1. 程式人生 > 其它 >[紀中D1A組T1/GDOI2016]最長公共子串 #字串# #暴力#

[紀中D1A組T1/GDOI2016]最長公共子串 #字串# #暴力#

目錄

題目

思路

好想不好寫系列
不難發現,只要兩個區間相交,被覆蓋的元素是可以所以互換的,也就是說,相交區間可以合併為一個大區間,用排序實現:

	for(int i = 1 ; i <= k ; i++)
		sec[i].l = read() + 1 , sec[i].r = read() + 1;
		
	sort(sec + 1 , sec + k + 1 , cmp);//區間左端點從小到大排序
	for(int i = 1 ; i <= k ; i++)
		if(sec[i].l == sec[i].r)//只包含一個元素的區間相當於沒用
			sec[i].l = sec[i].r = -1;//銷燬
	for(int i = 1 , j = 2 ; i <= k ; i++)
		if(sec[i].l != -1) {
			if(j <= i)	j = i + 1;
			while(j <= k && sec[j].l <= sec[i].r) {
				sec[i].r = max(sec[i].r , sec[j].r);//合併
				sec[j].l = sec[j].r = -1;
				++j;
			}
		}
	
	int k2 = 0;
	for(int i = 1 ; i <= k ; i++) {//整理
		if(sec[i].l != -1) {
			++k2;
			l[k2] = sec[i].l , r[k2] = sec[i].r;
		}
	}
	k = k2;

下面很容易想到暴力:列舉\(i,j,len\)表示從\(T\)\(i\)位,\(S\)\(j\)位開始匹配,匹配了長度為\(len\)的子串,時間複雜度是\(O(n^3)\)顯然不可過.
這裡提供一種比較好想的優化方法:
顯然,若\(T\)\(i\)位,\(S\)\(j\)位開始的可匹配的連續子串長度最大為\(len\),那麼\(T\)\(i+x\)位開始,\(S\)\(j+x\)位開始的可匹配的連續子串長度最大至少為\(len-x\),另外,若\(S\)\(j+x\)開始,長度為\(len-x\)的子串不被任何一個區間包含,則直接取等號,又因為題目問的是最大值,這樣列舉到\(i'=i+x,j'=j+x\)

時可以直接跳過.
因此,我們用\(mem_{i,j}\)記錄列舉\(i,j\)\(len\)的初始長度,得到當前\(len\)後,對\(mem_{i+x,j+x}(x\in (0,len))\) 進行更新.
然後發現還是會超時
問題出在更新:如果我們對\(x\)的所有值都更新是會被卡的(若果存在一個非常大的區間,\(len\)又很大,那麼一次更新的時間接近\(O(n)\)).考慮 列舉到後面,更新的對於相等的\(i,j\), \(mem_{i,j}\)是不下降的,那我們不如直接讓後面來更新.
學過分塊的人就有一種隱隱的感覺:每次列舉\(x\)的值時不超過\(\sqrt n\)沒錯,這樣就保證了兩邊複雜度的均衡,可以AC此題!

程式碼

#include <iostream>
#include <cstdio>
#include <algorithm>
#define DEBUG 0//除錯開關
#define recode 1//優化開關
#define reg register
using namespace std;
const int N = 2010;
const int K = 100010;
int read() {
	int re = 0;
	char c = getchar();
	bool sig = false;
	while(c < '0' || c > '9')
		sig |= (c == '-') , c = getchar();
	while(c >= '0' && c <= '9')
		re = (re << 1) + (re << 3) + c - '0' , c = getchar();
	return sig ? -re : re;
}
int sread(int *a , int &n) {
	char c = getchar();
	while(c < 'a' || c > 'z')	c = getchar();
	while(c >= 'a' && c <= 'z')	a[++n] = c - 'a' , c = getchar();
}
int t[N] , s[N] , n , m;
int sumt[N][30];

struct node{
	int l , r;
}sec[K];
bool cmp(node a , node b) {return a.l < b.l;}

int k , l[K] , r[K];

int a[K][30];
int ans;

int mem[N][N];
int main() {
	freopen("lcs.in" , "r" , stdin);
	freopen("lcs.out" , "w" , stdout);

	sread(t , m);
	sread(s , n);
	k = read();
	for(int i = 1 ; i <= m ; i++) 
		for(int j = 0 ; j < 26 ; j++) 
			sumt[i][j] = sumt[i - 1][j] + (int)(t[i] == j);
//section union-begin
#if recode
	for(int i = 1 ; i <= k ; i++)
		sec[i].l = read() + 1 , sec[i].r = read() + 1;
		
	sort(sec + 1 , sec + k + 1 , cmp);
	for(int i = 1 ; i <= k ; i++)
		if(sec[i].l == sec[i].r)
			sec[i].l = sec[i].r = -1;
	for(int i = 1 , j = 2 ; i <= k ; i++)
		if(sec[i].l != -1) {
			if(j <= i)	j = i + 1;
			while(j <= k && sec[j].l <= sec[i].r) {
				sec[i].r = max(sec[i].r , sec[j].r);
				sec[j].l = sec[j].r = -1;
				++j;
			}
		}
	
	int k2 = 0;
	for(int i = 1 ; i <= k ; i++) {
		if(sec[i].l != -1) {
			++k2;
			l[k2] = sec[i].l , r[k2] = sec[i].r;
		}
	}
	k = k2;
#else
	for(int i = 1 ; i <= k ; i++)
		l[i] = read() + 1 , r[i] = read() + 1;
	
	for(int i = 1 ; i <= k ; i++)
		for(int j = 1 ; j <= k ; j++)
			if(i != j) {
				if(l[i] <= l[j] && l[j] <= r[i]) {
					l[i] = min(l[i] , l[j]);
					r[i] = max(r[i] , r[j]);
					l[j] = r[j] = -1;
				}
			}
	int k2 = 0;
	for(int i = 1 ; i <= k ; i++)
		if(l[i] != -1 && l[i] != r[i]) {
			++k2;
			l[k2] = l[i];
			r[k2] = r[i];
		}
	k = k2;
#endif

	for(reg int i = 1 ; i <= k ; ++i)
		for(reg int j = l[i] ; j <= r[i] ; ++j)
			++a[i][s[j]] , s[j] = -i;
//section union end


//debug 
#if DEBUG
	cout << k << '\n';
	for(int i = 1 ; i <= k ; i++) {
		cout << l[i] << '\t' << r[i] << '\n';
		for(int j = 0 ; j < 26 ; j++)
			cout << a[i][j] << ' ';
		putchar('\n');
	}
#endif


//matching begin
	
	int upb = m > n ? m : n;
	for(reg int i = 1 ; i <= m ; ++i)
		for(reg int j = 1 ; j <= n ; ++j) {
			if(i > m || j > n)	continue;
			
			reg int las = i - 1 , len = mem[i][j];
			if(len == -1)	continue;
			while(len + i <= m && len + j <= n) {
				if(s[j + len] >= 0) {
					if(t[i + len] != s[j + len])	break;
					las = len + i;
				} else {
					if(s[j + len] != s[j + len - 1])
						las = i + len - 1;
					if(sumt[i + len][t[i + len]] - sumt[las][t[i + len]] > a[-s[j + len]][t[i + len]])	break;
				}
				++len;
			}
			if(len + i > m || len + j > n)	len = min(m - i + 1 , n - j + 1);

			if(len > ans)
				ans = len;
			
			int I = i , J = j;
			int cnt = 0;
			while(len >= 0 && s[J] != s[J + len])
				++I , ++J , --len , mem[I][J] = -1 , ++cnt;
			
			while(len > 0 && cnt * cnt < 1000) {
				++I , ++J , --len , ++cnt;
				if(mem[I][J] > len)
					break;
				else
					mem[I][J] = len;
			}
		}
//matching end

	cout << ans;
	return 0;
}
/*
xabcdafg
aafbcd
1
0 2

abcdafg
aafbcd
2
0 2
2 5

*/