1. 程式人生 > 實用技巧 >【題解】CF645E. Intellectual Inquiry / sequence【20201020 CSP 模擬賽】【貪心 矩陣乘法】

【題解】CF645E. Intellectual Inquiry / sequence【20201020 CSP 模擬賽】【貪心 矩陣乘法】

題目連結

題目連結(略有不同)

題意

有一長為 \(n\) 的序列 \(a\)\(a_i\in [1,k]\cap \mathbf{N}\)。你要在後面添 \(m\) 個數(\(a_i\in [1,k]\cap \mathbf{N}\)),使得新序列的本質不同子序列個數最大。模 \(998244353\)\(n\leq 10^6\)\(k\leq 100\)\(m\leq 10^{18}\)

題解

首先考慮如何求一個序列的本質不同子序列個數:設 \(f_i\) 為考慮前 \(i\) 個數的本質不同子序列個數,如果沒有出現相同數字,\(f_{i+1}=2f_i\),如果在 \(t\) 處出現過相同數字,則 \(f_{i+1}=2f_i-f_{t-1}\)

(前 \(t-1\) 個數加上最後一個和前 \(i\) 個數加上最後一個會重)。

加新的數字時,肯定優先加要減的 \(f\) 小的數字(要是這個數放在之後加,前面的 \(\times 2\) 次數很多的地方用要減的 \(f\) 大的數顯然不優),也就是出現最近一次出現最早的數字。於是到了之後,有 \(f_{i}=2f_{i-1}-f_{i-k-1}\),於是矩陣乘法優化常係數齊次線性遞推。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ll long long
ll getint(){
	ll ans=0,f=1;
	char c=getchar();
	while(c<'0'||c>'9'){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		ans=ans*10+c-'0';
		c=getchar();
	}
	return ans*f;
}
const int N=1e6+10,K=105,mod=1e9+7;
int a[N],n,k;
int lst[K];
bool v[K];
int b[K];
ll m;

struct mat{
	int a[K][K];
	mat(){memset(a,0,sizeof(a));}
	void init1(){
		for(int i=0;i<=k;i++)a[i][i]=1;
	}
	void init(){
		for(int i=0;i<k;i++)a[i][i+1]=1;
		a[k][0]=mod-1;a[k][k]=2;
	}
	int * operator[](int x){
		return a[x];
	}
};
mat operator* (mat &x,mat &y){
	mat z;
	for(int i=0;i<=k;i++){
		for(int j=0;j<=k;j++){
			for(int l=0;l<=k;l++){
				z[i][j]=(z[i][j]+x[i][l]*1ll*y[l][j])%mod;
			}
		}
	}
	return z;
}
int qpow(ll m){
	mat x,ans;
	ans.init1();
	x.init();
//	for(int i=0;i<=k;i++){
//		for(int j=0;j<=k;j++)cerr<<ans[i][j]<<" ";
//		cerr<<endl;
//	}cerr<<endl;
//	for(int i=0;i<=k;i++){
//		for(int j=0;j<=k;j++)cerr<<x[i][j]<<" ";
//		cerr<<endl;
//	}cerr<<endl;
	while(m){
		if(m&1)ans=ans*x;
		x=x*x;
		m>>=1;
	}
//	for(int i=0;i<=k;i++){
//		for(int j=0;j<=k;j++)cerr<<x[i][j]<<" ";
//		cerr<<endl;
//	}cerr<<endl;
	int res=0;
	for(int i=0;i<=k;i++)
		res=(res+b[i]*1ll*ans[k][i])%mod;
	return res;
}

signed main(){
	n=getint(),m=getint(),k=getint();
	for(int i=1;i<=n;i++)a[i]=getint();
	int f=1;
	for(int i=1;i<=n;i++){
		int t=lst[a[i]];
		lst[a[i]]=f;
		f=(f=f*2ll-t+mod)%mod;
	}
	int tmp=k;
	b[tmp]=f;
	for(int i=n;i>=1;--i){
		if(!v[a[i]])b[--tmp]=lst[a[i]];
		v[a[i]]=1;
	}
	if(m<=1e6){
		queue<int>q;
		for(int i=0;i<=k;i++)q.push(b[i]);
		for(int i=0;i<m;i++){
			q.push((q.back()*2ll-q.front()+mod)%mod);
			q.pop();
		}
		cout<<q.back()-1<<endl;
	}else{
		int ans=qpow(m);
		cout<<ans-1<<endl;
	}
}