1. 程式人生 > >BZOJ 2160 拉拉隊排練 Manacher + 字首和

BZOJ 2160 拉拉隊排練 Manacher + 字首和

題目大意:

就是現在給出一個長度為 n 的字串(1 <= n <= 10^6)和一個正整數K(1 <= K <= 10^12)

對於給出的長度為n的字串如果其迴文串的數量比K少則輸出-1, 否則輸出所有迴文串中長度為奇數的最長的前K個迴文串的長度的乘積, 結果對於19930726取模輸出

大致思路:

首先可以用manacher演算法確定每個位置的迴文半徑, 由於這裡只需要長度是奇數, 所以不需要再原來的字串的相鄰兩個字元之間插入未出現的字元, 直接在首尾新增好不同字元之後盤一遍manacher演算法, 對於位置i為中心的迴文半徑R[i], 用dp[i]來表示相鄰長度的迴文串的數量差分(其實就是一個常用的字首和技巧, 因為這裡每次更新[1, R[i]]這個區間 + 1, 而只在所有更新完畢之後才查詢所以沒有必要使用樹狀陣列, 直接根據每次更新的時候dp[1]++, dp[R[i] + 1]--, 最後後dp[1~i]的和就是最終ans[i]的值, 即長度為2*i - 1的迴文串的數量

然後用快速冪就可以了, 沒有什麼難度

細節就看程式碼吧

(吐槽一下第一次在BZOJ上交題沒在F.A.Q裡看到用%I64d還是%lld, 然後我PE了看了好久不知道為什麼....最後還是用%lld過了 = =)

話說用%I64d就算是Wrong Answer也判了Presentation Error..

程式碼如下:

Result  :  Accepted     Memory  :  21780 KB     Time  :  520 ms

/**************************************************************
    Problem: 2160
    User: Gatevin
    Language: C++
    Result: Accepted
    Time:520 ms
    Memory:21780 kb
****************************************************************/
 
/*
 * Author: Gatevin
 * Created Time:  2015/3/20 11:20:33
 * File Name: Chitoge_Kirisaki.cpp
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
 
#define maxn 1000100
 
const lint mod = 19930726;
 
int n;
lint K;
char in[maxn];
int R[maxn];
lint dp[maxn];
lint ans[maxn];
/*
 * 這裡沒有在輸入的字串中間插入特殊字元
 * 得到的半徑R都是奇數長度的迴文串的半徑
 */
void Manacher(char *s, int *R, int n)
{
    int p = 0, mx = 0;
    R[0] = 1;
    for(int i = 1; i <= n; i++)
    {
        if(mx > i)
            R[i] = min(R[2*p - i], mx - i);
        else R[i] = 1;
        while(s[i - R[i]] == s[i + R[i]])
            R[i]++;
        if(i + R[i] > mx)
            p = i, mx = i + R[i];
    }
    return;
}
 
lint quick_pow(lint base, lint pow)
{
    lint ret = 1;
    while(pow)
    {
        if(pow & 1) ret = (ret*base) % mod;
        pow >>= 1;
        base = base*base % mod;
    }
    return ret;
}
 
int main()
{
    scanf("%d %lld", &n, &K);
    {
        scanf("%s", in + 1);
        in[0] = '@'; in[n + 1] = '$'; in[n + 2] = '\0';
        Manacher(in, R, n);
        memset(dp, 0, sizeof(dp));
        for(int i = 1; i <= n; i++)//所有以i為中心的奇數長度的迴文串
        {
            dp[1]++;
            dp[R[i] + 1]--;
        }
        memset(ans, 0, sizeof(ans));
        int maxlen = 0;
        for(int i = 1; i <= (n + 1) >> 1; i++)//ans[i]表示長度為2*i - 1的串的個數
        {
            ans[i] = ans[i - 1] + dp[i];
            if(ans[i] > 0) maxlen = i;
        }
        lint result = 1;
        int r = maxlen;
        while(r && K > 0)
            result = (result*quick_pow((lint)(2*r - 1), min(ans[r], K))) % mod, K -= ans[r], r--;
        if(K > 0) printf("-1\n");
        else printf("%lld\n", result);
    }
    return 0;
}