1. 程式人生 > 其它 >2021“MINIEYE杯”中國大學生演算法設計超級聯賽(2)1004 - I love counting trie操作好題

2021“MINIEYE杯”中國大學生演算法設計超級聯賽(2)1004 - I love counting trie操作好題

題意:
一個長度為\(n\)的序列,每個位置\(i\)有一個權重\(w_i\),然後有\(Q\)個詢問,每次詢問包含\(l,r,a,b\)四個引數,其詢問含義為區間\([l,r]\)有多少種權值\(w_i\)使得,\(w_i⊕a \le b\)

思路:
這個題其實一看到的話找出符合特定大小關係的異或值,就會往\(trie\)樹上考慮,又看到詢問讓我們處理的是區間內不同種類的數,也就是相同的數我們不能重複計算,這裡之前也做過類似套路的題,記錄一下字首和字尾,然後查詢其實就變成了詢問區間\([l,r]\)內,後繼節點位置\(> r\)的數有哪些,我們只統計這些數對於答案的貢獻即可。
但是比賽時,一直不會處理這個區間問題,想到上可持久化\(trie\)

但是苦於不會在字首關係中維護這個貢獻,故只能學習\(std\)的做法。

首先題目沒要求我們帶修,那麼就是可以離線處理的了。
再看每個詢問,看似是一個區間問題,但是考慮到\(trie\)的構造過程,每個數都會在其二進位制位下被分成\(log\)份,所以每個詢問的\(a,b\)其實就可以對應\(trie\)\(log\)個節點。然後我們再把每個權值\(w_i\)對應到\(trie\)\(log\)個節點中。這樣每個節點就會儲存當前節點對應的數的位置和這一二進位制位上的\(0/1\)值,總個數是一定不會超過\(nlog_2^{w_i}\)的。所以我們詢問\(trie\)樹的所有節點,對每個節點我們暴力的更新這個節點對應的數和詢問,詢問完之後,再暴力的消除影響,這個區間統計和單點更新的過程我們就可以用樹狀陣列解決,總複雜度\(O(n*log_2^{w_i}*log_2^n)\)

下面詳述一下,把一個詢問掛在\(trie\)某個節點的具體過程,對於\(a,b\)的當前二進位制位\(v_a,v_b\)來說。

  • \(v_b= 1\),那麼如果我們去找\(v_a\)對應的相同值,它和\(v_a\)異或在一起很明顯是\(=0\)的,那麼不用繼續向下搜了,這個位置對應的數一定是滿足條件的,因為對於這一位來說\(0 < 1\)成立了,此時我們就把當前詢問的資訊即\(l,r,id\)存入這個節點表示,這個節點內對於這個詢問來說是有合法貢獻的。然後去看向另一個節點走即可。
  • \(v_b = 0\),那麼我們要想小於等於它,那就一定要向和\(v_a\)相同的方向,這條分支就不涉及存入節點資訊這個操作了,繼續向下一位二級制走就可以了。

這樣最後我們就會把一個詢問最多存入\(trie\)\(log\)個節點中,表示這些節點中的數會對這個詢問產生部分的貢獻,只需要最後對每個節點做一次統計,那麼最後得出的就是每個詢問的完整貢獻。

再說區間內的不同種類數是如何保證的,這樣想,對於每個\(trie\)的節點,我們存的資訊既有對應權值節點\(w_i\)的資訊,也有對應詢問\(q_i\)的資訊。我們通過排序保證詢問前的插入節點均為合法即可,哪些是合法的呢,上面已經提到過,維護一個後繼節點的位置,位置\(>r\)的即為合法權值。把不合法的插入值放入對應詢問的後面即可,這樣就保證了再離線處理每個詢問的時候的正確性。

ps:這道題涉及到的處理方法和小技巧還是很巧妙的,思路也很妙,把一個詢問拆成幾個部分,一個個部分去做,最後加和貢獻,特別的存入\(trie\)節點的操作,有助於幫助進一步理解\(trie\)

#include <bits/stdc++.h>

using namespace std;

#define pb push_back
#define eb emplace_back
#define MP make_pair
#define pii pair<int,int>
#define pll pair<ll,ll>
#define lson rt<<1
#define rson rt<<1|1
#define CLOSE std::ios::sync_with_stdio(false)
#define sz(x) (int)(x).size()
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-6;
const int N = 1e5 + 10;
int n,c[N],m,bit[N],ans[N];
struct node {
	int l,r,np,id;//np代表當前節點的後繼節點的位置
	bool operator < (const node &u) const {
		if(np != u.np) return np > u.np;//只統計後繼節點不在當前統計區間內的點
		else if(id != u.id) return id > u.id;//相等時把詢問放在前面
		else return l < u.l;//只出現一次的節點 直接按大小排即可
	}
};

std::vector<node>root[N*20];
int lowbit(int x) { return x & (-x); }
void add(int p,int v) {
	while(p <= n) { bit[p] += v; p += lowbit(p); }
}
int query(int p) {
	int ans = 0;
	while(p) { ans += bit[p]; p -= lowbit(p); }
	return ans;
}
int trie[N * 20][2],tot,last[N],nxt[N];
void insert(int x,int id) {
	int p = 0;
	for(int i = 20;i >= 0;i --) {
		int v = (x >> i) & 1;
		if(!trie[p][v]) trie[p][v] = ++tot;
		p = trie[p][v];
		root[p].pb(node{id,0,nxt[id],0});
	}
}

void query(int l,int r,int x,int y,int id) {
	int p = 0;
	for(int i = 20;i >= 0;i --) {
		int v1 = (x >> i) & 1,v2 = (y >> i) & 1;
		if(v2 == 1) {
			if(trie[p][v1]) root[trie[p][v1]].pb(node{l,r,r,id});//走這邊一定小對應 0 < 1 故可以提前算入答案
			p = trie[p][v1^1];
			if(!p) break;
		}
		else {
			p = trie[p][v1];
			if(!p) break;
		}
	}
	// cout << p << '\n';
	root[p].pb(node{l,r,r,id});//在結尾處新增詢問
}

void solve() {
	scanf("%d",&n); 
	for(int i = 1;i <= n;i ++) {
		scanf("%d",&c[i]);
	}
	//求區間內只出現一次的數的處理時 是可以相當於前驅和後繼節點的判斷來處理的
	for(int i = n;i >= 1;i --) {
		if(!last[c[i]]) nxt[i] = n + 1;
		else nxt[i] = last[c[i]];
		last[c[i]] = i;
	}
	for(int i = 1;i <= n;i ++) insert(c[i],i);
	scanf("%d",&m);
	int l,r,a,b;
	for(int i = 1;i <= m;i ++) {
		scanf("%d%d%d%d",&l,&r,&a,&b);
		query(l,r,a,b,i);
	}
	// cout << tot << '\n';
	for(int i = 1;i <= tot;i ++) {
		sort(root[i].begin(),root[i].end());
		for(int j = 0;j < sz(root[i]);j ++) {
			// cout << i << " : " << root[i][j].l << ' ' << root[i][j].r << ' ' << root[i][j].np << ' ' << root[i][j].id << '\n';
			if(root[i][j].id == 0) {//代表是插入 先做插入
				add(root[i][j].l,1);
			}
			else {
				// cout << query(root[i][j].r) - query(root[i][j].l-1)
				ans[root[i][j].id] += query(root[i][j].r) - query(root[i][j].l-1);
			}
		}
		for(int j = 0;j < sz(root[i]);j ++) if(root[i][j].id == 0) add(root[i][j].l,-1);
	}
	for(int i = 1;i <= m;i ++) printf("%d\n",ans[i]);
	return ;
}

int main() {
	int T = 1;
	while(T--) {
		solve();
	}
	return 0;
}