程式人生 > HDU-5977 - Garden of Eden 點分治

HDU-5977 - Garden of Eden 點分治

HDU - 5977

題意:

  給定一棵樹,每個節點上有一種蘋果(共 K 種)。問樹上有多少個有序節點對 (u, v),使得 u 到 v 的路徑上包含了全部 K 種蘋果(u = v 也算一對,所以 K = 1 時答案為 n²)。

思路:

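  點分治。對於每個節點,用狀態壓縮記錄從當前重心到這個節點的路徑上出現過的蘋果種類(一個 K 位的二進位遮罩);因為 K ≤ 10,狀態最多只有 2^10 = 1024 種。然後處理經過每個重心的路徑:直接列舉每個點的遮罩 x,另一個點的遮罩 y 能和它湊齊所有種類,當且僅當 x | y 等於全集,也就是 y 等於「全集 ^ s」,其中 s 是 x 的某個子集;因此列舉 x 的每個子集,累計和該子集互補的遮罩個數即可。最後用容斥減去兩端落在重心同一棵子樹內的點對。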

  列舉一個數的子集,例如 1010,它的子集包括 1010、1000、0010、0000。這裡有個技巧:用 s = (s - 1) & x 可以由大到小走遍 x 的所有非空子集;空子集 s = 0 對應的補集就是全集,需要另外累加 cnt[(1<<k)-1]:

    for(int s = x; s; s = (s - 1) & x){
              res += 1ll*cnt[((1<<k)-1) ^ s];
       }
//#pragma GCC optimize(3)
//#pragma comment(linker, "/STACK:102400000,102400000")  //c++
// #pragma GCC diagnostic error "-std=c++11"
// #pragma comment(linker, "/stack:200000000")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")

#include <algorithm>
#include <iterator>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <iomanip>
#include <bitset>
#include <cctype>
#include <cstdio>
#include <string>
#include <vector>
#include <stack>
#include <cmath>
#include <queue>
#include <list>
#include <map>
#include <set>
#include <cassert>
using namespace std;

typedef long long ll;

// HDU-5977 "Garden of Eden": count ordered node pairs (u, v) whose tree path
// contains every one of the k apple types, via centroid decomposition.
//
// Every component traversal is an iterative BFS, so a path-shaped tree with
// tens of thousands of nodes cannot overflow the call stack; only divide()
// recurses, and centroid decomposition keeps that O(log n) levels deep.
// Masks use k bits and cnt holds 2^k ints; the statement bounds k by 10.

const int maxn = 50009;

int n, k;
ll ans = 0;

int a[maxn];            // apple type of each node, shifted to 0-based on input
bool removed[maxn];     // true once a node has been used as a centroid
int par[maxn];          // BFS parent inside the most recent collect()
int sz[maxn];           // subtree size inside the current component
int heavy[maxn];        // size of the largest child subtree inside the component
int state[maxn];        // bitmask of apple types on the path centroid -> node
vector<int> adj[maxn];  // adjacency lists; nodes are numbered 1..n

vector<int> order;      // scratch: BFS order produced by collect()
vector<int> cnt;        // scratch: multiplicity per mask, all zero between calls
vector<int> distinct;   // scratch: distinct masks seen by count_pairs()
vector<int> all_masks;  // scratch: masks of every node in the current component
vector<int> sub_masks;  // scratch: masks of one child subtree of the centroid

// Fills `order` with the nodes reachable from src without stepping onto
// `from` or any removed node, in BFS order, recording each node's parent.
static void collect(int src, int from) {
    order.clear();
    order.push_back(src);
    par[src] = from;
    for (size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (size_t j = 0; j < adj[u].size(); ++j) {
            const int v = adj[u][j];
            if (v == par[u] || removed[v]) continue;
            par[v] = u;
            order.push_back(v);
        }
    }
}

// Returns a centroid of src's component: the node whose largest remaining
// piece after removal is smallest.
static int find_centroid(int src) {
    collect(src, 0);  // node ids start at 1, so 0 means "no parent"
    const int total = (int)order.size();
    for (int i = 0; i < total; ++i) {
        sz[order[i]] = 1;
        heavy[order[i]] = 0;
    }
    // Children appear after their parent in BFS order, so a reverse sweep
    // completes each subtree before folding it into the parent.
    for (int i = total - 1; i > 0; --i) {
        const int u = order[i], p = par[u];
        sz[p] += sz[u];
        heavy[p] = max(heavy[p], sz[u]);
    }
    int best = src, best_worst = total;
    for (int i = 0; i < total; ++i) {
        const int u = order[i];
        const int worst = max(heavy[u], total - sz[u]);
        if (worst < best_worst) {
            best_worst = worst;
            best = u;
        }
    }
    return best;
}

// Number of ordered pairs (i, j), i != j, of entries in `masks` whose union
// is the full type set. Leaves `cnt` all-zero again on return.
static ll count_pairs(const vector<int>& masks) {
    const int full = (1 << k) - 1;
    distinct.clear();
    for (size_t i = 0; i < masks.size(); ++i) {
        if (cnt[masks[i]]++ == 0) distinct.push_back(masks[i]);
    }
    ll res = 0;
    for (size_t i = 0; i < distinct.size(); ++i) {
        const int x = distinct[i];
        // x | y == full  <=>  y == full ^ s for exactly one subset s of x,
        // so walk every subset of x (empty one included) via s = (s - 1) & x.
        ll partners = 0;
        for (int s = x; ; s = (s - 1) & x) {
            partners += cnt[full ^ s];
            if (s == 0) break;
        }
        res += (ll)cnt[x] * partners;
        // A full mask also matched itself once per copy; drop those self-pairs.
        if (x == full) res -= cnt[x];
    }
    for (size_t i = 0; i < distinct.size(); ++i) cnt[distinct[i]] = 0;
    return res;
}

// Adds every qualifying pair whose path passes through the centroid of src's
// component to `ans`, then recurses into the pieces left after removing it.
static void divide(int src) {
    const int rt = find_centroid(src);
    const int rt_bit = 1 << a[rt];
    state[rt] = rt_bit;

    all_masks.clear();
    all_masks.push_back(rt_bit);  // the centroid itself as a path endpoint
    for (size_t j = 0; j < adj[rt].size(); ++j) {
        const int v = adj[rt][j];
        if (removed[v]) continue;
        collect(v, rt);
        sub_masks.clear();
        for (size_t i = 0; i < order.size(); ++i) {
            const int u = order[i];
            state[u] = state[par[u]] | (1 << a[u]);
            sub_masks.push_back(state[u]);
        }
        // Two endpoints in the same child subtree do not really meet at rt;
        // cancel them (inclusion-exclusion) before counting the whole union.
        ans -= count_pairs(sub_masks);
        all_masks.insert(all_masks.end(), sub_masks.begin(), sub_masks.end());
    }
    ans += count_pairs(all_masks);

    // Scratch buffers are no longer needed, so recursion may reuse them.
    removed[rt] = true;
    for (size_t j = 0; j < adj[rt].size(); ++j) {
        const int v = adj[rt][j];
        if (!removed[v]) divide(v);
    }
}

int main() {
    // Compare with 2 so a truncated or malformed header ends the loop
    // instead of spinning forever.
    while (scanf("%d%d", &n, &k) == 2) {
        if (n <= 0) {
            puts("0");
            continue;
        }
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &a[i]);
            --a[i];  // types are given as 1..k; bit positions are 0..k-1
        }
        for (int i = 1; i <= n; ++i) {
            adj[i].clear();
            removed[i] = false;
        }
        for (int i = 1; i < n; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        if (k == 1) {
            // Every path, including the single-node path (u, u), holds the
            // only type, so all n * n ordered pairs qualify.
            printf("%lld\n", (ll)n * n);
            continue;
        }
        cnt.assign(1 << k, 0);
        ans = 0;
        divide(1);
        printf("%lld\n", ans);
    }
    return 0;
}
HDU-5977