聯考20200722 T1 集合劃分
阿新 • • 發佈:2020-07-22
分析:
首先是一個\(O(n^2)\)的DP,設\(f_{i,j,0/1}\)表示做了前\(i\)個,用了\(j\)個\(A\),最後一個是\(A/B\)的方案數
然後我們不看最後一位,發現\(f_{i,j}\)兩個狀態可以用\(2*2\)的轉移矩陣DP
發現轉移矩陣與\(j\)沒有關係,把\(j\)去掉,維護\(f_i=\sum_{j=0}a_jx^j\)的生成函式,\(x^j\)項係數就是\(f_{i,j}\)
如果加一位\(A\)相當於乘一個\(x\),否則乘一個\(1\)
分治維護矩陣上的多項式
複雜度\(O(nlog^2n)\),我的常數巨大2333
#include<cstdio> #include<cmath> #include<cstring> #include<iostream> #include<algorithm> #include<queue> #include<set> #include<map> #include<vector> #include<string> #define maxn 200005 #define INF 0x3f3f3f3f #define MOD 998244353 #define Poly vector<int> using namespace std; inline long long getint() { long long num=0,flag=1;char c; while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1; while(c>='0'&&c<='9')num=num*10+c-48,c=getchar(); return num*flag; } int n; int A[maxn],B[maxn]; struct node{ Poly a[2][2]; }P[maxn]; int rev[maxn]; inline int upd(int x){return x<MOD?x:x-MOD;} inline int ksm(int num,int k) { int ret=1; for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD; return ret; } inline Poly add(Poly x,Poly y) { int mx=max(x.size(),y.size()); x.resize(mx),y.resize(mx); for(int i=0;i<mx;i++)x[i]=upd(x[i]+y[i]); return x; } inline void NTT(Poly &a,int N,int opt) { for(int i=0;i<N;i++)if(i<rev[i])swap(a[i],a[rev[i]]); for(int i=1;i<N;i<<=1) { int wn=ksm(3,(MOD-1)/(i<<1)); if(!~opt)wn=ksm(wn,MOD-2); for(int j=0;j<N;j+=i<<1)for(int k=0,w=1;k<i;k++,w=1ll*w*wn%MOD) { int x=a[j+k],y=1ll*a[i+j+k]*w%MOD; a[j+k]=upd(x+y),a[i+j+k]=upd(x-y+MOD); } } if(!~opt)for(int i=0,Inv=ksm(N,MOD-2);i<N;i++)a[i]=1ll*a[i]*Inv%MOD; } inline node mul(node y,node x) { int N=x.a[0][0].size(),M=y.a[0][0].size(),len=1; while(len<N+M)len<<=1; for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|(i&1?len>>1:0); for(int i=0;i<2;i++)for(int j=0;j<2;j++) { x.a[i][j].resize(len),y.a[i][j].resize(len); NTT(x.a[i][j],len,1),NTT(y.a[i][j],len,1); } node z; for(int i=0;i<2;i++)for(int j=0;j<2;j++) { z.a[i][j].resize(len); for(int k=0;k<2;k++)for(int l=0;l<len;l++)z.a[i][j][l]=(z.a[i][j][l]+1ll*x.a[i][k][l]*y.a[k][j][l])%MOD; } for(int i=0;i<2;i++)for(int j=0;j<2;j++)NTT(z.a[i][j],len,-1),z.a[i][j].resize(N+M-1); return z; } inline node solve(int l,int r) { if(l==r)return P[l]; int mid=(l+r)>>1; return mul(solve(l,mid),solve(mid+1,r)); } int main() { n=getint(),getint(); for(int i=1;i<=2*n;i++)A[i]=getint(); for(int i=1;i<=2*n;i++)B[i]=getint(); for(int i=1;i<=2*n;i++) { for(int j=0;j<2;j++)for(int k=0;k<2;k++)P[i].a[j][k].resize(2); if(A[i-1]<=A[i])P[i].a[0][0][1]=1; if(B[i-1]<=A[i])P[i].a[0][1][1]=1; if(A[i-1]<=B[i])P[i].a[1][0][0]=1; if(B[i-1]<=B[i])P[i].a[1][1][0]=1; } node Ans=solve(1,2*n); printf("%d\n",upd(Ans.a[0][0][n]+Ans.a[1][0][n])); }