7.30 NOI模擬賽 B Easy Sum 分塊 NTT
阿新 • • 發佈:2020-08-01
LINK:Easy Sum
考試的時候一臉懵逼 想不通這個\(n^2\)的還能怎麼優化.
事實上暴力\(n^2\)有30~40 而我臉黑 只有30 很氣...
把點抽象到楊輝三角上可以發現這是若干個行上的K個點求和 如果是對列上求和或者總體求和就好做的多.
另外一種\(n^2\)是 這n個點從\((a_i,b_i)\)這個位置走到\((0,k)\)
實際上在圖上進行dp求方案 很多神仙使用滾動陣列+迴圈展開\(n^2\)暴力過了這道題.
正解是這樣的:
將座標分塊 這樣做是方便後續的轉移.
考慮隔B分上一塊然後我們維護每一列的的dp值 從上一塊dp到下一塊.
每次由上一列dp到下一列需要做一個字首和的東西(實際上是字尾和.
其實就是乘以多項式\(\frac{1}{(1-x)}\)跳B列就乘以\(\frac{1}{(1-x)^B}\)
可能會有疑問為什麼不直接跳第一列而是一塊一塊跳.
在跳的過程中存在一個點要對下一列有貢獻了 所以我們此時暴力加上貢獻還是\(n^2\)的.
不妨考慮加到上一次要跳的那一列 這樣計算出自己需要res次字首和 那麼就是給上次的列的多項式加上一個\(x^y\cdot (1-x)^{B-res}\)
這個多項式最長只有B所以就可以接受了.
總複雜度為\(n\cdot B+\frac{n}{B}\cdot nlogn\) 顯然當B取\(\sqrt{nlogn}\)時最優.
程式碼寫的很容易理解.
這個思路是真的妙 建議一寫.
code
//#include<bits/stdc++.h> #include<iostream> #include<cstdio> #include<ctime> #include<cctype> #include<queue> #include<deque> #include<stack> #include<iostream> #include<iomanip> #include<cstdio> #include<cstring> #include<string> #include<ctime> #include<cmath> #include<cctype> #include<cstdlib> #include<queue> #include<deque> #include<stack> #include<vector> #include<algorithm> #include<utility> #include<bitset> #include<set> #include<map> #define ll long long #define db double #define INF 1000000000 #define inf 1000000000000000ll #define ldb long double #define pb push_back #define put_(x) printf("%d ",x); #define get(x) x=read() #define gt(x) scanf("%d",&x) #define gi(x) scanf("%lf",&x) #define put(x) printf("%d\n",x) #define putl(x) printf("%lld\n",x) #define rep(p,n,i) for(RE int i=p;i<=n;++i) #define go(x) for(int i=lin[x];i;i=nex[i]) #define fep(n,p,i) for(RE int i=n;i>=p;--i) #define vep(p,n,i) for(RE int i=p;i<n;++i) #define pii pair<int,int> #define mk make_pair #define RE register #define P 13331ll #define gf(x) scanf("%lf",&x) #define pf(x) ((x)*(x)) #define uint unsigned long long #define ui unsigned #define EPS 1e-5 #define sq sqrt #define S second #define F first #define mod 998244353 #define md 1000000007 #define max(x,y) ((x)<(y)?y:x) #define a(i) t[i].a #define b(i) t[i].b using namespace std; char buf[1<<15],*fs,*ft; inline char getc() { return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++; } inline int read() { RE int x=0,f=1;RE char ch=getc(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getc();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getc();} return x*f; } inline ll Read() { RE ll x=0,f=1;RE char ch=getc(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getc();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getc();} return x*f; } const int MAXN=350010,maxn=100010,N=2000,G=3; int lim=1,n,B,m,INV,IG; int fac[MAXN],inv[MAXN],h[N][N]; int f[MAXN],g[MAXN],rev[MAXN],O[MAXN]; struct wy{int a,b;}t[maxn]; vector<int>w[N]; inline int ksm(int b,int p) { int cnt=1; while(p) { if(p&1)cnt=(ll)cnt*b%mod; b=(ll)b*b%mod;p=p>>1; } return cnt; } inline int C(int a,int b){return (ll)fac[a]*inv[b]%mod*inv[a-b]%mod;} inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;} inline void NTT(int *a,int op) { vep(0,lim,i)if(i<rev[i])swap(a[i],a[rev[i]]); for(int len=2;len<=lim;len=len<<1) { int mid=len>>1; int wn=ksm(op==1?G:IG,(mod-1)/len); vep(1,mid,j)O[j]=(ll)O[j-1]*wn%mod; for(int j=0;j<lim;j+=len) { vep(0,mid,i) { int x=(ll)a[i+j+mid]*O[i]%mod; a[i+j+mid]=(a[i+j]-x+mod)%mod; a[i+j]=add(a[i+j],x); } } } if(op==-1)vep(0,lim,i)a[i]=(ll)a[i]*INV%mod; } inline bool cmp(wy x,wy y){return x.a<y.a;} int main() { //freopen("1.in","r",stdin); n=read();B=(int)sqrt(n*20*1.0); rep(1,n,i)get(a(i)),get(b(i)); lim=1;while(lim<=2*n)lim=lim<<1; vep(0,lim,i)rev[i]=rev[i>>1]>>1|((i&1)?lim>>1:0); m=n+B;INV=ksm(lim,mod-2);IG=ksm(G,mod-2);O[0]=1; fac[0]=1;rep(1,m,i)fac[i]=(ll)fac[i-1]*i%mod; inv[m]=ksm(fac[m],mod-2);fep(m-1,0,i)inv[i]=(ll)inv[i+1]*(i+1)%mod; rep(0,n,i)g[i]=C(i+B-1,B-1); NTT(g,1);//NTT(g,-1); rep(0,B,i) rep(0,i,j) h[i][j]=(ll)C(i,j)*((j&1)?mod-1:1)%mod; sort(t+1,t+1+n,cmp);int BB=n/B+1; rep(1,n,i)w[a(i)/B+1].pb(i); fep(BB,1,k) { vep(0,(int)w[k].size(),j) { int id=w[k][j]; int y=n-b(id); int res=a(id)%B+1; //(1-x)^{B-res}*x^y rep(0,B-res,i)if(y+i<=n)f[y+i]=add(f[y+i],h[B-res][i]); } NTT(f,1); vep(0,lim,i)f[i]=(ll)f[i]*g[i]%mod; NTT(f,-1); vep(n+1,lim,i)f[i]=0; } rep(0,n-1,i)printf("%d ",f[n-i]); return 0; }