1. 程式人生 > 實用技巧 >【題解】 CF1404C Fixed Point Removal 線段樹+樹狀陣列+帶悔貪心

【題解】 CF1404C Fixed Point Removal 線段樹+樹狀陣列+帶悔貪心

Legend

Link \(\textrm{to Codeforces}\)

給定長度為 \(n\ (1 \le n \le 3\times 10^5)\) 的陣列 \(a_i\ (1 \le a_i \le n)\)

每次你可以選擇刪除一個位置上的數字當且僅當它的下標等於數字本身,即 \(a_i=i\)

刪除後陣列後面一段會平移過來,改變下標。

\(q\ (1 \le q \le 3 \times 10^5)\) 組詢問,給出 \(x,y\ (0 \le x,y\)\(x+y < n)\) 詢問:

前面 \(x\) 個強制不能刪除,後面 \(y\) 個強制不能刪除,最多可以刪掉多少個數字。

Editorial

Inspiration

考慮只詢問一次怎麼做?每次從右邊找起,能刪就刪除。但這樣是 \(O(qn^2)\) 的。

題目中提到的查詢是無視一個字首和一個字尾,無視字尾應該很好做,可能隨便減減就能解決。

無視字首則留下了一個字尾,這不禁讓人想到預處理每一個字尾的答案。

所以我們從右往左依次加入數字,這樣每次都考慮的是一個字尾。

假設單獨選出 \([l,n]\) 區間,最初位置在 \(x\) 上的數字可以被刪掉,那麼顯然,選出 \([l-1,n]\) 的時候也能被刪掉。

所以我們每加入一個新的數字就去檢查有哪些數字可以被刪去,這個怎麼快速找呢?

optimization

我們維護一個初始的 \(v_i=i-a_i\)

  • \(v_i < 0\) 的都一定不能被刪除,因為數字只能前移;
  • \(v_i =0\) 的是可以被刪除的;
  • \(v_i > 0\) 是潛在的可能被以後刪除的。

每次只要找到一個 \(v_x=0\) 的位置,刪除它即可。

刪除一個位於 \(x\) 的數字之後,我們就手動把 \(v_i\ (i \in [x,n])\) 全部 \(-1\),表示前移一位。

發現我們可以用線段樹維護,找到最右側的 \(v_i=0\) 的位置,這樣就保證不會把其他 \(v_j=0\) 的位置破壞掉。

那麼這樣直到最後我們對於每一個位置上的數字,都可以得到一個二元組 \((suf_i ,id_i)\)

\(suf_i\)

表示的是這個數字是在第幾個位置上的數字被加進來之後才刪掉的。

\(id_i\) 表示這個位置的最初下標。

考慮對於每一組 \((x,y)\) 的詢問,我們要求什麼,實際上是要求形如 \((i,j)\) 且同時滿足 \(i \ge x+1\)\(j \le n-y\) 的二元組數量。

這個離線之後就是樹狀陣列板子了。

Code

我在程式碼中,並不是維護的 \(v_i=0\) 最靠右的位置,而是選擇了一個 \(v_i<0\) 的位置,這樣也是可以通過的。

為什麼呢?可以類比帶悔貪心的思路呀,一定是可以通過改變刪除順序使得這個位置也能被刪除的。

#include <bits/stdc++.h>

#define LL long long

const int MX = 3e5 + 233;

using namespace std;

int read(){
	char k = getchar(); int x = 0;
	while(k < '0' || k > '9') k = getchar();
	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
	return x;
}

int a[MX];

struct node{
	int l ,r ,mn ,mnfr ,add;
	node *lch ,*rch;
	node operator +(node B)const{
		node C;
		C.mn = min(this->mn ,B.mn);
		C.mnfr = (C.mn == this->mn) ? this->mnfr : B.mnfr;
		return C;
	}
}*root;

void pushup(node *x){
	x->mn = min(x->lch->mn ,x->rch->mn);
	x->mnfr = x->lch->mn == x->mn ? x->lch->mnfr : x->rch->mnfr;
}

void doadd(node *x ,int v){x->mn += v ,x->add += v;}
void pushdown(node *x){
	if(x->add){
		doadd(x->lch ,x->add);
		doadd(x->rch ,x->add);
		x->add = 0;
	}
}

node *build(int l ,int r ,int *__){
	node *x = new node; x->l = l ,x->r = r; x->add = 0;
	if(l == r) x->lch = x->rch = nullptr ,x->mn = __[l] ,x->mnfr = l;
	else{int mid = (l + r) >> 1;
		x->lch = build(l ,mid ,__);
		x->rch = build(mid + 1 ,r ,__);
		pushup(x);
	}return x;
}

void add(node *x ,int l ,int r ,int v){
	if(l <= x->l && x->r <= r) return doadd(x ,v);
	pushdown(x);
	if(l <= x->lch->r) add(x->lch ,l ,r ,v);
	if(r > x->lch->r) add(x->rch ,l ,r ,v);
	return pushup(x);
}

node query(node *x ,int l ,int r){
	if(l <= x->l && x->r <= r) return *x;
	pushdown(x);
	if(l <= x->lch->r && r > x->lch->r) return query(x->lch ,l ,r) + query(x->rch ,l ,r);
	if(l <= x->lch->r) return query(x->lch ,l ,r);
	return query(x->rch ,l ,r);
}

void change(node *x ,int p ,int v){
	if(p <= x->l && x->r <= p) return x->mn = v ,void();
	pushdown(x);
	if(p <= x->lch->r) change(x->lch ,p ,v);
	if(p > x->lch->r) change(x->rch ,p ,v);
	return pushup(x);
}

int pcnt;
struct Point{
	int x ,y ,type ,coef ,id;
	bool operator <(const Point &B)const{
		return x == B.x ? y == B.y ? type < B.type : y < B.y : x < B.x;
	}
}P[MX * 3];

class BIT{
	private:
		int data[MX];
	public:
		void add(int x ,int v){while(x < MX) data[x] += v ,x += x & -x;}
		int sum(int x){int s = 0; while(x > 0) s += data[x] ,x -= x & -x; return s;}
}C;

int Ans[MX];
int main(){
	int n = read() ,q = read();
	for(int i = 1 ; i <= n ; ++i){
		a[i] = read();
		a[i] = (a[i] > i ? INT_MAX : i - a[i]);
	}
	root = build(1 ,n ,a);
	for(int i = n ; i ; --i){
		while(true){
			node kksk = query(root ,i ,n);
			if(kksk.mn <= 0){
				P[++pcnt] = (Point){i ,kksk.mnfr ,0 ,0 ,0};
				// fprintf(stderr ,"(%d ,%d)\n" ,i ,kksk.mnfr);
				add(root ,kksk.mnfr ,n ,-1);
				change(root ,kksk.mnfr ,INT_MAX);
			}else break;
		}
	}
	for(int i = 1 ,x ,y ; i <= q ; ++i){
		x = read() ,y = read();
		P[++pcnt] = (Point){n ,n - y ,1 ,1 ,i};
		P[++pcnt] = (Point){x ,n - y ,1 ,-1 ,i};
	}
	sort(P + 1 ,P + 1 + pcnt);
	for(int i = 1 ; i <= pcnt ; ++i){
		if(P[i].type == 0){
			C.add(P[i].y ,1);
		}else{
			Ans[P[i].id] += P[i].coef * C.sum(P[i].y);
		}
	}
	for(int i = 1 ; i <= q ; ++i)
		printf("%d\n" ,Ans[i]);
	return 0;
}