1. 程式人生 > 實用技巧 >CF833B-線段樹優化DP

CF833B-線段樹優化DP

CF833B-線段樹優化DP

題意

將一個長為\(n\)的序列分成\(k\)段,每段貢獻為其中不同數字的個數,求最大貢獻和。

思路

此處感謝@gxy001 聚銠的精彩講解

先考慮暴力DP,可以想到一個時空複雜度\(O(n^2k)\)的方法,即記錄前i個數字分成了j段。我們現在來思考幾個問題來優化這個操作:

  1. 對於一個數字,它對那些地方實際有貢獻?
  2. 每次分割出一個區間段對後續操作有影響的位置在哪?
  3. 每次轉移都從哪些地方繼承?

下來一一解答這些問題。

  1. 對於一個數字,它能產生貢獻的區間其實就是該數字上一次出現的位置的後一位到它本身的位置。
  2. 對於每次劃分,它以前的位置的貢獻已經被考慮,所以我們只能考慮後面的位置。
  3. 相應的,每次轉移會繼承前面所有DP值的最大值

那麼我們可以將k提出來,每次迴圈繼承上一次所有的dp值。因為只考慮從前面轉移dp值,所以不會對之前的決策產生影響,所以是正確的。

看看1、3問題的答案,是不是想到了RMQ和區間賦值?

於是我們可以通過線段樹來實現DP優化。

具體來講,迭代k次,每次線段樹更新為上一次序列的dp值,然後從前往後掃,每個數會對其上述區間產生1的貢獻,轉移繼承前面所有dp值的最大值即可。

時間複雜度將一維優化為log。

程式碼

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read(){
	int w=0,x=0;char c=getchar();
	while(!isdigit(c))w|=c=='-',c=getchar();
	while(isdigit(c))x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return w?-x:x;
}
namespace star
{
	const int maxn=35005;
	int n,k,cur[maxn],pre[maxn],f[maxn][60];
	struct SegmentTree{
		#define ls (ro<<1)
		#define rs (ro<<1|1)
		struct tree{
			int l,r,tag,v;
		}e[maxn<<2];
		inline void pushup(int ro){
			e[ro].v=max(e[ls].v,e[rs].v);
		}
		inline void pushdown(int ro){
			e[ls].tag+=e[ro].tag;e[rs].tag+=e[ro].tag;
			e[ls].v+=e[ro].tag;e[rs].v+=e[ro].tag;
			e[ro].tag=0;
		}
		void build(int ro,int l,int r){
			e[ro].l=l,e[ro].r=r;
			if(l==r)return;
			int mid=l+r>>1;
			build(ls,l,mid);
			build(rs,mid+1,r);
		}
		void rebuild(int tim,int ro){
			int l=e[ro].l,r=e[ro].r;
			e[ro].tag=0;
			if(l==r){
				e[ro].v=f[l][tim];return;
			}
			rebuild(tim,ls);rebuild(tim,rs);
			pushup(ro);
		}
		void update(int ro,int x,int y){
			int l=e[ro].l,r=e[ro].r;
			if(l>=x and r<=y){
				e[ro].v+=1;
				e[ro].tag+=1;return;
			}
			pushdown(ro);
			int mid=l+r>>1;
			if(mid>=x)update(ls,x,y);
			if(mid<y)update(rs,x,y);
			pushup(ro);
		}
		int query(int ro,int x,int y){
			int l=e[ro].l,r=e[ro].r;
			if(l==x and r==y)return e[ro].v;
			pushdown(ro);
			int mid=l+r>>1;
			if(mid<x)return query(rs,x,y);
			else if(mid>=y)return query(ls,x,y);
			else return max(query(ls,x,mid),query(rs,mid+1,y));
		}
		#undef ls
		#undef rs 
	}T;
	inline void work(){
		n=read(),k=read();
		for(int x,i=1;i<=n;i++)x=read(),pre[i]=cur[x],cur[x]=i;
		T.build(1,0,n);
		for(int i=1;i<=k;i++){
			T.rebuild(i-1,1);
			for(int x=1;x<=n;x++) T.update(1,pre[x],x-1),f[x][i]=T.query(1,0,x-1);
		}
		printf("%d",f[n][k]);
	}
}
signed main(){
	star::work();
	return 0;
}