1. 程式人生 > 其它 >洛谷 P3781 - [SDOI2017]切樹遊戲(動態 DP+FWT)

洛谷 P3781 - [SDOI2017]切樹遊戲(動態 DP+FWT)

[DP&資料結構] 動態 dp+FWT+一些玄學優化技巧,毒瘤題

洛谷題面傳送門

SDOI 2017 R2 D1 T3,nb tea %%%

講個笑話,最近我在學動態 dp,wjz 在學 FWT,而我們剛好在同一天做到了這道題,而這道題剛好又是 FWT+動態 dp

首先考慮怎樣暴力計算答案,我們記 \(dp_{u,j}\) 表示以 \(u\) 為根的子樹中有多少個連通塊包含 \(u\) 且權值的異或和為 \(j\),初始 \(dp_{u,val_u}=1\),每次遍歷 \(u\) 的一個子樹 \(v\) 就對這個子樹就對這兩個子樹的 \(dp\) 做一個合併,即 \(dp_{u,x}\leftarrow dp_{u,x}+\sum\limits_{y=0}^{m-1}dp_{u,y}\times dp_{v,x\oplus y}\)

,最終答案即為 \(\sum\limits_{u}dp_{u,k}\)。正確性顯然,時間複雜度 \(\mathcal O(qnm^2)\),可以通過前四個測試點。

考慮優化,首先一個非常明顯的優化是,DP 轉移方程式長得一臉 xor 卷積的樣子,如果我們記 \(f*g\) 表示 \(f,g\) 兩個集合冪級數的 FWTxor,那麼上述式子可以改寫為 \(dp_u=dp_u+dp_u*dp_v=dp_u*(dp_v+1)\),因此考慮將所有 \(dp_u\) 都變為 \(\text{FWT}(dp_u)\),那麼 \(dp_u\) 的初始值就變為 \(dp_{u,i}=\text{FWT}(E_{val_u})_i\)

,其中 \(E_i\) 為滿足 \(f_i=1,f_j=0(j\ne i)\) 的集合冪級數 \(f\),這個可以通過預處理所有 \(\text{FWT}(E_i)\) 實現 \(\mathcal O(m)\) 初始化。轉移操作可以根據 FWT 那一套理論變成 \(dp_{u,i}\leftarrow dp_{u,i}\times(dp_{v,i}+1)\),這樣即可實現 \(\mathcal O(m)\) 轉移。然後每次操作完了之後再 IFWT 回來即可,時間複雜度降到了 \(\mathcal O(qnm+qm\log m)\),還是隻能通過前四個測試點(

注意到上述 \(dp\) 對於不帶修改的情況是 efficient enough 的,但是帶上修改就直接萎掉了,因此考慮擅長處理修改操作的動態 \(dp\)

來解決這個問題,按照動態 \(dp\) 的套路我們將樹剖成一條條重鏈,\(dp\) 分為輕兒子和重兒子處理,我們記 \(dpl_{u,i}=\sum\limits_{v\in\text{lightson}(u)}(dp_{v,i}+1)\),那麼記 \(w=wson_u\),則有 \(dp_{u,i}=dpl_{u,i}\times\text{FWT}(E_{val_u})_i\times (dp_{w,i}+1)\)

但是光記錄一個 \(dp\) 值是遠遠不夠的,因為最終我們要求的是整棵子樹中 \(dp_{u,k}\) 的值之和,所有我們不得不再額外記錄 \(sum_{u,i}\) 表示子樹中所有點的 \(dp_{u,i}\) 之和,那麼有 \(sum_{u,i}=\sum\limits_{v\in \text{son}(u)}sum_{v,i}+dp_{u,i}\),按照套路我們還是記 \(suml_{u,i}=\sum\limits_{v\in\text{lightson}(u)}sum_{v,i}\),那麼有 \(sum_{u,i}=sum_{w,i}+dp_{u,i}+suml_{u,i}=sum_{w,i}+dpl_{u,i}\times\text{FWT}(E_{val_u})_i\times(dp_{w,i}+1)+suml_{u,i}\)

考慮將這東西寫成矩陣的形式,那麼有:

\[\begin{bmatrix}dp_{u}&sum_{u}&1\end{bmatrix}=\begin{bmatrix}dp_{w}&sum_{w}&1\end{bmatrix}\times \begin{bmatrix} dpl_u\times\text{FWT}(E_{val_u})&dpl_u\times\text{FWT}(E_{val_u})&0\\ 0&1&0\\ dpl_u\times\text{FWT}(E_{val_u})&dpl_u\times\text{FWT}(E_{val_u})+suml_u&1 \end{bmatrix} \]

其中 \(f\times g\) 就對應項相乘好了,\(f+g\) 也同理。

\(A_u=\begin{bmatrix} dpl_u\times\text{FWT}(E_{val_u})&dpl_u\times\text{FWT}(E_{val_u})&0\\ 0&1&0\\ dpl_u\times\text{FWT}(E_{val_u})&dpl_u\times\text{FWT}(E_{val_u})+suml_u&1 \end{bmatrix}\),那麼對於一個點 \(u\) 而言,記它到重鏈底經過的節點依次是 \(u=v_1,v_2,\cdots,v_k\),那麼有

\[\begin{bmatrix}dp_{u}&sum_{u}&1\end{bmatrix}=\begin{bmatrix}0&0&1\end{bmatrix}\times\prod\limits_{i=k}^1A_{v_i} \]

這個可以樹鏈剖分+線段樹維護。

修改操作就按照動態 \(dp\) 的套路不斷跳重鏈並撤銷原來的 \(dp_{top_u}\)\(dpl_{fa[top_u]}\)\(suml_{fa[top_u]}\) 的影響並加入新的貢獻即可,時間複雜度 \(q\log^2nm+qm\log m\),LOJ 上可以通過,而洛谷上由於某兩位毒瘤提供的毒瘤卡樹剖的資料,只能獲得 \(80\) 分的好成績。

說起來輕巧,實現起來一堆細節需要注意:

  1. 直接矩陣乘法會多 \(27\) 的常數,導致 TLE,因此需要按照套路進行優化,注意到這個 \(3\times 3\) 的矩陣中只有四個位置是有用的,因此可以只維護這四個位置的值,即 \(\begin{bmatrix}a&b&0\\0&1&0\\c&d&1\end{bmatrix}\),那麼有 \(\begin{bmatrix}a_1&b_1&0\\0&1&0\\c_1&d_1&1\end{bmatrix}\times \begin{bmatrix}a_2&b_2&0\\0&1&0\\c_2&d_2&1\end{bmatrix}=\begin{bmatrix}a_1a_2&a_1b_2+b_1&0\\0&1&0\\a_2c_1+c_2&b_2c_1+d_1+d_2&1\end{bmatrix}\),這樣常數可以降到 \(4\)
  2. 注意線段樹 pushup 的順序。
  3. 在撤銷原來的貢獻時會出現除以 \(0\) 的情況,因此可以將 \(dpl_{u,i}\) 存成一個個結構體,每個結構體用 \(x\times 0^y\) 表示一個數,每次乘以 \(0\) 時令 \(y\) 加一,除以 \(0\) 則令 \(y\) 減一,這樣可以避免這個問題(u1s1 蒟蒻是第一次遇到這個套路呢,大佬不喜勿噴)
  4. 注意計算新加入的貢獻時是計算線段樹上 \([dfn[top[x]]],dfn[bot[top[x]]]\) 內矩陣的乘積,而不是 \([dfn[x]],dfn[bot[top[x]]]\),蒟蒻因為這個錯誤調了 1h,心態炸裂。

碼了 212 行……

const int MAXN=3e4;
const int MAXV=1<<7;
const int MOD=1e4+7;
const int INV2=5004;
int n,m,val[MAXN+5],inv[MOD+4];
void getinv(){
	for(int i=(inv[0]=inv[1]=1)+1;i<MOD;i++) inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
}
void FWTxor(int *a,int len,int type){
	for(int i=2;i<=len;i<<=1)
		for(int j=0;j<len;j+=i)
			for(int k=0;k<(i>>1);k++){
				int X=a[j+k],Y=a[(i>>1)+j+k];
				if(~type) a[j+k]=(X+Y)%MOD,a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
				else a[j+k]=(X+Y)*INV2%MOD,a[(i>>1)+j+k]=(X-Y+MOD)*INV2%MOD;
			}
}
struct num0{//number expressed as x*0^y
	int x,y;
	num0(int v=1){(!v)?(y=x=1):(x=v,y=0);}
	num0 operator *(const int &rhs){
		(!rhs)?(++y):(x=x*rhs%MOD);
		return *this;
	}
	num0 operator /(const int &rhs){
		(!rhs)?(--y):(x=x*inv[rhs]%MOD);
		return *this;
	}
	int num(){return y?0:x;}
};
struct poly{
	int a[MAXV+5];
	poly(){memset(a,0,sizeof(a));}
	poly(int x){for(int i=0;i<m;i++) a[i]=x;}
	poly operator +(poly rhs) const{
		poly res;
		for(int i=0;i<m;i++) res.a[i]=(a[i]+rhs.a[i])%MOD;
		return res;
	}
	poly operator *(poly rhs) const{
		poly res(1);
		for(int i=0;i<m;i++) res.a[i]=a[i]*rhs.a[i]%MOD;
		return res;
	}
	void FWT(){FWTxor(a,m,1);}
	void IFWT(){FWTxor(a,m,-1);}
} e[MAXV+5];
int hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int siz[MAXN+5],fa[MAXN+5],dep[MAXN+5],wson[MAXN+5];
int top[MAXN+5],dfn[MAXN+5],tim=0,rid[MAXN+5];
int bot[MAXN+5];
void dfs1(int x=1,int f=0){
	siz[x]=1;fa[x]=f;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f) continue;
		dep[y]=dep[x]+1;dfs1(y,x);siz[x]+=siz[y];
		if(siz[y]>siz[wson[x]]) wson[x]=y;
	}
}
void dfs2(int x=1,int tp=1){
	top[x]=tp;rid[dfn[x]=++tim]=x;if(wson[x]) dfs2(wson[x],tp);
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==wson[x]||y==fa[x]) continue;
		dfs2(y,y);
	}
}
int f[MAXN+5][MAXV+5],sum[MAXN+5][MAXV+5],suml[MAXN+5][MAXV+5];
num0 fl[MAXN+5][MAXV+5];
void dfs3(int x=1){
	for(int i=0;i<m;i++) f[x][i]=e[val[x]].a[i],fl[x][i]=1;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x]) continue;dfs3(y);
		for(int i=0;i<m;i++) f[x][i]=f[x][i]*(f[y][i]+1)%MOD;
		for(int i=0;i<m;i++) sum[x][i]=(sum[x][i]+sum[y][i])%MOD;
	} for(int i=0;i<m;i++) sum[x][i]=(sum[x][i]+f[x][i])%MOD;
}
void dfs4(int x=1){
	if(wson[x]) dfs4(wson[x]);
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x]||y==wson[x]) continue;dfs4(y);
		for(int i=0;i<m;i++) fl[x][i]=fl[x][i]*((f[y][i]+1)%MOD);
		for(int i=0;i<m;i++) suml[x][i]=(suml[x][i]+sum[y][i])%MOD;
	}
}
struct mat{
	poly a,b,c,d;
	mat(){}
	mat operator *(const mat &rhs){
		mat res;res.a=a*rhs.a;res.b=b+a*rhs.b;
		res.c=rhs.a*c+rhs.c;res.d=rhs.b*c+d+rhs.d;
		return res;
	}
};
void print(mat x){
	for(int i=0;i<m;i++) printf("%d%c",x.a.a[i]," \n"[i==m-1]);
	for(int i=0;i<m;i++) printf("%d%c",x.b.a[i]," \n"[i==m-1]);
	for(int i=0;i<m;i++) printf("%d%c",x.c.a[i]," \n"[i==m-1]);
	for(int i=0;i<m;i++) printf("%d%c",x.d.a[i]," \n"[i==m-1]);
}
mat get(int x){
	mat res;
	for(int i=0;i<m;i++) res.a.a[i]=res.b.a[i]=res.c.a[i]=fl[x][i].num()*e[val[x]].a[i]%MOD;
	for(int i=0;i<m;i++) res.d.a[i]=(res.a.a[i]+suml[x][i])%MOD;
	return res;
}
struct node{int l,r;mat v;} s[MAXN*4+5];
void pushup(int k){s[k].v=s[k<<1|1].v*s[k<<1].v;}
void build(int k,int l,int r){
	s[k].l=l;s[k].r=r;if(l==r) return s[k].v=get(rid[l]),void();
	int mid=l+r>>1;build(k<<1,l,mid);build(k<<1|1,mid+1,r);pushup(k);
}
mat query(int k,int l,int r){
	if(l<=s[k].l&&s[k].r<=r) return s[k].v;
	int mid=s[k].l+s[k].r>>1;
	if(r<=mid) return query(k<<1,l,r);
	else if(l>mid) return query(k<<1|1,l,r);
	else return query(k<<1|1,mid+1,r)*query(k<<1,l,mid);
}
void modify(int k,int p){
	if(s[k].l==s[k].r) return s[k].v=get(rid[p]),void();
	int mid=s[k].l+s[k].r>>1;
	if(p<=mid) modify(k<<1,p);else modify(k<<1|1,p);
	pushup(k);
}
void change(int x){
	while(x){
		if(fa[top[x]]){
			mat res=query(1,dfn[top[x]],dfn[bot[top[x]]]);
//			printf("%d\n",fa[top[x]]);
//			for(int i=0;i<m;i++) printf("{%d,%d}%c",fl[fa[top[x]]][i].x,fl[fa[top[x]]][i].y," \n"[i==m-1]);
//			for(int i=0;i<m;i++) printf("%d%c",(res.c.a[i]+1)%MOD," \n"[i==m-1]);
//			print(get(fa[top[x]]));
			for(int i=0;i<m;i++) fl[fa[top[x]]][i]=fl[fa[top[x]]][i]/((res.c.a[i]+1)%MOD);
			for(int i=0;i<m;i++) suml[fa[top[x]]][i]=(suml[fa[top[x]]][i]-res.d.a[i]+MOD)%MOD;
		} modify(1,dfn[x]);
		if(fa[top[x]]){
			mat res=query(1,dfn[top[x]],dfn[bot[top[x]]]);
//			print(res);
			for(int i=0;i<m;i++) fl[fa[top[x]]][i]=fl[fa[top[x]]][i]*((res.c.a[i]+1)%MOD);
			for(int i=0;i<m;i++) suml[fa[top[x]]][i]=(suml[fa[top[x]]][i]+res.d.a[i])%MOD;
//			for(int i=0;i<m;i++) printf("{%d,%d}%c",fl[fa[top[x]]][i].x,fl[fa[top[x]]][i].y," \n"[i==m-1]);
//			for(int i=0;i<m;i++) printf("%d%c",(res.c.a[i]+1)%MOD," \n"[i==m-1]);
//			print(get(fa[top[x]]));
		} x=fa[top[x]];
	}
}
int main(){
	scanf("%d%d",&n,&m);getinv();
	for(int i=0;i<m;i++) e[i].a[i]=1,e[i].FWT();
//	for(int i=0;i<m;i++) for(int j=0;j<m;j++) printf("%d%c",e[i].a[j]," \n"[j==m-1]);
	for(int i=1;i<=n;i++) scanf("%d",&val[i]);
	for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
	dfs1();dfs2();dfs3();dfs4();build(1,1,n);
//	for(int i=1;i<=n;i++) for(int j=0;j<m;j++) printf("%d%c",f[i][j]," \n"[j==m-1]);
//	for(int i=1;i<=n;i++) for(int j=0;j<m;j++) printf("%d%c",sum[i][j]," \n"[j==m-1]);
//	for(int i=1;i<=n;i++) for(int j=0;j<m;j++) printf("%d%c",fl[i][j].num()," \n"[j==m-1]);
//	for(int i=1;i<=n;i++) for(int j=0;j<m;j++) printf("%d%c",suml[i][j]," \n"[j==m-1]);
//	for(int i=0;i<m;i++) printf("%d%c",t.d.a[i]," \n"[i==m-1]);
	for(int i=1;i<=n;i++) if(top[i]==i){
		int cur=i;while(wson[cur]) cur=wson[cur];
		bot[i]=cur;
	} int qu;scanf("%d",&qu);
	while(qu--){
		char opt[9];scanf("%s",opt+1);
		if(opt[1]=='C'){
			int x,v;scanf("%d%d",&x,&v);
			val[x]=v;change(x);
		} else {
			int k;scanf("%d",&k);
			mat res=query(1,dfn[1],dfn[bot[1]]);
//			for(int i=0;i<m;i++) printf("%d%c",res.d.a[i]," \n"[i==m-1]);
			res.d.IFWT();
			printf("%d\n",res.d.a[k]);
		}
	}
	return 0;
}