1. 程式人生 > 其它 >【題解】CF1393E2 Twilight and Ancient Scroll (harder version)

【題解】CF1393E2 Twilight and Ancient Scroll (harder version)

給你 \(n\) 個字串,對每個字串,你可以刪除其任意一個字元或讓其保持原樣,求最後使得字串字典序不降得方案數。

不難想到一個 \(\mathcal{O}(n(\sum |S|)^2)\) ,我們定義狀態 \(f_{i,j}\) 表示前 \(i\) 個字串,結尾的字串是第 \(i\) 個字串刪除第 \(j\) 個字元的方案數,如果不刪除則 \(j = n + 1\)

手算一下不難發現 \(f_{i,j}\)\(f_{i-1}\) 中刪除 \(k\) 後比當前串刪除 \(j\) 後小的位置的 \(f_{i-1,k}\) 之和。如果對於 \(j,k\) 按照刪除對應位置後字串的字典序排序,那麼轉移一定是一段字首和。

對這些位置排序可以做到線性,具體做法見這道題

排序後,如果我們用雙指標掃一遍,我們一共需要比較 \(\mathcal{O}(\sum |S|)\)\(S_{i,j}\)\(S_{i-1,k}\),如果直接比較時間複雜度是 \(\mathcal{O}((\sum |S|)^2)\)

我們可以用雜湊加速這個過程,二分出兩個串第一個不同的位置即可。這樣時間複雜度是 \(\mathcal{O}(\sum|S|\log )\) 的。

細節非常多,這裡就不展開討論。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define pre(i,a,b) for(int i=a;i>=b;i--)
const int N = 1000005, P = 1000000007, Q = 777777773, bas = 233;int pw[N];
void calc(char *s,int *u,int n){
	int cur = 1, L = 1, R = n;
	rep(i, 1, n)
		if(i == n){
			pre(j, n, R + 1)u[j + 1] = u[j];
			while(cur)cur--, u[L++] = i - cur;
			u[L] = n + 1;
		}
		else if(s[i] == s[i + 1])cur++;
		else {
			if(s[i] < s[i + 1]){
				while(cur)cur--, u[R--] = i - cur;
			}
			else{
				while(cur)cur--, u[L++] = i - cur;
			}
			cur = 1;
		}
}
void init(char *s,int *u,int n){
	u[0] = 0;
	rep(i, 1, n)u[i] = (1LL * u[i - 1] * bas + s[i]) % Q;
}
int get(int *u,int x,int ban){
	if(x < ban)return u[x];
	int res = x - ban + 1;
	return (1LL * u[ban - 1] * pw[res] % Q + u[x + 1] - 1LL * u[ban] * pw[res] % Q + Q)% Q;
}
bool cmp(char *s,char *t,int *u,int *v,int x,int y,int lenu,int lenv){
	int ed = ~0,l = 1, r = std::min(lenu - (x <= lenu), lenv - (y <= lenv));
	while(l <= r){
		int mid = (l + r) >> 1;
		if(get(u, mid, x) == get(v, mid, y))l = mid + 1;
		else r = mid - 1, ed = mid;
	}
	if(~ed)return s[ed + (ed >= x)] < t[ed + (ed >= y)];
	else return lenu - (x <= lenu) <= lenv - (y <= lenv);
}
int n,f[N],u[N],v[N],g[N],hs1[N],hs2[N];char s[N],t[N];
int main(){
	pw[0] = 1;rep(i, 1, N - 5)pw[i] = 1LL * pw[i - 1] * bas % Q;
	scanf("%d",&n);
	scanf("%s",s + 1);
	int w = strlen(s + 1);
	calc(s, u, w);init(s, hs1, w);
	rep(i, 1, w + 1)f[i] = 1;
	rep(op, 2, n){
		scanf("%s", t + 1);
		int m = strlen(t + 1), j = 1, sum = 0;
		calc(t, v, m);init(t, hs2, m);
		rep(i, 1, m + 1)g[i] = 0;
		rep(i, 1, m + 1){
			while(j <= w + 1 && cmp(s, t, hs1, hs2, u[j], v[i], w, m))sum = (sum + f[j++]) % P;
			g[i] = sum;
		}
		w = m;rep(i, 1, m + 1)f[i] = g[i], s[i] = t[i], u[i] = v[i], hs1[i] = hs2[i];
	}
	int ans = 0;
	rep(i, 1, w + 1)ans = (ans + f[i]) % P;
	printf("%d\n",ans);
	return 0;
}