BZOJ4543/BZOJ3522 [POI2014]Hotel加強版(長鏈剖分)
阿新 • • 發佈:2018-11-12
題目好神仙……這個叫長鏈剖分的玩意兒更神仙……
考慮dp,設\(f[i][j]\)表示以\(i\)為根的子樹中到\(i\)的距離為\(j\)的點的個數,\(g[i][j]\)表示\(i\)的子樹中有\(g[i][j]\)對點深度相同,他們到LCA的距離為\(d\),且他們的LCA到\(i\)的距離為\(d-j\)。或者換句話來說就是以\(i\)為根的子樹中有這麼多個點對,而且沒有第三個點去和這些點對匹配,第三個點不在\(i\)的子樹中且到\(i\)的距離為\(j\),\(g[i][j]\)表示這些點對的個數
設\(u\)為當前點,\(v\)為某一子樹,那麼轉移方程如下
\[f[u][i]+=f[v][i+1]\]
\[g[u][i-1]+=g[v][i]\]
\[g[u][i+1]+=f[u][i+1]*f[v][i]\]
\[ans+=f[u][i-1]*g[v][i]+g[u][i+1]*f[v][i]\]
如果是原題的\(n\leq 5000\)已經足夠了,然而當\(n\leq 100000\)的時候很明顯gg了
發現狀態陣列的第二維實際上跟這個節點的深度有關,於是考慮用長鏈剖分優化。簡單來說記每一個節點深度最大的兒子為它的重兒子。因為第一次轉移的時候有\(f[u][i]=f[v][i-1],g[u][i]=g[v][i+1]\),於是可以類似於dsu on tree的思想,對於每個重兒子的資訊直接繼承,輕兒子暴力跑一遍。重兒子的資訊可以直接用指標來達到\(O(1)\)
這個時間複雜度大概是\(O(n)\)的,對於每個點轉移的複雜度為\(\sum dep[v]-dep[son[u]]=\sum dep[v]-dep[u]+1\),然後所有點的加起來除了葉子結點都互相抵消,於是總的複雜度為\(O(n)\)
空間複雜度也是\(O(n)\),因為非葉節點的空間都是由它所在重鏈的兒子轉移來的,所以對每個葉節點開正比於此重鏈長度的空間即可
//minamoto #include<bits/stdc++.h> #define ll long long using namespace std; #define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) char buf[1<<21],*p1=buf,*p2=buf; int read(){ int res,f=1;char ch; while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1); for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0'); return res*f; } const int N=1e5+5,M=1005; int head[N],Next[N<<1],ver[N<<1],tot; inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;} ll memp[N*5],*f[N],*g[N],*to=memp+5,ans; int n,dep[N],mx[N]; void dfs(int u,int fa){ mx[u]=u; for(int i=head[u];i;i=Next[i]){ int v=ver[i]; if(v!=fa){ dep[v]=dep[u]+1,dfs(v,u); if(dep[mx[v]]>dep[mx[u]])mx[u]=mx[v]; } } for(int i=head[u];i;i=Next[i]){ int v=ver[i]; if(v!=fa&&(mx[v]!=mx[u]||u==1)){ v=mx[v],to+=dep[v]-dep[u]+1; f[v]=to,g[v]=(to+=1),to+=(dep[v]-dep[u])*2+1; } } } void dp(int u,int fa){ for(int i=head[u];i;i=Next[i]){ int v=ver[i];if(v==fa)continue;dp(v,u); if(mx[v]==mx[u])f[u]=f[v]-1,g[u]=g[v]+1; } ans+=g[u][0],f[u][0]=1; for(int i=head[u];i;i=Next[i]){ int v=ver[i];if(v==fa||mx[v]==mx[u])continue; for(int j=0;j<=dep[mx[v]]-dep[u];++j) ans+=f[u][j-1]*g[v][j]+g[u][j+1]*f[v][j]; for(int j=0;j<=dep[mx[v]]-dep[u];++j){ g[u][j-1]+=g[v][j]; g[u][j+1]+=f[u][j+1]*f[v][j]; f[u][j+1]+=f[v][j]; } } } int main(){ // freopen("testdata.in","r",stdin); n=read(); for(int i=1,u,v;i<n;++i)u=read(),v=read(),add(u,v),add(v,u); while(to!=memp)*to=0,--to;*to=0,++to; dep[1]=1;dfs(1,0),dp(1,0); printf("%lld\n",ans);return 0; }