1. 程式人生 > 其它 >CF1649E Tyler and Strings題解

CF1649E Tyler and Strings題解

題意

給你一個長度為\(n\)的序列\(s\)和一個長度為\(m\)的序列\(t\),現在你可以任意排列\(s\)中的元素.現在問你能組合出多少個本質不同的序列,使得字典序小於\(t\)

分析

發現最終的要求是字典序小於,所以我們可以從\(1\)號位置到\(\min(n,m)\)的位置迭代.假設我們現在迭代到的位置為\(j\),那麼位置\(j\)要麼放一個字典序小於\(t_j\)的字元,放了之後從\(j+1\)\(\min(n,m)\)的位置都可以隨便放.要麼放一個字典序等於\(t_j\)的字元,然後因為我們可以確定放的是哪一個字元,就可以減去這個字元的影響然後繼續從\(j+1\)開始統計答案.
以上是主要思路,但是我們發現如果我們暴力迭代並統計的話,我們設權值的最大值為\(c\)

,那麼我們需要列舉用哪個字典序小於\(t_j\)的元素,然後統計答案,複雜度很顯然接受不了.
考慮優化,首先考慮怎麼優化統計答案的部分,我們設用\(v_1\)\(1\),\(v_2\)\(2\)...\(v_c\)\(c\)來組成不同的序列,那麼顯然能組成的不同的序列個數為\(\frac{(v_1+v_2+...+v_c)!}{v_1!v_2!...v_c!}\),這玩意顯然可以通過預處理處理出來.然後我們考慮迭代的時候,我們設\(s_k<t_j\),所以用\(s_k\)這個元素,那麼它後邊的就可以隨便選,答案就是\(\frac{(v_1+v_2+...+v_c-1)!}{v_1!v_2!...(v_k-1)!...v_c!}\)
,而且由於對於每個字典序小於\(t_j\)的元素,我們都要統計一遍這個答案,所以我們可以設\(A_i=\frac{(v_1+v_2+...+v_c-1)!}{v_1!v_2!...(v_i-1)!...v_c!}\),那麼對於某次迭代,只需要對這個陣列求字首和即可.
但是這樣我們會發現一個問題,就是當我們選擇了一個字典序等於\(t_j\)的元素時,這個元素會被去掉,我們之前統計的所有的\(A_i\)就需要重新被計算一遍,我們顯然接受不了.考慮怎麼快速解決這個問題.我們發現:假設我們用的元素的值為\(k\),那麼對於所有的\(i\) \(\not=\) \(k,A_i=A_i\times\frac{v_i}{v_1+v_2+...+v_c-1}\)
,而對於\(k\),則有\(A_k=A_k\times\frac{v_k-1}{v_1+v_2+...+v_c-1}\).
分析到這裡,我們發現以上兩種操作就是對整個\(A\)陣列進行區間乘法以及區間求和操作,可以用基本的線段樹來實現這種操作然後統計答案即可.
有一點需要注意,當\(n<m\)時,如果前幾個的字典序都選取了與\(t\)相等的,這樣組成的序列字典序也是小於\(t\)的,所以需要特判。

程式碼

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
const int P=998244353;
int n,m,c;
int s[N],t[N];
int inv[N],fac[N],ninv[N];
int buck[N],add[N],tmp[N];
struct Tree{
	int mul,val;
	int size;
}tree[N<<2];

inline int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)
			res=1ll*res*x%P;
		x=1ll*x*x%P;
		y>>=1;
	}
	return res%P;
}

#define LC (root<<1)
#define RC (root<<1|1)
void Build(int root,int start,int end){
	tree[root].size=end-start+1;
	tree[root].mul=1;
	if(start==end){
		tree[root].val=add[start];
		return;
	}
	int mid=(start+end)>>1;
	Build(LC,start,mid);
	Build(RC,mid+1,end);
	tree[root].val=(tree[LC].val+tree[RC].val)%P;
	return;
}
void pushdown(int root){
	if(tree[root].mul>1){
		tree[LC].mul=1ll*tree[root].mul*tree[LC].mul%P;
		tree[RC].mul=1ll*tree[root].mul*tree[RC].mul%P;
		tree[LC].val=1ll*tree[LC].val*tree[root].mul%P;
		tree[RC].val=1ll*tree[RC].val*tree[root].mul%P;
	}
	tree[root].mul=1;
	return;
}
void modify_group(int root,int qstart,int qend,int nstart,int nend,int off){
	if(qend<qstart)
		return;
	if(qstart>nend||qend<nstart)
		return;
	if(qstart<=nstart&&qend>=nend){
		tree[root].val=1ll*tree[root].val*off%P;
		tree[root].mul=1ll*tree[root].mul*off%P;
		return;
	}
	int mid=(nstart+nend)>>1;
	pushdown(root);
	modify_group(LC,qstart,qend,nstart,mid,off);
	modify_group(RC,qstart,qend,mid+1,nend,off);
	tree[root].val=(tree[LC].val+tree[RC].val)%P;
	return;
}
int query(int root,int qstart,int qend,int nstart,int nend){
	if(qend<qstart)
		return 0;
	if(qstart>nend||qend<nstart)
		return 0;
	if(qstart<=nstart&&qend>=nend)
		return tree[root].val%P;
	int mid=(nstart+nend)>>1;
	pushdown(root);
	return (query(LC,qstart,qend,nstart,mid)+query(RC,qstart,qend,mid+1,nend))%P;
}

int main(void){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
		scanf("%d",&s[i]);
	for(int i=1;i<=n;i++)
		c=max(c,s[i]);
	for(int i=1;i<=m;i++)
		scanf("%d",&t[i]);
	fac[1]=1;
	for(int i=2;i<=n;i++)
		fac[i]=1ll*fac[i-1]*i%P;
	inv[0]=ninv[0]=1;
	for(int i=1;i<=n;i++){
		inv[i]=ksm(fac[i],P-2)%P;
		ninv[i]=ksm(i,P-2)%P;
	}
	for(int i=1;i<=n;i++)
		buck[s[i]]++;
	int inih=0,inil=1;
	for(int i=1;i<=c;i++){
		inih=(inih+buck[i])%P;
		inil=1ll*inil*inv[buck[i]]%P;
	}
//	printf("%lld\n",1ll*fac[inih]*inil%P);
	for(int i=1;i<=c;i++){
		if(buck[i]<=0)
			continue;
		int bottom=1ll*inil*fac[buck[i]]%P*inv[buck[i]-1]%P;
		add[i]=1ll*fac[inih-1]*bottom%P;
//		printf("%d ",add[i]);
	}
	
	Build(1,1,c);
	int sum=inih;
	long long ans=0;
	
	if(m>n){
		for(int i=1;i<=c;i++)
			tmp[i]=buck[i];
		bool flag=1;
		for(int i=1;i<=n&&flag;i++){
			if(tmp[t[i]]>=1){
				tmp[t[i]]--;
			}
			else 
				flag=0;
		}
		if(flag)
			ans++;
	}
	
	for(int i=1;i<=min(n,m);i++){
		ans=(ans+query(1,1,t[i]-1,1,c))%P;
		if(buck[t[i]]==0)
			break;
		int dest=1ll*buck[t[i]]*ninv[sum-1]%P;
		int d2=1ll*(buck[t[i]]-1)*ninv[buck[t[i]]]%P;
		modify_group(1,1,c,1,c,dest);
		modify_group(1,t[i],t[i],1,c,d2);
		sum--;
		buck[t[i]]--;
	}
	printf("%lld\n",ans); 
	return 0;
}