tree - dp - 長鏈剖分
阿新 • • 發佈:2018-11-26
題目大意:
給你一顆樹,點有點權,對所有三元組(x,y,z),滿足dis(x,y)=dis(y,z)=dis(x,z),統計a(x)a(y)+a(x)a(z)+a(y)a(z)的和。n<=100000。
題解:
條件等價於存在一箇中心點。
列舉三個點的LCA,然後劈成兩半,一半是鏈,一半是Y倒過來寫,發現二者能合併當且僅當鏈長等於Y倒過來寫的下面長度減去上面長度,而這二者不超過子樹深度,因此長鏈剖分即可。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair< int,int> pii;
typedef set<int>::iterator sit;
const int N=100010;
struct edges{
int to,pre;
}e[N<<1];int h[N],etop,a[N],l[N],son[N];lint ans=0;
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch> '9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
int getl(int x,int fa=0)
{
l[x]=1,son[x]=0;
for(int i=h[x],y;i;i=e[i].pre)
if((y=e[i].to)^fa)
{
l[x]=max(l[x],getl(y,x)+1);
if(l[y]>l[son[x]]) son[x]=y;
}
return l[x];
}
inline int *arr(int n) { int *p=new int[n];return memset(p,0,sizeof(int)*n),p; }
#define P(x) (x>=mod?x-=mod:0)
int dfs(int x,int fa,int *Ax,int *Bx,int *Cx,int *Dx)
{
if(son[x]) dfs(son[x],x,Ax-1,Bx-1,Cx+1,Dx+1),ans+=(Ax[0]+(lint)a[x]*Bx[0])%mod,P(ans);
Cx[0]=a[x],Dx[0]=1;
for(int i=h[x],y;i;i=e[i].pre)
if((y=e[i].to)!=fa&&e[i].to!=son[x])
{
int *Ay=arr(l[y]*2+1)+l[y],*By=arr(l[y]*2+1)+l[y],
*Cy=arr(l[y]+1),*Dy=arr(l[y]+1);
dfs(y,x,Ay,By,Cy,Dy);
rep(d,0,l[y])
{
if(d) ans+=((lint)Dx[d-1]*Ay[d]+(lint)Cx[d-1]*By[d])%mod,P(ans);
ans+=((lint)Ax[d+1]*Dy[d]+(lint)Bx[d+1]*Cy[d])%mod,P(ans);
}
rep(d,0,l[y]-1) Ax[d]+=Ay[d+1],Bx[d]+=By[d+1],P(Ax[d]),P(Bx[d]);
rep(d,1,l[y]) Ax[d]+=(lint)Cx[d]*Cy[d-1]%mod,P(Ax[d]),
Bx[d]+=((lint)Cx[d]*Dy[d-1]+(lint)Dx[d]*Cy[d-1])%mod,P(Bx[d]);
rep(d,0,l[y]) Cx[d+1]+=Cy[d],Dx[d+1]+=Dy[d],P(Cx[d+1]),P(Dx[d+1]);
}
return 0;
}
int main()
{
int n=inn(),u,v;n=inn();
rep(i,1,n-1) u=inn(),v=inn(),add_edge(u,v),add_edge(v,u);
rep(i,1,n) a[i]=inn(),(a[i]>=mod?a[i]%=mod:0);
getl(1);
int *A=arr(l[1]*2+1)+l[1],*B=arr(l[1]*2+1)+l[1],
*C=arr(l[1]+1),*D=arr(l[1]+1);
return dfs(1,0,A,B,C,D),!printf("%d\n",int(ans));
}