1. 程式人生 > 實用技巧 >2020牛客暑期多校訓練營(第二場) A All with Pairs

2020牛客暑期多校訓練營(第二場) A All with Pairs

思路:首先將所有後綴的hash值求出來,並對每個字尾出現的次數計數, 之後列舉每個串的字首, 假設串 a 的存在對應字尾的字首為 s1, s2, s3, |s1| < |s2| < |s3|,
假設 s3 對應串1,串2,串4 的字尾,首先 ans += cnt[s3], 然後看 s2, 若 s2 是 s3 的字尾, 則ans += cnt[s2] - cnt[s3], 因為 s3 是串1,串2,串4 的字尾,
s2 是 s3 的字尾, 那麼 s2 也是串1,串2,串4 的字尾,然後 串1,串2,串4 只和 串 a 的 s3 計算貢獻, 所以cnt[s2] 要減去 cnt[s3], 若 s2 不是 s3 的字尾,
則 ans += cnt[2], s1 同理。至於判斷 s2 是否是 s3 的字尾, KMP 即可。

#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <string>
#include <string.h>
#include <map>
#include <iostream>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> pii;
const int maxn = 1e6 + 50;
const ULL mod = 998244353;
const LL p = 233;
const int mod2 = 998244353;
int INF = 1e9;

#define fi first
#define se second
string s[maxn];
string T;
LL Hash[maxn];
LL pp[maxn];
map<ULL, LL> mmap;
LL cnt[maxn];
int Next[maxn];
int slen, tlen;
void getNext()
{
    int j, k;
    j = 0; k = -1; Next[0] = -1;
    while(j < tlen){
        if(k == -1 || T[j] == T[k]) Next[++j] = ++k;
        else k = Next[k];
    }
}

int main(int argc, char const *argv[])
{
    int n;
    scanf("%d", &n);
    pp[0] = 1;
    for(int i = 1; i < maxn; i++){
        pp[i] = 1LL * pp[i - 1] * p % mod;
    }
    for(int i = 1; i <= n; i++){
        cin >> s[i];
        int len = s[i].size();
        ULL base = 1;
        ULL hs = 0;
        for(int j = len - 1; j >= 0; --j){
            hs += base * s[i][j];
            mmap[hs]++;
            base *= p;
        }
    }

    LL ans = 0;
    for(int i = 1; i <= n; i++){
        T = s[i];
        tlen = T.size();
        ULL hs = 0;
        for(int j = 0; j < tlen; j++){
            cnt[j] = 0;
            hs = hs * p + s[i][j];
            if(!mmap.count(hs)) continue;
            cnt[j] += mmap[hs];
        }
        getNext();
        for(int j = 1; j <= tlen; j++){
            if(Next[j] == 0) continue;
            cnt[Next[j] - 1] -= cnt[j - 1];
        }
        for(int j = 0; j < tlen; j++){
            ans = (ans + 1LL * (j + 1) * (j + 1) % mod2 * cnt[j] % mod2) % mod2;
        }
    }

    cout << ans << endl;
    return 0;
}