1. 程式人生 > >NOI2015.品酒大會(字尾陣列)

NOI2015.品酒大會(字尾陣列)

給出一個 長度為 n 的字串,每一位有一個權值 val。定義兩個位字元為 r 相似,是指分別從這兩個字元開始,到後面的 r 個字元都相等。兩個 r 相似的字元還有一個權值為這兩個字元權值的乘積。問對於 r = 0, 1, 2, … , n - 1,統計出有多少種方法可以選出 2 個“相似”的字元,並回答選擇 2 個”r 相似”的字元可以得到的權值的最大值。 

首先說一個暴力的做法,可以得到 50 分:

先預處理出 1 - n 累加的結果。對於這個字串求出 Height 陣列,然後列舉 r ,按 r Height 分組,然後統計方案數:假設組內有 x 個字尾,那麼貢獻的方案數就是 1 + 2 + ... + (x - 1)

,統計最大值:維護最大值和次大值,最小值和次小值,更新答案。

這種方法的複雜度主要看資料,如果資料很亂,導致 Height陣列的最大值很小,那麼這種方法就可以過,否則就會 TLE

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int MAX_N = 300005;
const long long INF = 9223372036854775807;

typedef long long LL;

int n, a[MAX_N], sa[MAX_N], r[MAX_N], h[MAX_N], id1, id2;
LL val[MAX_N], Mx, Mxx, Mn, Mnn;
int ws[MAX_N], wv[MAX_N], wa[MAX_N], wb[MAX_N];
LL add[MAX_N], ans, ret;

void da(int *a, int *sa, int n, int m) {
	int *x = wa, *y = wb;
	for (int i = 0; i < m; i ++) ws[i] = 0;
	for (int i = 0; i < n; i ++) ws[x[i] = a[i]] ++;
	for (int i = 1; i < m; i ++) ws[i] += ws[i - 1];
	for (int i = n - 1; i >= 0; i --) sa[-- ws[x[i]]] = i;
	for (int k = 1; k <= n; k <<= 1) {
		int p = 0;
		for (int i = n - k; i < n; i ++) y[p ++] = i;
		for (int i = 0; i < n; i ++) if (sa[i] >= k) y[p ++] = sa[i] - k;
		for (int i = 0; i < n; i ++) wv[i] = x[y[i]];
		for (int i = 0; i < m; i ++) ws[i] = 0;
		for (int i = 0; i < n; i ++) ws[wv[i]] ++;
		for (int i = 1; i < m; i ++) ws[i] += ws[i - 1];
		for (int i = n - 1; i >= 0; i --) sa[-- ws[wv[i]]] = y[i];
		swap(x, y); p = 1; x[sa[0]] = 0;
		for (int i = 1; i < n; i ++) x[sa[i]] = (y[sa[i - 1]] == y[sa[i]]) && (y[sa[i - 1] + k] == y[sa[i] + k]) ? p - 1 : p ++;
		if (p >= n) break; m = p;
	}
}
void calc() {
	for (int i = 1; i <= n; i ++) r[sa[i]] = i;
	int k = 0, j;
	for (int i = 0; i < n; h[r[i ++]] = k)
		for (k ? k -- : 0, j = sa[r[i] - 1]; a[i + k] == a[j + k]; k ++);
}
void work(int x) {
	int sum = 0, j;
	for (int i = 2; i <= n; i = j + 1) {
		for (; h[i] < x && i <= n; i ++);
		for (j = i; h[j] >= x; j ++);
		if (i == j) continue;
		sum = 0; Mx = -INF; Mn = Mnn = INF;
		ans += add[j - i];
		for (int k = i - 1; k < j; k ++) {
			if (val[sa[k]] > Mx) Mx = val[sa[k]], id1 = k;
			if (val[sa[k]] < Mn) Mn = val[sa[k]], id2 = k;
		}
		Mxx = -INF, Mnn = INF;
		for (int k = i - 1; k < j; k ++) {
			if (val[sa[k]] > Mxx && id1 != k) Mxx = val[sa[k]];
			if (val[sa[k]] < Mnn && id2 != k) Mnn = val[sa[k]];
		}
		ret = max(ret, max(Mxx * Mx, Mn * Mnn));
	}
}
void init() {
	scanf("%d", &n); getchar();
	for (int i = 0; i < n; i ++) {
		char c; scanf("%c", &c);
		a[i] = (int)c;
	}
	a[n] = 0; 
	Mx = Mxx = -INF; Mn = Mnn = INF;
	for (int i = 0; i < n; i ++) {
		scanf("%lld", &val[i]);
		if (Mx < val[i]) Mx = val[i], id1 = i;
		if (Mn > val[i]) Mn = val[i], id2 = i;
	}
	da(a, sa, n + 1, 128); calc();
	for (int i = 1; i <= n; i ++) printf("%d ", sa[i]); printf("\n");
	for (int i = 1; i <= n; i ++) printf("%d ", h[i]); printf("\n");
}
void doit() {
	int mxh = -1;
	for (int i = 1; i <= n; i ++) mxh = max(mxh, h[i]);
	for (int i = 1; i <= n; i ++) add[i] = add[i - 1] + i;
	for (int i = 0; i < n; i ++) {
		if (Mxx < val[i] && i != id1) Mxx = val[i];
		if (Mnn > val[i] && i != id2) Mnn = val[i];
	}
	printf("%lld %lld\n", add[n - 1], max(Mx * Mxx, Mn * Mnn));
	for (int i = 1; i < n; i ++) {
		if (i <= mxh) {
			ans = 0; ret = -INF;
			Mx = Mxx = -INF; Mn = Mnn = INF;
			work(i);
			printf("%lld %lld\n", ans, ret == -INF ? 0 : ret);
		} else printf("0 0\n");
	}
}
int main() {
	init();
	doit();
	return 0;
}
滿分演算法:

Height 陣列,然後按從大到小的順序排序,因為可以發現 Height 中的大值不會影響小值對答案的貢獻。每次更新答案,將當前兩個字串合併,即用並查集維護一下,他們對於方案數的貢獻就是這兩個字尾所在的集合個數的乘積,同時維護一下最大最小值就行了。時間複雜度 O( nlogn )

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int MAX_N = 300005;

typedef long long LL;

int n, a[MAX_N], sa[MAX_N], r[MAX_N], h[MAX_N];
int ws[MAX_N], wv[MAX_N], wa[MAX_N], wb[MAX_N];
int f[MAX_N], sz[MAX_N];
LL ans[MAX_N], sum[MAX_N], mx[MAX_N], mn[MAX_N], val[MAX_N];
char c;
struct node {
	int h, x, y;
}g[MAX_N];

inline bool cmp(node a, node b) { return a.h > b.h; }
int find(int x) { return x == f[x] ? x : (f[x] = find(f[x])); }
void uniont(int x, int y) { 
	f[y] = x; sz[x] += sz[y];
	mx[x] = max(mx[x], mx[y]);
	mn[x] = min(mn[x], mn[y]);
}
void da(int *a, int *sa, int n, int m) {
	int *x = wa, *y = wb;
	for (int i = 0; i < m; i ++) ws[i] = 0;
	for (int i = 0; i < n; i ++) ws[x[i] = a[i]] ++;
	for (int i = 1; i < m; i ++) ws[i] += ws[i - 1];
	for (int i = n - 1; i >= 0; i --) sa[-- ws[x[i]]] = i;
	for (int k = 1; k <= n; k <<= 1) {
		int p = 0;
		for (int i = n - k; i < n; i ++) y[p ++] = i;
		for (int i = 0; i < n; i ++) if (sa[i] >= k) y[p ++] = sa[i] - k;
		for (int i = 0; i < n; i ++) wv[i] = x[y[i]];
		for (int i = 0; i < m; i ++) ws[i] = 0;
		for (int i = 0; i < n; i ++) ws[wv[i]] ++;
		for (int i = 1; i < m; i ++) ws[i] += ws[i - 1];
		for (int i = n - 1; i >= 0; i --) sa[-- ws[wv[i]]] = y[i];
		swap(x, y); p = 1; x[sa[0]] = 0;
		for (int i = 1; i < n; i ++) x[sa[i]] = (y[sa[i - 1]] == y[sa[i]]) && (y[sa[i - 1] + k] == y[sa[i] + k]) ? p - 1 : p ++;
		if (p >= n) break; m = p;
	}
}
void calc() {
	for (int i = 1; i <= n; i ++) r[sa[i]] = i;
	int k = 0, j;
	for (int i = 0; i < n; h[r[i ++]] = k)
		for (k ? k -- : 0, j = sa[r[i] - 1]; a[i + k] == a[j + k]; k ++);
}
void init() {
	scanf("%d", &n); getchar();
	for (int i = 0; i < n; i ++) {
		char c; scanf("%c", &c);
		a[i] = (int)c;
	}
	a[n] = 0; 
	for (int i = 1; i <= n; i ++) scanf("%lld", &val[i]);
	da(a, sa, n + 1, 128); calc();
	for (int i = 1; i <= n; i ++) 
		f[i] = i, mx[r[i - 1]] = val[i], mn[r[i - 1]] = val[i], sz[i] = 1;
}
void doit() {
	for (int i = 2; i <= n; i ++) g[i - 1].h = h[i], g[i - 1].x = i, g[i - 1].y = i - 1;
	sort(g + 1, g + n, cmp);
	memset(sum, 128, sizeof(sum));
	for (int i = g[1].h, j = 1; i >= 0; i --) {
		ans[i] = ans[i + 1], sum[i] = sum[i + 1];
		for (; j < n && g[j].h == i; j ++) {
			int x = find(g[j].x), y = find(g[j].y);
			sum[i] = max(sum[i], 1ll * mx[x] * mx[y]);
			sum[i] = max(sum[i], 1ll * mn[x] * mn[y]);
			ans[i] += 1ll * sz[x] * sz[y];
			uniont(x,y);
		}
	}
	for (int i = 0; i < n; i ++) 
		printf("%lld %lld\n", ans[i], ans[i] == 0 ? 0 : sum[i]);
}
int main() {
	init();
	doit();
	return 0;
}