1. 程式人生 > 實用技巧 >題解 [HNOI2019]序列

題解 [HNOI2019]序列

題目傳送門

題目大意

給出一個\(n\)個數的數列\(A_{1,2,...,n}\),求出一個單調不減的數列\(B_{1,2,...,n}\),使得\(\sum_{i=1}^{n}(A_i-B_i)^2\)最小。

\(m\)次查詢,每次將某個\(A_x\)更改為\(y\),求出修改後的答案。查詢之間互相獨立。

\(n,m\le 10^5\)

思路

其實這道題正解是用保序迴歸,但是找規律也能找出來。

我們通過觀察發現,對於一段相同的\(B_i\)\(B_i\)是該段的平均值。於是我們大膽猜測,我們最優方案就是把\(A\)序列劃分成一些區間,每一段的\(B\)都是\(A\)的平均值,並且\(B\)

單調不降。

我們發現這順便還可以發現一個事情:我們肯定應該多分割槽間,否則的話我們肯定只分一段就完事了。

於是,我們現在考慮一個區間對答案的貢獻:

\[\sum_{i=l}^{r} (A_i-d)^2 \]

其中\(d\)是這段區間的平均值。

\[=\sum_{i=l}^{r} (A_i^2-2A_id+d^2) \]

\[=\sum_{i=l}^{r} A_i^2-\dfrac{(\sum_{i=l}^{r}A_i)^2}{r-l+1} \]

於是我們發現我們只需要維護區間平方和、區間和、區間長度。於是,我們就順利地拿到了\(50\)分。

考慮\(100\)分。一個很顯然的事情是,我們會改變的決策區間一定是\([L,R]\)

,至於\([1,L)\)\((R,n]\)都不會被影響,於是我們可以預處理一下。

問題就是如何找到\(L,R\)。一個不是很顯然的事情就是,我們選的端點一定都是一開始分的區間的某些端點。因為我們一個區間如果從中間分開,前一段的平均值一定比後一段的平均值大,因為如果比它小的話肯定分開更優(分得越多越優)。於是,如果我們不是選一開始的端點的話,被分開的那一段前一段一定比後一段的平均值大,與平均值單調不減矛盾,得證。

同時,我們還發現答案是具有單調性的。於是我們可以考慮先二分\(R\),然後再二分\(L\)。二分\(R\)直接二分就好了,二分\(L\)的話,我們發現並不需要考慮前面的是否滿足條件(這個自己想一下就明白了),只需要考慮分開的後面的,於是這個我們可以在主席樹上查詢。具體見程式碼。

時間複雜度\(\Theta(n\log^2 n)\),空間複雜度\(\Theta(n\log n)\)

\(\texttt{Code}\)

#include <bits/stdc++.h>
using namespace std;

#define Int register int
#define mod 998244353
#define ll long long
#define MAXN 100005

template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}

int n,m,a[MAXN],inv[MAXN];
int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}

struct node{
	int rt,ans,top;//對應主席樹的頂點、答案、棧頂編號 
}pre[MAXN],suf[MAXN];

struct Data{
	int len,sqr;ll sum;//因為要比較平均值大小,所以區間和不能取模
	Data(){}
	Data(int len1,ll sum1,int sqr1){len = len1,sum = sum1,sqr = sqr1;}
	Data operator + (Data p){return Data(len + p.len,sum + p.sum,add (sqr,p.sqr));}
	Data operator - (Data p){return Data(len - p.len,sum - p.sum,dec (sqr,p.sqr));} 
	bool operator < (Data p)const{return p.len ? (len ? 1.0 * sum / len < 1.0 * p.sum / p.len: 1) : 0;}
	bool operator <= (Data &p)const{return p.len ? (len ? 1.0 * sum / len <= 1.0 * p.sum / p.len: 1) : 0;} 
	int calc (){return dec (sqr,mul (mul (sum % mod,sum % mod),inv[len]));}
}sum[MAXN],sta[MAXN];
Data calc (int l,int r){return sum[r] - sum[l - 1];}

struct Segment{
	int cnt;
	struct Node{
		int son[2],l,r,k;//k是指l->r這寫區間第一段區間的右端點 
	}tree[MAXN * 60];
	void Pushup (int x){
		tree[x].l = tree[tree[x].son[0]].l,tree[x].r = tree[tree[x].son[tree[x].son[1] > 0]].r;
		tree[x].k = tree[tree[x].son[0]].k;
	}
	void modify (int &x,int l,int r,int pos,int L,int R){
		tree[++ cnt] = tree[x],x = cnt;
		if (l == r) return tree[x].l = L,tree[x].r = tree[x].k = R,void ();
		int mid = (l + r) >> 1;
		if (pos <= mid) modify (tree[x].son[0],l,mid,pos,L,R);
		else modify (tree[x].son[1],mid + 1,r,pos,L,R);
		Pushup (x);
	}
	int queryr (int x,int l,int r,int pos){//查詢第pos段區間的右端點 
		if (l == r) return tree[x].r;
		int mid = (l + r) >> 1;
		if (pos <= mid) return queryr (tree[x].son[0],l,mid,pos);
		else return queryr (tree[x].son[1],mid + 1,r,pos);  
	}
	int queryl (int x,int l,int r,int pos,Data &tmp){//找到最靠右的滿足的L,並對答案進行合併 
		if (r <= pos){
			Data Lget = calc (tree[x].l,tree[x].k),Rget = calc (tree[x].k + 1,tree[x].r);
			if (Rget + tmp <= Lget) return tmp = tmp + calc (tree[x].l,tree[x].r),0;
			if (l == r) return tree[x].r;
		}
		int res,mid = (l + r) >> 1;
		if (pos > mid && (res = queryl (tree[x].son[1],mid + 1,r,pos,tmp))) return res;
		else return queryl (tree[x].son[0],l,mid,pos,tmp);
	}
}Tree;

void Init(){
	for (Int i = 1,top = 0;i <= n;++ i){
		sta[++ top] = Data (1,a[i],mul (a[i],a[i]));
		while (top > 1 && sta[top] <= sta[top - 1]) sta[top - 1] = sta[top - 1] + sta[top],-- top;
		pre[i].ans = add (pre[i - sta[top].len].ans,sta[top].calc());//計算答案
		pre[i].rt = pre[i - 1].rt,pre[i].top = top;
		Tree.modify (pre[i].rt,1,n,top,i - sta[top].len + 1,i);
	}
	for (Int i = n,top = 0;i;-- i){
		sta[++ top] = Data (1,a[i],mul (a[i],a[i]));
		while (top > 1 && sta[top - 1] <= sta[top]) sta[top - 1] = sta[top - 1] + sta[top],-- top;
		suf[i].ans = add (suf[i + sta[top].len].ans,sta[top].calc());
		suf[i].rt = suf[i + 1].rt,suf[i].top = top;
		Tree.modify (suf[i].rt,1,n,top,i,i + sta[top].len - 1);
	}
}

signed main(){
	read (n,m),inv[1] = 1;
	for (Int i = 2;i <= n;++ i) inv[i] = mul (mod - (mod / i),inv[mod % i]);
	for (Int i = 1;i <= n;++ i) read (a[i]),sum[i] = sum[i - 1] + Data (1,a[i],mul (a[i],a[i]));Init ();
	write (pre[n].ans),putchar ('\n');
	for (Int i = 1,x,y;i <= m;++ i){
		read (x,y);
		int l = 0,r = suf[x + 1].top - 1;
		while (l <= r){
			int mid = (l + r) >> 1,Rpos = mid ? Tree.queryr (suf[x + 1].rt,1,n,suf[x + 1].top - mid + 1) : x;//Rpos就是選出來的R 
			Data tmp = Data (1,y,mul (y,y)) + calc (x + 1,Rpos);int Lpos = x > 1 ? Tree.queryl (pre[x - 1].rt,1,n,pre[x - 1].top,tmp) : x; 
			if (tmp < calc (Rpos + 1,Tree.queryr (suf[x + 1].rt,1,n,suf[x + 1].top - mid))) r = mid - 1;
			else l = mid + 1;
		}
		int mid = r + 1,Rpos = mid ? Tree.queryr (suf[x + 1].rt,1,n,suf[x + 1].top - mid + 1) : x;
		Data tmp = Data (1,y,mul (y,y)) + calc (x + 1,Rpos);int Lpos = x > 1 ? Tree.queryl (pre[x - 1].rt,1,n,pre[x - 1].top,tmp) : x; 
		write (add (tmp.calc(),add (pre[Lpos].ans,suf[Rpos + 1].ans))),putchar ('\n');
	}
	return 0;
}