【題解】CF1393E2 Twilight and Ancient Scroll (harder version)
阿新 • • 發佈:2021-10-04
給你 \(n\) 個字串,對每個字串,你可以刪除其任意一個字元或讓其保持原樣,求最後使得字串字典序不降得方案數。
不難想到一個 \(\mathcal{O}(n(\sum |S|)^2)\) ,我們定義狀態 \(f_{i,j}\) 表示前 \(i\) 個字串,結尾的字串是第 \(i\) 個字串刪除第 \(j\) 個字元的方案數,如果不刪除則 \(j = n + 1\) 。
手算一下不難發現 \(f_{i,j}\) 是 \(f_{i-1}\) 中刪除 \(k\) 後比當前串刪除 \(j\) 後小的位置的 \(f_{i-1,k}\) 之和。如果對於 \(j,k\) 按照刪除對應位置後字串的字典序排序,那麼轉移一定是一段字首和。
對這些位置排序可以做到線性,具體做法見這道題。
排序後,如果我們用雙指標掃一遍,我們一共需要比較 \(\mathcal{O}(\sum |S|)\) 次 \(S_{i,j}\) 和 \(S_{i-1,k}\),如果直接比較時間複雜度是 \(\mathcal{O}((\sum |S|)^2)\) 。
我們可以用雜湊加速這個過程,二分出兩個串第一個不同的位置即可。這樣時間複雜度是 \(\mathcal{O}(\sum|S|\log )\) 的。
細節非常多,這裡就不展開討論。
#include<bits/stdc++.h> #define rep(i,a,b) for(int i=a;i<=b;i++) #define pre(i,a,b) for(int i=a;i>=b;i--) const int N = 1000005, P = 1000000007, Q = 777777773, bas = 233;int pw[N]; void calc(char *s,int *u,int n){ int cur = 1, L = 1, R = n; rep(i, 1, n) if(i == n){ pre(j, n, R + 1)u[j + 1] = u[j]; while(cur)cur--, u[L++] = i - cur; u[L] = n + 1; } else if(s[i] == s[i + 1])cur++; else { if(s[i] < s[i + 1]){ while(cur)cur--, u[R--] = i - cur; } else{ while(cur)cur--, u[L++] = i - cur; } cur = 1; } } void init(char *s,int *u,int n){ u[0] = 0; rep(i, 1, n)u[i] = (1LL * u[i - 1] * bas + s[i]) % Q; } int get(int *u,int x,int ban){ if(x < ban)return u[x]; int res = x - ban + 1; return (1LL * u[ban - 1] * pw[res] % Q + u[x + 1] - 1LL * u[ban] * pw[res] % Q + Q)% Q; } bool cmp(char *s,char *t,int *u,int *v,int x,int y,int lenu,int lenv){ int ed = ~0,l = 1, r = std::min(lenu - (x <= lenu), lenv - (y <= lenv)); while(l <= r){ int mid = (l + r) >> 1; if(get(u, mid, x) == get(v, mid, y))l = mid + 1; else r = mid - 1, ed = mid; } if(~ed)return s[ed + (ed >= x)] < t[ed + (ed >= y)]; else return lenu - (x <= lenu) <= lenv - (y <= lenv); } int n,f[N],u[N],v[N],g[N],hs1[N],hs2[N];char s[N],t[N]; int main(){ pw[0] = 1;rep(i, 1, N - 5)pw[i] = 1LL * pw[i - 1] * bas % Q; scanf("%d",&n); scanf("%s",s + 1); int w = strlen(s + 1); calc(s, u, w);init(s, hs1, w); rep(i, 1, w + 1)f[i] = 1; rep(op, 2, n){ scanf("%s", t + 1); int m = strlen(t + 1), j = 1, sum = 0; calc(t, v, m);init(t, hs2, m); rep(i, 1, m + 1)g[i] = 0; rep(i, 1, m + 1){ while(j <= w + 1 && cmp(s, t, hs1, hs2, u[j], v[i], w, m))sum = (sum + f[j++]) % P; g[i] = sum; } w = m;rep(i, 1, m + 1)f[i] = g[i], s[i] = t[i], u[i] = v[i], hs1[i] = hs2[i]; } int ans = 0; rep(i, 1, w + 1)ans = (ans + f[i]) % P; printf("%d\n",ans); return 0; }