1. 程式人生 > 實用技巧 >演算法淺談之樹上差分

演算法淺談之樹上差分

先放一道例題[USACO15DEC]Max Flow P

題目大意

給你一棵\(n\)個點的樹,有\(k\)條管道,每條管道有個起始點和終結點。從起始點到終結點的路徑上每個經過的點權值都要\(+1\)

現在問你這\(k\)條管道都處理完後權值最大的點的權值是多少

\(N\le50000\)

\(K\le100000\)

分析

乍一看有一點棘手啊。

如果是條鏈

我們考慮這棵樹是一條鏈。那麼就是一個正常的差分:起始點\(+1\),終結點後面的點\(-1\),最後哪一個變數從頭加到尾,記錄最大值。

其他情況

那就只能用樹上差分來做。

樹上差分。顧名思義,就是在樹上進行差分操作。具體是:起始點和終止點\(+1\),LCA與LCA的父親\(-1\)

。這樣就可以完成樹上差分

不會LCA的請看https://www.cnblogs.com/hulean/p/11144059.html

具體為什麼的話,畫圖就很明顯了。

dfs

差分處理完後,我們只需要一遍dfs來累加每個點的值,並且維護最大值。

這道題就做完了

#include <bits/stdc++.h>
using namespace std ;
const int MAXN = 50000 + 5 ;
struct Node {
	int next , to ;
} edge[ MAXN << 1 ] ;
int head[ MAXN ] , cnt ;
int n , k , w[ MAXN ] , ans ;
int deep[ MAXN ] , fa[ MAXN ][ 21 ] ;
inline int read () {
	int tot = 0 , f = 1 ; char c = getchar () ;
	while ( c < '0' || c > '9' ) { if ( c == '-' ) f = -1 ; c = getchar () ; }
	while ( c >= '0' && c <= '9' ) { tot = tot * 10 + c - '0' ; c = getchar () ; }
	return tot * f ;
}
inline void add ( int x , int y ) {
	edge[ ++ cnt ].next = head[ x ] ;
	edge[ cnt ].to = y ;
	head[ x ] = cnt ;
}
inline void dfs ( int u , int father ) {
	fa[ u ][ 0 ] = father ; deep[ u ] = deep[ father ] + 1 ;
	for ( int i = 1 ; i <= 20 ; i ++ )
		fa[ u ][ i ] = fa[ fa[ u ][ i - 1 ] ][ i - 1 ] ;
	for ( int i = head[ u ] ; i ; i = edge[ i ].next ) {
		int v = edge[ i ].to ;
		if ( v == father ) continue ;
		dfs ( v , u ) ;
	}
}
inline int Lca ( int x , int y ) {
	if ( x == y ) return x ;
	if ( deep[ x ] < deep[ y ] ) swap ( x , y ) ;
	for ( int i = 20 ; i >= 0 ; i -- ) {
		if ( deep[ fa[ x ][ i ] ] >= deep[ y ] ) x = fa[ x ][ i ] ;
	}
	if ( x == y ) return x ;
	for ( int i = 20 ; i >= 0 ; i -- ) {
		if ( fa[ x ][ i ] != fa[ y ][ i ] ) x = fa[ x ][ i ] , y = fa[ y ][ i ] ;
	}
	return fa[ x ][ 0 ] ;
}
inline void work ( int u , int father ) {
	for ( int i = head[ u ] ; i ; i = edge[ i ].next ) {
		int v = edge[ i ].to ;
		if ( v == father ) continue ;
		work ( v , u ) ;
		w[ u ] += w[ v ] ;
	}
	ans = max ( ans , w[ u ] ) ;
}
signed main () {
	n = read () ; k = read () ;
	for ( int i = 1 ; i < n ; i ++ ) {
		int x = read () , y = read () ;
		add ( x , y ) ; add ( y , x ) ;
	}
	dfs ( 1 , 0 ) ;
	for ( int i = 1 ; i <= k ; i ++ ) {
		int x = read () , y = read () ;
		w[ x ] ++ ; w[ y ] ++ ;
		int lca = Lca ( x , y ) ;
		w[ lca ] -- ; w[ fa[ lca ][ 0 ] ] -- ;
	}
	work ( 1 , 0 ) ;
	printf ( "%d\n" , ans ) ;
	return 0 ;
}