1. 程式人生 > >tree - dp - 長鏈剖分

tree - dp - 長鏈剖分

題目大意:
給你一顆樹,點有點權,對所有三元組(x,y,z),滿足dis(x,y)=dis(y,z)=dis(x,z),統計a(x)a(y)+a(x)a(z)+a(y)a(z)的和。n<=100000。
題解:
條件等價於存在一箇中心點。
列舉三個點的LCA,然後劈成兩半,一半是鏈,一半是Y倒過來寫,發現二者能合併當且僅當鏈長等於Y倒過來寫的下面長度減去上面長度,而這二者不超過子樹深度,因此長鏈剖分即可。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define lint long long
#define mod 998244353 #define ull unsigned lint #define db long double #define pb push_back #define mp make_pair #define fir first #define sec second #define gc getchar() #define debug(x) cerr<<#x<<"="<<x #define sp <<" " #define ln <<endl using namespace std; typedef pair<
int,int> pii; typedef set<int>::iterator sit; const int N=100010; struct edges{ int to,pre; }e[N<<1];int h[N],etop,a[N],l[N],son[N];lint ans=0; inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; } inline int inn() { int x,ch;while((ch=gc)<'0'||ch>
'9'); x=ch^'0';while((ch=gc)>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^'0');return x; } int getl(int x,int fa=0) { l[x]=1,son[x]=0; for(int i=h[x],y;i;i=e[i].pre) if((y=e[i].to)^fa) { l[x]=max(l[x],getl(y,x)+1); if(l[y]>l[son[x]]) son[x]=y; } return l[x]; } inline int *arr(int n) { int *p=new int[n];return memset(p,0,sizeof(int)*n),p; } #define P(x) (x>=mod?x-=mod:0) int dfs(int x,int fa,int *Ax,int *Bx,int *Cx,int *Dx) { if(son[x]) dfs(son[x],x,Ax-1,Bx-1,Cx+1,Dx+1),ans+=(Ax[0]+(lint)a[x]*Bx[0])%mod,P(ans); Cx[0]=a[x],Dx[0]=1; for(int i=h[x],y;i;i=e[i].pre) if((y=e[i].to)!=fa&&e[i].to!=son[x]) { int *Ay=arr(l[y]*2+1)+l[y],*By=arr(l[y]*2+1)+l[y], *Cy=arr(l[y]+1),*Dy=arr(l[y]+1); dfs(y,x,Ay,By,Cy,Dy); rep(d,0,l[y]) { if(d) ans+=((lint)Dx[d-1]*Ay[d]+(lint)Cx[d-1]*By[d])%mod,P(ans); ans+=((lint)Ax[d+1]*Dy[d]+(lint)Bx[d+1]*Cy[d])%mod,P(ans); } rep(d,0,l[y]-1) Ax[d]+=Ay[d+1],Bx[d]+=By[d+1],P(Ax[d]),P(Bx[d]); rep(d,1,l[y]) Ax[d]+=(lint)Cx[d]*Cy[d-1]%mod,P(Ax[d]), Bx[d]+=((lint)Cx[d]*Dy[d-1]+(lint)Dx[d]*Cy[d-1])%mod,P(Bx[d]); rep(d,0,l[y]) Cx[d+1]+=Cy[d],Dx[d+1]+=Dy[d],P(Cx[d+1]),P(Dx[d+1]); } return 0; } int main() { int n=inn(),u,v;n=inn(); rep(i,1,n-1) u=inn(),v=inn(),add_edge(u,v),add_edge(v,u); rep(i,1,n) a[i]=inn(),(a[i]>=mod?a[i]%=mod:0); getl(1); int *A=arr(l[1]*2+1)+l[1],*B=arr(l[1]*2+1)+l[1], *C=arr(l[1]+1),*D=arr(l[1]+1); return dfs(1,0,A,B,C,D),!printf("%d\n",int(ans)); }