1. 程式人生 > 其它 >【題解】[ZJOI2019]語言

【題解】[ZJOI2019]語言

Problem

\(\text{Solution:}\)

菜雞太菜了想了好久沒有思路……只知道要求樹上鍊的並集但不知道咋整……雖然題目演算法線段樹合併和樹上差分看到題就能想到……但是怎麼做還是有思維難度……(對筆者來說)

看了好久的題解都沒有看懂 這次寫的詳細一點。

  • 所謂矩陣面積並

某些題解中說了一句 “可以用樹剖套掃描線” “就是一個矩陣面積並” 的做法,複雜度 \(O(n\log^3 n).\)

實際上,我們對樹進行樹剖後,每一條路徑 \((s,t)\) 都可以被劃分為 \(\log n\) 個線段。我們可以把每一段被劃分的重鏈區間看做矩形,那我們對每個點求的鏈並集,實際上就是一個抽象的 “矩形面積並” 。

複雜度是樹剖的 \(\log\) 和掃面線本身的兩個 \(\log\). 我沒有實現這種做法。

  • 線段樹合併、樹上差分與樹剖

常規的 \(\log^2 n\) 做法是這個。

由於有序對不好算,我們可以算無序對的個數,除以二即可。

\(S_u\) 表示 \(u\) 可以到達的點的集合(不包括它自己)。那麼 \(ans=\frac{\sum S_i}{2}.\)

如何對每一個點計算出它的 \(S\) 呢?

首先明確一點:我們對每一個點建立的線段樹是一棵 以 dfs序 為基礎的線段樹,它維護了整棵樹的資訊。

那麼,對每一個點都有這樣一棵線段樹,對於每一個語言傳遞過程,我們在路上的每一個點上都進行一次區間覆蓋操作,最終每一個點上線段樹上被覆蓋的長度就是它的 \(S\)

.

顯然這玩意複雜度不對。

首先,對樹上路徑修改,我們需要用到樹剖。

那麼對於一條路徑上的點,我們可以自然想到用樹上差分的思路來優化。

並不知道一棵支援區間修改的樹咋合併) 上文所述,我們需要維護一棵線段樹上被覆蓋的長度,並且是 全域性詢問

那這個東西就長得很像掃描線了。 維護區間被覆蓋的次數來更新即可。

那麼合併呢?

合併要注意一下細節:每一次合併到的點都要記錄資訊,因為這種寫法的區間覆蓋次數沒有什麼下傳操作,不要只在葉子上維護 pushup 操作,這樣維護的資訊是錯誤而不全面的。

其他寫法都一樣,複雜度是一個 \(\log\).

那麼總結就是:考慮暴力,每一個點都記錄一下鏈上資訊,又因為需要樹上修改路徑需要樹剖操作,而對於一條路徑上的點我們可以考慮樹上差分優化,進而利用線段樹合併來解決這題。

時間複雜度:\(O(n\log^2 n).\)

空間分析

之前從神魚那裡見到,空間複雜度是 \(2n\log n\) 的。

計算一下:\(Num=2\cdot 10^5 \cdot \log 10^5=3.32\cdot 10^6\) 級別。然而程式碼中,進行修改的操作達到了 \(O(m\log n)\) 級別的個數,也就是多了個 \(\log\) ,約為\(5.5\cdot 10^7\)級別。空間沒有卡滿,程式碼開到 \(2\cdot 10^7\) 可以過去。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e7+10;
int ls[MAXN],rs[MAXN],topp,rub[MAXN],node;
int cnt[MAXN],len[MAXN],id[MAXN],rk[MAXN];
int rt[MAXN],pa[MAXN],siz[MAXN],son[MAXN];
int dep[MAXN],head[MAXN],tot,n,m,top[MAXN];
long long ans;
struct E {
	int nxt,to;
} e[MAXN];

inline int Max(int x,int y){return x>y?x:y;}

inline void add(int x,int y) {
	e[++tot] = ( E ) {
		head [ x ] , y
	} ;
	head[x]=tot;
}

void dfs1(int x,int fa) {
	pa [ x ] = fa ;
	siz [ x ] = 1 ;
	dep [ x ] = dep [ fa ] + 1 ;
	for ( int i=head[x]; i; i=e[i].nxt) {
		int j=e[i].to;
		if(j==fa)continue;
		dfs1(j,x);
		siz[x]+=siz[j];
		if(siz[j]>siz[son[x]])son[x]=j;
	}
}

int dfstime;
void dfs2(int u,int t) {
	top [ u ] = t ;
	id [ u ] = ++ dfstime ;
	if ( ! son [ u ] )
		return ;
	dfs2 ( son [ u ] , t ) ;
	for ( int i = head [ u ] ; i ; i = e [ i ] .nxt ) {
		int j = e [ i ] .to ;
		if ( j == son [ u ] || j == pa [ u ] )
			continue ;
		dfs2 ( j , j ) ;
	}
}

inline void del(int x) {
	rub[++topp]=x;
	len[x]=cnt[x]=ls[x]=rs[x]=0;
}

inline int New() {
	if(topp)return rub[topp--];
	return ++node;
}

inline void pushup(int x,int l,int r) {
	if(cnt[x]>0)
		len[x]=r-l+1;
	else
		len[x]=len[ls[x]]+len[rs[x]];
}

void change(int &x,int L,int R,int l,int r,int v) {
	if(!x)x=New();
	if ( l <= L && R <= r ) {
		cnt[x]+=v;
		pushup(x,L,R);
		return;
	}
	int mid=(L+R)>>1;
	if(l<=mid)change(ls[x],L,mid,l,r,v);
	if(mid<r)change(rs[x],mid+1,R,l,r,v);
	pushup(x,L,R);
}

int merge(int x,int y,int l,int r) {
	if(!x||!y)return x+y;
	cnt[x]+=cnt[y];
	if(l==r) {
		pushup(x,l,r);
		del(y);
		return x;
	}
	int mid=(l+r)>>1;
	ls[x]=merge(ls[x],ls[y],l,mid);
	rs[x]=merge(rs[x],rs[y],mid+1,r);
	pushup(x,l,r);
	del(y);
	return x;
}

int query(int x,int L,int R,int l,int r) {
	if(L>=l&&R<=r)return len[x];
	int mid=(L+R)>>1,val=0;
	if(l<=mid)val+=query(ls[x],L,mid,l,r);
	if(mid<r)val+=query(rs[x],mid+1,R,l,r);
	return val;
}

void changes(int root,int x,int y,int v) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]>=dep[top[y]]) {
			change(rt[root],1,n,id[top[x]],id[x],v);
			x=pa[top[x]];
		} else {
			change(rt[root],1,n,id[top[y]],id[y],v);
			y=pa[top[y]];
		}
	}
	if(id[x]<=id[y])change(rt[root],1,n,id[x],id[y],v);
	else change(rt[root],1,n,id[y],id[x],v);
}

inline int LCA(int x,int y) {
	while ( top [ x ] != top [ y ] ) {
		if(dep[top[x]]>=dep[top[y]])x=pa[top[x]];
		else y=pa[top[y]];
	}
	return dep [ x ] < dep [ y ] ? x : y ;
}

void dfs3 ( int x ) {
	for ( int i = head [ x ] ; i ; i = e [ i ] .nxt ) {
		int j = e [ i ] .to ;
		if ( j == pa [ x ] )
			continue ;
		dfs3 ( j ) ;
		rt [ x ] = merge ( rt [ x ] , rt [ j ] , 1 , n ) ;
	}
	ans += Max(query ( rt [ x ] , 1 , n , 1 , n ) -1,0) ;
}

signed main() {
	freopen("1.in","r",stdin);
	freopen("my.out","w",stdout);
	scanf("%d%d",&n,&m);
	for(int i=1; i<n; ++i) {
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(1,1);
	dfs2(1,1);
	for(; m; m--) {
		int s,t;
		scanf("%d%d",&s,&t);
		int lca=LCA(s,t);
		changes(s,s,t,1);
		changes(t,s,t,1);
		changes(lca,s,t,-1);
		if(lca==1)continue;
		changes(pa[lca],s,t,-1);
	}
	dfs3(1);
	ans>>=1;
	printf("%lld\n",ans);
	return 0;
}