[HAOI2016]找相同字元 後綴陣列+並查集
阿新 • • 發佈:2018-12-10
Description 給定兩個字串,求出在兩個字串中各取出一個子串使得這兩個子串相同的方案數。兩個方案不同當且僅當這兩個子串中有一個位置不同。
Sample Input
aabb
bbaa
Sample Output
10
這題感覺跟[NOI2015]品酒大會很像。套路題吧,不說了。。。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <utility>
using namespace std;

// [HAOI2016] Count pairs (substring of A, substring of B) that are equal.
//
// Build the suffix array of S = A + '#' + B. The number of equal substring
// pairs that start at suffix i (in A) and suffix j (in B) is LCP(i, j), so the
// answer is the sum of LCP(i, j) over all such (i, j). Equivalently, it is
// the sum over h >= 1 of the number of (A, B) suffix pairs with LCP >= h.
//
// Rank-adjacent suffix pairs are merged in a union-find in decreasing order of
// their LCP. Once every adjacent pair with LCP >= h is merged, each set is a
// maximal run of suffixes whose pairwise LCP is >= h. Summing
// sum1 * sum2 over the sets therefore counts the (A, B) pairs with LCP >= h.
//
// Adjacent LCPs are computed exactly (Kasai). Polynomial hashing modulo 2^64
// is not used because it has known collisions (e.g. Thue-Morse strings) that
// would silently corrupt the LCPs.
//
// Input is assumed to be lowercase letters only (mapped to codes 1..26).

typedef long long LL;

const int MAXN = 410000;   // capacity of the 1-indexed concatenated string
const int MAXLEN = 204000; // longest accepted input string (2 * MAXLEN + 2 < MAXN);
                           // must match the field width in the scanf calls below

struct node { int x, id; } height[MAXN]; // x = LCP(sa[id], sa[id + 1])
char ss[MAXLEN + 2];                     // input buffer, string stored at ss + 1
int n, m, a[MAXN], Rank[MAXN], Rsort[MAXN], sa[MAXN], tt[MAXN], yy[MAXN];
int fa[MAXN], sum1[MAXN], sum2[MAXN];

// Union-find root lookup with full path compression. Iterative so that a long
// parent chain cannot overflow the call stack.
int findfa(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {
        int parent = fa[x];
        fa[x] = root;
        x = parent;
    }
    return root;
}

// Orders adjacent-suffix records by LCP, largest first.
bool cmp(const node &lhs, const node &rhs) { return lhs.x > rhs.x; }

// Prefix-doubling suffix array over a[1..n] (codes in 1..m, 0 past the end).
// On return sa[r] is the start of the r-th smallest suffix and Rank is its
// inverse permutation. Overwrites m, Rsort, tt and yy.
void get_sa() {
    memcpy(Rank, a, sizeof(Rank));
    // Counting sort by the first character.
    memset(Rsort, 0, sizeof(Rsort));
    for (int i = 1; i <= n; i++) Rsort[Rank[i]]++;
    for (int i = 1; i <= m; i++) Rsort[i] += Rsort[i - 1];
    for (int i = n; i >= 1; i--) sa[Rsort[Rank[i]]--] = i;
    int ln = 1, p = 0;
    while (p < n) {
        // yy: positions ordered by the second key Rank[i + ln]. Positions whose
        // second half runs past the end (key 0) come first.
        int k = 0;
        for (int i = n - ln + 1; i <= n; i++) yy[++k] = i;
        for (int i = 1; i <= n; i++)
            if (sa[i] - ln > 0) yy[++k] = sa[i] - ln;
        // Stable counting sort of yy by the first key Rank[i].
        memset(Rsort, 0, sizeof(Rsort));
        for (int i = 1; i <= n; i++) Rsort[Rank[i]]++;
        for (int i = 1; i <= m; i++) Rsort[i] += Rsort[i - 1];
        for (int i = n; i >= 1; i--) sa[Rsort[Rank[yy[i]]]--] = yy[i];
        // Re-rank by the (first, second) key pair. The second comparison only
        // runs when the first keys match, which means both suffixes are at
        // least ln long, so sa[.] + ln stays <= n + 1 (tt there is 0).
        for (int i = 1; i <= n; i++) tt[i] = Rank[i];
        p = 1;
        Rank[sa[1]] = 1;
        for (int i = 2; i <= n; i++) {
            if (tt[sa[i]] != tt[sa[i - 1]] || tt[sa[i] + ln] != tt[sa[i - 1] + ln]) p++;
            Rank[sa[i]] = p;
        }
        m = p;
        ln *= 2;
    }
}

// Kasai's algorithm: sets height[r] = {LCP(sa[r], sa[r + 1]), r} for
// 1 <= r < n, in O(n). If suffix i shares k >= 1 characters with the suffix
// ranked right after it, suffix i + 1 shares at least k - 1 with its own
// successor. So the match length is extended from k - 1, not recomputed.
void get_height() {
    int k = 0;
    for (int i = 1; i <= n; i++) {
        if (Rank[i] == n) { // the largest suffix has no successor
            k = 0;
            continue;
        }
        int j = sa[Rank[i] + 1];
        while (i + k <= n && j + k <= n && a[i + k] == a[j + k]) k++;
        height[Rank[i]].x = k;
        height[Rank[i]].id = Rank[i];
        if (k > 0) k--;
    }
}

int main() {
    // First string -> a[1..nn]. A missing string leaves no common substrings.
    if (scanf("%204000s", ss + 1) != 1) { printf("0\n"); return 0; }
    n = strlen(ss + 1);
    int nn = n;
    for (int i = 1; i <= n; i++) a[i] = ss[i] - 'a' + 1;
    // Separator code 28 occurs nowhere else, so no common prefix can run from
    // one string into the other. Second string -> a[nn + 2..n].
    if (scanf("%204000s", ss + 1) != 1) { printf("0\n"); return 0; }
    m = strlen(ss + 1);
    a[n + 1] = 28;
    for (int i = 1; i <= m; i++) a[i + n + 1] = ss[i] - 'a' + 1;
    n += m + 1;

    m = 30; // key range of the initial counting sort (codes 1..28)
    get_sa();
    get_height();
    sort(height + 1, height + n, cmp);

    // Every suffix starts in its own set. The separator position counts
    // toward neither string.
    for (int i = 1; i <= n; i++) {
        fa[i] = i;
        sum1[i] = (i <= nn);     // suffix starts in the first string
        sum2[i] = (i >= nn + 2); // suffix starts in the second string
    }

    LL ans = 0;
    LL hh = 0; // sum of sum1 * sum2 over all sets
    for (int i = 1; i < n && height[i].x > 0; i++) {
        int fx = findfa(sa[height[i].id]);
        int fy = findfa(sa[height[i].id + 1]);
        if (fx != fy) {
            // Union by size keeps the trees shallow.
            if (sum1[fx] + sum2[fx] > sum1[fy] + sum2[fy]) swap(fx, fy);
            hh -= (LL)sum1[fx] * sum2[fx] + (LL)sum1[fy] * sum2[fy];
            fa[fx] = fy;
            sum1[fy] += sum1[fx];
            sum2[fy] += sum2[fx];
            hh += (LL)sum1[fy] * sum2[fy];
        }
        // After the last pair with this LCP value, hh is the number of (A, B)
        // pairs with LCP >= h for every threshold h in (lower, height[i].x].
        int lower = (i + 1 < n) ? height[i + 1].x : 0;
        if (lower < height[i].x) ans += hh * (height[i].x - lower);
    }
    printf("%lld\n", ans);
    return 0;
}