1. 程式人生 > 實用技巧 >CF::Gym 100960G - Youngling Tournament

CF::Gym 100960G - Youngling Tournament

CF::Gym題目頁面傳送門

給定一個數列\(a\),支援\(q\)次操作:令\(a_x=y\),並設\(a'\)\(a\)從大到小排序的結果,查詢滿足該位置上的數大於等於其後所有數的和的位置數量。

\(n\in\left[1,10^5\right],q\in\left[1,5\times10^4\right],a_i\in\left[1,10^{12}\right]\)

考慮設\(suM_i\)\(a'\)在位置\(i\)處的字尾和,那麼位置\(i\)滿足條件顯然當且僅當\(a'_i-suM_{i+1}\geq0\)。不難想到維護\(a'_i-suM_{i+1}\)

肯定是要按\(a_i\)

從大往小排列的。注意到每次修改,可能會讓\(a_x\)\(a'\)中移個位置,而其他元素的相對位置不變。設\(a_x\)本來在\(a'\)中位置為\(p\),修改完跑到了\(p'\)。那麼分\(p<p'\)\(p\geq p'\)兩種情況。這裡以前者為例,後者類似。

顯然整個\(a'\)序列分成三段:

  1. \(1\sim p\),這一段的\(a'_i-suM_{i+1}\)值顯然都要加上\(a_x-y\)
  2. \(p\sim p'\),這一段的\(a'_i-suM_{i+1}\)值顯然都要加上\(-y\)
  3. \(p'\sim n\),這一段的\(a'_i-suM_{i+1}\)值顯然不變。

看到區間增加,不難想到線段樹配合懶標記。那麼問題來了,維護啥呢?咋查詢呢?線段樹套平衡樹肯定是不行的,因為是區間修改。線段樹直接維護也維護不動。考慮讓線段樹起到剪枝的作用:每個節點維護當前區間的\(a'_i-suM_{i+1}\)最大值(這個顯然是懶標記可做的)。查詢的時候從根往下走,對於每個兒子,如果它的最大值\(\geq0\)則往下走,否則不往下走(即裡面不可能有符合要求的位置)。

這樣複雜度是多少呢?注意到一個非常重要的性質:答案是\(\mathrm O(\log v)\)級別的,其中\(v\)\(a\)的值域大小。證明非常簡單,大概就是每有一個答案,字尾和就要增一倍,並且當前位置的數要大於等於字尾和。這樣一來,每個符合要求的數就會有一條對應的線段樹上的根到葉子的鏈,多條鏈的並集是被經過的節點集合,\(\mathrm O(\log n\log v)\)

考慮到還要插入與刪除,想用線段樹的話要離線預留好位置,比較煩我懶得寫了。其他的方法有:動態開點線段樹,\(\mathrm\!\left(\log^2v\right)\);平衡樹,複雜度不變,分析差不多。我寫了後者,使用fhq-Treap。插入刪除直接轉化成修改,就不需要垃圾桶了。

時間複雜度\(\mathrm O(n\log n+q\log n\log v)\)

程式碼:

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define mp make_pair
#define X first
#define Y second
const int inf=0x3f3f3f3f3f3f3f3f;
mt19937 rng(20060617);
const int N=100000;
int n,qu;
int a[N+1];
int b[N+1],suM[N+2];
struct fhq_treap{ 
	int sz,root;
	struct node{unsigned key;int lson,rson,sz,v,dif,sum,mx,lz;}nd[N+1];
	#define key(p) nd[p].key
	#define lson(p) nd[p].lson
	#define rson(p) nd[p].rson
	#define sz(p) nd[p].sz
	#define v(p) nd[p].v
	#define dif(p) nd[p].dif
	#define sum(p) nd[p].sum
	#define mx(p) nd[p].mx
	#define lz(p) nd[p].lz
	int bld(int l=1,int r=n){
		int mid=l+r>>1,p=nwnd(b[mid],b[mid]-suM[mid+1]);
		if(l<mid)lson(p)=bld(l,mid-1);
		if(r>mid)rson(p)=bld(mid+1,r);
		return sprup(p),p;
	}
	void init(){
		sz=0;
		nd[0]=node({0,0,0,0,0,-inf,0,-inf,0});
		root=bld();
	}
	void sprup(int p){
		sum(p)=sum(lson(p))+v(p)+sum(rson(p));
		mx(p)=max(mx(lson(p)),max(dif(p),mx(rson(p))));
		sz(p)=sz(lson(p))+1+sz(rson(p));
	}
	void sprdwn(int p){
		if(lz(p)){
			tag(lson(p),lz(p));tag(rson(p),lz(p));
			lz(p)=0;
		}
	}
	void tag(int p,int v){
		if(p)dif(p)+=v,mx(p)+=v,lz(p)+=v;
	}
	pair<int,int> split(int x,int p=-1){~p||(p=root);
		if(!x)return mp(0,p);
		sprdwn(p);
		pair<int,int> sp;
		if(x<=sz(lson(p)))return sp=split(x,lson(p)),lson(p)=sp.Y,sprup(p),mp(sp.X,p);
		return sp=split(x-1-sz(lson(p)),rson(p)),rson(p)=sp.X,sprup(p),mp(p,sp.Y);
	}
	int mrg(int p,int q){
		if(!p||!q)return p|q;
		sprdwn(p);sprdwn(q);
		if(key(p)<key(q))return rson(p)=mrg(rson(p),q),sprup(p),p;
		return lson(q)=mrg(p,lson(q)),sprup(q),q;
	}
	int grt(int v,int p=-1){~p||(p=root);
		if(!p)return 0;
		sprdwn(p);
		if(v(p)>v)return sz(lson(p))+1+grt(v,rson(p));
		return grt(v,lson(p));
	}
	int nwnd(int v,int dif){
		return nd[++sz]=node({rng(),0,0,1,v,dif,v,dif,0}),sz;
	}
	void mv_rit(int v1,int v2){
		pair<int,int> sp=split(grt(v1)),sp0=split(1,sp.Y),sp1=split(grt(v2,sp0.Y),sp0.Y);
		//sp.X,del(sp0.X),sp1.X,insert(v2),sp1.Y
		v(sp0.X)=sum(sp0.X)=v2,dif(sp0.X)=mx(sp0.X)=v2-sum(sp1.Y);
		tag(sp.X,v1-v2);tag(sp1.X,-v2);
		root=mrg(sp.X,mrg(sp1.X,mrg(sp0.X,sp1.Y)));
	}
	void mv_lft(int v1,int v2){
		pair<int,int> sp=split(grt(v2)),sp0=split(grt(v1,sp.Y),sp.Y),sp1=split(1,sp0.Y);
		//sp.X,insert(v2),sp0.X,del(sp1.X),sp1.Y
		v(sp1.X)=sum(sp1.X)=v2,dif(sp1.X)=mx(sp1.X)=v2-sum(sp0.X)-sum(sp1.Y);
		tag(sp.X,v1-v2);tag(sp0.X,v1);
		root=mrg(sp.X,mrg(sp1.X,mrg(sp0.X,sp1.Y)));
	}
	int cnt(int p=-1){~p||(p=root);
		if(!p)return 0;
		sprdwn(p);
		int res=0;
		if(dif(p)>=0)res++;
		if(mx(lson(p))>=0)res+=cnt(lson(p));
		if(mx(rson(p))>=0)res+=cnt(rson(p));
		return res;
	}
	void dfs(int p=-1){~p||(p=root);
		if(!p)return;
		sprdwn(p);
		dfs(lson(p));
		cout<<v(p)<<" "<<dif(p)<<"!\n";
		dfs(rson(p));
	}
}trp;
signed main(){
	cin>>n;
	for(int i=1;i<=n;i++)scanf("%lld",a+i),b[i]=a[i];
	sort(b+1,b+n+1,greater<int>());
	for(int i=n;i;i--)suM[i]=suM[i+1]+b[i];
	trp.init();
//	trp.dfs();
	cout<<trp.cnt()<<"\n";
	cin>>qu;
	while(qu--){
		int x,y;
		scanf("%lld%lld",&x,&y);
//		cout<<a[x]<<" "<<y<<"!!\n";
		if(a[x]>y)trp.mv_rit(a[x],y);
		else trp.mv_lft(a[x],y);
		a[x]=y;
//		trp.dfs();
		printf("%lld\n",trp.cnt());
	}
	return 0;
}

後來yxh告訴我了另一個神仙方法,是用二進位制搞的,程式碼特別短。當時乍一看是1log的,woc這麼強的嗎???還準備學習一下。現在才發現也是2log的,複雜度跟上述做法一樣,就懶得研究了。

總體來說是比較簡單的一題。