1. 程式人生 > >BZOJ1906樹上的螞蟻&BZOJ3700發展城市——RMQ求LCA+樹鏈的交

BZOJ1906樹上的螞蟻&BZOJ3700發展城市——RMQ求LCA+樹鏈的交

題目描述

 眾所周知,Hzwer學長是一名高富帥,他打算投入巨資發展一些小城市。
 Hzwer打算在城市中開N個賓館,由於Hzwer非常壕,所以賓館必須建在空中,但是這樣就必須建立賓館之間的連線通道。機智的Hzwer在賓館中修建了N-1條隧道,也就是說,賓館和隧道形成了一個樹形結構。
 Hzwer有時候會花一天時間去視察某個城市,當來到一個城市之後,Hzwer會分析這些賓館的顧客情況。對於每個顧客,Hzwer用三個數值描述他:(S, T, V)表示該顧客這天想要從賓館S走到賓館T,他的速度是V。
 Hzwer需要做一些收集一些資料,這樣他就可以規劃他接下來的投資。
 其中有一項資料就是收集所有顧客可能的碰面次數。
 每天清晨,顧客同時從S出發以V的速度前往T(注意S可能等於T),當到達了賓館T的時候,顧客顯然要找個房間住下,那麼別的顧客再經過這裡就不會碰面了。特別的,兩個顧客同時到達一個賓館是可以碰面的。同樣,兩個顧客同時從某賓館出發也會碰面。

輸入

 第一行一個正整數T(1<=T<=20),表示Hzwer發展了T個城市,並且在這T個城市分別視察一次。
 對於每個T,第一行有一個正整數N(1<=N<=10^5)表示Hzwer在這個城市開了N個賓館。
 接下來N-1行,每行三個整數X,Y,Z表示賓館X和賓館Y之間有一條長度為Z的隧道
 再接下來一行M表示這天顧客的數量。
 緊跟著M行每行三個整數(S, T, V)表示該顧客會從賓館S走到賓館T,速度為v

輸出

 對於每個T,輸出一行,表示顧客的碰面次數。

樣例輸入

1
3
1 2 1
2 3 1
3
1 3 2
3 1 1
1 2 3

樣例輸出

2
0

提示

【資料規模】

 1<=T<=20   1<=N<=10^5   0<=M<=10^3   1<=V<=10^6   1<=Z<=10^3

 

這題細節好多啊,蒟蒻的我調了一下午。

考慮到m的範圍比較小,因此可以兩兩列舉判斷是否相遇。

對於兩個路徑,如果能夠相遇,相遇點一定在兩個路徑的交路徑上。

如何求樹上路徑交?

對於兩個路徑A(a.u,a.v)與B(b.u,b.v)求出lca(a.u,b.u),lca(a.v,b.v),lca(a.v,b.u),lca(a.u,b.v)

去掉這四個點中不在A或B路徑上的點,再去重後按dfs序排序,取後兩個(如果只有一個說明路徑只交於一點)就是交路徑的兩個端點

判斷出兩個路徑起點先到達的交路徑的端點是否是同一個,如果是就說明兩個顧客是同向運動,反之則是相向運動。

如果兩顧客是同向運動:只要先進入交路徑的顧客後走出交路徑就一定相遇。

如果兩顧客是相向運動:分別求出兩顧客進入和走出交路徑的時間,判斷只要兩時間段有交集就能相遇,因為除法較慢,所以轉成交叉相乘判斷。

在判斷和求路徑過程中多次求lca,用O(logn)的方法求顯然會TLE,在這裡採用RMQ求lca:

在dfs時求出尤拉遍歷序(就是遍歷到一個點存一次)及每個點第一次被遍歷的位置

對於x,y兩點的lca就是尤拉序上兩點第一次被遍歷位置之間深度最小的點,用ST表即可O(1)查詢

這道題有點卡常,注意涉及到乘速度時可能會爆longlong。

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline char _read()
{
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
    int x=0,f=1;char ch=_read();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=_read();}
    while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=_read();}
    return x*f;
}
int T,n,m;
int head[100010];
int s[100010];
int to[200010];
int next[200010];
int val[200010];
int d[100010];
int dep[100010];
int f[200010][18];
int g[200010][18];
int tot;
int num;
int x,y,z;
int ans;
int p[5];
int cnt;
int b[200010];
struct miku
{
    int u,v,w;
}a[1010];
inline void add(int x,int y,int z)
{
    tot++;
    next[tot]=head[x];
    head[x]=tot;
    to[tot]=y;
    val[tot]=z;
}
inline void dfs(int x,int fa)
{
    d[x]=d[fa]+1;
    s[x]=++num;
    f[num][0]=d[x];
    g[num][0]=x;
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=fa)
        {
            dep[to[i]]=dep[x]+val[i];
            dfs(to[i],x);
            f[++num][0]=d[x];
            g[num][0]=x;
        }
    }
}
inline void ST()
{
    for(int j=1;j<=17;j++)
    {
        for(int i=1;i<=num;i++)
        {
            if(i+(1<<j)-1>num)
            {
                break;
            }
            if(f[i][j-1]<f[i+(1<<(j-1))][j-1])
            {
                f[i][j]=f[i][j-1];
                g[i][j]=g[i][j-1];
            }
            else
            {
                f[i][j]=f[i+(1<<(j-1))][j-1];
                g[i][j]=g[i+(1<<(j-1))][j-1];
            }
        }
    }
}
inline int lca(int x,int y)
{
    x=s[x];
    y=s[y];
    if(x>y)
    {
        swap(x,y);
    }
    int len=b[y-x+1];
    if(f[x][len]<f[y-(1<<len)+1][len])
    {
        return g[x][len];
    }
    else
    {
        return g[y-(1<<len)+1][len];
    }
}
inline bool find(int anc,int x,int y)
{
    int fx=lca(a[x].u,a[x].v);
    int fy=lca(a[y].u,a[y].v);
    if(lca(fx,anc)!=fx||lca(fy,anc)!=fy)
    {
        return false;
    }
    if(fx!=lca(fx,a[x].u)&&fx!=lca(fx,a[x].v))
    {
        return false;
    }
    if(fy!=lca(fy,a[y].u)&&fy!=lca(fy,a[y].v))
    {
        return false;
    }
    return true;
}
inline int dis(int x,int y)
{
    int anc=lca(x,y);
    return dep[x]+dep[y]-2*dep[anc];
}
inline bool cmp(int x,int y)
{
    return s[x]<s[y];
}
inline bool cpr(ll a,ll b,ll c)
{
    if(a<=b&&b<=c)
    {
        return 1;
    }
    else
    {
        return 0;
    }
}
inline int check(int x,int y)
{
    if(a[x].u==a[y].u)
    {
        return 1;
    }
    int res;
    cnt=0;
    res=lca(a[x].u,a[y].u);
    if(find(res,x,y)){p[++cnt]=res;}
    res=lca(a[x].v,a[y].v);
    if(find(res,x,y)){p[++cnt]=res;}
    res=lca(a[x].u,a[y].v);
    if(find(res,x,y)){p[++cnt]=res;}
    res=lca(a[y].u,a[x].v);
    if(find(res,x,y)){p[++cnt]=res;}
    if(cnt==0)
    {
        return 0;
    }
    sort(p+1,p+1+cnt,cmp);
    cnt=unique(p+1,p+1+cnt)-p-1;
    if(cnt==1)
    {
        if(1ll*dis(a[x].u,p[1])*a[y].w==1ll*dis(a[y].u,p[1])*a[x].w)
        {
            return 1;
        }
        else
        {
            return false;
        }
    }
    int st=p[cnt];
    int ed=p[cnt-1];
    int A1,A2,B1,B2;
    ll a1,a2,b1,b2;
    if(dis(a[x].u,st)<dis(a[x].u,ed))
    {
        A1=st;
        A2=ed;
    }
    else
    {
        A1=ed;
        A2=st;
    }
    if(dis(a[y].u,st)<dis(a[y].u,ed))
    {
        B1=st;
        B2=ed;
    }
    else
    {
        B1=ed;
        B2=st;
    }
    a1=1ll*dis(a[x].u,A1)*a[y].w;
    a2=1ll*dis(a[x].u,A2)*a[y].w;
    b1=1ll*dis(a[y].u,B1)*a[x].w;
    b2=1ll*dis(a[y].u,B2)*a[x].w;
    if(A1==B1)
    {
        if(a1==b1)
        {
            return 1;
        }
        if(a1<b1)
        {
            return b2<=a2;
        }
        else
        {
            return a2<=b2;
        }
    }
    else
    {
        if(cpr(a1,b1,a2))return 1;
        if(cpr(a1,b2,a2))return 1;
        if(cpr(b1,a1,b2))return 1;
        if(cpr(b1,a2,b2))return 1;
        return 0;
    }
}
int main()
{
    T=read();
    b[0]=-1;
    for(int i=1;i<=200010;i++)
    {
        b[i]=b[i>>1]+1;
    }
    while(T--)
    {
        memset(head,0,sizeof(head));
        num=0;
        tot=0;
        ans=0;
        n=read();
        for(int i=1;i<n;i++)
        {
            x=read();
            y=read();
            z=read();
            add(x,y,z);
            add(y,x,z);
        }
        dfs(1,0);
        ST();
        m=read();
        for(int i=1;i<=m;i++)
        {
            a[i].u=read();
            a[i].v=read();
            a[i].w=read();
        }
        for(int i=1;i<=m;i++)
        {
            for(int j=i+1;j<=m;j++)
            {
                ans+=check(i,j);
            }
        }
        printf("%d\n",ans);
    }
}