1. 程式人生 > >BZOJ3451Normal——點分治+FFT

BZOJ3451Normal——點分治+FFT

題目描述
某天WJMZBMR學習了一個神奇的演算法:樹的點分治!
這個演算法的核心是這樣的:
消耗時間=0
Solve(樹 a)
消耗時間 += a 的 大小
如果 a 中 只有 1 個點
退出
否則在a中選一個點x,在a中刪除點x
那麼a變成了幾個小一點的樹,對每個小樹遞迴呼叫Solve
我們注意到的這個演算法的時間複雜度跟選擇的點x是密切相關的。
如果x是樹的重心,那麼時間複雜度就是O(nlogn)
但是由於WJMZBMR比較傻逼,他決定隨機在a中選擇一個點作為x!
Sevenkplus告訴他這樣做的最壞複雜度是O(n^2)
但是WJMZBMR就是不信>_<。。。
於是Sevenkplus花了幾分鐘寫了一個程式證明了這一點。。。你也試試看吧^_^
現在給你一顆樹,你能告訴WJMZBMR他的傻逼演算法需要的期望消耗時間嗎?(消耗時間按在Solve裡面的那個為標準)

輸入格式
第一行一個整數n,表示樹的大小
接下來n-1行每行兩個數a,b,表示a和b之間有一條邊
注意點是從0開始標號的

輸出格式
一行一個浮點數表示答案
四捨五入到小數點後4位
如果害怕精度跪建議用long double或者extended

樣例輸入
3
0 1
1 2
樣例輸出
5.6667
提示
n<=30000

答案要求的是點分樹上所有點的子樹size和(所有的點都是隨機的)的期望,根據期望的線性性質,答案等於每個點的期望被算的次數之和。我們考慮點x,他怎樣才會對y產生1點貢獻呢?只有先選了x,才能保證x對y做出一點貢獻。所以做出貢獻的概率為1dis(i,j)+1(因為每個點選到的概率均等,而得先選一個點)。所以最終的答案為i=1nj=1n1dis(i,j)+1
那麼怎麼求這個答案呢?
我們需要知道樹上路徑長度為x的路徑有多少條,我們都知道x為定值時,直接一個簡單的點分治即可。但是現在x是所有值,所以我們考慮依舊進行點分治,然後統計每個點到當前根的距離,那麼sum(x)=i=0xnum(i)num(xi)num(i)表示到根距離為i的點的數量。這是一個十分顯然的卷積的形式,我們直接卷積求出sum即可。
#include<bits/stdc++.h>
#define db double
#define MAXN 131072
#define MD 998244353
#define ll long long
using namespace std;
int read(){
    char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
    while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
const db pi=acos(-1.0);
struct comple{
    double x,y;
    comple (double xx=0,double yy=0){x=xx,y=yy;}
    comple operator+(const comple a){return comple(x+a.x,y+a.y);}
    comple operator-(const comple a){return comple(x-a.x,y-a.y);}
    comple operator*(const comple a){return comple(x*a.x-y*a.y,y*a.x+x*a.y);}
}a[MAXN],b[MAXN],w[2][MAXN];
int n,cnt,root,sum,limit=1,l,m,head[MAXN<<1],nxt[MAXN<<1],go[MAXN<<1],buck[MAXN<<1],f[MAXN],siz[MAXN],vis[MAXN],d[MAXN],dep[MAXN],r[MAXN<<1];
ll ans;
int pows(ll a,int b){
    ll base=1;
    while(b){
        if(b&1) base=base*a%MD;
        a=a*a%MD;b/=2;
    }
    return base;
}
void add(int x,int y){
    go[cnt]=y;nxt[cnt]=head[x];head[x]=cnt;cnt++;
    go[cnt]=x;nxt[cnt]=head[y];head[y]=cnt;cnt++;
}
void pre(){
    comple Wn(cos(2*pi/limit),sin(2*pi/limit));
    w[0][0]=w[1][0]=comple(1,0);
    for(int i=1;i<limit;i++) w[1][i]=w[1][i-1]*Wn;
    for(int i=1;i<limit;i++) w[0][i]=w[1][limit-i];
}
void FFT(comple *A,int type){
    for(int i=0;i<limit;i++) if(i<r[i]) swap(A[i],A[r[i]]);
    for(int mid=1;mid<limit;mid<<=1){
        for(int R=mid<<1,j=0;j<limit;j+=R){
            for(int k=0;k<mid;k++){
                comple x=A[j+k],y=w[type==1][limit/(mid<<1)*k]*A[j+k+mid];
                A[j+k]=x+y;A[j+k+mid]=x-y;
            }
        }
    }
}
void getroot(int x,int fa){
    f[x]=0;siz[x]=1;
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(to==fa||vis[to]) continue;
        getroot(to,x);f[x]=max(f[x],siz[to]);
        siz[x]+=siz[to]; 
    }
    f[x]=max(f[x],sum-siz[x]);
    if(f[x]<f[root]) root=x;
}
void getdep(int x,int fa){
    dep[++dep[0]]=d[x];a[d[x]].x++;m=max(m,d[x]);
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(vis[to]||to==fa) continue;
        d[to]=d[x]+1;getdep(to,x);
    }
}
void calc(int x,int w,int type){
    d[x]=w;dep[0]=0;m=0;limit=1;l=0;
    getdep(x,0);m<<=1;
    while(limit<=m) limit<<=1,l++;
    for(int i=1;i<limit;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    pre();FFT(a,1);
    for(int i=0;i<limit;i++) a[i]=a[i]*a[i];
    FFT(a,-1);
    for(int i=0;i<=m;i++) buck[i]+=type*(int)(a[i].x/limit+0.5);
    for(int i=0;i<limit;i++) a[i].x=a[i].y=0;
}
void solve(int x){
    vis[x]=1;calc(x,0,1);
    for(int i=head[x];i!=-1;i=nxt[i]){
        int to=go[i];
        if(vis[to]) continue;
        calc(to,1,-1);
        sum=siz[to];root=0;
        getroot(to,0);
        solve(root);
    }
}
int main()
{
    n=read();memset(head,-1,sizeof(head));f[0]=2e9;sum=n;
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        add(x,y);
    }
    getroot(1,0);
    solve(root);
    for(int i=0;i<n;i++) ans=(ans+1ll*pows(i+1,MD-2)*buck[i])%MD;
    printf("%d\n",ans);
    return 0;
}
//這裡我沒有寫小數,我用對998244353求逆元的形式。