Codeforces gym101955 A【樹形dp】
阿新 • • 發佈:2019-01-04
有n個大號和m個小號
然後需要對這些號進行匹配,一個大號最多匹配2個小號
匹配條件是大號和小號構成了字首關係
字串長度不超過10
問方案數
思路
因為要構成字首關係
所以就考慮在trie樹上dp
\(f_{i,j,k}\)表示i的子樹中,還需要來自祖先的j個小號,並且有需要匹配但是沒有匹配的小號k個
然後如果當前是一個大號節點
可以從子樹中選一個小號
\(f_{u,j,k - 1}<=f_{v,j,k} * k\)
可以從子樹中選兩個小號
\(f_{u,j,k - 2}<=f_{v,j,k} * (\frac{k *(k - 1)}{2})\)
可以從祖先中選一個小號
\(f_{u,j+1, k}<=f_{u,j,k}\)
可以從祖先中選兩個小號(因為在祖先中需要選擇兩次,避免重複計算這裡除以2)
\(f_{u,j+2,k}<=f_{u,j,k}*\frac{1}{2}\)
可以從祖先選一個子樹選一個
\(f_{u,j+1,k-1}<=f_{u,j,k}*k\)
這裡我們考慮等價選擇的多種方案的時候只在深度淺的地方算
然後實際上如果是小號節點,同理就好了
#include<bits/stdc++.h> using namespace std; const int N = 1e5 + 10; const int Mod = 1e9 + 7; const int CHARSET_SIZE = 26; int add(int a, int b) { return (a += b) >= Mod ? a - Mod : a; } int mul(int a, int b) { return 1ll * a * b % Mod; } struct Node { int ch[CHARSET_SIZE], typ; void init() { typ = 0; memset(ch, 0, sizeof(ch)); } } p[N]; int tot = 0, n, m; char c[N]; int f[N][12][22], g[N][12][22]; void init() { tot = 1; p[1].init(); } void insert(char *s, int typ) { int len = strlen(s + 1), u = 1; for (int i = 1; i <= len; i++) { int cur = s[i] - 'a'; if (!p[u].ch[cur]) p[p[u].ch[cur] = ++tot].init(); u = p[u].ch[cur]; } p[u].typ = typ; } void dfs(int u) { for (int i = 0; i <= 10; i++) for (int j = 0; j <= 20; j++) f[u][i][j] = g[u][i][j] = 0; f[u][0][0] = 1; for (int i = 0; i < CHARSET_SIZE; i++) { int v = p[u].ch[i]; if (!v) continue; dfs(v); for (int j = 10; j >= 0; j--) for (int k = 20; k >= 0; k--) if (f[u][j][k]) for (int l = 0; l <= 10 - j; l++) for (int t = 0; t <= 20 - k; t++) g[u][j + l][k + t] = add(g[u][j + l][k + t], mul(f[u][j][k], f[v][l][t])); for (int j = 0; j <= 10; j++) for (int k = 0; k <= 20; k++) { f[u][j][k] = g[u][j][k]; g[u][j][k] = 0; } } if (!p[u].typ) return; for (int i = 0; i <= 10; i++) { for (int j = 0; j <= 20; j++) if (f[u][i][j]) { if (p[u].typ == 1) { if (i + 1 <= 10) g[u][i + 1][j] = add(g[u][i + 1][j], f[u][i][j]); if (j - 1 >= 0) g[u][i][j - 1] = add(g[u][i][j - 1], mul(j, f[u][i][j])); if (i + 2 <= 10) g[u][i + 2][j] = add(g[u][i + 2][j], mul((Mod + 1) >> 1, f[u][i][j])); if (j - 2 >= 0) g[u][i][j - 2] = add(g[u][i][j - 2], mul((j * (j - 1)) >> 1, f[u][i][j])); if (i + 1 <= 10 && j - 1 >= 0) g[u][i + 1][j - 1] = add(g[u][i + 1][j - 1], mul(j, f[u][i][j])); } else { if (i - 1 >= 0) g[u][i - 1][j] = add(g[u][i - 1][j], mul(i, f[u][i][j])); if (j + 1 <= 20) g[u][i][j + 1] = add(g[u][i][j + 1], f[u][i][j]); } } } for (int i = 0; i <= 10; i++) for (int j = 0; j <= 20; j++) f[u][i][j] = add(f[u][i][j], g[u][i][j]); } void solve(int cas) { init(); scanf("%d %d", &n, &m); for (int i = 1; i <= n; i++) { scanf("%s", c + 1); insert(c, 1); } for (int i = 1; i <= m; i++) { scanf("%s", c + 1); insert(c, 2); } dfs(1); printf("Case #%d: %d\n", cas, f[1][0][0]); } int main() { int T; scanf("%d", &T); for (int i = 1; i <= T; i++) solve(i); return 0; }