1. 程式人生 > 其它 >[SOS DP][容斥] Codeforces 1620G Subsequences Galore

[SOS DP][容斥] Codeforces 1620G Subsequences Galore

題目大意

給定一個字串序列 \([t_1,t_2,\cdots ,t_m]\) ,定義 \(f([t_1,t_2,\cdots,t_m])\) 為至少是其中一個字串 \(t_i\) 的子序列的字串個數,其中 \(f([])=0\)

給定一個字串序列 \([s_1,s_2,\cdots ,s_m]\),對每一個子集 \([s_{i_1}, s_{i_2}, \cdots, s_{i_k}]\) 求出 \(f({[s_{i_1}, s_{i_2}, \cdots, s_{i_k}]})\)\(998244353\) 取模後的值。

輸出 \(f({[s_{i_1}, s_{i_2}, \cdots, s_{i_k}]})\times k\times (i_1+i_2+\cdots+i_k)\)

的異或和(不取模)。

注意每個字串 \(s_i\)​ 中的字母都是排好序的。

題解

設字串 \(s\) 中字元 \(c\) 的個數為 \(\mathrm{cnt}(c)\),則該字串的子序列的個數為 \(\prod_{c='a'}^{'z'}(\mathrm{cnt}(c)+1)\)

對於同時是多個字串的子序列的字串的個數,只要對每個字元 \(c\)\(\mathrm{cnt}(c)\)\(\min\)\(1\) 再相乘即可。

接下來考慮容斥。設字串集為 \(S\),設同時是 \(S\) 中所有字串的子序列的字串的個數為 \(g(S)\),則有

\[f(S)=\sum_{T\subseteq S} (-1)^{|T|-1} g(T) \]

這個式子實際上是一個符號和奇偶校驗碼有關的子集和,直接用SOS DP或者說類似於FMT的方法直接算即可,時間複雜度 \(O(|\Sigma|n2^n)\)

\(|\Sigma|\) 是字符集大小,本題中是 \(26\)

\(x\) 的奇偶校驗碼可以用 __builtin_parity(x) 直接算,若 \(x\) 的二進位制中有偶數個 \(1\),則返回 \(0\),否則返回 \(1\)

Code

#include <bits/stdc++.h>
using namespace std;

#define LL long long
const LL MOD = 998244353;
char buf[20010];
int f[1 << 23], a[26];
vector<int> v[25];
int n;

int main() {
    scanf("%d", &n);
    for (int i = 0;i < n;++i) {
        scanf("%s", buf + 1);
        v[i].resize(26);
        for (int j = 1;buf[j];++j)
            ++v[i][buf[j] - 'a'];
    }
    for (int i = 1;i < (1 << n);++i) {
        memset(a, 0x3f, sizeof(a));
        for (int j = 0;j < n;++j) {
            if (!(i & (1 << j))) continue;
            for (int k = 0;k < 26;++k)
                a[k] = min(a[k], v[j][k]);
        }
        f[i] = 1;
        for (int k = 0;k < 26;++k)
            f[i] = 1LL * f[i] * (a[k] + 1) % MOD;
    }
    for (int i = 0;i < n;++i) {
        for (int j = 0;j < (1 << n);++j) {
            int x = __builtin_parity(j);
            if (j & (1 << i)) {
                f[j] = (f[j] + -1 * f[j ^ (1 << i)]) % MOD;
                if (f[j] < 0) f[j] += MOD;
            }
        }
    }
    LL ans = 0;
    for (int i = 0;i < (1 << n);++i) {
        if (!__builtin_parity(i)) f[i] = MOD - f[i];
        int k = 0, x = 0;
        for (int j = 0;j < n;++j)
            if (i & (1 << j)) { ++k; x += j + 1; }
        LL temp = 1LL * k * x * f[i];
        ans ^= temp;
    }
    printf("%I64d\n", ans);

    return 0;
}