1. 程式人生 > >HDu 6153 a secret 擴充套件KMP模板

HDu 6153 a secret 擴充套件KMP模板

題意大致就是給你兩個串,然後讓你求第二個串的各個字尾在第一個串的出現次數,然後讓次數乘該字尾的長度累加輸出,最後結果是對1e9+7取模

思路:在講具體思路之前,如果對擴充套件kmp不瞭解或者沒聽說過的話,建議先看看這個部落格https://blog.csdn.net/dyx404514/article/details/41831947

迴歸正題:我們首先要對這個問題進行轉換,因為在我所學過的演算法中大都是對字首進行處理,這樣也符合串的輸入順序,所以首先需要對兩個串進行反轉,這樣求串2的字尾就變成了串2的字首,於是問題就轉換為求轉換過後的串2的各字首在串1出現的次數和長度的乘積累加和,進一步分析就需要解決這個乘積的問題,很明顯單純的先計算出現次數然後再乘再累加一定會超時的,所以就需要對這個問題再進行轉換,這裡使用的方法就是計算串1的連續字串在串2的相同字首數,如果求得串1在第1個位置開始的連續字串跟串2相同的字首是3,那麼就代表串2的前三個字元在串1出現,然後這個三就包括a,aa,aaa,(假設前三個字元是aaa),那麼先計算這個三個的乘積累加,很明顯是一個等差數列:1*1+1*2+1*3,如果後續又出現這幾個字串,那麼在加就好,反正乘積就是加法的累計。其餘的細節就是擴充套件kmp了,在這裡不在細講,上面的連線裡說的很詳細

程式碼:

#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<cstdlib>
#include<algorithm>

using namespace std;

const int maxn = 1e6 + 10;
const int mod = 1e9 + 7;
int nxt[maxn];
int ex[maxn];
string s, s1;

inline long long add(long long n)
{
		long long ans = ((n%mod)*((n + 1) % mod) / 2) % mod;
		return ans;
}

void  getnxt()
{
		int len = s1.size();
		int j = 0, k = 1;
		nxt[0] = len;
		while (j + 1 < len&&s1[j + 1] == s1[j])	j++;
		nxt[1] = j;
		for (int i = 2; i < len; i++)
		{
				int p = nxt[k] + k - 1;
				int L = nxt[i - k];
				if (i + L < p + 1)
						nxt[i] = L;
				else
				{
						j = max(0, p - i + 1);
						while (i + j < len&&s1[i + j] == s1[j])
								j++;
						nxt[i] = j;
						k = i;
				}
		}
}

void exkmp()
{
		int len = s.size(), len2 = s1.size();
		getnxt();
		int j = 0, k = 0;
		while (j < len&&j < len2&&s[j] == s1[j]) j++;
		ex[0] = j;
		for (int i = 1; i < len; i++)
		{
				int p = ex[k] + k - 1;
				int L = nxt[i - k];
				if (i + L < p + 1)
						ex[i] = L;
				else
				{
						j = max(0, p - i + 1);
						while (i + j < len&&j < len2&&s[i + j] == s1[j]) 		j++;
						ex[i] = j;
						k = i;
				}
		}
}

void ini()
{
		memset(nxt, 0, sizeof(nxt));
		memset(ex, 0, sizeof(ex));
		s.clear();
		s1.clear();
}

int main()
{
		ios::sync_with_stdio(false);
		int t;
		cin >> t;
		while (t--)
		{
				ini();
				cin >> s >> s1;
				int len = s.size();
				reverse(s.begin(), s.end());
				reverse(s1.begin(), s1.end());
				exkmp();
				long long ans = 0;
				for (int i = 0; i < len; i++)
				{
						if (ex[i])
								ans = (ans + add(ex[i]) % mod) % mod;
				}
				cout << ans % mod << endl;
		}
		//system("pause");
		return 0;
}