H - Set【18南京網路賽】
阿新 • • 發佈:2020-12-02
H - Set
題意:
有 \(n\) 個集合,初始時第 \(i\) 個集合中的數只有 \(a_i\)
支援三種操作
1 u v
若第 \(u\) 個數和第 \(v\) 個數在不同的集合中,則將這兩個集合合併
2 u
把第 \(u\) 個數所在的集合中所有的數都+1
3 u k x
詢問操作,你需要輸出第 \(u\) 個數所在的集合中滿足 \(t = x \ \ (mod\ \ 2^k)\) 的 \(t\) 的個數
思路:
其實就是用 Tire
樹來儲存集合裡的數,因為只需要合併,直接合並就好。
對於操作二,只需要交換左子樹和右子樹,這樣左子樹就相當於全部加上了 1
,然後遞迴處理左子樹,實在是妙
在插入與合併的時候維護節點的子樹大小資訊,在查詢操作的時候直接輸出子樹大小就可以
這個題當時寫的時候 debug
滴了一下午
void merge(int rt1,int rt2){ // 暴力合併好像就可以 // 把 rt1 合併到 rt2 tr[rt2].size += tr[rt1].size; if(tr[rt1].s[0]){ if (!tr[rt2].s[0]) tr[rt2].s[0] = tr[rt1].s[0]; else merge(tr[rt1].s[0], tr[rt2].s[0]); } if(tr[rt1].s[1]){ if (!tr[rt2].s[1])tr[rt2].s[1] = tr[rt1].s[1]; else merge(tr[rt1].s[1], tr[rt2].s[1]); } }
這個是正確的版本,我當時寫的艾斯比版本如下
void merge(int rt1,int rt2){ // 暴力合併好像就可以 // 把 rt1 合併到 rt2 tr[rt2].size += tr[rt1].size; if(tr[rt1].s[0]){ if (!tr[rt2].s[0]) tr[rt2].s[0] = getnode(); merge(tr[rt1].s[0], tr[rt2].s[0]); } if(tr[rt1].s[1]){ if (!tr[rt2].s[1])tr[rt2].s[1] = getnode(); merge(tr[rt1].s[1], tr[rt2].s[1]); } }
一直 段錯誤, 在合併的時候加了一個判斷大小交換開大資料就過了 3/5
個點,我才意識到 段錯誤是因為 getnode
呼叫太多了
我一直奇怪我這樣的寫法本應更節省記憶體才對
最後只需要 200Mb
記憶體就可以
#include<bits/stdc++.h>
using namespace std;
const int N = 6e5 + 10;
struct node{
int s[2], size;
}tr[N*30];
int tot, root[N];
int getnode(){
tot++;
tr[tot].s[0] = tr[tot].s[1] = 0;
tr[tot].size = 0;
return tot;
}
int fa[N];
int find(int a){
return a == fa[a] ? a : fa[a] = find(fa[a]);
}
void insert(int rt,int val){
int cur = rt;
bitset<30>s(val);
tr[cur].size++;
for (int i = 0;i < 30;i++) {
int v = s[i];
if (!tr[cur].s[v])tr[cur].s[v] = getnode();
cur = tr[cur].s[v];
tr[cur].size++;
}
}
void merge(int rt1,int rt2){
// 暴力合併好像就可以
// 把 rt1 合併到 rt2
tr[rt2].size += tr[rt1].size;
if(tr[rt1].s[0]){
if (!tr[rt2].s[0]) tr[rt2].s[0] = tr[rt1].s[0];
else merge(tr[rt1].s[0], tr[rt2].s[0]);
}
if(tr[rt1].s[1]){
if (!tr[rt2].s[1])tr[rt2].s[1] = tr[rt1].s[1];
else merge(tr[rt1].s[1], tr[rt2].s[1]);
}
}
void add(int rt){
if (!rt)return;
swap(tr[rt].s[0], tr[rt].s[1]);
add(tr[rt].s[0]);
}
int query(int rt,int k,int x){
bitset<30>s(x);
int cur = rt;
for(int i = 0;i < k;i++){
if (!tr[cur].s[s[i]])return 0;
cur = tr[cur].s[s[i]];
}
return tr[cur].size;
}
int n, m;
int main(){
scanf("%d%d", &n, &m);
for(int i = 1;i <= n;i++){
fa[i] = i;root[i] = getnode();
int x;scanf("%d", &x);
insert(root[i], x);
}
while(m--){
int op, u, v, k, x;
scanf("%d", &op);
if(op == 1){
scanf("%d%d", &u, &v);
u = find(u);
v = find(v);
if (u == v)continue;
//if (tr[root[u]].size > tr[root[v]].size)swap(u, v);
merge(root[u], root[v]);
fa[u] = v;
}
else if(op == 2){
scanf("%d", &u);
u = find(u);
add(root[u]);
}
else{
scanf("%d%d%d", &u, &k, &x);
u = find(u);
printf("%d\n", query(root[u], k, x));
}
}
}