BZOJ1906樹上的螞蟻&BZOJ3700發展城市——RMQ求LCA+樹鏈的交
題目描述
眾所周知,Hzwer學長是一名高富帥,他打算投入巨資發展一些小城市。
Hzwer打算在城市中開N個賓館,由於Hzwer非常壕,所以賓館必須建在空中,但是這樣就必須建立賓館之間的連線通道。機智的Hzwer在賓館中修建了N-1條隧道,也就是說,賓館和隧道形成了一個樹形結構。
Hzwer有時候會花一天時間去視察某個城市,當來到一個城市之後,Hzwer會分析這些賓館的顧客情況。對於每個顧客,Hzwer用三個數值描述他:(S, T, V)表示該顧客這天想要從賓館S走到賓館T,他的速度是V。
Hzwer需要做一些收集一些資料,這樣他就可以規劃他接下來的投資。
其中有一項資料就是收集所有顧客可能的碰面次數。
每天清晨,顧客同時從S出發以V的速度前往T(注意S可能等於T),當到達了賓館T的時候,顧客顯然要找個房間住下,那麼別的顧客再經過這裡就不會碰面了。特別的,兩個顧客同時到達一個賓館是可以碰面的。同樣,兩個顧客同時從某賓館出發也會碰面。
輸入
第一行一個正整數T(1<=T<=20),表示Hzwer發展了T個城市,並且在這T個城市分別視察一次。
對於每個T,第一行有一個正整數N(1<=N<=10^5)表示Hzwer在這個城市開了N個賓館。
接下來N-1行,每行三個整數X,Y,Z表示賓館X和賓館Y之間有一條長度為Z的隧道
再接下來一行M表示這天顧客的數量。
緊跟著M行每行三個整數(S, T, V)表示該顧客會從賓館S走到賓館T,速度為v
輸出
對於每個T,輸出一行,表示顧客的碰面次數。
樣例輸入
13
1 2 1
2 3 1
3
1 3 2
3 1 1
1 2 3
樣例輸出
20
提示
【資料規模】
1<=T<=20 1<=N<=10^5 0<=M<=10^3 1<=V<=10^6 1<=Z<=10^3
這題細節好多啊,蒟蒻的我調了一下午。
考慮到m的範圍比較小,因此可以兩兩列舉判斷是否相遇。
對於兩個路徑,如果能夠相遇,相遇點一定在兩個路徑的交路徑上。
如何求樹上路徑交?
對於兩個路徑A(a.u,a.v)與B(b.u,b.v)求出lca(a.u,b.u),lca(a.v,b.v),lca(a.v,b.u),lca(a.u,b.v)
去掉這四個點中不在A或B路徑上的點,再去重後按dfs序排序,取後兩個(如果只有一個說明路徑只交於一點)就是交路徑的兩個端點
判斷出兩個路徑起點先到達的交路徑的端點是否是同一個,如果是就說明兩個顧客是同向運動,反之則是相向運動。
如果兩顧客是同向運動:只要先進入交路徑的顧客後走出交路徑就一定相遇。
如果兩顧客是相向運動:分別求出兩顧客進入和走出交路徑的時間,判斷只要兩時間段有交集就能相遇,因為除法較慢,所以轉成交叉相乘判斷。
在判斷和求路徑過程中多次求lca,用O(logn)的方法求顯然會TLE,在這裡採用RMQ求lca:
在dfs時求出尤拉遍歷序(就是遍歷到一個點存一次)及每個點第一次被遍歷的位置
對於x,y兩點的lca就是尤拉序上兩點第一次被遍歷位置之間深度最小的點,用ST表即可O(1)查詢
這道題有點卡常,注意涉及到乘速度時可能會爆longlong。
#include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; inline char _read() { static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } inline int read() { int x=0,f=1;char ch=_read(); while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=_read();} while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=_read();} return x*f; } int T,n,m; int head[100010]; int s[100010]; int to[200010]; int next[200010]; int val[200010]; int d[100010]; int dep[100010]; int f[200010][18]; int g[200010][18]; int tot; int num; int x,y,z; int ans; int p[5]; int cnt; int b[200010]; struct miku { int u,v,w; }a[1010]; inline void add(int x,int y,int z) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; val[tot]=z; } inline void dfs(int x,int fa) { d[x]=d[fa]+1; s[x]=++num; f[num][0]=d[x]; g[num][0]=x; for(int i=head[x];i;i=next[i]) { if(to[i]!=fa) { dep[to[i]]=dep[x]+val[i]; dfs(to[i],x); f[++num][0]=d[x]; g[num][0]=x; } } } inline void ST() { for(int j=1;j<=17;j++) { for(int i=1;i<=num;i++) { if(i+(1<<j)-1>num) { break; } if(f[i][j-1]<f[i+(1<<(j-1))][j-1]) { f[i][j]=f[i][j-1]; g[i][j]=g[i][j-1]; } else { f[i][j]=f[i+(1<<(j-1))][j-1]; g[i][j]=g[i+(1<<(j-1))][j-1]; } } } } inline int lca(int x,int y) { x=s[x]; y=s[y]; if(x>y) { swap(x,y); } int len=b[y-x+1]; if(f[x][len]<f[y-(1<<len)+1][len]) { return g[x][len]; } else { return g[y-(1<<len)+1][len]; } } inline bool find(int anc,int x,int y) { int fx=lca(a[x].u,a[x].v); int fy=lca(a[y].u,a[y].v); if(lca(fx,anc)!=fx||lca(fy,anc)!=fy) { return false; } if(fx!=lca(fx,a[x].u)&&fx!=lca(fx,a[x].v)) { return false; } if(fy!=lca(fy,a[y].u)&&fy!=lca(fy,a[y].v)) { return false; } return true; } inline int dis(int x,int y) { int anc=lca(x,y); return dep[x]+dep[y]-2*dep[anc]; } inline bool cmp(int x,int y) { return s[x]<s[y]; } inline bool cpr(ll a,ll b,ll c) { if(a<=b&&b<=c) { return 1; } else { return 0; } } inline int check(int x,int y) { if(a[x].u==a[y].u) { return 1; } int res; cnt=0; res=lca(a[x].u,a[y].u); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[x].v,a[y].v); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[x].u,a[y].v); if(find(res,x,y)){p[++cnt]=res;} res=lca(a[y].u,a[x].v); if(find(res,x,y)){p[++cnt]=res;} if(cnt==0) { return 0; } sort(p+1,p+1+cnt,cmp); cnt=unique(p+1,p+1+cnt)-p-1; if(cnt==1) { if(1ll*dis(a[x].u,p[1])*a[y].w==1ll*dis(a[y].u,p[1])*a[x].w) { return 1; } else { return false; } } int st=p[cnt]; int ed=p[cnt-1]; int A1,A2,B1,B2; ll a1,a2,b1,b2; if(dis(a[x].u,st)<dis(a[x].u,ed)) { A1=st; A2=ed; } else { A1=ed; A2=st; } if(dis(a[y].u,st)<dis(a[y].u,ed)) { B1=st; B2=ed; } else { B1=ed; B2=st; } a1=1ll*dis(a[x].u,A1)*a[y].w; a2=1ll*dis(a[x].u,A2)*a[y].w; b1=1ll*dis(a[y].u,B1)*a[x].w; b2=1ll*dis(a[y].u,B2)*a[x].w; if(A1==B1) { if(a1==b1) { return 1; } if(a1<b1) { return b2<=a2; } else { return a2<=b2; } } else { if(cpr(a1,b1,a2))return 1; if(cpr(a1,b2,a2))return 1; if(cpr(b1,a1,b2))return 1; if(cpr(b1,a2,b2))return 1; return 0; } } int main() { T=read(); b[0]=-1; for(int i=1;i<=200010;i++) { b[i]=b[i>>1]+1; } while(T--) { memset(head,0,sizeof(head)); num=0; tot=0; ans=0; n=read(); for(int i=1;i<n;i++) { x=read(); y=read(); z=read(); add(x,y,z); add(y,x,z); } dfs(1,0); ST(); m=read(); for(int i=1;i<=m;i++) { a[i].u=read(); a[i].v=read(); a[i].w=read(); } for(int i=1;i<=m;i++) { for(int j=i+1;j<=m;j++) { ans+=check(i,j); } } printf("%d\n",ans); } }