1. 程式人生 > 其它 >【luogu P3781】【LOJ 2269】切樹遊戲(FWT)(DDP)

【luogu P3781】【LOJ 2269】切樹遊戲(FWT)(DDP)

切樹遊戲

題目連結:luogu P3781 / LOJ 2269

題目大意

給你一棵樹,會有單點修改,要你在其中求有多少棵子樹的權值異或和是一個詢問的 k。

思路

首先考慮沒有修改的 DP。

先是最暴力的:\(f_{i,j}\)\(i\) 的子樹,當前選的子樹有 \(i\),異或和為 \(j\)
詢問答案 \(k\) 就是 \(\sum\limits_{i=1}^nf_{k,j}\)
然後轉移是列舉子樹 \(son\)\(f_{i,j}=\sum\limits_{x\oplus y=j}f_{i,x}f_{son,y}+f_{i,j}\)
然後這個顯然是可以用 FWT 優化的:\(nf_{i}=FWT[a_i]\)


(其實卷積起來就是 \(f_{i,k}=nf_{i,k}\prod\limits_{son}(f_{son,k}+1)\)

然後至於答案我們可以搞 \(h_{i,k}=f_{i,k}+\sum\limits_{son}h_{son,k}\)

那我們這個就可以搞,然後搞出來最後 IFWT 轉回去。
那一次的複雜度是 \(O(nm\log m)\)


考慮加上動態改點,發現上面的 DP 其實還算簡單,我們考慮 DDP。
那還是輕重鏈剖分,然後重兒子是 \(hson_x\)
那我們設 \(lf_{x,k}=nf_{x,k}\prod\limits_{son\wedge son\neq hson_{x}}(f_{son,k}+1)\)


\(lh_{x,k}=\sum\limits_{son\wedge son\neq hson_{x}}lh_{son,k}\)
然後你會發現這個 \(k\) 一直只跟自己有關聯,然後一直掛在這裡很煩,那我們考慮把上面轉移那些去掉 \(k\) 的一維,然後把它們當做陣列轉移。
(因為你都跟自己有關,所以加減乘除都是 \(O(m)\) 的,就不用卷積)

然後看看怎麼轉移,準備上矩陣。
\(f_{x}=lf_{x}(f_{hson_x}+1)\)
\(h_x=f_x+lh_x+h_{hson_x}=lf_{x}(f_{hson_x}+1)+lh_x+h_{hson_x}\)

然後你就可以嘗試建矩陣啦!
先搞一個 \(\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}\)


然後要通過乘一個矩陣把 \(\begin{vmatrix}f_{hson_x}\\ h_{hson_x}\\ 1\end{vmatrix}\) 變成 \(\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}\)
然後可以構造出:
\(\begin{vmatrix}lf_x&0& lf_x\\ lf_x&1&lf_x+lh_x\\ 0&0&1\end{vmatrix}*\begin{vmatrix}f_{hson_x}\\ h_{hson_x}\\ 1\end{vmatrix}=\begin{vmatrix}f_x\\ h_x\\ 1\end{vmatrix}\)

然後就可以用這個矩陣搞,複雜度為 \(O(qlog^{(2)}nm*27)\),好像有點大,就算用全域性平衡二叉樹也過不去。

然後我們看到這個矩陣的樣子比較特別,考慮手玩一下:
\(\begin{vmatrix}a_1&0& b_1\\ c_1&1&d_1\\ 0&0&1\end{vmatrix}*\begin{vmatrix}a_2&0& b_2\\ c_2&1&d_2\\ 0&0&1\end{vmatrix}=\begin{vmatrix}a_1a_2&0& a_1b_2+b_1\\ a_2c_1+c_2&1&c_1b_2+d_2+d_1\\ 0&0&1\end{vmatrix}\)
然後你會發現你只用維護四個值,而且它們怎麼維護是固定的,所以常數就變成了 \(4\),用平衡二叉樹就可以過啦!
(好像說 luogu 樹鏈剖分被卡了的說)

然後至於實現的話。。。多用結構體。
具體一下就是你矩陣裡面每個值是陣列用結構體,輕邊的維護不能直接轉移,因為你修改的時候要除,但是裡面可能是 \(0\),所以你要再弄一個結構體專門來輕邊的轉移,就是一個正常的陣列加上一個表示陣列每一位乘了 \(0\) 的個數。
然後你乘 \(0\) 就不乘而是加 \(0\) 的個數,除就是減,然後弄一個函式把它轉回乘整除的陣列,就是 \(O(m)\) 列舉一次把有 \(0\) 的變成 \(0\)
然後建議全域性平衡二叉樹也可以封裝一下。

(麻了程式碼是真的長)

程式碼

#include<cstdio>
#include<algorithm>

using namespace std;

const int mo = 1e4 + 7;
const int N = 3e4 + 10;
const int M = 128;
struct node {
	int to, nxt;
}e[N << 1];
int n, m, a[N], inv[N], le[N], KK;
int sz[N], son[N], ans[M];
char c;

int jia(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int jian(int x, int y) {return x < y ? x - y + mo : x - y;}
int cheng(int x, int y) {return x * y % mo;}

void FWT(int *f, int limit, int op) {//FWT
	for (int mid = 1; mid < limit; mid <<= 1) {
		for (int R = mid << 1, j = 0; j < limit; j += R)
			for (int k = 0; k < mid; k++) {
				int x = f[j | k], y = f[j | mid | k];
				f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
				if (op == -1) f[j | k] = cheng(f[j | k], inv[2]), f[j | mid | k] = cheng(f[j | mid | k], inv[2]);
			}
	}
}

void add(int x, int y) {e[++KK] = (node){y, le[x]}; le[x] = KK;}

void dfs(int now, int father) {//重鏈剖分
	sz[now] = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			dfs(e[i].to, now); sz[now] += sz[e[i].to];
			if (sz[e[i].to] > sz[son[now]]) son[now] = e[i].to;
		}
}

struct poly {//記得你矩陣裡面每個值都是一個數組,那加減我們得維護
	int f[M];
	
	int& operator [](int& x) {return f[x];}
	poly operator +(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = jia(f[i], y[i]); return re;}
	poly operator -(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = jian(f[i], y[i]); return re;}
	poly operator *(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = cheng(f[i], y[i]); return re;}
	poly operator /(poly y) {poly re; for (int i = 0; i < m; i++) re[i] = cheng(f[i], inv[y[i]]); return re;}
	void operator +=(poly y) {for (int i = 0; i < m; i++) f[i] = jia(f[i], y[i]);}
	void operator -=(poly y) {for (int i = 0; i < m; i++) f[i] = jian(f[i], y[i]);}
}one, ee[N];

struct matrix {
	poly a[2][2];
	
	poly* operator [](const int& x) {return a[x];}
	matrix operator *(matrix b) {//優化了的矩陣轉移
		matrix re;
		re[0][0] = a[0][0] * b[0][0];
		re[0][1] = a[0][0] * b[0][1] + a[0][1];
		re[1][0] = b[0][0] * a[1][0] + b[1][0];
		re[1][1] = a[1][0] * b[0][1] + b[1][1] + a[1][1];
		return re;
	}
};

struct Light {//對於輕邊上的轉移我們可以單獨開一個結構體,因為要記錄 0 個數
	int num0[M], val[M];
	
	void change(int pl, int va) {
		if (!va) num0[pl] = val[pl] = 1;
			else num0[pl] = 0, val[pl] = va;
	}
	
	void operator *=(poly f) {
		for (int i = 0; i < m; i++)
			if (!f.f[i]) num0[i]++;
				else val[i] = cheng(val[i], f.f[i]);
	}
	
	void operator /=(poly f) {
		for (int i = 0; i < m; i++)
			if (!f.f[i]) num0[i]--;
				else val[i] = cheng(val[i], inv[f.f[i]]);
	}
};

poly get_poly(Light &b) {//把輕邊的結構體轉回給陣列
	poly x; for (int i = 0; i < m; i++) x[i] = b.num0[i] ? 0 : b.val[i]; return x;
}

struct BST {//全域性平衡二叉樹
	Light lf[N]; poly lh[N];
	int fa[N], ls[N], rs[N], root, sta[N], ssz[N];
	matrix val[N], sum[N];
	
	bool nrt(int x) {
		return ls[fa[x]] == x || rs[fa[x]] == x;
	}
	
	void Make_val(int now, int to) {//求 lf,lh (輕邊的值轉移) 
		lf[now] *= (sum[to][1][0] + one);
		lh[now] += sum[to][1][1];
	}
	
	void Clean_val(int now, int to) {
		lf[now] /= (sum[to][1][0] + one);
		lh[now] -= sum[to][1][1];
	}
	
	void Make_Val(int now) {//建矩陣 
		val[now][0][0] = val[now][0][1] = val[now][1][0] = val[now][1][1] = get_poly(lf[now]);
		val[now][1][1] += lh[now];
	}
	
	void up(int now) {
		sum[now] = sum[ls[now]] * val[now] * sum[rs[now]];
	}
	
	int buildT(int l, int r) {
		if (l > r) return 0;
		int tot = 0; for (int i = l; i <= r; i++) tot += ssz[sta[i]];
		for (int i = l, now = ssz[sta[i]]; i <= r; i++, now += ssz[sta[i]])
			if (now * 2 >= tot) {
				ls[sta[i]] = buildT(l, i - 1); rs[sta[i]] = buildT(i + 1, r);
				fa[ls[sta[i]]] = fa[rs[sta[i]]] = sta[i]; up(sta[i]); return sta[i];
			}
	}
	
	int build(int now, int fr) {
		for (int i = now; i; fr = i, i = son[i]) {
			for (int j = le[i]; j; j = e[j].nxt)
				if (e[j].to != fr && e[j].to != son[i]) {
					int x = build(e[j].to, i); fa[x] = i;
					Make_val(i, x);
				}
			Make_Val(i);
		}
		sta[0] = 0;
		for (int i = now; i; i = son[i]) sta[++sta[0]] = i, ssz[i] = sz[i] - sz[son[i]];
		reverse(sta + 1, sta + sta[0] + 1);//反轉,因為也是從下到上DP的 
		return buildT(1, sta[0]);
	}
	
	void Init() {
		for (int i = 1; i <= n; i++)
			for (int j = 0; j < m; j++)
				lf[i].change(j, ee[i][j]);
		val[0][0][0] = sum[0][0][0] = one;
		root = build(1, 0);
	}
	
	void change(int x, int y) {
		lf[x] /= ee[x];
		for (int i = 0; i < m; i++) ee[x][i] = 0; a[x] = y; ee[x][a[x]] = 1;
		FWT(ee[x].f, m, 1); lf[x] *= ee[x]; Make_Val(x);
		for (; x; x = fa[x]) {
			if (nrt(x)) up(x);//重鏈上直接上傳
				else {//輕鏈要修改好父親的值,上傳
					Clean_val(fa[x], x); up(x); Make_val(fa[x], x); Make_Val(fa[x]);
				}
		}
	}
	
	void update_ans() {
		for (int i = 0; i < m; i++) ans[i] = sum[root][1][1][i];
		FWT(ans, m, -1);
	}
}T;

void Init() {
	for (int i = 0; i < m; i++) one[i] = 1;
	inv[0] = inv[1] = 1; for (int i = 2; i < mo; i++) inv[i] = cheng(inv[mo % i], mo - mo / i);
}

int main() {
	scanf("%d %d", &n, &m); Init();
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]), ee[i][a[i]] = 1, FWT(ee[i].f, m, 1);
	for (int i = 1; i < n; i++) {
		int x, y; scanf("%d %d", &x, &y); add(x, y); add(y, x);
	}
	
	dfs(1, 0);
	T.Init();
	T.update_ans(); 
	
	int q; scanf("%d", &q);
	while (q--) {
		c = getchar(); while (c != 'C' && c != 'Q') c = getchar();
		if (c == 'C') {
			while (c != ' ') c = getchar();
			int x, y; scanf("%d %d", &x, &y);
			T.change(x, y); T.update_ans();
		}
		else {
			while (c != ' ') c = getchar();
			int x; scanf("%d", &x);
			printf("%d\n", ans[x]);
		}
	}
	
	return 0;
}