BZOJ3689 - 異或之(trie)
阿新 • • 發佈:2020-08-14
題意
給定n個非負整數A[1], A[2], ……, A[n]。
對於每對(i, j)滿足1 <= i < j <= n,得到一個新的數A[i] xor A[j],這樣共有n*(n-1)/2個新的數。求這些數(不包含A[i])中前k小的數。
思路
一看到異或,就想到可能要用trie樹來處理。
層數越深,兩個數異或的結果越小。
所以可以從最底層往上處理,對trie樹每一層的每個結點左子樹和右子樹包含的數暴力一一異或,一直到得到超過k個數。由於每上一層,增加的計算最多為2倍,所以複雜度是線性的。
得到的這些數就一定是排前的。最後排序輸出k個數即可。
#include <bits/stdc++.h> #define endl '\n' #define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0) #define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout) #define FI freopen(".//data_generator//in.txt","r",stdin) #define FO freopen("res.txt","w",stdout) #define pb push_back #define mp make_pair #define seteps(N) fixed << setprecision(N) typedef long long ll; using namespace std; /*-----------------------------------------------------------------*/ ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;} #define INF 0x3f3f3f3f const int N = 2e6 + 10; const double eps = 1e-5; int tr[N][2]; int l[N], r[N]; vector<int> points[50]; vector<int> arr; vector<int> ans; int head[N]; int cnt[N]; int k; int ct; void insert(int x) { int cur = 0; for(int i = 30; i >= 0; i--) { bool ntp = (x & (1 << i)); if(!tr[cur][ntp]) { tr[cur][ntp] = ++ct; cur = ct; tr[cur][0] = tr[cur][1] = 0; } else { cur = tr[cur][ntp]; } } cnt[cur]++; } void dfs(int cur, int dep, int val) { points[dep + 1].push_back(cur); if(dep < 0) { l[cur] = arr.size(); for(int i = 0 ;i < cnt[cur]; i++) { arr.push_back(val); } r[cur] = arr.size() - 1; return ; } if(tr[cur][0]) { dfs(tr[cur][0], dep - 1, val); l[cur] = l[tr[cur][0]]; r[cur] = r[tr[cur][0]]; } if(tr[cur][1]) { dfs(tr[cur][1], dep - 1, val + (1 << dep)); if(!tr[cur][0]) { l[cur] = l[tr[cur][1]]; } r[cur] = r[tr[cur][1]]; } } int main() { //FILE; IOS; int n; cin >> n >> k; for(int i = 1; i <= n; i++) { int v; cin >> v; insert(v); } dfs(0, 30, 0); int tot = 0; for(int p : points[0]) { int c = cnt[p]; tot += c * (c - 1) / 2; } if(tot >= k) { for(int i = 0; i < k; i++) { cout << 0 << endl; } } else { for(int i = 0; i < tot; i++) ans.push_back(0); for(int i = 1; i <= 31; i++) { for(int p : points[i]) { int ls = tr[p][0], rs = tr[p][1]; if(ls && rs) { for(int i = l[ls]; i <= r[ls]; i++) { for(int j = l[rs]; j <= r[rs]; j++) { ans.push_back(arr[i] ^ arr[j]); } } } } if(ans.size() > k) break; } sort(ans.begin(), ans.end()); for(int i = 0; i < k; i++) { cout << ans[i] << " "; } } }