洛谷P3435 [POI2006]OKR-Periods of Words
題目
這題意不是一般人能讀懂的,為了讀懂題目,我還特意去翻了題解[手動笑哭]
題目大意:
給定一個字串s
對於s的每一個字首子串s1,規定一個字串Q,Q滿足:Q是s1的字首子串且Q不等於s1且s1是字串Q+Q的字首.設siz為所有滿足條件的Q中Q的最大長度(注意這裡僅僅針對s1而不是s,即一個siz的值對應一個s1)
求出所有siz的和
不要被這句話誤導了:
求給定字串所有字首的最大週期長度之和
正確斷句:求給定字串 所有/字首的最大週期長度/之和
我就想了半天:既然是"最大週期長度",那不是唯一的嗎?為什麼還要求和呢?
思路
其實這題要AC並不難(看通過率就知道)
看圖
要滿足Q是s1的字首,則Q的1~5位和s1的1~5位是一樣的,又因為s1是Q+Q的字首,所以又要滿足s1的6~8位和Q+Q的6~8位一樣,即s1的6~8位和Q的1~3位相等,回到s1,標藍色的兩個位置相等.
回顧下KMP中next陣列的定義:next[i]
表示對於某個字串a,"a中長度為next[i]
的字首子串"與"a中以第i為結尾,長度為next[i]
的非字首子串"相等,且next[i]
取最大值
是不是悟到了什麼,是不是感覺這題和next陣列冥冥之中有某種相似之處?
但是,這僅僅只是開始
按照題目的意思,我們要讓Q的長度最大,也就是圖中藍色部分長度最小,但是next中存的是藍色部分的最大值,顯然,兩者相違背,難道我們要改造next陣列嗎?明顯不行,若next儲存的改為最小值,則原來求next的方法行不通.考慮換一種思路(一定要對KMP中next的求法理解透徹,不然下面看不懂,不行的next[i],next[next[i-1]],next[next[next[i]]]...
都能滿足"字首等於以i結尾的子串"這個條件,且越往後,值越小,所以,我們的目標就定在上面序列中從後往前第一個不為0的next值
極端條件下,暴力跑可以去到O(n^2),理論上會超時(我沒試過)
兩種優化:
- 記憶化,時間效率應該是O(n)這裡不詳細講,可以去到洛谷題解檢視
- 倍增(我第一時間想到並AC的做法):
我們將j=next[j]
這一語句稱作"j跳了一次"(感覺怪怪的),將next拓展為2維,next[i][k]
表示結尾為i,j跳了2^k的字首字元長度(也就是next[i][0]
等價於原來的next[i]
藉助倍增LCA的思想(沒學沒關係,現學現用),這裡不做贅述,上程式碼
int tmp = i;
for(rr int j = siz[i] ; j >= 0 ; --j)//siz[i]是next[i][j]中第一個為0的小標j,注意倒序列舉
if(next[tmp][j] != 0)//如果不為0則跳
tmp = next[tmp][j];
倍增方法在字串長度去到10^6時是非常危險的,帶個log理論是2*10^7左右,常數再大那麼一丟丟就TLE了,還好資料比較水,但是作為倍增和KMP的練習做一下也是不錯的
最後,記得開longlong(不然我就一次AC了)
完整程式碼
#include <iostream>
#include <cmath>
#include <cstdio>
#define nn 1000010
#define rr register
#define ll long long
using namespace std;
int next[nn][30] ;
int siz[nn];
char s[nn];
int n;
int main() {
// freopen("P3435_3.in" , "r" , stdin);
cin >> n;
do
s[1] = getchar();
while(s[1] < 'a' || s[1] > 'z');
for(rr int i = 2 ; i <= n ; i++)
s[i] = getchar();
next[1][0] = 0;
for(rr int i = 2 , j = 0 ; i <= n ; i++) {
while(j != 0 && s[i] != s[j + 1])
j = next[j][0];
if(s[j + 1] == s[i])
++j;
next[i][0] = j;
}
rr int k = log(n) / log(2) + 1;
for(rr int j = 1 ; j <= k ; j++)
for(rr int i = 1 ; i <= n ; i++) {
next[i][j] = next[next[i][j - 1]][j - 1];
if(next[i][j] == 0)
siz[i] = j;
}
ll ans = 0;
for(rr int i = 1 ; i <= n ; ++i) {
int tmp = i;
for(rr int j = siz[i] ; j >= 0 ; --j)
if(next[tmp][j] != 0)
tmp = next[tmp][j];
if(2 * (i - tmp) >= i && tmp != i)
ans += (ll)i - tmp;
}
cout << ans;
return 0;
}