1. 程式人生 > 其它 >LOJ3626 「2021 集訓隊互測」愚蠢的線上法官

LOJ3626 「2021 集訓隊互測」愚蠢的線上法官

我們不妨設 \(A\)\([1,n]\) ,且為 \(\text{dfs}\) 序。那麼顯然該矩陣的第一列的列向量是完全相同的,我們再找到根節點的兒子,那麼對於這些兒子的子樹,子樹內是該兒子的值,子樹外是根節點的值。如此遞迴,我們發現每一次遞迴的時候,子樹外的節點是不變的,子樹內的節點是變成子樹的根節點的權值。

我們發現行向量也是類似的,考慮進行一個類似於高斯消元的過程,每一個節點減去其父親,我們就可以將每一個節點不屬於該子樹的部分消去。那麼此時這個矩陣上的點就只有 \(O(n^2)\) 個了,但是你經過一定的位置交換可以發現,你只有右上方的位置有值,所以直接乘一波對角線就行了。

此時我們再來考慮一波如果 \(A\)

並不是 \(n\) 的排列,首先如果存在重複的部分,那麼這一行一列直接可以被我們消成 \(0\) ,所以行列式就是 \(0\) ,這樣只需要考慮如果有些位置不存在怎麼辦。我們考慮不存在的位置對於權值的影響,其實我們可以直接利用給一行或者一列加來使得位置消去,只不過需要注意考慮對於行列式的影響。

給一行或一列加好像不太好做?考慮對於一個節點 \(u\) ,如果這個節點不存在子樹節點在 \(A\) 中或者只有一個存在,那麼不會有位置的值與 \(u\) 相關;考慮如果存在兩個,但是這兩個互相包含,也不會存在值與 \(u\) 相關;考慮有多個互不包含的點,他會形成一個類似於如圖的結構。

\[\begin{bmatrix} a_1 & x & x & z & \cdots & z\\ x & a_2 & x & z & \cdots & z\\ x & x & a_3 & z & \cdots & z\\ z & z & z & a_4 & \cdots & y\\ \cdots & \cdots & \cdots & \cdots & \cdots & \cdots\\ z & z & z & y & \cdots & a_n \end{bmatrix} \]

我們考慮利用遞迴處理這個結構(這個結構實際上就是一個樹形結構),考慮在當前節點,我們先將最外層的點權(在上面的例子中就是 \(z\)

)給減掉,這樣我們選擇遞迴求解給當前矩陣加上一個數 \(x\) 之後的行列式,可以證明是一個一次函式 \(ax+b\) (證明過程就是求解過程),我們此時考慮我們已經得到了兩個子矩陣的一次函式,其中一個是上述的,另一個是 \(cx+d\) ,合併之後的結果可以證明是 \((ad+bc)x+bd\)

我們不妨先將這個函式的貢獻拆開,不帶 \(x\) 的部分便是本身的行列式,帶 \(x\) 的部分便是加上的值的貢獻,\(x\) 的幾次項我們邊可以看成是選擇多少個位置加 \(x\) 。首先考慮如果是直接相乘這兩個一次函式的話,便是不考慮 \(z\) 部分可以加上一個 \(x\) 的矩陣的行列式。考慮二次項相當於是選擇兩個位置加上 \(x\)

,且是一個在上方一個在下方,我們考慮必然存在唯一對應的一個排列,使得選擇兩個位置在 \(z\) 的與之一一對應且貢獻相反,也就是說二次項的貢獻是不存在的。

於是我們直接樹上遞迴即可,甚至前面的消元都不需要了,膜拜李神。

#include<bits/stdc++.h>
using namespace std;
const int N=5e5+5;
const int MOD=998244353;
int ADD(int x,int y){return x+y>=MOD?x+y-MOD:x+y;}
int SUB(int x,int y){return x-y<0?x-y+MOD:x-y;}
int TIME(int x,int y){return (int)(1ll*x*y%MOD);}
int n,m,v[N],cnt[N];
struct Edge{int nxt,to;}e[N<<1];int fir[N];
void add(int u,int v,int i){e[i]=(Edge){fir[u],v},fir[u]=i;}
pair<int,int> dp(int u,int fa,int tmp){
	tmp=SUB(v[u],v[fa]);pair<int,int> f;
	if(cnt[u]) f=make_pair(1,0);else f=make_pair(0,1);
	for(int i=fir[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==fa) continue;
		pair<int,int> g=dp(v,u,tmp);
		f.first=ADD(TIME(f.first,g.second),TIME(f.second,g.first));
		f.second=TIME(f.second,g.second);
	}
	return f.second=ADD(f.second,TIME(f.first,tmp)),f;
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;++i) scanf("%d",&v[i]);
	for(int i=1,x;i<=m;++i) scanf("%d",&x),cnt[x]++;
	for(int i=1,u,v;i<n;++i) scanf("%d%d",&u,&v),add(u,v,i<<1),add(v,u,i<<1|1);
	for(int i=1;i<=n;++i) if(cnt[i]>1) return printf("0\n"),0;
	return printf("%d\n",dp(1,0,0).second),0;
}