[ SDOI 2016 ] 模式字串
阿新 • 發佈：2021-07-14
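題目
給定一棵 n 個結點的樹，每個結點上有一個大寫字母；再給定長度為 m 的模式串 S。求有多少個有序點對 (u, v)，使得從 u 到 v 的路徑上的字母依序連起來恰好是 S 重複 k 次（k 為正整數）。輸入包含多組資料。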
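思路
點分治。先把 S 週期性延伸，並預先算出它（以及它的反串）每個前綴的雜湊值。對每個重心做 DFS，判斷子樹中每個結點走向重心的路徑是否為上述前綴之一，並按深度模 m 分桶計數。若重心的字母恰好能接上，且兩側長度加上重心後總長為 m 的倍數，就把當前子樹的結點與先前子樹中的結點（或重心本身）配對計數。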
程式碼
// SDOI 2016 "Pattern String": count ordered node pairs (u, v) of a tree whose
// path letters u -> v spell the pattern S repeated a whole number of times.
// Approach: centroid decomposition; path pieces hanging off a centroid are
// matched against prefixes of S (and of reversed S) extended periodically,
// using polynomial hashing modulo 2^64.
//
// NOTE(review): DFS_rt / DFS_sz / DFS recurse once per tree edge of the current
// component, so a path-shaped tree needs a call stack of depth ~n.
// NOTE(review): a single base-31 hash with natural 2^64 overflow is known to be
// breakable by crafted inputs; consider a second modulus if tests are adversarial.
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
using namespace std;

typedef unsigned long long ULL;

const int N = 1000010, P = 31;

int T, n, m;
// pre[1..n]: pattern letters extended periodically; suf[1..n]: same for reversed pattern.
int pre[N], suf[N];
// w[u]: letter of node u mapped to 1..26; tmp: raw input buffer.
char w[N], tmp[N];
// B[i] = P^i (mod 2^64); prehash[i] / sufhash[i]: hash of pre[1..i] / suf[1..i].
ULL B[N] = { 1 }, prehash[N], sufhash[N];

// Adjacency list as linked edge arrays: h = head per node, t = edge target,
// p = next edge of the same source, idx = number of edges stored.
int h[N], t[N << 1], p[N << 1], idx;

// Appends the directed edge a -> b.
void add(int a, int b) { t[idx] = b, p[idx] = h[a], h[a] = idx++; }

// sz: subtree sizes; vis: nodes already used as centroids; root: last centroid found.
int sz[N], vis[N], root;

// Finds a centroid of the unvisited component containing u (size tot) and
// stores it in `root`. fa is the DFS parent (-1 at the start).
void DFS_rt(int u, int fa, int tot) {
    sz[u] = 1;
    int mx = 0;
    // Test the sentinel before touching t[]: the old loop read t[-1].
    for (int i = h[u]; i != -1; i = p[i]) {
        const int v = t[i];
        if (vis[v] || v == fa) continue;
        DFS_rt(v, u, tot);
        sz[u] += sz[v];
        mx = max(mx, sz[v]);
    }
    mx = max(mx, tot - sz[u]);
    if (mx * 2 <= tot) root = u;
}

// Recomputes sz[] for the unvisited component hanging below u (parent fa).
void DFS_sz(int u, int fa) {
    sz[u] = 1;
    for (int i = h[u]; i != -1; i = p[i]) {
        const int v = t[i];
        if (!vis[v] && v != fa) {
            DFS_sz(v, u);
            sz[u] += sz[v];
        }
    }
}

// Per-centroid buckets indexed by (depth % m):
//   sumpre/sumsuf: nodes from already-processed subtrees (plus the centroid
//     itself at index 0) whose path toward the centroid matches pre/suf.
//   Tpre/Tsuf: the same, for the subtree currently being walked.
int sumpre[N], sumsuf[N], Tpre[N], Tsuf[N];

// Walks one subtree of the centroid.
//   u, fa : current node and its parent
//   dep   : number of nodes from u up to (excluding) the centroid, >= 1
//   mx    : updated to the maximum depth seen
//   mid   : letter of the centroid
//   hash  : sum of w[v_j] * P^(j-1) over the nodes v_1..v_(dep-1) between the
//           centroid and u (v_1 adjacent to the centroid)
// After adding u, hash == prehash[dep] exactly when the letters read from u
// toward the centroid spell pre[1..dep] (same for suf).
// Returns the number of pairs formed with nodes already in sumpre/sumsuf.
long long DFS(int u, int fa, int dep, int &mx, int mid, ULL hash) {
    long long cnt = 0;
    mx = max(mx, dep);
    hash += w[u] * B[dep - 1];
    const int r = dep % m;
    if (hash == prehash[dep]) {
        Tpre[r]++;
        // The centroid must supply the next pattern letter; the other side must
        // then bring the total length up to a multiple of m.
        if (mid == pre[r + 1]) cnt += sumsuf[m - r - 1];
    }
    if (hash == sufhash[dep]) {
        Tsuf[r]++;
        if (mid == suf[r + 1]) cnt += sumpre[m - r - 1];
    }
    for (int i = h[u]; i != -1; i = p[i]) {
        const int v = t[i];
        if (!vis[v] && v != fa) cnt += DFS(v, u, dep + 1, mx, mid, hash);
    }
    return cnt;
}

// Counts matching ordered pairs inside the unvisited component containing u,
// whose size is tot. The count can exceed int range, so long long is used.
long long Solve(int u, int tot) {
    if (tot < m) return 0;  // no path in this component has m nodes
    DFS_rt(u, -1, tot), u = root;
    DFS_sz(u, -1), vis[u] = 1;
    int tag = 0;  // highest bucket index touched, for cleanup
    long long cnt = 0;
    // The centroid itself is a depth-0 endpoint on either side.
    if (w[u] == pre[1]) sumpre[0]++;
    if (w[u] == suf[1]) sumsuf[0]++;
    for (int i = h[u]; i != -1; i = p[i]) {
        const int v = t[i];
        if (vis[v]) continue;
        int mx = 0;
        cnt += DFS(v, u, 1, mx, w[u], 0);
        // Merge this subtree only after it has been paired, so nodes are never
        // matched with nodes of their own subtree. Buckets are depth % m < m.
        const int top = min(mx, m - 1);
        tag = max(tag, top);
        for (int j = 0; j <= top; j++) {
            sumpre[j] += Tpre[j], Tpre[j] = 0;
            sumsuf[j] += Tsuf[j], Tsuf[j] = 0;
        }
    }
    for (int j = 0; j <= tag; j++) sumpre[j] = sumsuf[j] = 0;
    for (int i = h[u]; i != -1; i = p[i]) {
        const int v = t[i];
        if (!vis[v]) cnt += Solve(v, sz[v]);
    }
    return cnt;
}

// Extends s[1..m] periodically to s[1..n] and fills hash[i] with the
// polynomial hash of s[1..i] (hash[0] stays 0).
void Process(int s[], ULL hash[]) {
    for (int i = m + 1; i <= n; i++) s[i] = s[i - m];
    for (int i = 1; i <= n; i++) hash[i] = hash[i - 1] * P + s[i];
}

int main() {
    for (int i = 1; i < N; i++) B[i] = B[i - 1] * P;
    scanf("%d", &T);
    while (T--) {
        memset(vis, 0, sizeof vis);
        memset(h, -1, sizeof h), idx = 0;
        scanf("%d%d%s", &n, &m, tmp + 1);
        for (int i = 1; i <= n; i++) w[i] = tmp[i] - 'A' + 1;
        for (int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);
        scanf("%s", tmp + 1);
        for (int i = 1; i <= m; i++) pre[i] = tmp[i] - 'A' + 1;
        for (int i = 1; i <= m; i++) suf[i] = pre[m - i + 1];
        Process(pre, prehash), Process(suf, sufhash);
        printf("%lld\n", Solve(1, n));
    }
    return 0;
}