1. 程式人生 > 實用技巧 >[HDU5421]Victor and String(PAM)

[HDU5421]Victor and String(PAM)

題面

http://acm.hdu.edu.cn/showproblem.php?pid=5421

題解

前置知識

這道題很好地證明了PAM是可以雙向新增的。

如果只有操作2、3、4,那就是PAM模板而已。現在考慮1怎麼做。

發現操作1和2、以及3,4的維護並不衝突。可以仿效結尾加點的操作,具體如下:

  • 維護lastl,lastr表示當前的字串的最長字首迴文子串和最長字尾迴文子串。(可以為全串)
  • 更新fail,next的操作毫不影響,因為本來轉移時迴文串就是要左右兩邊各加一個字元
  • lastl、lastr有可能會相互影響,就是當lastl或者lastr其中之一求出來為全串的時候,那麼也要把另一個也賦為全串。

更詳細操作可見程式碼。

程式碼

#include<iostream>
#include<cstring>

using namespace std;

#define rg register
#define In inline
#define ll long long

const int N = 1e5;

In ll read(){
	ll s = 0,ww = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
	return s * ww;
}

In void write(ll x){
	if(x < 0)putchar('-'),x = -x;
	if(x > 9)write(x / 10);
	putchar('0' + x % 10);
}

char s[2*N+5];
int l,r;

struct PAM{
	int nx[N+5][26],fail[N+5],len[N+5];
	ll dep[N+5];
	int cnt,lastl,lastr;
	ll sum;
	void clear(){
		cnt = 1;
		fail[0] = fail[1] = 1;
		memset(nx[0],0,sizeof(nx[0]));
		memset(nx[1],0,sizeof(nx[1]));
		len[0] = 0,len[1] = -1;
		lastl = lastr = 0;
		sum = 0;
	}
	void create(){
		cnt++;
		memset(nx[cnt],0,sizeof(nx[cnt]));
	}
	void extendl(char c,int n){
		int id = c - 'a';
		int p = lastl;
		while(s[l+len[p]+1] != s[l])p = fail[p];
		if(!nx[p][id]){
			create();
			len[cnt] = len[p] + 2;
			int q = fail[p];
			while(s[l+len[q]+1] != s[l])q = fail[q];
			fail[cnt] = nx[q][id];
			dep[cnt] = dep[fail[cnt]] + 1;
			nx[p][id] = cnt;
		}
		lastl = nx[p][id];
		sum += dep[lastl];
		if(len[lastl] == r - l + 1)lastr = lastl;
	}
	void extendr(char c,int n){
		int id = c - 'a';
		int p = lastr;
		while(s[r-len[p]-1] != s[r])p = fail[p];
		if(!nx[p][id]){
			create();
			len[cnt] = len[p] + 2;
			int q = fail[p];
			while(s[r-len[q]-1] != s[r])q = fail[q];
			fail[cnt] = nx[q][id];
			dep[cnt] = dep[fail[cnt]] + 1;
			nx[p][id] = cnt;
		}
		lastr = nx[p][id];
		sum += dep[lastr];
		if(len[lastr] == r - l + 1)lastl = lastr;
	}
}P;

int n;

int main(){
	while(~scanf("%d",&n)){
		P.clear();
		memset(s,0,sizeof(s));
		l = N,r = N - 1;
		while(n--){
			int opt = read();
			if(opt <= 2){
				char c = getchar();
				while(c < 'a' || c > 'z')c = getchar();
				if(opt == 1)s[--l] = c,P.extendl(c,l);
				else s[++r] = c,P.extendr(c,r);
			}
			else{
				if(opt == 3)write(P.cnt - 1),putchar('\n');
				else write(P.sum),putchar('\n');
			}
		}
	}
	return 0;
}