【樹鏈剖分模板】bzoj1036 樹的統計
阿新 • 發佈：2018-12-11
// Heavy-light decomposition template for BZOJ 1036 ("Tree Statistics").
//
// Input: n, then n-1 undirected edges, then n node weights, then m operations:
//   CHANGE u t  -- set the weight of node u to t
//   QMAX u v    -- print the maximum weight on the path u..v (inclusive)
//   QSUM u v    -- print the sum of weights on the path u..v (inclusive)
// The tree is rooted at node 1 and split into heavy chains. Each chain occupies a
// contiguous range of positions in a segment tree, so a path query is answered by
// one range query per chain crossed.
#include <cstdio>
#include <climits>
#include <algorithm>

const int N = 30000 + 5;  // capacity for node ids 1..n; NOTE(review): ids are assumed to lie in [1, n]

int n;     // number of nodes
int v[N];  // v[x]: initial weight of node x

// Adjacency stored as a linked edge list; each undirected edge is added twice.
int num, last[N], nxt[2 * N], ver[2 * N];

inline void add(int x, int y) {
    nxt[++num] = last[x];
    last[x] = num;
    ver[num] = y;
}

// siz: subtree size, son: heavy child (0 = none), fa: parent (0 for the root), deep: depth.
int siz[N], son[N], fa[N], deep[N];

// First pass: fill fa/deep/siz for x's subtree and pick each node's heavy child.
// Recursive, so the required stack depth equals the tree height (up to n).
void dfs_size(int x) {
    siz[x] = 1;
    son[x] = 0;
    for (int i = last[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa[x]) continue;
        fa[y] = x;
        deep[y] = deep[x] + 1;
        dfs_size(y);
        if (siz[y] > siz[son[x]]) son[x] = y;  // siz[0] == 0, so any child beats "none"
        siz[x] += siz[y];
        }
}

// id: last assigned position; a[pos]: weight stored at a position;
// top[x]: head of x's heavy chain; ord[x]: position assigned to x.
int id, a[N], top[N], ord[N];

// Second pass: number nodes so every heavy chain is contiguous (heavy child first).
void dfs_chain(int x) {
    a[++id] = v[x];
    ord[x] = id;
    // son[0] is never written, so the root (fa == 0) always starts its own chain.
    top[x] = (x == son[fa[x]]) ? top[fa[x]] : x;
    if (son[x]) dfs_chain(son[x]);
    for (int i = last[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y != fa[x] && y != son[x]) dfs_chain(y);
    }
}

// Segment-tree node covering positions [l, r].
struct point {
    int l, r;
    long long sum;  // range sum; 64-bit so long paths of large weights cannot overflow
    int maxx;       // range maximum
} t[4 * N];

// Result of a range / path query.
struct info {
    long long sum;
    int maxx;
};

// Recompute node i from its two children.
void pull(int i) {
    t[i].sum = t[2 * i].sum + t[2 * i + 1].sum;
    t[i].maxx = std::max(t[2 * i].maxx, t[2 * i + 1].maxx);
}

// Build node i over positions [p, q] from a[].
void build_seg(int i, int p, int q) {
    t[i].l = p;
    t[i].r = q;
    if (p == q) {
        t[i].sum = a[p];
        t[i].maxx = a[p];
        return;
    }
    int mid = (p + q) >> 1;
    build_seg(2 * i, p, mid);
    build_seg(2 * i + 1, mid + 1, q);
    pull(i);
}

// Point assignment: set position p to value x.
void change(int i, int p, int x) {
    if (t[i].l == t[i].r) {  // the descent below only reaches the leaf holding p
        t[i].sum = x;
        t[i].maxx = x;
        return;
    }
    int mid = (t[i].l + t[i].r) >> 1;
    if (p <= mid) change(2 * i, p, x);
    else change(2 * i + 1, p, x);
    pull(i);
}

// Sum and maximum over positions [p, q]; requires [p, q] to lie within node i's range.
info ask(int i, int p, int q) {
    if (p <= t[i].l && t[i].r <= q) {
        info whole;
        whole.sum = t[i].sum;
        whole.maxx = t[i].maxx;
        return whole;
    }
    int mid = (t[i].l + t[i].r) >> 1;
    if (q <= mid) return ask(2 * i, p, q);
    if (p > mid) return ask(2 * i + 1, p, q);
    info lhs = ask(2 * i, p, q);
    info rhs = ask(2 * i + 1, p, q);
    info re;
    re.sum = lhs.sum + rhs.sum;
    re.maxx = std::max(lhs.maxx, rhs.maxx);
    return re;
}

// Sum and maximum of the weights on the tree path x..y (both endpoints included).
inline info query(int x, int y) {
    info re;
    re.sum = 0;
    re.maxx = INT_MIN;  // identity for max, valid for any int weight
    while (top[x] != top[y]) {
        // Lift whichever endpoint has the deeper chain head, consuming that whole chain segment.
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        info k = ask(1, ord[top[x]], ord[x]);
        re.sum += k.sum;
        re.maxx = std::max(re.maxx, k.maxx);
        x = fa[top[x]];
    }
    // Now on the same chain: the shallower node has the smaller position.
    if (deep[x] < deep[y]) std::swap(x, y);
    info k = ask(1, ord[y], ord[x]);
    re.sum += k.sum;
    re.maxx = std::max(re.maxx, k.maxx);
    return re;
}

int main() {
    // Reject sizes the fixed-capacity arrays cannot hold.
    if (scanf("%d", &n) != 1 || n < 1 || n >= N) return 1;
    for (int i = 1; i < n; i++) {
        int x, y;
        if (scanf("%d%d", &x, &y) != 2) return 1;
        add(x, y);
        add(y, x);
    }
    for (int i = 1; i <= n; i++) {
        if (scanf("%d", &v[i]) != 1) return 1;
    }

    deep[1] = 1;
    dfs_size(1);
    dfs_chain(1);
    build_seg(1, 1, id);

    int m = 0;
    if (scanf("%d", &m) != 1) m = 0;  // missing count: nothing to do
    char op[15];
    while (m--) {
        int x, y;
        // Width-limited %s so an over-long token cannot overflow op.
        if (scanf("%14s%d%d", op, &x, &y) != 3) break;
        // The second letter distinguishes CHANGE / QMAX / QSUM.
        if (op[1] == 'H') {
            change(1, ord[x], y);
        } else if (op[1] == 'M') {
            printf("%d\n", query(x, y).maxx);
        } else if (op[1] == 'S') {
            printf("%lld\n", query(x, y).sum);
        }
    }
    return 0;
}