POJ - 1741 (dsu on tree & 點分治)
阿新 • • 發佈:2020-11-13
題目連結:傳送門(POJ是真的煩)
題目思路:
對於dsu on tree 直接暴力統計深度--即u到根節點的距離(樹狀陣列維護桶,也可以用 排序雙指標--但單步容斥來得到合法答案),在子樹中查詢的查詢 k - (deep[u] - dis)+ dis +1 ,其中 dis 為子樹根的深度,+1是由於樹狀陣列的起點為1,而深度存在為0,故存入樹狀時整體平移一個單位,deep[u]-dis ,即u到子樹根的距離,+dis 是由於查詢的子樹中的 deep[v] = dis + v到子樹根的距離 ,要查詢的是v到子樹根的距離 但樹狀存的是deep[v] (deep[v] 和 v到子樹根的距離 是一一對應的),因此查詢的時候要加一個dis。
程式碼:
1 #include<functional> 2 #include<algorithm> 3 #include<cmath> 4 #include<cstdio> 5 #include<cctype> 6 #include<cstring> 7 #include<vector> 8 using namespace std; 9 typedef long long LL; 10 typedef unsigned long long uLL; 11 typedef pair<intView Code,int> pii; 12 typedef pair<LL,LL> pLL; 13 typedef pair<double,double> pdd; 14 const int N=2e4+5; 15 const int M=1e7+5; 16 const int inf=0x3f3f3f3f; 17 const LL mod=998244353; 18 const double eps=1e-8; 19 const long double pi=acos(-1.0L); 20 #define ls (i<<1) 21 #define rs (i<<1|1) 22#define fi first 23 #define se second 24 #define pb push_back 25 #define eb emplace_back 26 #define mk make_pair 27 #define mem(a,b) memset(a,b,sizeof(a)) 28 LL read() 29 { 30 LL x=0,t=1; 31 char ch; 32 while(!isdigit(ch=getchar())) if(ch=='-') t=-1; 33 while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); } 34 return x*t; 35 } 36 int son[N],sz[N],c[M],deep[N],ans,n,k,res; 37 int m=1e7; 38 vector<pii> e[N]; 39 inline int lowbit(int x) 40 { 41 return x&(-x); 42 } 43 void update(int x,int y) 44 { 45 for(int i=x;i<=m;i+=lowbit(i)) c[i]+=y; 46 } 47 int query(int x) 48 { 49 int tmp=0; 50 for(int i=x;i;i-=lowbit(i)) tmp+=c[i]; 51 return tmp; 52 } 53 void dfs(int u,int pre) 54 { 55 sz[u]=1; 56 son[u]=0; 57 for(int i=0;i<e[u].size();i++) 58 { 59 pii x=e[u][i]; 60 int v=x.fi; 61 if(v==pre) continue; 62 deep[v]=deep[u]+x.se; 63 dfs(v,u); 64 sz[u]+=sz[v]; 65 if(sz[son[u]]<sz[v]) son[u]=v; 66 } 67 } 68 void cal(int u,int pre,int dis) 69 { 70 if(k+1-deep[u]+2*dis<=0) return ; 71 res+=query(k+1-deep[u]+2*dis); 72 for(int i=0;i<e[u].size();i++) 73 { 74 pii x=e[u][i]; 75 int v=x.fi; 76 if(v==pre) continue; 77 cal(v,u,dis); 78 } 79 } 80 void doit(int u,int pre,int x) 81 { 82 if(deep[u]+1>m) return ; 83 update(deep[u]+1,x); 84 for(int i=0;i<e[u].size();i++) 85 { 86 pii t=e[u][i]; 87 int v=t.fi; 88 if(v==pre) continue; 89 doit(v,u,x); 90 } 91 } 92 void dfs2(int u,int pre,int flag) 93 { 94 for(int i=0;i<e[u].size();i++) 95 { 96 pii x=e[u][i]; 97 int v=x.fi; 98 if(v==pre||v==son[u]) continue; 99 dfs2(v,u,1); 100 } 101 if(son[u]) dfs2(son[u],u,0); 102 res=query(k+1+deep[u]); 103 update(deep[u]+1,1); 104 for(int i=0;i<e[u].size();i++) 105 { 106 pii x=e[u][i]; 107 int v=x.fi; 108 if(v==pre||v==son[u]) continue; 109 cal(v,u,deep[u]); 110 doit(v,u,1); 111 } 112 ans+=res; 113 if(flag) doit(u,pre,-1); 114 } 115 int main() 116 { 117 while(scanf("%d%d",&n,&k)==2&&(n||k)) 118 { 119 for(int i=1;i<=n;i++) e[i].clear(); 120 for(int i=1;i<n;i++) 121 { 122 int x=read(),y=read(),z=read(); 123 e[x].pb(mk(y,z)); 124 e[y].pb(mk(x,z)); 125 } 126 127 ans=res=0; 128 dfs(1,0); 129 dfs2(1,0,1); 130 printf("%d\n",ans); 131 } 132 return 0; 133 }
對於點分治,直接排序雙指標 統計合法答案即可(由於 dsu on tree 用了一次桶 ,這裡就用排序雙指標實現,這兩種實現方式效率是差不多的,不過排序雙指標更好理解和實現)
程式碼:
1 #include<functional> 2 #include<algorithm> 3 #include<cmath> 4 #include<cstdio> 5 #include<cctype> 6 #include<cstring> 7 #include<vector> 8 using namespace std; 9 typedef long long LL; 10 typedef unsigned long long uLL; 11 typedef pair<int,int> pii; 12 typedef pair<LL,LL> pLL; 13 typedef pair<double,double> pdd; 14 const int N=1e4+5; 15 const int M=1e7+5; 16 const int inf=0x3f3f3f3f; 17 const LL mod=998244353; 18 const double eps=1e-8; 19 const long double pi=acos(-1.0L); 20 #define ls (i<<1) 21 #define rs (i<<1|1) 22 #define fi first 23 #define se second 24 #define pb push_back 25 //#define eb emplace_back 26 #define mk make_pair 27 #define mem(a,b) memset(a,b,sizeof(a)) 28 LL read() 29 { 30 LL x=0,t=1; 31 char ch; 32 while(!isdigit(ch=getchar())) if(ch=='-') t=-1; 33 while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); } 34 return x*t; 35 36 } 37 int n,k; 38 vector<pii> e[N]; 39 int res,rt,ans,vis[N],a[N],cnt,sz[N]; 40 void dfs(int u,int pre,int tot) 41 { 42 int ma=0; 43 sz[u]=1; 44 for(int i=0;i<e[u].size();i++) 45 { 46 pii x=e[u][i]; 47 int v=x.fi,w=x.se; 48 if(v==pre||vis[v]) continue; 49 dfs(v,u,tot); 50 sz[u]+=sz[v]; 51 ma=max(ma,sz[v]); 52 } 53 ma=max(ma,tot-sz[u]); 54 if(ma<res) res=ma,rt=u; 55 } 56 void doit(int u,int pre,int dis) 57 { 58 a[++cnt]=dis; 59 for(int i=0;i<e[u].size();i++) 60 { 61 pii x=e[u][i]; 62 if(!vis[x.fi]&&x.fi!=pre) doit(x.fi,u,dis+x.se); 63 } 64 } 65 int cal(int u,int pre,int dis) 66 { 67 cnt=0; 68 int tmp=0; 69 doit(u,pre,dis); 70 sort(a+1,a+cnt+1); 71 //for(int i=1;i<=cnt;i++) printf("%d%c",a[i],i==cnt?'\n':' '); 72 int l=1,r=cnt; 73 for(;l<r;l++) 74 { 75 while(l<r&&a[l]+a[r]>k) r--; 76 tmp+=r-l; 77 } 78 //printf("tmp = %d\n",tmp); 79 return tmp; 80 } 81 void solve(int u) 82 { 83 //printf("u = %d\n",u); 84 ans+=cal(u,0,0); 85 for(int i=0;i<e[u].size();i++) 86 { 87 pii x=e[u][i]; 88 int v=x.fi,w=x.se; 89 if(vis[v]) continue; 90 ans-=cal(v,u,w); 91 } 92 vis[u]=1; 93 for(int i=0;i<e[u].size();i++) 94 { 95 pii x=e[u][i]; 96 int v=x.fi,w=x.se; 97 if(vis[v]) continue; 98 res=inf; 99 dfs(v,u,sz[v]); 100 solve(rt); 101 } 102 } 103 void init() 104 { 105 ans=0; 106 for(int i=1;i<=n;i++) vis[i]=0; 107 for(int i=1;i<=n;i++) e[i].clear(); 108 } 109 int main() 110 { 111 while(~scanf("%d%d",&n,&k)&&(n||k)) 112 { 113 init(); 114 for(int i=1;i<n;i++) 115 { 116 int x=read(),y=read(),z=read(); 117 e[x].pb(mk(y,z)); 118 e[y].pb(mk(x,z)); 119 } 120 res=inf; 121 dfs(1,0,n); 122 solve(rt); 123 printf("%d\n",ans); 124 } 125 return 0; 126 }View Code