1. 程式人生 > 其它 >樹上差分 學習筆記

樹上差分 學習筆記

前置知識:差分

例題:P2367 語文成績

序列維護區間加,最後詢問序列最小值。


線段樹

差分即可。

對於在原數列 \(a_u\)\(a_v\) 都加一個 \(x\),考慮在差分陣列 \(b\) 中,變化的只有 \(b_u\)\(b_{v+1}\)

因為在原數列 \(a_u\)\(a_v\) 都加一個 \(x\),對於 \(u\) 之前和 \(v+1\) 之後數的差不會有任何變化。\(u+1\)\(v\) 之前的數也不會有變化,實際上,只有 \(a_u\) 對於 \(a_{u-1}\) 的差相較於之前大了 \(x\)\(a_{v+1}\) 對於 \(a_v\) 的差相較於之前小了 \(x\)

。所以維護區間加只需要 \(b_u+x\)\(b_{v+1}-x\) 即可。

最後,對差分陣列 \(b\) 跑一遍字首和即可還原出原陣列。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

long long n,m,a[5000010],b[5000010],s[5000010],ans=1000000000000;

int main(){
	long long i,j,u,v;
	scanf("%lld %lld",&n,&m);
	for(i=1;i<=n;i++) scanf("%lld",&a[i]);
	for(i=1;i<=n;i++){
		b[i]=a[i]-a[i-1];
	}
	while(m--){
		cin>>u>>v>>j;
		b[u]+=j;
		b[v+1]-=j;
	}
	for(i=1;i<=n;i++){
		s[i]=b[i]+s[i-1];
	}
	for(i=1;i<=n;i++){
		ans=min(ans,s[i]);
	}
	cout<<ans<<endl;
	return 0;
}

樹上差分

點差分

例題:P3128 [USACO15DEC]Max Flow P

在樹上給出多條路徑,問所有路徑經過最多的點經過了多少次。


典型的點差分。

考慮對於差分陣列 \(s\),一條從 \(u\)\(v\) 的路徑好像是隻需要 \(s_u+1,s_v+1,s_{f_{\text{LCA}(u,v)}}-1\) 就行。

但這是不對的。

考慮在樹上字首和恢復原值時,\(\text{LCA(u,v)}\) 的值會因為 \(s_u\)\(s_v\) 都加 \(1\) 而導致多加了一個 \(1\)。所以在點差分時 \(s_{\text{LCA}(u,v)}\) 也應該減 \(1\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const long long loglim=20;
long long n,m,h[50010],tot,f[50010][loglim+5],d[50010],s[50010],ans;
struct edge{
	long long v,nxt;
}e[100010];

void add(long long u,long long v){
	tot++;
	e[tot].v=v; e[tot].nxt=h[u];
	h[u]=tot;
}
void build(long long fr,long long u,long long dep){
	long long i,j,v;
	d[u]=dep;
	for(i=1;i<=loglim;i++){
		f[u][i]=f[f[u][i-1]][i-1];
	}
	for(i=h[u];i;i=e[i].nxt){
		v=e[i].v;
		if(v!=fr){
			f[v][0]=u;
			build(u,v,dep+1);
		}
	}
}
long long lca(long long u,long long v){
	long long i,j;
	if(d[u]<d[v]) swap(u,v);
	for(i=loglim;i>=0;i--){
		if(d[f[u][i]]>=d[v]) u=f[u][i];
	}
	if(u==v) return u;
	for(i=loglim;i>=0;i--){
		if(f[u][i]!=f[v][i]){
			u=f[u][i];
			v=f[v][i];
		}
	}
	return f[u][0];
}
void solve(long long fr,long long u){//樹上字首和
	long long i,j,v;
	for(i=h[u];i;i=e[i].nxt){
		v=e[i].v;
		if(v!=fr){
			solve(u,v);
			s[u]+=s[v];
		}
	}
}

int main(){
	long long i,j,u,v;
	cin>>n>>m;
	for(i=1;i<n;i++){
		cin>>u>>v;
		add(u,v);
		add(v,u);
	}
	build(0,1,1);
	while(m--){
		cin>>u>>v;
		s[u]++;		 //樹
		s[v]++;		 //上
		j=lca(u,v);	 //點
		s[j]--;		 //差
		s[f[j][0]]--;//分
	}
	solve(0,1);
	for(i=1;i<=n;i++){
		ans=max(ans,s[i]);
	}
	cout<<ans<<endl;
	return 0;
}

邊差分

例題:P2680 [NOIP2015 提高組] 運輸計劃

樹邊有非負權,給定多條路徑。現可將某邊權變為0,使得最長路徑最小。輸出該最小值。


gx:好難!

zkw:屑!

會樹剖的zkw把這題秒了,gx只能去寫他151行的樹上差分。


直接說正解了:

最大路徑最小暗示二分答案,顯然這題答案滿足單調性。二分答案的時候 \(check(mid)\) 判斷能否通過把某條邊變成 0 來使得答案小於等於 \(mid\)

我們把長度(可以預處理)大於 \(mid\) 的路徑叫做大路徑,反過來就是小路徑。我們需要找出一條邊是所有大路徑都經過的,看看能否通過把它變成0後使所有大路徑的長度小於等於 \(mid\)(判斷的時候只用判斷最長的大路徑(可以預處理)是否滿足即可)。

判斷一條邊是否被所有大路徑經過可以用邊差分。首先把樹邊經過的次數存到兒子節點處。然後每有一條 \(u\)\(v\) 的路徑,只需 \(s_u+1,s_v+1,s_{\text{LCA(u,v)}}-2\) 即可。最後樹上字首和,還原原陣列,即每條邊經過的次數。

總時間複雜度:\(O(m \log^2 n)\)

當然可以通過初始化 \(\text{LCA}\) 把時間複雜度降到 \(O(m \log n)\),但這樣也能過。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int loglim=18;
int n,m,h[300010],tot,f[300010][loglim+5],g[300010][loglim+5],d[300010],s[300010],w[300010];
struct edge{
	int v,w,nxt;
}e[600010];
struct path{
	int u,v,op;
}p[300010];

inline int read(){
    char ch=getchar(); 
    int x=0,f=1;
    while(ch<'0' || ch>'9') {
        if(ch=='-') 
            f=-1;
        ch=getchar();
    } 
    while('0'<=ch && ch<='9') {
        x=x*10+ch-'0';
        ch=getchar();
    } 
    return x*f;
}
inline void add(int u,int v,int w){
	tot++;
	e[tot].v=v; e[tot].w=w; e[tot].nxt=h[u];
	h[u]=tot;
}
inline void build(int fr,int u,int dep){
	register int i,j,v;
	d[u]=dep;
	for(i=1;i<=loglim;i++){
		f[u][i]=f[f[u][i-1]][i-1];
		g[u][i]=g[u][i-1]+g[f[u][i-1]][i-1];
	}
	for(i=h[u];i;i=e[i].nxt){
		v=e[i].v;
		if(v!=fr){
			f[v][0]=u;
			g[v][0]=e[i].w;
			w[v]=e[i].w;
			build(u,v,dep+1);
		}
	}
}
inline int lca(int u,int v){
	register int i,j;
	if(d[u]<d[v]) swap(u,v);
	for(i=loglim;i>=0;i--){
		if(d[f[u][i]]>=d[v]){
			u=f[u][i];
		}
	}
	if(u==v) return u;
	for(i=loglim;i>=0;i--){
		if(f[u][i]!=f[v][i]){
			u=f[u][i];
			v=f[v][i];
		}
	}
	return f[u][0];
}
inline int G(int u,int v){
	register int i,j,now=0;
	if(d[u]<d[v]) swap(u,v);
	for(i=loglim;i>=0;i--){
		if(d[f[u][i]]>=d[v]){
			now+=g[u][i];
			u=f[u][i];
		}
	}
	if(u==v) return now;
	for(i=loglim;i>=0;i--){
		if(f[u][i]!=f[v][i]){
			now+=g[u][i];
			now+=g[v][i];
			u=f[u][i];
			v=f[v][i];
		}
	}
	now+=g[u][0];
	now+=g[v][0];
	return now;
}
inline void solve(int fr,int u){
	register int i,j,v;
	for(i=h[u];i;i=e[i].nxt){
		v=e[i].v;
		if(v!=fr){
			solve(u,v);
			s[u]+=s[v];
		}
	}
}
inline bool cmpp(path u,path v){
	return u.op>v.op;
}
inline bool check(int gx){
	register int i,j,u,v,sum=0;
	memset(s,0,sizeof(s));
	for(i=1;i<=m;i++){
		if(p[i].op<=gx) break;
		else sum++;
		u=p[i].u; v=p[i].v;
		s[u]++; s[v]++;
		s[lca(u,v)]-=2;
	}
	solve(0,1);
	if(sum==0) return 1;
	for(i=1;i<=n;i++){
		if(s[i]==sum){
			if(p[1].op-w[i]<=gx) return 1;
		}
	}
	return 0;
}

int main(){
	register int i,j,u,v,l,r,mid;
	cin>>n>>m;
	for(i=1;i<n;i++){
		u=read(); v=read(); j=read();
		add(u,v,j);
		add(v,u,j);
	}
	build(0,1,1);
	for(i=1;i<=m;i++){
		p[i].u=read(); p[i].v=read();
		u=p[i].u; v=p[i].v;
		p[i].op=G(u,v);
	}
	sort(p+1,p+m+1,cmpp);
	l=0;r=100000000000;
	while(l<r){
		mid=(l+r)>>1;
		if(check(mid)){
			r=mid;
		}
		else{
			l=mid+1;
		}
	}
	cout<<l<<endl;
	return 0;
}