CF1649E Tyler and Strings題解
阿新 • • 發佈:2022-03-09
題意
給你一個長度為\(n\)的序列\(s\)和一個長度為\(m\)的序列\(t\),現在你可以任意排列\(s\)中的元素.現在問你能組合出多少個本質不同的序列,使得字典序小於\(t\)
分析
發現最終的要求是字典序小於,所以我們可以從\(1\)號位置到\(\min(n,m)\)的位置迭代.假設我們現在迭代到的位置為\(j\),那麼位置\(j\)要麼放一個字典序小於\(t_j\)的字元,放了之後從\(j+1\)到\(\min(n,m)\)的位置都可以隨便放.要麼放一個字典序等於\(t_j\)的字元,然後因為我們可以確定放的是哪一個字元,就可以減去這個字元的影響然後繼續從\(j+1\)開始統計答案.
以上是主要思路,但是我們發現如果我們暴力迭代並統計的話,我們設權值的最大值為\(c\)
考慮優化,首先考慮怎麼優化統計答案的部分,我們設用\(v_1\)個\(1\),\(v_2\)個\(2\)...\(v_c\)個\(c\)來組成不同的序列,那麼顯然能組成的不同的序列個數為\(\frac{(v_1+v_2+...+v_c)!}{v_1!v_2!...v_c!}\),這玩意顯然可以通過預處理處理出來.然後我們考慮迭代的時候,我們設\(s_k<t_j\),所以用\(s_k\)這個元素,那麼它後邊的就可以隨便選,答案就是\(\frac{(v_1+v_2+...+v_c-1)!}{v_1!v_2!...(v_k-1)!...v_c!}\)
但是這樣我們會發現一個問題,就是當我們選擇了一個字典序等於\(t_j\)的元素時,這個元素會被去掉,我們之前統計的所有的\(A_i\)就需要重新被計算一遍,我們顯然接受不了.考慮怎麼快速解決這個問題.我們發現:假設我們用的元素的值為\(k\),那麼對於所有的\(i\) \(\not=\) \(k,A_i=A_i\times\frac{v_i}{v_1+v_2+...+v_c-1}\)
分析到這裡,我們發現以上兩種操作就是對整個\(A\)陣列進行區間乘法以及區間求和操作,可以用基本的線段樹來實現這種操作然後統計答案即可.
有一點需要注意,當\(n<m\)時,如果前幾個的字典序都選取了與\(t\)相等的,這樣組成的序列字典序也是小於\(t\)的,所以需要特判。
程式碼
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
const int P=998244353;
int n,m,c;
int s[N],t[N];
int inv[N],fac[N],ninv[N];
int buck[N],add[N],tmp[N];
struct Tree{
int mul,val;
int size;
}tree[N<<2];
inline int ksm(int x,int y){
int res=1;
while(y){
if(y&1)
res=1ll*res*x%P;
x=1ll*x*x%P;
y>>=1;
}
return res%P;
}
#define LC (root<<1)
#define RC (root<<1|1)
void Build(int root,int start,int end){
tree[root].size=end-start+1;
tree[root].mul=1;
if(start==end){
tree[root].val=add[start];
return;
}
int mid=(start+end)>>1;
Build(LC,start,mid);
Build(RC,mid+1,end);
tree[root].val=(tree[LC].val+tree[RC].val)%P;
return;
}
void pushdown(int root){
if(tree[root].mul>1){
tree[LC].mul=1ll*tree[root].mul*tree[LC].mul%P;
tree[RC].mul=1ll*tree[root].mul*tree[RC].mul%P;
tree[LC].val=1ll*tree[LC].val*tree[root].mul%P;
tree[RC].val=1ll*tree[RC].val*tree[root].mul%P;
}
tree[root].mul=1;
return;
}
void modify_group(int root,int qstart,int qend,int nstart,int nend,int off){
if(qend<qstart)
return;
if(qstart>nend||qend<nstart)
return;
if(qstart<=nstart&&qend>=nend){
tree[root].val=1ll*tree[root].val*off%P;
tree[root].mul=1ll*tree[root].mul*off%P;
return;
}
int mid=(nstart+nend)>>1;
pushdown(root);
modify_group(LC,qstart,qend,nstart,mid,off);
modify_group(RC,qstart,qend,mid+1,nend,off);
tree[root].val=(tree[LC].val+tree[RC].val)%P;
return;
}
int query(int root,int qstart,int qend,int nstart,int nend){
if(qend<qstart)
return 0;
if(qstart>nend||qend<nstart)
return 0;
if(qstart<=nstart&&qend>=nend)
return tree[root].val%P;
int mid=(nstart+nend)>>1;
pushdown(root);
return (query(LC,qstart,qend,nstart,mid)+query(RC,qstart,qend,mid+1,nend))%P;
}
int main(void){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&s[i]);
for(int i=1;i<=n;i++)
c=max(c,s[i]);
for(int i=1;i<=m;i++)
scanf("%d",&t[i]);
fac[1]=1;
for(int i=2;i<=n;i++)
fac[i]=1ll*fac[i-1]*i%P;
inv[0]=ninv[0]=1;
for(int i=1;i<=n;i++){
inv[i]=ksm(fac[i],P-2)%P;
ninv[i]=ksm(i,P-2)%P;
}
for(int i=1;i<=n;i++)
buck[s[i]]++;
int inih=0,inil=1;
for(int i=1;i<=c;i++){
inih=(inih+buck[i])%P;
inil=1ll*inil*inv[buck[i]]%P;
}
// printf("%lld\n",1ll*fac[inih]*inil%P);
for(int i=1;i<=c;i++){
if(buck[i]<=0)
continue;
int bottom=1ll*inil*fac[buck[i]]%P*inv[buck[i]-1]%P;
add[i]=1ll*fac[inih-1]*bottom%P;
// printf("%d ",add[i]);
}
Build(1,1,c);
int sum=inih;
long long ans=0;
if(m>n){
for(int i=1;i<=c;i++)
tmp[i]=buck[i];
bool flag=1;
for(int i=1;i<=n&&flag;i++){
if(tmp[t[i]]>=1){
tmp[t[i]]--;
}
else
flag=0;
}
if(flag)
ans++;
}
for(int i=1;i<=min(n,m);i++){
ans=(ans+query(1,1,t[i]-1,1,c))%P;
if(buck[t[i]]==0)
break;
int dest=1ll*buck[t[i]]*ninv[sum-1]%P;
int d2=1ll*(buck[t[i]]-1)*ninv[buck[t[i]]]%P;
modify_group(1,1,c,1,c,dest);
modify_group(1,t[i],t[i],1,c,d2);
sum--;
buck[t[i]]--;
}
printf("%lld\n",ans);
return 0;
}