Luogu2178 NOI2015 品酒大會 SA、並查集
阿新 • • 發佈:2019-01-02
感覺題目講的很不清楚……
題目意思就是給出一個長度為\(n\)的字串,求對於\(r=0,1,...,n-1\),求出\(LCP(suffix_p,suffix_q) \geq r\)的無序數對\((p,q)\)的數目,並令一對無序數對的價值為\(val_p \times val_q\),則還要求對於每一個\(r\),所有滿足上述條件的無序數對中的最大價值
跟字尾\(LCP\)長度有關,直接上\(SA\)。求出\(sa\)陣列和\(height\)陣列,我們考慮如何實現對於每一個\(r\)的詢問快速求出答案。不妨將\(r\)從大到小求解,那麼對於某一個字尾\(sa_k\),滿足\(LCP(suffix_{sa_p} , suffix_{sa_k}) \geq r\)
然後我們考慮如何拓展區間。考慮對於\(height_k=q\),當\(r>q\)的時候\(k\)位置兩端的區間不會越過\(k-1\)與\(k\),而當\(r \leq q\)時這兩段區間就會合成一段區間。這個顯然是可以使用並查集維護的,並且可以比較輕鬆地在並查集上維護最大價值。
#include<bits/stdc++.h> #define mid ((l + r) >> 1) #define lch Tree[x].l #define rch Tree[x].r //This code is written by Itst using namespace std; inline int read(){ int a = 0; char c = getchar(); bool f = 0; while(!isdigit(c) && c != EOF){ if(c == '-') f = 1; c = getchar(); } if(c == EOF) exit(0); while(isdigit(c)){ a = (a << 3) + (a << 1) + (c ^ '0'); c = getchar(); } return f ? -a : a; } const int MAXN = 3e5 + 10; int fa[MAXN] , val[MAXN] , valMax[MAXN][2] , valMin[MAXN][2]; int sa[MAXN] , rk[MAXN] , pot[MAXN] , tp[MAXN << 1] , h[MAXN]; int ind[MAXN] , size[MAXN] , N , maxN = 26; char s[MAXN]; long long Max , cnt , ans[MAXN][2]; int find(int x){ return fa[x] == x ? x : (fa[x] = find(fa[x])); } void Debug(){ for(int i = 1 ; i <= N ; ++i) cout << sa[i] << ' '; cout << endl; for(int i = 1 ; i <= N ; ++i) cout << ind[i] << ' '; cout << endl << endl; } void input(){ N = read(); scanf("%s" , s + 1); for(int i = 1 ; i <= N ; ++i){ val[i] = read(); if(val[i] < 0) valMin[i][0] = val[i]; } } void sort(int p){ memset(pot , 0 , sizeof(pot)); for(int i = 1 ; i <= N ; ++i) ++pot[rk[i]]; for(int i = 1 ; i <= maxN ; ++i) pot[i] += pot[i - 1]; for(int i = 1 ; i <= N ; ++i) sa[++pot[rk[tp[i]] - 1]] = tp[i]; memcpy(tp , rk , sizeof(int) * (N + 1)); for(int i = 1 ; i <= N ; ++i) rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]); maxN = rk[sa[N]]; } bool cmp(int a , int b){ return h[a] < h[b]; } void init(){ memset(valMax , -0x3f , sizeof(valMax)); Max = -1ll * 0x3f3f3f3f * 0x3f3f3f3f; for(int i = 1 ; i <= N ; ++i) rk[tp[i] = i] = s[i] - 'a' + 1; sort(0); for(int i = 1 ; i <= N && maxN < N ; i <<= 1){ int cnt = 0; for(int j = 1 ; j <= i ; ++j) tp[++cnt] = N - i + j; for(int j = 1 ; j <= N ; ++j) if(sa[j] > i) tp[++cnt] = sa[j] - i; sort(i); } for(int i = 1 ; i <= N ; ++i){ if(rk[i] == 1) continue; int t = rk[i]; h[t] = max(0 , h[rk[i - 1]] - 1); while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]]) ++h[t]; ind[t] = t; } sort(ind + 2 , ind + N + 1 , cmp); for(int i = 1 ; i <= N ; ++i){ fa[i] = i; size[i] = 1; valMax[i][0] = val[i]; } } inline void merge(int x , int y){ fa[x] = y; int num[4] = {valMax[x][0] , valMax[x][1] , valMax[y][0] , valMax[y][1]}; sort(num , num + 4); valMax[y][0] = num[3]; valMax[y][1] = num[2]; Max = max(Max , 1ll * valMax[y][0] * valMax[y][1]); num[0] = valMin[x][0]; num[1] = valMin[x][1]; num[2] = valMin[y][0]; num[3] = valMin[y][1]; sort(num , num + 4); valMin[y][0] = num[0]; valMin[y][1] = num[1]; if(1ll * valMin[y][0] * valMin[y][1]) Max = max(Max , 1ll * valMin[y][0] * valMin[y][1]); cnt -= 1ll * size[x] * (size[x] - 1) / 2 + 1ll * size[y] * (size[y] - 1) / 2; size[y] += size[x]; cnt += 1ll * size[y] * (size[y] - 1) / 2; } void work(){ int p = N; for(int i = N - 1 ; i >= 0 ; --i){ while(p > 1 && h[ind[p]] == i){ merge(find(sa[ind[p]]) , find(sa[ind[p] - 1])); --p; } if(cnt){ ans[i][0] = cnt; ans[i][1] = Max; } } } void output(){ for(int i = 0 ; i <= N - 1 ; ++i) cout << ans[i][0] << ' ' << ans[i][1] << '\n'; } int main(){ input(); init(); work(); output(); return 0; }