樹上的最遠點對 51Nod - 1766
阿新 • 發佈：2018-12-12
https://www.51nod.com/Challenge/Problem.html#!#problemId=1766
分別在點集 S1、S2 中各找出距離最遠的兩個點；將 S1、S2 合併後，最遠點對的兩個端點一定在這四個點之中（前提是邊權非負）。因此「最遠點對」滿足區間可加性，可以用線段樹維護與查詢。
另外是求 LCA：發現之前一直寫的其實不是真正的 O(1) 做法，每次查詢仍是 O(log n) 的複雜度。其實完全可以在歐拉序上利用倍增陣列（稀疏表）結合位運算做到 O(1) 查詢，做法和普通的 RMQ 一樣。
#include <bits/stdc++.h>
using namespace std;

// 51Nod 1766: given a weighted tree on vertices 1..n, answer queries
// (a, b, c, d) with the maximum dist(x, y) over x in [a, b], y in [c, d].
//
// Merge lemma (edge weights non-negative): if (p1, q1) is a farthest pair of a
// vertex set S1 and (p2, q2) is a farthest pair of S2, then S1 ∪ S2 has a
// farthest pair with both endpoints in {p1, q1, p2, q2}. Likewise, the farthest
// pair with one endpoint in each set is one of the four cross pairs.
// A segment tree over vertex ids therefore stores the farthest pair of every
// id range. Distances come from root distances plus an O(1) LCA (sparse table
// over the Euler tour).

const int maxn = 1e5 + 10;
const int LOGN = 20;  // sparse-table levels; 2^19 already exceeds any Euler-tour length (< 2 * maxn)

struct node1 {  // adjacency-list edge
    int v, w, next;
};
struct node2 {  // farthest pair (u, v) of a vertex set and its distance val
    int u, v;
    long long val;
};

node1 edge[2 * maxn];
node2 tree[4 * maxn];
int dp[2 * maxn][LOGN];  // dp[i][j]: shallowest vertex among Euler positions [i, i + 2^j)
int first[maxn], deep[maxn], mp1[maxn], mp2[2 * maxn], logval[2 * maxn];
long long dis[maxn];     // distance from the root; 64-bit so a sum of two cannot overflow
int stk_node[maxn], stk_par[maxn], stk_edge[maxn];  // explicit DFS stack (avoids deep recursion)
int n, q, num, edge_cnt;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns false if EOF is hit before a digit
// (the original looped forever in that case because it never checked EOF).
template <class T>
inline bool read_int(T &ret) {
    int c = getchar();
    ret = 0;
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    while (c >= '0' && c <= '9') {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return true;
}

// Appends directed edge u -> v with weight w to u's adjacency list.
void addedge(int u, int v, int w) {
    edge[edge_cnt].v = v;
    edge[edge_cnt].w = w;
    edge[edge_cnt].next = first[u];
    first[u] = edge_cnt++;
}

// Pushes vertex u (entered from parent p) onto the DFS stack and records its
// first Euler-tour position in mp1/mp2.
static void push_frame(int &top, int u, int p) {
    ++num;
    mp1[u] = num;
    mp2[num] = u;
    stk_node[top] = u;
    stk_par[top] = p;
    stk_edge[top] = first[u];
    ++top;
}

// Iterative DFS from root producing the same Euler tour as the recursive
// version: each vertex is written when entered, and the parent is written
// again after each child returns. Fills deep[], dis[], mp1[], mp2[]; num ends
// as the tour length (2n - 1 for a tree). Expects deep[root] and dis[root] to
// be set by the caller.
void euler_tour(int root) {
    int top = 0;
    num = 0;
    push_frame(top, root, 0);
    while (top > 0) {
        int u = stk_node[top - 1];
        int e = stk_edge[top - 1];
        if (e == -1) {  // all children of u done: pop, re-record the parent
            --top;
            if (top > 0) {
                ++num;
                mp2[num] = stk_node[top - 1];
            }
            continue;
        }
        stk_edge[top - 1] = edge[e].next;
        int v = edge[e].v;
        if (v == stk_par[top - 1]) continue;  // don't walk back to the parent
        deep[v] = deep[u] + 1;
        dis[v] = dis[u] + edge[e].w;
        push_frame(top, v, u);
    }
}

// Builds the Euler tour rooted at vertex 1 and the min-depth sparse table over it.
void solve() {
    deep[1] = 1;
    dis[1] = 0;
    euler_tour(1);
    logval[0] = -1;
    for (int i = 1; i <= num; i++) logval[i] = logval[i / 2] + 1;
    for (int i = 1; i <= num; i++) dp[i][0] = mp2[i];
    // Bounded by LOGN so dp's column index can never go out of range; the
    // window [i, i + 2^j - 1] must fit inside [1, num].
    for (int j = 1; j < LOGN && (1 << j) <= num; j++) {
        for (int i = 1; i + (1 << j) - 1 <= num; i++) {
            int a = dp[i][j - 1], b = dp[i + (1 << (j - 1))][j - 1];
            dp[i][j] = deep[a] < deep[b] ? a : b;
        }
    }
}

// O(1) LCA: the shallowest vertex on the Euler tour between the first
// occurrences of x and y.
int getlca(int x, int y) {
    int l = mp1[x], r = mp1[y];
    if (l > r) swap(l, r);
    int k = logval[r - l + 1];
    int a = dp[l][k], b = dp[r - (1 << k) + 1][k];
    return deep[a] < deep[b] ? a : b;
}

// Weighted tree distance between x and y.
long long dist(int x, int y) {
    return dis[x] + dis[y] - 2 * dis[getlca(x, y)];
}

// Replaces best with (x, y) if that pair is strictly farther apart.
static void relax(node2 &best, int x, int y) {
    long long d = dist(x, y);
    if (best.val < d) {
        best.u = x;
        best.v = y;
        best.val = d;
    }
}

// Farthest pair with one endpoint taken from a's pair and one from b's pair.
node2 cross(const node2 &a, const node2 &b) {
    node2 res = {a.u, b.u, -1};
    relax(res, a.u, b.u);
    relax(res, a.u, b.v);
    relax(res, a.v, b.u);
    relax(res, a.v, b.v);
    return res;
}

// Farthest pair of the union of the two vertex sets summarized by a and b.
node2 merge_pair(const node2 &a, const node2 &b) {
    node2 res = a.val < b.val ? b : a;
    node2 c = cross(a, b);
    if (res.val < c.val) res = c;
    return res;
}

// Builds the segment tree over vertex ids [l, r]; a leaf's farthest pair is
// the vertex with itself (distance 0).
void build(int l, int r, int cur) {
    if (l == r) {
        tree[cur].u = tree[cur].v = l;
        tree[cur].val = 0;
        return;
    }
    int m = (l + r) / 2;
    build(l, m, 2 * cur);
    build(m + 1, r, 2 * cur + 1);
    tree[cur] = merge_pair(tree[2 * cur], tree[2 * cur + 1]);
}

// Farthest pair among vertex ids in [pl, pr] ∩ [l, r]. Returns u == -1 when
// that intersection is empty.
node2 query(int pl, int pr, int l, int r, int cur) {
    if (pl <= l && r <= pr) return tree[cur];
    int m = (l + r) / 2;
    node2 lhs = {-1, -1, -1}, rhs = {-1, -1, -1};
    if (pl <= m) lhs = query(pl, pr, l, m, 2 * cur);
    if (pr > m) rhs = query(pl, pr, m + 1, r, 2 * cur + 1);
    if (lhs.u == -1) return rhs;
    if (rhs.u == -1) return lhs;
    return merge_pair(lhs, rhs);
}

int main() {
    if (!read_int(n) || n <= 0) return 0;
    memset(first, -1, sizeof(first));
    edge_cnt = 0;
    for (int i = 1; i <= n - 1; i++) {
        int u, v, w;
        if (!(read_int(u) && read_int(v) && read_int(w))) return 0;
        addedge(u, v, w);
        addedge(v, u, w);
    }
    solve();
    build(1, n, 1);
    if (!read_int(q)) return 0;
    while (q--) {
        int a, b, c, d;
        if (!(read_int(a) && read_int(b) && read_int(c) && read_int(d))) break;
        if (a > b) swap(a, b);  // [b, a] denotes the same id range
        if (c > d) swap(c, d);
        node2 s1 = query(a, b, 1, n, 1);
        node2 s2 = query(c, d, 1, n, 1);
        // x must come from [a, b] and y from [c, d]: only cross pairs count.
        printf("%lld\n", cross(s1, s2).val);
    }
    return 0;
}
/*
Sample input:
20
9 10 1
9 4 1
10 11 1
10 12 1
10 6 1
10 1 1
12 7 1
12 2 1
1 13 1
1 14 1
1 5 1
7 3 1
7 8 1
14 15 1
14 16 1
14 17 1
3 18 1
17 19 1
17 20 1
1
8 20 4 4
Expected output:
6
*/