1. 程式人生 > >Luogu2178 NOI2015 品酒大會 SA、並查集

Luogu2178 NOI2015 品酒大會 SA、並查集

傳送門


感覺題目講的很不清楚……

題目意思就是給出一個長度為\(n\)的字串,求對於\(r=0,1,...,n-1\),求出\(LCP(suffix_p,suffix_q) \geq r\)的無序數對\((p,q)\)的數目,並令一對無序數對的價值為\(val_p \times val_q\),則還要求對於每一個\(r\),所有滿足上述條件的無序數對中的最大價值

跟字尾\(LCP\)長度有關,直接上\(SA\)。求出\(sa\)陣列和\(height\)陣列,我們考慮如何實現對於每一個\(r\)的詢問快速求出答案。不妨將\(r\)從大到小求解,那麼對於某一個字尾\(sa_k\),滿足\(LCP(suffix_{sa_p} , suffix_{sa_k}) \geq r\)

\(p\)一定是一段區間,而且這一段區間隨著\(r\)的縮小不斷增大。

然後我們考慮如何拓展區間。考慮對於\(height_k=q\),當\(r>q\)的時候\(k\)位置兩端的區間不會越過\(k-1\)\(k\),而當\(r \leq q\)時這兩段區間就會合成一段區間。這個顯然是可以使用並查集維護的,並且可以比較輕鬆地在並查集上維護最大價值。

#include<bits/stdc++.h>
#define mid ((l + r) >> 1)
#define lch Tree[x].l
#define rch Tree[x].r
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    bool f = 0;
    while(!isdigit(c) && c != EOF){
        if(c == '-')
            f = 1;
        c = getchar();
    }
    if(c == EOF)
        exit(0);
    while(isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return f ? -a : a;
}

const int MAXN = 3e5 + 10;
int fa[MAXN] , val[MAXN] , valMax[MAXN][2] , valMin[MAXN][2];
int sa[MAXN] , rk[MAXN] , pot[MAXN] , tp[MAXN << 1] , h[MAXN];
int ind[MAXN] , size[MAXN] , N , maxN = 26;
char s[MAXN];
long long Max , cnt , ans[MAXN][2];

int find(int x){
    return fa[x] == x ? x : (fa[x] = find(fa[x]));
}

void Debug(){
    for(int i = 1 ; i <= N ; ++i)
        cout << sa[i] << ' ';
    cout << endl;
    for(int i = 1 ; i <= N ; ++i)
        cout << ind[i] << ' ';
    cout << endl << endl;
}

void input(){
    N = read();
    scanf("%s" , s + 1);
    for(int i = 1 ; i <= N ; ++i){
        val[i] = read();
        if(val[i] < 0)
            valMin[i][0] = val[i];
    }
}

void sort(int p){
    memset(pot , 0 , sizeof(pot));
    for(int i = 1 ; i <= N ; ++i)
        ++pot[rk[i]];
    for(int i = 1 ; i <= maxN ; ++i)
        pot[i] += pot[i - 1];
    for(int i = 1 ; i <= N ; ++i)
        sa[++pot[rk[tp[i]] - 1]] = tp[i];
    memcpy(tp , rk , sizeof(int) * (N + 1));
    for(int i = 1 ; i <= N ; ++i)
        rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]);
    maxN = rk[sa[N]];
}

bool cmp(int a , int b){
    return h[a] < h[b];
}

void init(){
    memset(valMax , -0x3f , sizeof(valMax));
    Max = -1ll * 0x3f3f3f3f * 0x3f3f3f3f;
    for(int i = 1 ; i <= N ; ++i)
        rk[tp[i] = i] = s[i] - 'a' + 1;
    sort(0);
    for(int i = 1 ; i <= N && maxN < N ; i <<= 1){
        int cnt = 0;
        for(int j = 1 ; j <= i ; ++j)
            tp[++cnt] = N - i + j;
        for(int j = 1 ; j <= N ; ++j)
            if(sa[j] > i)
                tp[++cnt] = sa[j] - i;
        sort(i);
    }
    for(int i = 1 ; i <= N ; ++i){
        if(rk[i] == 1)
            continue;
        int t = rk[i];
        h[t] = max(0 , h[rk[i - 1]] - 1);
        while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]])
            ++h[t];
        ind[t] = t;
    }
    sort(ind + 2 , ind + N + 1 , cmp);
    for(int i = 1 ; i <= N ; ++i){
        fa[i] = i;
        size[i] = 1;
        valMax[i][0] = val[i];
    }
}

inline void merge(int x , int y){
    fa[x] = y;
    int num[4] = {valMax[x][0] , valMax[x][1] , valMax[y][0] , valMax[y][1]};
    sort(num , num + 4);
    valMax[y][0] = num[3];
    valMax[y][1] = num[2]; 
    Max = max(Max , 1ll * valMax[y][0] * valMax[y][1]);
    num[0] = valMin[x][0];
    num[1] = valMin[x][1];
    num[2] = valMin[y][0];
    num[3] = valMin[y][1];
    sort(num , num + 4);
    valMin[y][0] = num[0];
    valMin[y][1] = num[1];
    if(1ll * valMin[y][0] * valMin[y][1])
        Max = max(Max , 1ll * valMin[y][0] * valMin[y][1]);
    cnt -= 1ll * size[x] * (size[x] - 1) / 2 + 1ll * size[y] * (size[y] - 1) / 2;
    size[y] += size[x];
    cnt += 1ll * size[y] * (size[y] - 1) / 2;
}

void work(){
    int p = N;
    for(int i = N - 1 ; i >= 0 ; --i){
        while(p > 1 && h[ind[p]] == i){
            merge(find(sa[ind[p]]) , find(sa[ind[p] - 1]));
            --p;
        }
        if(cnt){
            ans[i][0] = cnt;
            ans[i][1] = Max;
        }
    }
}

void output(){
    for(int i = 0 ; i <= N - 1 ; ++i)
        cout << ans[i][0] << ' ' << ans[i][1] << '\n';
}

int main(){
    input();
    init();
    work();
    output();
    return 0;
}