1. 程式人生 > 實用技巧 >P6292 區間本質不同子串個數 SAM+LCT+線段樹

P6292 區間本質不同子串個數 SAM+LCT+線段樹

題意:

戳這裡

分析:

  • 前置芝士:SAM(求本質不同的子串數目),LCT (在SAM上動態修改)線段樹

首先我們先考慮求區間內元素種類數 這類問題的常見做法,就是對於每一個元素只維護它最後一次出現的位置,然後區間查詢和值就可以了,但為了實現這個操作,我們必須找到一個方法求出本質相同的子串上一次出現的位置,對於一個串 \(S\) 我們記它右端點最後一次出現的位置在 \(lst\) 這樣左端點掃到 \([1,lst-|S|+1]\) 時貢獻 \(+1\)

對於這種本質不同子串的問題我們建出 \(SAM\),考慮右移一個位置帶來的影響,假設右移到 \(i\) 這個位置,所有以 \(i\) 結尾的子串,就是字首 \(i\)

對應的節點在 \(parent\) 樹上的祖先,所以我們可以暴力跳祖先將這些節點對應的子串最後一次出現的位置改為 \(i\) ,對於區間 \([1,lst-|s|+1]\) 整體加 \(1\), 但是這樣的複雜度不太對勁

我們繼續在 \(parent\) 瞎搞 思考,由於所有以 \(i\) 結尾的子串的長度是連續的,所以我們只需要將 \(parent\) 樹上某一點到根的路徑上所有的串的 \(lst\) 改為 \(i\) 就行了,由於我們查詢的時候是根據左端點掃描到的個數統計答案,所以我們需要將這些 \(lst\)\(i\) 的串的左端點改為一段連續的區間,並對這些連續的區間,每一個區間整體+\(1\)

,相當於我們需要加一個等差數列,對於區間加等差數列,單點查詢轉化為區間加差分數列,區間查詢這樣直接上線段樹就可以了

程式碼:

#include<bits/stdc++.h>
#define pii pair<int,int>
#define mk(x,y) make_pair(x,y)
#define lc rt<<1
#define rc rt<<1|1
#define pb push_back
#define fir first
#define sec second

using namespace std;

namespace zzc
{
	inline int read()
	{
		int x=0,f=1;char ch=getchar();
		while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
		while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
		return x*f;
	}
    
    const int maxn = 4e5+5;
    int n,qt;
    char ch[maxn];
    long long ans[maxn];

    struct suffix_automaton
    {
        int lst,cnt;
        int len[maxn],trans[maxn][26],link[maxn],pos[maxn];
        suffix_automaton(){cnt=lst=1;}
        int insert(int x)
        {
            int cur=++cnt,tmp=lst;lst=cnt;
            len[cur]=len[tmp]+1;
            for(;tmp&&!trans[tmp][x];tmp=link[tmp]) trans[tmp][x]=cur;
            if(!tmp)
            {
                link[cur]=1;
            }
            else
            {
                int q=trans[tmp][x];
                if(len[tmp]+1==len[q])
                {
                    link[cur]=q;
                }
                else
                {
                    int clone=++cnt;
                    len[clone]=len[tmp]+1;
                    link[clone]=link[q];
                    link[q]=link[cur]=clone;
                    for(int i=0;i<26;i++) trans[clone][i]=trans[q][i];
                    for(;tmp&&trans[tmp][x]==q;tmp=link[tmp]) trans[tmp][x]=clone;
                }
            }
            return cur;
        }
        void build()
        {
            for(int i=1;i<=n;i++) pos[i]=insert(ch[i]-'a');
        }
    }sam;
    
    struct segment_tree
    {
        int ch[maxn][2],tag[maxn];
        long long sum[maxn];
        void pushup(int rt){sum[rt]=sum[lc]+sum[rc];}
        void add(int rt,int l,int r,int k)
        {
            tag[rt]+=k;
            sum[rt]+=1ll*k*(r-l+1);
        }
        void pushdown(int rt,int l,int r)
        {
            if(tag[rt])
            {
                int mid=(l+r)>>1;
                add(lc,l,mid,tag[rt]);
                add(rc,mid+1,r,tag[rt]);
                tag[rt]=0;
            }
        }
        void build(int rt,int l,int r)
        {
            sum[rt]=tag[rt]=0;
            if(l==r) return ;
            int mid=(l+r)>>1;
            build(lc,l,mid);build(rc,mid+1,r);
            pushup(rt);
        }
        void modify(int rt,int l,int r,int ql,int qr,int k)
        {
            if(ql<=l&&r<=qr)
            {
                add(rt,l,r,k);
                return ;
            }
            pushdown(rt,l,r);
            int mid=(l+r)>>1;
            if(ql<=mid) modify(lc,l,mid,ql,qr,k);
            if(qr>mid) modify(rc,mid+1,r,ql,qr,k);
            pushup(rt);
        }
        long long query(int rt,int l,int r,int ql,int qr)
        {
            if(ql<=l&&r<=qr) return sum[rt];
            pushdown(rt,l,r);
            int mid=(l+r)>>1;
            long long res=0;
            if(ql<=mid) res+=query(lc,l,mid,ql,qr);
            if(qr>mid) res+=query(rc,mid+1,r,ql,qr);
            return res;
        }
    }seg;

    struct LCT
    {
        int fa[maxn],ch[maxn][2],tag[maxn],val[maxn];
        void build()
        {
            for(int i=2;i<=sam.cnt;i++)
            {
                fa[i]=sam.link[i];
                ch[i][0]=ch[i][1]=0;
                val[i]=tag[i]=0;
            }
        }
        bool isroot(int x) {return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}
        void assign(int x,int k) {val[x]=tag[x]=k;}
        void pushdown(int x) 
        {
            if(tag[x])
            {
                if(ch[x][0]) assign(ch[x][0],tag[x]);
                if(ch[x][1]) assign(ch[x][1],tag[x]);
                tag[x]=0;
            }
        }
        void pushall(int x)
        {
            if(!isroot(x)) pushall(fa[x]);
            pushdown(x);
        }
        void rotate(int x)
        {
            int y=fa[x],z=fa[y],l,r;
            if(ch[y][0]==x)l=0;else l=1;r=l^1;
            if(!isroot(y)){if(ch[z][0]==y)ch[z][0]=x;else ch[z][1]=x;}
            fa[x]=z;fa[y]=x;fa[ch[x][r]]=y;
            ch[y][l]=ch[x][r];ch[x][r]=y;
        }
        void splay(int x)
        {
            pushall(x);
            while(!isroot(x))
            {
                int y=fa[x],z=fa[y];
                if(!isroot(y)){if((ch[z][0]==y)^(ch[y][0]==x))rotate(x);else rotate(y);}
                rotate(x);
            }
        }
        void access(int x,int p)
        {
            int t=0;
            for(;x;t=x,x=fa[x])
            {
                splay(x);
                if(int k=val[x]) seg.modify(1,1,n,k-sam.len[x]+1,k-sam.len[fa[x]],-1);
                ch[x][1]=t;
            }
            assign(t,p);
            seg.modify(1,1,n,1,p,1);
        }
    }lct;

    struct que
    {
        int id,l,r;
        bool operator <(const que &b)const
        {
            return r==b.r?l<b.l:r<b.r;
        }
    }q[maxn];

	void work()
	{
	    scanf("%s",ch+1);n=strlen(ch+1);
        sam.build();
        lct.build();
        seg.build(1,1,n);
        qt=read();
        for(int i=1;i<=qt;i++)
        {
            q[i].l=read();q[i].r=read();
            q[i].id=i;
        }
        sort(q+1,q+qt+1);
        for(int i=1,j=1;i<=qt;i++)
        {
            while(j<=q[i].r) lct.access(sam.pos[j],j),j++;
            ans[q[i].id]=seg.query(1,1,n,q[i].l,q[i].r);
        }
        for(int i=1;i<=qt;i++) printf("%lld\n",ans[i]);
	}

}

int main()
{
	zzc::work();
	return 0;
}