bzoj 3730 震波 —— 點分治+樹狀陣列
阿新 • • 發佈:2018-12-27
題目:https://www.lydsy.com/JudgeOnline/problem.php?id=3730
建點分樹,每個點記兩個樹狀陣列,存它作為重心管轄的範圍內,所有點到它的距離情況和到它在點分樹上的父親的距離情況;
於是算的時候可以減去重複的,就是跳到父親之前把自己會被重複統計的部分減去;
注意跳點分樹父親時,查詢的距離都是原本詢問點到那個父親的距離,而不是上一層父親到那個父親的距離;
樹狀陣列的大小總共是 nlogn 的,因為每層有 n 個點,一共 logn 層;
於是一開始寫的是開一個長長的樹狀陣列,記錄每個點用的是它上面的哪一段;
然後感覺很艱難...段的長度到底應該是最大深度還是點數?第二個樹狀陣列的段長度是2倍?......
於是改成 vector 了,用 resize() 函式可以方便地開大小,雖然還是不知道具體應該開多大?
查詢距離可以把求 LCA 的過程用尤拉序+ST表變成 O(1) 的;
艱難寫好,對拍竟然沒問題!然而交上去卻一直RE...
後來發現前一個操作是修改,並不需要把 ans 賦成0!(對拍沒發現是因為不會生成資料,只好每次只做一次查詢)
然而還是RE...??
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> using namespaceREstd; int const xn=1e5+5,xm=xn*20,xxn=(xn<<1)+5; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],v[xn]; int dep[xn],fa[xn],siz[xn],mx,rt,tmp,ans; //int t[xm],t2[xm<<1],wmx,wmx2,l[xn],r[xn],l2[xn],r2[xn]; int tim,in[xn],st[xxn][25],bit[xxn],bin[25]; bool vis[xn]; vector<int>tr[xn],tr2[xn];int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } int Max(int x,int y){return x>y?x:y;} int Min(int x,int y){return x<y?x:y;} void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void dfsx(int x,int ff) { dep[x]=dep[ff]+1; in[x]=++tim; st[tim][0]=x; for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=ff)dfsx(u,x),st[++tim][0]=x; //out[x]=++tim; st[tim][0]=x;// } void init() { dfsx(1,0); bin[0]=1; for(int i=1;i<=20;i++)bin[i]=bin[i-1]*2; bit[1]=0; for(int i=2;i<=tim;i++)bit[i]=bit[i>>1]+1; for(int i=1;i<=20;i++) for(int j=1;j+bin[i-1]<=tim&&st[j+bin[i-1]][i-1];j++) st[j][i]=Min(st[j][i-1],st[j+bin[i-1]][i-1]); } int lca(int x,int y) { if(in[x]>in[y])swap(x,y); int t=bit[in[y]-in[x]+1]; return Min(st[in[x]][t],st[in[y]-bin[t]+1][t]); } int dist(int x,int y){return dep[x]+dep[y]-2*dep[lca(x,y)];} void getrt(int x,int ff,int sum,int dis) { int nmx=0; siz[x]=1; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==ff||vis[u])continue; getrt(u,x,sum,dis+1); siz[x]+=siz[u]; nmx=Max(nmx,siz[u]); } nmx=Max(nmx,sum-siz[x]); if(nmx<mx)mx=nmx,rt=x,tmp=dis; } /* void ins(int nw,int x,int v){if(x==0){t[l[nw]]+=v; return;} for(int p=l[nw]+x;p<=r[nw];x+=(x&-x),p=l[nw]+x)t[p]+=v;} int query(int nw,int x){int ret=0; for(int p=l[nw]+x;x;x-=(x&-x),p=l[nw]+x)ret+=t[p]; return ret+t[l[nw]];} void ins2(int nw,int x,int v){ if(x==0){t2[l2[nw]]+=v; return;} printf("nw=%d x=%d v=%d\n",nw,x,v); for(int p=l2[nw]+x;p<=r2[nw];x+=(x&-x),p=l2[nw]+x)t2[p]+=v,printf("t2[%d]+=%d x=%d\n",p,v,x);} int query2(int nw,int x){int ret=0; for(int p=l2[nw]+x;x;x-=(x&-x),p=l2[nw]+x)ret+=t2[p],printf("t2[%d]=%d x=%d\n",p,t2[p],x); return ret+t2[l[nw]];} */ void ins(int nw,int x,int v){if(x==0){tr[nw][0]+=v; return;} for(;x<tr[nw].size();x+=(x&-x))tr[nw][x]+=v;} int query(int nw,int x){int ret=0; x=Min(x,tr[nw].size()-1); for(;x;x-=(x&-x))ret+=tr[nw][x]; return ret+tr[nw][0];} void ins2(int nw,int x,int v){if(x==0){tr2[nw][0]+=v; return;} for(;x<tr2[nw].size();x+=(x&-x))tr2[nw][x]+=v;} int query2(int nw,int x){int ret=0; x=Min(x,tr2[nw].size()-1); for(;x;x-=(x&-x))ret+=tr2[nw][x]; return ret+tr2[nw][0];} void dfs(int x,int ff,int dis) { ins(rt,dis,v[x]); if(fa[rt])ins2(rt,dist(x,fa[rt]),v[x]); for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=ff&&!vis[u])dfs(u,x,dis+1); } void work(int x,int sum) { vis[x]=1; ins(x,0,v[x]); if(fa[x])ins2(x,dist(x,fa[x]),v[x]);// for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; dfs(u,x,1); } for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; int ns=(siz[u]>siz[x]?sum-siz[x]:siz[u]); mx=xn; getrt(u,0,ns,1); fa[rt]=x; //len[rt]=tmp; /* l[rt]=wmx+1; wmx+=mx+1; r[rt]=wmx;//maxdep l2[rt]=wmx2+1; wmx2+=2*mx+2; r2[rt]=wmx2; */ tr[rt].resize(mx+1); tr2[rt].resize(2*mx+2); work(rt,ns); } } void ask(int p,int x,int k,int dis) { ans+=query(x,k); int len=dist(p,fa[x]); if(dis>=len&&fa[x]) { ans-=query2(x,dis-len); ask(p,fa[x],dis-len,dis); } } void change(int p,int x,int v1,int v2) { ins(x,dist(p,x),-v1); ins2(x,dist(p,fa[x]),-v1); ins(x,dist(p,x),v2); ins2(x,dist(p,fa[x]),v2); if(fa[x])change(p,fa[x],v1,v2); } int main() { n=rd(); int m=rd(); for(int i=1;i<=n;i++)v[i]=rd(); for(int i=1,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x); init(); mx=xn; getrt(1,0,n,0); /* l[rt]=wmx+1; wmx+=mx+1; r[rt]=wmx; l2[rt]=wmx2+1; wmx2+=2*mx+2; r2[rt]=wmx2; */ tr[rt].resize(mx+1); tr2[rt].resize(2*mx+2); work(rt,n); for(int i=1,op,x,y;i<=m;i++) { op=rd(); x=(rd()^ans); y=(rd()^ans); if(op==0)ans=0,ask(x,x,y,y),printf("%d\n",ans); else change(x,x,v[x],y),v[x]=y;//,ans=0; } return 0; }
然後借鑑了TJ:https://www.cnblogs.com/enigma-aw/p/6209545.html
預處理父親好方便!
程式碼如下:
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> using namespace std; int const xn=1e5+5; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],v[xn]; int dep[xn],siz[xn],mx,rt,ans; int fa[xn][20],dis[xn][20]; bool vis[xn]; vector<int>tr[xn],tr2[xn]; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } int Max(int x,int y){return x>y?x:y;} int Min(int x,int y){return x<y?x:y;} void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void getrt(int x,int ff,int sum) { int nmx=0; siz[x]=1; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==ff||vis[u])continue; getrt(u,x,sum); siz[x]+=siz[u]; nmx=Max(nmx,siz[u]); } nmx=Max(nmx,sum-siz[x]); if(nmx<mx)mx=nmx,rt=x; } void build(int x,int p,int ff,int d) { for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==ff||vis[u])continue; fa[u][++dep[u]]=p; dis[u][dep[u]]=d; build(u,p,x,d+1); } } void work(int x,int sum) { vis[x]=1; build(x,x,0,1); tr[x].resize(sum+1); tr2[x].resize(sum+1);//nlogn for(int i=hd[x],u;i;i=nxt[i]) { if(vis[u=to[i]])continue; int ns=(siz[u]>siz[x]?sum-siz[x]:siz[u]); mx=xn; getrt(u,0,ns); work(rt,ns); } } void ins(int nw,int x,int v){for(;x<tr[nw].size()&&x;x+=(x&-x))tr[nw][x]+=v;} int query(int nw,int x){if(x<0)return 0; int ret=0; x=Min(x,tr[nw].size()-1); for(;x;x-=(x&-x))ret+=tr[nw][x]; return ret+v[nw];}//v[nw] void ins2(int nw,int x,int v){for(;x<tr2[nw].size()&&x;x+=(x&-x))tr2[nw][x]+=v;}//x int query2(int nw,int x){if(x<0)return 0; int ret=0; x=Min(x,tr2[nw].size()-1); for(;x;x-=(x&-x))ret+=tr2[nw][x]; return ret;} int ask(int x,int k) { int ret=query(x,k); for(int i=dep[x];i;i--)// if(k>=dis[x][i]) ret+=query(fa[x][i],k-dis[x][i])-query2(fa[x][i+1],k-dis[x][i]); return ret; } void change(int x,int val) { int d=dis[x][dep[x]],ff; ins2(x,d,val);//x for(int i=dep[x];i;i--) { d=dis[x][i]; ff=fa[x][i]; ins(ff,d,val); d=dis[x][i-1]; ins2(ff,d,val);//fa[x][i] } } int main() { n=rd(); int m=rd(); for(int i=1;i<=n;i++)v[i]=rd(); for(int i=1,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x); mx=xn; getrt(1,0,n); work(rt,n); for(int i=1;i<=n;i++)fa[i][dep[i]+1]=i; for(int i=1;i<=n;i++)change(i,v[i]); for(int i=1,op,x,y;i<=m;i++) { op=rd(); x=(rd()^ans); y=(rd()^ans); if(op==0)ans=ask(x,y),printf("%d\n",ans); else change(x,y-v[x]),v[x]=y;//,ans=0; } return 0; }