1. 程式人生 > >BZOJ4543/BZOJ3522 [POI2014]Hotel加強版(長鏈剖分)

BZOJ4543/BZOJ3522 [POI2014]Hotel加強版(長鏈剖分)

題目好神仙……這個叫長鏈剖分的玩意兒更神仙……

考慮dp,設\(f[i][j]\)表示以\(i\)為根的子樹中到\(i\)的距離為\(j\)的點的個數,\(g[i][j]\)表示\(i\)的子樹中有\(g[i][j]\)對點深度相同,他們到LCA的距離為\(d\),且他們的LCA到\(i\)的距離為\(d-j\)。或者換句話來說就是以\(i\)為根的子樹中有這麼多個點對,而且沒有第三個點去和這些點對匹配,第三個點不在\(i\)的子樹中且到\(i\)的距離為\(j\)\(g[i][j]\)表示這些點對的個數

\(u\)為當前點,\(v\)為某一子樹,那麼轉移方程如下
\[f[u][i]+=f[v][i+1]\]


\[g[u][i-1]+=g[v][i]\]
\[g[u][i+1]+=f[u][i+1]*f[v][i]\]
\[ans+=f[u][i-1]*g[v][i]+g[u][i+1]*f[v][i]\]

如果是原題的\(n\leq 5000\)已經足夠了,然而當\(n\leq 100000\)的時候很明顯gg了

發現狀態陣列的第二維實際上跟這個節點的深度有關,於是考慮用長鏈剖分優化。簡單來說記每一個節點深度最大的兒子為它的重兒子。因為第一次轉移的時候有\(f[u][i]=f[v][i-1],g[u][i]=g[v][i+1]\),於是可以類似於dsu on tree的思想,對於每個重兒子的資訊直接繼承,輕兒子暴力跑一遍。重兒子的資訊可以直接用指標來達到\(O(1)\)

的轉移

這個時間複雜度大概是\(O(n)\)的,對於每個點轉移的複雜度為\(\sum dep[v]-dep[son[u]]=\sum dep[v]-dep[u]+1\),然後所有點的加起來除了葉子結點都互相抵消,於是總的複雜度為\(O(n)\)

空間複雜度也是\(O(n)\),因為非葉節點的空間都是由它所在重鏈的兒子轉移來的,所以對每個葉節點開正比於此重鏈長度的空間即可

//minamoto
#include<bits/stdc++.h>
#define ll long long
using namespace std;
#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
int read(){
    int res,f=1;char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
const int N=1e5+5,M=1005;
int head[N],Next[N<<1],ver[N<<1],tot;
inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;}
ll memp[N*5],*f[N],*g[N],*to=memp+5,ans;
int n,dep[N],mx[N];
void dfs(int u,int fa){
    mx[u]=u;
    for(int i=head[u];i;i=Next[i]){
        int v=ver[i];
        if(v!=fa){
            dep[v]=dep[u]+1,dfs(v,u);
            if(dep[mx[v]]>dep[mx[u]])mx[u]=mx[v];
        }
    }
    for(int i=head[u];i;i=Next[i]){
        int v=ver[i];
        if(v!=fa&&(mx[v]!=mx[u]||u==1)){
            v=mx[v],to+=dep[v]-dep[u]+1;
            f[v]=to,g[v]=(to+=1),to+=(dep[v]-dep[u])*2+1;
        }
    }
}
void dp(int u,int fa){
    for(int i=head[u];i;i=Next[i]){
        int v=ver[i];if(v==fa)continue;dp(v,u);
        if(mx[v]==mx[u])f[u]=f[v]-1,g[u]=g[v]+1;
    }
    ans+=g[u][0],f[u][0]=1;
    for(int i=head[u];i;i=Next[i]){
        int v=ver[i];if(v==fa||mx[v]==mx[u])continue;
        for(int j=0;j<=dep[mx[v]]-dep[u];++j)
        ans+=f[u][j-1]*g[v][j]+g[u][j+1]*f[v][j];
        for(int j=0;j<=dep[mx[v]]-dep[u];++j){
            g[u][j-1]+=g[v][j];
            g[u][j+1]+=f[u][j+1]*f[v][j];
            f[u][j+1]+=f[v][j];
        }
    }
}
int main(){
//  freopen("testdata.in","r",stdin);
    n=read();
    for(int i=1,u,v;i<n;++i)u=read(),v=read(),add(u,v),add(v,u);
    while(to!=memp)*to=0,--to;*to=0,++to;
    dep[1]=1;dfs(1,0),dp(1,0);
    printf("%lld\n",ans);return 0;
}