1. 程式人生 > 實用技巧 >【牛客7872 J】樹上啟發式合併

【牛客7872 J】樹上啟發式合併

【牛客7872 J】樹上啟發式合併

題意

樹上啟發式合併,求有多少點對滿足,這兩個點x和y相互之間不是祖先和後代的關係
同時滿足\(val[x]+val[y]=2 * val[ lca(x,y) ]\)

題解

根據兩個點不能互為祖先的要求可知:

比較可行的方式是列舉這個作為lca的結點,對於一個作為lca的結點
什麼樣的結點會以它為lca呢,當然是以它的不同的兒子為根結點的子樹中的結點
因此,統計答案的方式也比較巧妙,對於一個作為lca的結點u

  • 首先遍歷它的第一個兒子v1的那棵子樹,用一個mp陣列記錄當前已經遍歷過的結點中每個數出現的次數
    遍歷第1個兒子那棵子樹時把mp維護好。
  • 然後從第2個兒子開始,先對每一個結點v,獲取到當前mp[2*val[u]-val[v]]的大小
    這表示能和結點v一起組成符合條件的點對
    有多少。
  • 這樣查詢完第2個兒子上所有節點後,再把第2個兒子子樹上的所有結點的mp值維護好,依次迴圈這樣一個過程

由於在做這個過程的時候必須保證mp值的準確,所以每次一個lca判斷完後要清空該棵子樹對mp值造成的影響。
那麼考慮什麼樣的結點不用清空呢,那就是該結點作為父親結點的最後一個兒子維護答案時不用清空。
那麼我們怎樣能使時間複雜度儘可能降低呢?那就是把所有兒子中最重的(子樹大小最大的兒子)放在最後一個訪問,這樣就可以節省下清空它的時間複雜度,這就是啟發式合併,運用最後一個兒子不需要清空的性質來降低時間複雜度。

Code

/****************************
* Author : W.A.R            *
* Date : 2020-10-31-20:44   *
****************************/
/*
*/
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<map>
#include<unordered_map>
#include<stack>
#include<string>
#include<set>
#define mem(a,x) memset(a,x,sizeof(a))
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
const ll mod=1e9+7;

namespace Fast_IO{
    const int MAXL((1 << 18) + 1);int iof, iotp;
    char ioif[MAXL], *ioiS, *ioiT, ioof[MAXL],*iooS=ioof,*iooT=ioof+MAXL-1,ioc,iost[55];
    char Getchar(){
        if (ioiS == ioiT){
            ioiS=ioif;ioiT=ioiS+fread(ioif,1,MAXL,stdin);return (ioiS == ioiT ? EOF : *ioiS++);
        }else return (*ioiS++);
    }
    void Write(){fwrite(ioof,1,iooS-ioof,stdout);iooS=ioof;}
    void Putchar(char x){*iooS++ = x;if (iooS == iooT)Write();}
    inline int read(){
        int x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
    }
    inline long long read_ll(){
        long long x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
    }
    template <class Int>void Print(Int x, char ch = '\0'){
        if(!x)Putchar('0');if(x<0)Putchar('-'),x=-x;while(x)iost[++iotp]=x%10+'0',x/=10;
        while(iotp)Putchar(iost[iotp--]);if (ch)Putchar(ch);
    }
    void Getstr(char *s, int &l){
        for(ioc=Getchar();ioc==' '||ioc=='\n'||ioc=='\t';)ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(l=0;!(ioc==' '||ioc=='\n'||ioc=='\t'||ioc==EOF);ioc=Getchar())s[l++]=ioc;s[l] = 0;
    }
    void Putstr(const char *s){for(int i=0,n=strlen(s);i<n;++i)Putchar(s[i]);}
}
using namespace Fast_IO;
struct node{int to,nxt;}e[maxn];
int son[maxn],siz[maxn],cnt[maxn],head[maxn],val[maxn],ct;
ll ans;
unordered_map<int,int>mp;
void addE(int u,int v){e[++ct].to=v;e[ct].nxt=head[u];head[u]=ct;}
void dfs(int u,int fa){
	siz[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==fa)continue;
		dfs(v,u);siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
	}
}
void add(int u,int fa,int value){
	mp[val[u]]+=value;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa)continue;
		add(v,u,value);
	}
}
void calc(int u,int fa,int lca){
	ans+=mp[2*val[lca]-val[u]];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa)continue;
		calc(v,u,lca);
	}
}
void getAns(int u,int fa,bool heavy){
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa||v==son[u])continue;
		getAns(v,u,0);
	}
	if(son[u])getAns(son[u],u,1);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa||v==son[u])continue;
		calc(v,u,u);
		add(v,u,1);
	}
	mp[val[u]]++;
	if(!heavy)add(u,fa,-1);
}
int main(){
	int n=read();
	for(int i=1;i<=n;i++)val[i]=read();
	for(int i=1;i<n;i++){int u=read(),v=read();addE(u,v);addE(v,u);}
	dfs(1,0);
	getAns(1,0,0);
	printf("%lld\n",ans<<1);
	return 0;
}