1. 程式人生 > 其它 >Codeforces 1276F - Asterisk Substrings(SAM+線段樹合併+虛樹)

Codeforces 1276F - Asterisk Substrings(SAM+線段樹合併+虛樹)

SAM、線段樹合併、虛樹三位一體,hot tea

Codeforces 題面傳送門 & 洛谷題面傳送門

SAM hot tea %%%%%%%

首先我們顯然可以將所有能夠得到的字串分成六類:\(\varnothing,\text{*},s,\text{*}s,s\text{*},s\text{*}t\),其中 \(s,t\) 分別代表某個非空字串,\(\text{*}\) 則代表題目中的星號,顯然前兩種情況的貢獻都是 \(1\),算出後幾種情況的答案後直接加 \(2\) 即可,第三種情況也異常簡單,相當於求 \(s\) 中本質不同的子串個數,SAM 基操,相信來挑戰這道題的人都會。第四種情況可以看作在某個 \(s[1…n-1]\) 的子串後面添上一個 \(\text{*}\)

,因此第四種情況合法的子串數就是 \(s[1…n-1]\) 本質不同子串個數,這個一波 SAM 帶走即可,同理第五種情況合法的子串數就是 \(s[2…n]\) 本質不同子串數,同樣一波 SAM 帶走,比較棘手的是第六種情況。

由於我們要求本質不同子串個數,因此列舉星號代替了原字串中的哪個字元是不明智的,因為這樣會算重。因此我們考慮退一步,我們不是剛才對 \(s[1…n-1]\)​ 建了一個 SAM 嗎,那麼我們就考慮列舉星號左邊那部分對應的子串在 SAM 上哪個節點對應的字串集合中,假設為 \(x\)​,那麼根據 SAM 那一套理論,這個字串所有出現位置的右端點都在 \(x\)​ 對應的 \(\text{endpos}\)

​ 中,也就是 \(x\)​ 在 parent tree 上子樹內 \(s\)​ 所有字首對應的節點。換句話說,所有星號可以替換的位置都在集合 \(\{y+1|y\in\text{endpos}(x)\}\)​ 中,考慮星號右邊的字串有哪些取值,那麼顯然,星號右邊的字串必須是這些星號可以放的位置右邊的字元組成的字尾的某個字首,也就是 \(S=\{s[y+2…n]|y\in\text{endpos}(x)\}\)​ 中某個字串的某個字首,也就是說我們只需求 \(S=\{s[y+2…n]|y\in\text{endpos}(x)\}\)​ 有多少個本質不同的字首即可。直接列舉字首顯然不可取,不過注意到如果我們對反串也建一個 SAM,那麼反串的 SAM 上某個節點 \(z\)
​ 表示的字串符合要求,當且僅當存在一個 \(y\in\text{endpos}(x)\)​,滿足 \(s[y+2…n]\)​ 在反串的 SAM 上對應的節點在 \(z\) 的子樹內,也就是說如果我們找出這樣的 \(|\text{endpos}(x)|\)​ 個字尾在反串 SAM 上對應的節點,那麼所有符合條件的點組成的集合就是它們到根節點路徑上節點的並,而我們要求的,就是這個並中所有節點表示的字串集合的大小之和。然後到這裡我就不會做了,一直在想怎麼剖。事實上這東西長得一臉虛樹,根據虛樹那一套理論,這東西節點並的字串集合的大小,就是它們按照 DFS 序排序之後,DFS 序列相鄰兩個點+一頭一尾路徑上點的權值之和扣掉它們的 LCA 的權值的和除以 \(2\)(在本題中就是它們對應的等價類的大小,即 \(len_x-len_{lnk_x}\)​),外加它們的 LCA 到根節點這段路徑上權值之和。然後我們考慮線段樹合併維護這貨,具體來說就開一棵線段樹,以 DFS 序為下標,如果一個葉節點存在則表示其表示的區間的 DFS 序在該節點的子樹記憶體在,然後對於線段樹上每個節點,維護該區間中,在 SAM 上存在的 DFS 序的最小值、最大值,以及兩兩之間的距離和,每次轉移時將一個點為根的線段樹與其兒子節點表示的線段樹合併一波即可算出答案。初始狀態大概就對於所有 \(i\in[1,n-2]\),我們找出長度為 \(i\)​ 的字首在原串 SAM 上表示的節點,設其為 \(x\),然後假設 \(s[i+2…n]\) 在反串上表示的節點為 \(y\),那麼我們就將 \(y\) 在反串上的 DFS 序加入 \(x\) 表示的線段樹即可。使用 ST 表求 LCA 可以實現 \(\mathcal O(n\log n)\) 的複雜度。

const int MAXN=1e5;
const int MAXP=2e5;
const int MAX_ND=MAXP<<5;
const int LOG_N=19;
char str[MAXN+5];int n;
struct graph{
	int hd[MAXP+5],to[MAXP+5],nxt[MAXP+5],ec=0;
	void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
};
struct SAM{
	struct node{int ch[27],len,lnk;} s[MAXP+5];
	int cur,ncnt,ed[MAXN+5];graph t;
	SAM(){cur=ncnt=1;}
	void extend(char c,int ps){
		int id=c-'a',nw=++ncnt,p=cur;
		s[nw].len=s[cur].len+1;ed[ps]=nw;cur=nw;
		while(p&&!s[p].ch[id]) s[p].ch[id]=nw,p=s[p].lnk;
		if(!p) return s[nw].lnk=1,void();
		int q=s[p].ch[id];
		if(s[q].len==s[p].len+1) return s[nw].lnk=q,void();
		int cl=++ncnt;s[cl].len=s[p].len+1;
		s[cl].lnk=s[q].lnk;s[q].lnk=cl;s[nw].lnk=cl;
		for(int i=0;i<26;i++) s[cl].ch[i]=s[q].ch[i];
		while(p&&s[p].ch[id]==q) s[p].ch[id]=cl,p=s[p].lnk;
	}
	void build(){for(int i=2;i<=ncnt;i++) t.adde(s[i].lnk,i);}
	ll calc(){
		ll res=0;
		for(int i=1;i<=ncnt;i++) res+=s[i].len-s[s[i].lnk].len;
		return res;
	}
} s_tot,fw,bk;//forward backward
#define g1 fw.t
#define g2 bk.t
int dep[MAXP+5];pii st[MAXP*2+5][LOG_N+2];
int dfn[MAXP+5],rid[MAXP+5],tim=0,eul[MAXP+5],tim_eul=0;
void dfs0(int x,int f){
	rid[dfn[x]=++tim]=x;
	st[eul[x]=++tim_eul][0]=mp(dep[x],x);
	for(int e=g2.hd[x];e;e=g2.nxt[e]){
		int y=g2.to[e];if(y==f) continue;
		dep[y]=dep[x]+1;dfs0(y,x);
		st[eul[x]=++tim_eul][0]=mp(dep[x],x);
	}
}
int getlca(int x,int y){
	x=eul[x];y=eul[y];if(x>y) swap(x,y);
	int k=31-__builtin_clz(y-x+1);
	return min(st[x][k],st[y-(1<<k)+1][k]).se;
}
int getdis(int x,int y){
	int lc=getlca(x,y);
	return bk.s[x].len+bk.s[y].len-(bk.s[lc].len<<1);
}
struct data{
	int lc,lft,rit;ll sum;
	data(int _lc=0,int _lft=0,int _rit=0,ll _sum=0):
		lc(_lc),lft(_lft),rit(_rit),sum(_sum){}
	data operator +(const data &rhs){
		if(!lc) return rhs;
		if(!rhs.lc) return data(lc,lft,rit,sum);
		return data(getlca(lc,rhs.lc),lft,rhs.rit,
		sum+rhs.sum+getdis(rid[rit],rid[rhs.lft]));
	}
};
namespace segtree{
	struct node{int ch[2];data d;} s[MAX_ND+5];
	int ncnt=0;
	void pushup(int k){s[k].d=s[s[k].ch[0]].d+s[s[k].ch[1]].d;}
	void insert(int &k,int l,int r,int x){
		if(!k) k=++ncnt;if(l==r) return s[k].d=data(rid[x],x,x,0),void();
		int mid=l+r>>1;
		(x<=mid)?insert(s[k].ch[0],l,mid,x):insert(s[k].ch[1],mid+1,r,x);
		pushup(k);
	}
	int merge(int x,int y,int l,int r){
		if(!x||!y) return x+y;int z=++ncnt;
		if(l==r) return s[z].d=data(rid[l],l,l,0),z;
		int mid=l+r>>1;
		s[z].ch[0]=merge(s[x].ch[0],s[y].ch[0],l,mid);
		s[z].ch[1]=merge(s[x].ch[1],s[y].ch[1],mid+1,r);
		return pushup(z),z;
	}
}
using segtree::insert;
using segtree::merge;
int rt[MAXP+5];ll res=0;
void dfs(int x,int f){
	for(int e=g1.hd[x];e;e=g1.nxt[e]){
		int y=g1.to[e];if(y==f) continue;
		dfs(y,x);rt[x]=merge(rt[x],rt[y],1,bk.ncnt);
	} data d=segtree::s[rt[x]].d;
	if(d.lc) res+=1ll*(fw.s[x].len-fw.s[fw.s[x].lnk].len)*
	(d.sum+getdis(rid[d.lft],rid[d.rit])+(getdis(d.lc,1)<<1));
}
int main(){
	scanf("%s",str+1);n=strlen(str+1);
	for(int i=1;i<=n;i++) s_tot.extend(str[i],/*19260817*/i);
	for(int i=1;i<n;i++) fw.extend(str[i],i);
	for(int i=n;i>1;i--) bk.extend(str[i],i);
	fw.build();bk.build();
	dfs0(1,0);
	for(int i=1;i<=LOG_N;i++) for(int j=1;j+(1<<i)-1<=tim_eul;j++)
		st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]);
	for(int i=1;i<n-1;i++) insert(rt[fw.ed[i]],1,bk.ncnt,dfn[bk.ed[i+2]]);
	dfs(1,0);res>>=1;res+=2+s_tot.calc()+fw.calc()+bk.calc();
	printf("%lld\n",res);
	return 0;
}