牛客練習賽31 F 瑟班守護者莎利雅與護教軍(推導 + 線段樹)
一道算是挺難想的線段樹。
需要維護兩個陣列,一個d和一個f。f的定義參見體題面,前後最大值中最小的,再與自己的d取一個大的。然後有三種操作,1是求f的和;2是求區間[l,r]中,有多少個f的數值大於給定的數字v;3是修改某一個位置的d的值,使其增加v。在這裡,f隨著d的數值而變化,但是這個變化的關係沒有什麼直觀的聯絡。
下面我們仔細考慮一下,這兩個到底有什麼關係。
首先我們得出幾個顯然的推論:
推論一:如果某個位置x的d增加了v之後,存在一個y>x,使得d[y]>d[x],那麼有max(d[k])>max(d[j]),k和j的定義見題面。
推論二:推論一反過來,如果存在y<x,使得d[y]>d[x],那麼有max(d[j])>max(d[k])
然後開始推導。
考慮當一個位置x增加了v,不妨假設增加了之後 d[x]>d[i] (1<=i<x) 並且滿足推論一的條件,那麼根據推論一,在區間[x+1,y-1]中,y為滿足條件的最小值,f的數值只受max(d[j])影響,而此時max(d[j])=d[x],所以這一區間的f增大為d[x]。
當滿足前一條件,但是不滿足推論一的條件的時候,也即在x的右邊找不到比當前d[x]大的d[y],那麼我去找區間[x+1,n]中最左邊的一個最大值所在的位置y。我們會發現,對於一個在區間[x+1,y-1]中的點,當更新之前,滿足max(d[j])>max(d[k])時,更新之後這個關係還是不變,f還是可以用d[k]更新;但當更新之前是max(d[j])<max(d[k])時,f取值是d[j],但此時已經是max(d[j])>max(d[k])了,所以f的取值應該是與d[k]關聯。綜上,兩種情況用d[k]去更新f都不會出錯。所以此時對於區間[x+1,y-1]我們用d[k]去更新。
同理,如果d[x]>d[i] (x<i<=n)並且滿足推論二的條件,那麼根據推論二,在區間[y+1,x-1]中,y為滿足條件的最大值,f的數值只會受到max(d[k])的影響,而此時max(d[k])=d[x],所以這一區間的f增大為d[x]。
當不滿足推論二的條件的時候也是類似的道理。找到區間[1,x-1]中最右邊的一個最大值所在的位置y,對於區間[y+1,x-1]的f用d[x]去更新即可。
這樣,動態維護f的操作我們就完成了。接下來考慮這個區間內大於某一個數字的f的個數怎麼求。
經過觀察,我們可以發現,那些f的數值大於某些數字的區域一定是連續的一個區間,而且這個區間的端點l和r,滿足l左邊的所有d[i]都小於對應的數字,r右邊的所有d[i]都小於對應的數字。這個很容易去證明,如果存在兩個位置l和r,d[l]和d[r]都大於對應的數字,那麼顯然他們之間的所有位置的f都要大於對應的數字。因為d[j]和d[k]都大於等於對應的數字,而f是二者最小值與d的最大值,因此也一定大於對應的數字。所以只需要找到d中最左邊和最右邊的大於對應數字的位置即可。那麼第二個操作也可以解決了。
最後總結一下。我們建立兩棵線段樹,一個維護d,另一個維護f。對於一個修改操作,我們需要在d中查詢左右第一個大於更新之後d[x]的位置,如果不存在則對應找左邊做靠右的最大值和右邊最靠左的最大值,這樣我們就可以知道需要修改的區間。然後在f的對應區間進行更新。對於查詢區間內大於某一個數字的個數,只需要在d中找大於對應數字的最左邊和最右邊的數字的位置,兩個之間的數字都滿足條件。然後查詢f的和直接輸出f那棵樹的和即可。具體操作見程式碼:
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define pi 3.141592653589793
#define mod 998244353
#define LL long long
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define sf(x) scanf("%lld",&x)
#define sc(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)
using namespace std;
const int N = 300010;
LL f[N],d[N],mx1[N],mx2[N],n,m;
typedef pair<LL,int> P;
struct ST1
{
#define ls i<<1
#define rs i<<1|1
struct node
{
int lpos,rpos,l,r;
LL max;
} T[N<<2];
inline void push_up(int i)
{
T[i].max=max(T[ls].max,T[rs].max);
T[i].lpos=T[T[ls].max>=T[rs].max?ls:rs].lpos;
T[i].rpos=T[T[rs].max>=T[ls].max?rs:ls].rpos;
}
void build(int i,int l,int r)
{
T[i]=node{0,0,l,r,0};
if (l==r)
{
T[i].lpos=T[i].rpos=l;
T[i].max=d[l]; return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
push_up(i);
}
void update(int i,int pos,LL x)
{
if (T[i].l==T[i].r)
{
T[i].max+=x;
return;
}
int mid=(T[i].l+T[i].r)>>1;
if (mid>=pos) update(ls,pos,x);
else if (mid<pos) update(rs,pos,x);
push_up(i);
}
int query1(int i,int l,int r,LL x)
{
if (l>r||T[i].max<x) return 0;
if (T[i].l==T[i].r) return T[i].max>=x?T[i].l:0;
int mid=(T[i].l+T[i].r)>>1,res=0;
if (l<=mid) res=query1(ls,l,r,x);
if (r>mid&&!res) res=query1(rs,l,r,x);
return res;
}
int query2(int i,int l,int r,LL x)
{
if (l>r||T[i].max<x) return 0;
if (T[i].l==T[i].r) return T[i].max>=x?T[i].l:0;
int mid=(T[i].l+T[i].r)>>1,res=0;
if (r>mid&&!res) res=query2(rs,l,r,x);
if (l<=mid&&!res) res=query2(ls,l,r,x);
return res;
}
P getmax1(int i,int l,int r)
{
if (l<=T[i].l&&T[i].r<=r) return P(T[i].max,T[i].lpos);
int mid=(T[i].l+T[i].r)>>1; P res={0,0};
if (l<=mid) res=getmax1(ls,l,r);
if (r>mid)
{
P tmp=getmax1(rs,l,r);
if (tmp.first>res.first) res=tmp;
}
return res;
}
P getmax2(int i,int l,int r)
{
if (l<=T[i].l&&T[i].r<=r) return P(T[i].max,T[i].rpos);
int mid=(T[i].l+T[i].r)>>1; P res={0,0};
if (r>mid) res=getmax2(rs,l,r);
if (l<=mid)
{
P tmp=getmax2(ls,l,r);
if (tmp.first>res.first) res=tmp;
}
return res;
}
} seg1;
struct ST2
{
#define ls i<<1
#define rs i<<1|1
struct node
{
LL lazy,sum;
int l,r;
} T[N<<2];
void build(int i,int l,int r)
{
T[i]=node{0,0,l,r};
if (l==r)
{
T[i].sum=f[l];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
T[i].sum=T[ls].sum+T[rs].sum;
}
inline void push_down(int i)
{
LL lazy=T[i].lazy;
T[ls].lazy=lazy; T[rs].lazy=lazy;
T[ls].sum=(T[ls].r-T[ls].l+1)*lazy;
T[rs].sum=(T[rs].r-T[rs].l+1)*lazy;
T[i].lazy=0;
}
void update(int i,int l,int r,LL x)
{
if (l>r) return;
if (T[i].l==l&&T[i].r==r)
{
T[i].sum=(T[i].r-T[i].l+1)*x;
T[i].lazy=x; return;
}
if (T[i].lazy) push_down(i);
int mid=(T[i].l+T[i].r)>>1;
if (mid>=r) update(ls,l,r,x);
else if (mid<l) update(rs,l,r,x);
else
{
update(ls,l,mid,x);
update(rs,mid+1,r,x);
}
T[i].sum=T[ls].sum+T[rs].sum;
}
} seg2;
int main()
{
sf(n); sf(m);
for(int i=1;i<=n;i++) sf(d[i]);
seg1.build(1,1,n);
for(int i=1;i<=n;i++)
mx1[i]=max(mx1[i-1],d[i]);
for(int i=n;i>=1;i--)
mx2[i]=max(mx2[i+1],d[i]);
for(int i=1;i<=n;i++)
f[i]=max(d[i],min(mx1[i-1],mx2[i+1]));
seg2.build(1,1,n);
while(m--)
{
LL op,l,r,v;
sf(op);
if (op==1) printf("%lld\n",seg2.T[1].sum);
else if (op==2)
{
sc(l,r,v);
int ll=seg1.query1(1,1,l-1,v);
int rr=seg1.query2(1,r+1,n,v);
if (ll) ll=l; else ll=seg1.query1(1,l,r,v);
if (rr) rr=r; else rr=seg1.query2(1,l,r,v);
if (ll*rr==0) puts("0"); else printf("%d\n",rr-ll+1);
} else
{
sf(l); sf(v);
seg1.update(1,l,v); d[l]+=v;
int ll=seg1.query2(1,1,l-1,d[l]);
int rr=seg1.query1(1,l+1,n,d[l]);
if (ll&&rr) continue;
seg2.update(1,l,l,d[l]);
if (ll) seg2.update(1,ll+1,l-1,d[l]);
else if (l!=1)
{
P tmp=seg1.getmax2(1,1,l-1);
seg2.update(1,tmp.second,l-1,tmp.first);
}
if (rr) seg2.update(1,l+1,rr-1,d[l]);
else if (l!=1)
{
P tmp=seg1.getmax1(1,l+1,n);
seg2.update(1,l+1,tmp.second,tmp.first);
}
}
}
return 0;
}