1. 程式人生 > >「雅禮集訓 2017 Day7」事情的相似度

「雅禮集訓 2017 Day7」事情的相似度

「雅禮集訓 2017 Day7」事情的相似度

題目連結

我們先將字串建字尾自動機。然後對於兩個字首\([1,i]\)\([1,j]\),他們的最長公共字尾長度就是他們在\(fail\)樹上對應節點的\(lca\)\(maxlen\)

所以現在問題就變成了一個樹上問題:給定一棵樹,每個點有一個權值\((mxlen)\),詢問編號在一段區間內的點兩兩之間\(lca\)權值的最大值。

方法很多,這裡用的\(dsu\ on\ tree\)。對於每個點\(v\),我們計算其作為\(lca\)的貢獻。顯然貢獻的情況是一個點對,他們在\(v\)的不同子樹中(\(v\)自己也算一個子樹)。但是這樣點對的數量可能達到\(O(n^2)\)

不過我們仔細思考一下就會發現,其實這樣的點對不多。對於一個\(lca\),一個子節點\(v\),我們要與一個在之前已經加入的節點,我們發現,根據貪心,只需要與\(v\)的前驅和後繼組合就可以了。

程式碼:

#include<bits/stdc++.h>
#define ll long long
#define N 200005

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

int n,m;
char s[N];
int ch[N<<1][2],fail[N<<1],mxlen[N<<1];
int id[N<<1];
int cnt=1,last=1;

void Insert(int f,int i) {
    static int now,p;
    now=++cnt;
    p=last,last=now;
    id[now]=i;
    mxlen[now]=mxlen[p]+1;
    while(p&&!ch[p][f]) ch[p][f]=now,p=fail[p];
    if(!p) return fail[now]=1,void();
    
    int q=ch[p][f];
    if(mxlen[q]==mxlen[p]+1) return fail[now]=q,void();
    
    int New=++cnt;
    memcpy(ch[New],ch[q],sizeof(ch[q]));
    fail[New]=fail[q];
    fail[q]=fail[now]=New;
    mxlen[New]=mxlen[p]+1;
    while(p&&ch[p][f]==q) ch[p][f]=New,p=fail[p];
}

struct load {int to,next;}e[N<<2];
int h[N<<1],edge=1;
void add(int i,int j) {e[++edge]=(load) {j,h[i]};h[i]=edge;}
int val[N<<1];

int size[N<<1],son[N<<1];
void dfs(int v) {
    size[v]=1;
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        dfs(to);
        size[v]+=size[to];
        if(size[son[v]]<size[to]) son[v]=to;
    }
}

set<int>pos;
set<int>::iterator it;
void statis(int v,int flag) {
    if(id[v]) {
        if(flag) pos.insert(id[v]);
        else pos.erase(id[v]);
    }
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        statis(to,flag);
    }
}

struct node {
    int l,r,mx;
    bool operator <(const node &a)const {return r<a.r;}
}st[N*50];
int sum;
struct query {
    int l,r,id;
    bool operator <(const query &a)const {return r<a.r;}
}q[N];
int ans[N];

void cal(int v,int mx) {
    if(id[v]) {
        it=pos.lower_bound(id[v]);
        if(it!=pos.end()) st[++sum]=(node) {id[v],*it,mx};
        if(it!=pos.begin()) st[++sum]=(node) {*(--it),id[v],mx};
    }
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        cal(to,mx);
    }
}

void solve(int v,int flag) {
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        if(to==son[v]) continue ;
        solve(to,0);
    }
    if(son[v]) solve(son[v],1);
    if(id[v]) {
        it=pos.lower_bound(id[v]);
        if(it!=pos.end()) st[++sum]=(node) {id[v],*it,val[v]};
        if(it!=pos.begin()) st[++sum]=(node) {*(--it),id[v],val[v]};
        pos.insert(id[v]);
    }
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        if(to==son[v]) continue ;
        cal(to,val[v]);
        statis(to,1);
    }
    if(!flag) pos.clear();
}

void solve2(int v) {
    if(id[v]) pos.insert(id[v]);
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        solve2(to);
    }
    for(int i=h[v];i;i=e[i].next) {
        int to=e[i].to;
        cal(to,val[v]);
        statis(to,1);
    }
    pos.clear();
}
struct Bit {
    int tem[N];
    int low(int i) {return i&(-i);}
    void add(int v,int f) {for(int i=v;i<=n;i+=low(i)) tem[i]=max(tem[i],f);}
    int query(int v) {
        int ans=0;
        for(int i=v;i;i-=low(i)) ans=max(ans,tem[i]);
        return ans;
    }
}bit;

int main() {
    n=Get(),m=Get();
    scanf("%s",s+1);
    for(int i=1;i<=n;i++) Insert(s[i]-'0',i);
    for(int i=2;i<=cnt;i++) {
        val[i]=mxlen[i];
        add(fail[i],i);
    }
    dfs(1);
    solve(1,1);
    sort(st+1,st+1+sum);
    
    for(int i=1;i<=m;i++) q[i].l=Get(),q[i].r=Get(),q[i].id=i;
    sort(q+1,q+1+m);
    
    int tag=1;
    for(int i=1;i<=m;i++) {
        while(tag<=sum&&st[tag].r<=q[i].r) {
            bit.add(n-st[tag].l+1,st[tag].mx);
            tag++;
        }
        ans[q[i].id]=bit.query(n-q[i].l+1);
    }
    
    for(int i=1;i<=m;i++) cout<<ans[i]<<"\n";
    return 0;
}