1. 程式人生 > 實用技巧 >POJ - 1741 (dsu on tree & 點分治)

POJ - 1741 (dsu on tree & 點分治)

題目連結:傳送門(POJ是真的煩)

題目思路:

對於dsu on tree 直接暴力統計深度--即u到根節點的距離(樹狀陣列維護桶,也可以用 排序雙指標--但單步容斥來得到合法答案),在子樹中查詢的查詢 k - (deep[u] - dis)+ dis +1 ,其中 dis 為子樹根的深度,+1是由於樹狀陣列的起點為1,而深度存在為0,故存入樹狀時整體平移一個單位,deep[u]-dis ,即u到子樹根的距離,+dis 是由於查詢的子樹中的 deep[v] = dis + v到子樹根的距離 ,要查詢的是v到子樹根的距離 但樹狀存的是deep[v] (deep[v] 和 v到子樹根的距離 是一一對應的),因此查詢的時候要加一個dis。

程式碼:

  1 #include<functional>
  2 #include<algorithm>
  3 #include<cmath>
  4 #include<cstdio>
  5 #include<cctype>
  6 #include<cstring>
  7 #include<vector>
  8 using namespace std;
  9 typedef long long LL;
 10 typedef unsigned long long uLL;
 11 typedef pair<int
,int> pii; 12 typedef pair<LL,LL> pLL; 13 typedef pair<double,double> pdd; 14 const int N=2e4+5; 15 const int M=1e7+5; 16 const int inf=0x3f3f3f3f; 17 const LL mod=998244353; 18 const double eps=1e-8; 19 const long double pi=acos(-1.0L); 20 #define ls (i<<1) 21 #define rs (i<<1|1) 22
#define fi first 23 #define se second 24 #define pb push_back 25 #define eb emplace_back 26 #define mk make_pair 27 #define mem(a,b) memset(a,b,sizeof(a)) 28 LL read() 29 { 30 LL x=0,t=1; 31 char ch; 32 while(!isdigit(ch=getchar())) if(ch=='-') t=-1; 33 while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); } 34 return x*t; 35 } 36 int son[N],sz[N],c[M],deep[N],ans,n,k,res; 37 int m=1e7; 38 vector<pii> e[N]; 39 inline int lowbit(int x) 40 { 41 return x&(-x); 42 } 43 void update(int x,int y) 44 { 45 for(int i=x;i<=m;i+=lowbit(i)) c[i]+=y; 46 } 47 int query(int x) 48 { 49 int tmp=0; 50 for(int i=x;i;i-=lowbit(i)) tmp+=c[i]; 51 return tmp; 52 } 53 void dfs(int u,int pre) 54 { 55 sz[u]=1; 56 son[u]=0; 57 for(int i=0;i<e[u].size();i++) 58 { 59 pii x=e[u][i]; 60 int v=x.fi; 61 if(v==pre) continue; 62 deep[v]=deep[u]+x.se; 63 dfs(v,u); 64 sz[u]+=sz[v]; 65 if(sz[son[u]]<sz[v]) son[u]=v; 66 } 67 } 68 void cal(int u,int pre,int dis) 69 { 70 if(k+1-deep[u]+2*dis<=0) return ; 71 res+=query(k+1-deep[u]+2*dis); 72 for(int i=0;i<e[u].size();i++) 73 { 74 pii x=e[u][i]; 75 int v=x.fi; 76 if(v==pre) continue; 77 cal(v,u,dis); 78 } 79 } 80 void doit(int u,int pre,int x) 81 { 82 if(deep[u]+1>m) return ; 83 update(deep[u]+1,x); 84 for(int i=0;i<e[u].size();i++) 85 { 86 pii t=e[u][i]; 87 int v=t.fi; 88 if(v==pre) continue; 89 doit(v,u,x); 90 } 91 } 92 void dfs2(int u,int pre,int flag) 93 { 94 for(int i=0;i<e[u].size();i++) 95 { 96 pii x=e[u][i]; 97 int v=x.fi; 98 if(v==pre||v==son[u]) continue; 99 dfs2(v,u,1); 100 } 101 if(son[u]) dfs2(son[u],u,0); 102 res=query(k+1+deep[u]); 103 update(deep[u]+1,1); 104 for(int i=0;i<e[u].size();i++) 105 { 106 pii x=e[u][i]; 107 int v=x.fi; 108 if(v==pre||v==son[u]) continue; 109 cal(v,u,deep[u]); 110 doit(v,u,1); 111 } 112 ans+=res; 113 if(flag) doit(u,pre,-1); 114 } 115 int main() 116 { 117 while(scanf("%d%d",&n,&k)==2&&(n||k)) 118 { 119 for(int i=1;i<=n;i++) e[i].clear(); 120 for(int i=1;i<n;i++) 121 { 122 int x=read(),y=read(),z=read(); 123 e[x].pb(mk(y,z)); 124 e[y].pb(mk(x,z)); 125 } 126 127 ans=res=0; 128 dfs(1,0); 129 dfs2(1,0,1); 130 printf("%d\n",ans); 131 } 132 return 0; 133 }
View Code

對於點分治,直接排序雙指標 統計合法答案即可(由於 dsu on tree 用了一次桶 ,這裡就用排序雙指標實現,這兩種實現方式效率是差不多的,不過排序雙指標更好理解和實現)

程式碼:

  1 #include<functional>
  2 #include<algorithm>
  3 #include<cmath>
  4 #include<cstdio>
  5 #include<cctype>
  6 #include<cstring>
  7 #include<vector>
  8 using namespace std;
  9 typedef long long LL;
 10 typedef unsigned long long uLL;
 11 typedef pair<int,int> pii;
 12 typedef pair<LL,LL> pLL;
 13 typedef pair<double,double> pdd;
 14 const int N=1e4+5;
 15 const int M=1e7+5;
 16 const int inf=0x3f3f3f3f;
 17 const LL mod=998244353;
 18 const double eps=1e-8;
 19 const long double pi=acos(-1.0L);
 20 #define ls (i<<1)
 21 #define rs (i<<1|1)
 22 #define fi first
 23 #define se second
 24 #define pb push_back
 25 //#define eb emplace_back
 26 #define mk make_pair
 27 #define mem(a,b) memset(a,b,sizeof(a))
 28 LL read()
 29 {
 30     LL x=0,t=1;
 31     char ch;
 32     while(!isdigit(ch=getchar())) if(ch=='-') t=-1;
 33     while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); }
 34     return x*t;
 35 
 36 }
 37 int n,k;
 38 vector<pii> e[N];
 39 int res,rt,ans,vis[N],a[N],cnt,sz[N];
 40 void dfs(int u,int pre,int tot)
 41 {
 42     int ma=0;
 43     sz[u]=1;
 44     for(int i=0;i<e[u].size();i++)
 45     {
 46         pii x=e[u][i];
 47         int v=x.fi,w=x.se;
 48         if(v==pre||vis[v]) continue;
 49         dfs(v,u,tot);
 50         sz[u]+=sz[v];
 51         ma=max(ma,sz[v]);
 52     }
 53     ma=max(ma,tot-sz[u]);
 54     if(ma<res) res=ma,rt=u;
 55 }
 56 void doit(int u,int pre,int dis)
 57 {
 58     a[++cnt]=dis;
 59     for(int i=0;i<e[u].size();i++)
 60     {
 61         pii x=e[u][i];
 62         if(!vis[x.fi]&&x.fi!=pre) doit(x.fi,u,dis+x.se);
 63     }
 64 }
 65 int cal(int u,int pre,int dis)
 66 {
 67     cnt=0;
 68     int tmp=0;
 69     doit(u,pre,dis);
 70     sort(a+1,a+cnt+1);
 71     //for(int i=1;i<=cnt;i++) printf("%d%c",a[i],i==cnt?'\n':' ');
 72     int l=1,r=cnt;
 73     for(;l<r;l++)
 74     {
 75         while(l<r&&a[l]+a[r]>k) r--;
 76         tmp+=r-l;
 77     }
 78     //printf("tmp = %d\n",tmp);
 79     return tmp;
 80 }
 81 void solve(int u)
 82 {
 83     //printf("u = %d\n",u);
 84     ans+=cal(u,0,0);
 85     for(int i=0;i<e[u].size();i++)
 86     {
 87         pii x=e[u][i];
 88         int v=x.fi,w=x.se;
 89         if(vis[v]) continue;
 90         ans-=cal(v,u,w);
 91     }
 92     vis[u]=1;
 93     for(int i=0;i<e[u].size();i++)
 94     {
 95         pii x=e[u][i];
 96         int v=x.fi,w=x.se;
 97         if(vis[v]) continue;
 98         res=inf;
 99         dfs(v,u,sz[v]);
100         solve(rt);
101     }
102 }
103 void init()
104 {
105     ans=0;
106     for(int i=1;i<=n;i++) vis[i]=0;
107     for(int i=1;i<=n;i++) e[i].clear();
108 }
109 int main()
110 {
111     while(~scanf("%d%d",&n,&k)&&(n||k))
112     {
113         init();
114         for(int i=1;i<n;i++)
115         {
116             int x=read(),y=read(),z=read();
117             e[x].pb(mk(y,z));
118             e[y].pb(mk(x,z));
119         }
120         res=inf;
121         dfs(1,0,n);
122         solve(rt);
123         printf("%d\n",ans);
124     }
125     return 0;
126 }
View Code