1. 程式人生 > 實用技巧 >【題解】 [EZEC-4]求和

【題解】 [EZEC-4]求和

對於百分之十的資料:隨便過。

下面推式子:

\[\sum_{i=1}^n\sum_{j=1}^n\gcd(i,j)^{i+j} \]

\[=\sum_{d=1}^n\sum_{i=1}^n\sum_{j=1}^nd^{i+j}[\gcd(i,j)=d] \]

\[=\sum_{d=1}^n\sum_{i=1}^\frac{n}{d}\sum_{j=1}^\frac{n}{d}d^{d(i+j)}[\gcd(i,j)=1] \]

\[=\sum_{d=1}^n\sum_{i=1}^\frac{n}{d}\sum_{j=1}^\frac{n}{d}d^{d(i+j)}\sum_{k|\gcd(i,j)}\mu(k) \]

\[=\sum_{d=1}^n\sum_{k=1}^\frac{n}{d}\mu(k)\sum_{i=1}^\frac{n}{kd}\sum_{j=1}^\frac{n}{kd}d^{kd(i+j)} \]

\(T=kd:\)

\[=\sum_{T=1}^n\sum_{k|T}\mu(k)\sum_{i=1}^\frac{n}{T}\sum_{j=1}^\frac{n}{T}[(\frac{T}{k})^T]^{i+j} \]

現在的問題在於\(\sum_{i=1}^\frac{n}{T}\sum_{j=1}^\frac{n}{T}[(\frac{T}{k})^T]^{i+j}.\)

  • 線性遞推

以下是@SOSCHINA大佬的思路:

\(g(n)=\sum_{i=1}^n\sum_{j=1}^n k^s.\)

列舉\(s=i+j.\)

則有:

\[g(n)=\sum_{s=2}^{n+1}(s-1)k^s+\sum_{s=n+2}^{2n}(2n-s+1)k^s \]

\[g(n+1)=\sum_{s=2}^{n+1}(s-1)k^s+\sum_{s=n+2}^{2n+2}(sn+3-s)k^s \]

\[g(n+1)-g(n)=\sum_{s=n+2}^{2n}2k^s+2k^{2n+1}+k^{2n+2} \]

第三行就是兩行相減。

對第一行的解釋:\([2,n+1]\)這裡的數,每個數作為\(i+j\)

都出現了\(x-1\)次。因為\(i\)可以取遍\([1,x-1].\)後面的那一些,\([n+2,2n]\)會發現\(i\)最大隻能到\(n,\)不能再取遍\(x-1\)個值了。此時能取到的應該是\(2n-s+1\)種。

對於\(g(n+1):\)這裡是把第一個式子的最後一個值移動到了後面那個式子,方便做差。

這時我們可以在小模數的情況下做到\(O(n*mod\))的預處理。

  • 化簡形式

\(x=(\frac{T}{k})^T.\)

則原式為\(\sum_{i=1}^\frac{n}{T}\sum_{j=1}^\frac{n}{T} x^{i+j}.\)

像不像一個多項式。

它就等於\((x+x^2+...x^\frac{n}{T})^2.\)

於是我們可以等比數列求和解出。

剩下的,可以做到\(O(n\log n\log mod)\)處理出整個式子。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=1500001;
int mod,TT;
bitset<MAXN+1>vis;
int p[MAXN+1],mu[MAXN+1],T[MAXN+1],cnt,n,Ans;
inline int Mod(long long x){
	if(x<0)return x+mod;
    if(x>=mod)return x%mod;
    return x;
}
inline int add(int x,int y) {return Mod(1ll*x+1ll*y+1ll*mod);}
inline int mul(int x,int y) {return Mod(1ll*x*y);}
inline int qpow(int a,int b) {
	if(!b)return 1;
	if(a<=1||b==1)return a;
	a %= mod; 
	int res=1;
	while(b) {
		if(b&1)res=mul(res,a);
		a=mul(a,a);
		b>>=1;
	}
	return res;
}
inline int calc(int x,int y){
	if(y==1)return x;
	if(x==1)return y;
	int ans=x;
	int inv=qpow((1-x+mod)%mod,mod-2);
	int fm=(1-qpow(x,y)+mod)%mod;
	ans=mul(ans,mul(fm,inv));
	return ans;
}
inline int Calc(int x,int y){int ans=calc(x,y);return mul(ans,ans);} 
int main() {
	scanf("%d",&TT);
	mu[1]=1;
	int N=MAXN;
	for(register int i=2; i<=N; ++i) {
		if(!vis[i])p[++cnt]=i,mu[i]=-1;
		for(register int j=1; j<=cnt&&i*p[j]<=N; ++j) {
			vis[i*p[j]]=1;
			if(i%p[j]==0)break;
			mu[i*p[j]]=-mu[i];
		}
	}
	while(TT--) {
		scanf("%d%d",&n,&mod);
		N=n;Ans=0;
		for(register int i=1; i<=N; ++i) {
			for(register int j=i,k,x; j<=N; j+=i) {
				k=i;if(!mu[k])continue;
				x=qpow(j/k,j);
				T[j]=add(T[j],mul(mu[k],Calc(x,n/j)));
			}
		}
		for(register int i=1; i<=n; ++i)Ans=add(Ans,T[i]),T[i]=0;
		printf("%d\n",Ans);
	}

	return 0;
}

由於這裡是\(5*10^5\)的資料,所以略微卡常,但筆者通過非常不精湛的卡常技術跑到了\(3s\)以內,所以這裡的時間限制我開了\(3.2s\).

對等比數列進行精細處理,可以做到\(O(n\log^2n)\)的複雜度。

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN=1500000;
int mod,TT;
bitset<MAXN<<1>vis;
int p[MAXN<<1],mu[MAXN<<1],T[MAXN<<1],cnt,n,Ans;
inline int Mod(long long x){
    if(x>=mod)return x%mod;
    return x;
}
inline int add(int x,int y) {
	return Mod(x+y+mod);
}
inline int mul(int x,int y) {
	return Mod(1ll*x*y);
}
inline int qpow(int a,int b) {
	if(!b)return 1;
	if(a<=1||b==1)return a;
	a %= mod; 
	int res=1;
	while(b) {
		if(b&1)res=mul(res,a);
		a=mul(a,a);
		b>>=1;
	}
	return res;
}
inline int calc(int x,int y){
	if(y==1)return x;
		int res=calc(x,y/2);
	res=add(res,mul(res,qpow(x,y/2)));
	if(y&1)res=add(res,mul(x,qpow(x,y-1)));
	return res;
}
inline int Calc(int x,int y){int ans=calc(x,y);return mul(ans,ans);} 
signed main() {
	scanf("%lld",&TT);
	mu[1]=1;
	int N=MAXN;
	for(int i=2; i<=N; ++i) {
		if(!vis[i])p[++cnt]=i,mu[i]=-1;
		for(int j=1; j<=cnt&&i*p[j]<=N; ++j) {
			vis[i*p[j]]=1;
			if(i%p[j]==0)break;
			mu[i*p[j]]=-mu[i];
		}
	}
	while(TT--) {
		scanf("%lld%lld",&n,&mod);
		N=n;
		Ans=0;
		for(int i=1; i<=N; ++i) {
			for(int j=i; j<=N; j+=i) {
				int k=i;
				int x=qpow(j/k,j);
				if(!mu[k])continue;
				T[j]=add(T[j],mul(mu[k],Calc(x,n/j)));
			}
		}
		for(int i=1; i<=n; ++i)Ans=add(Ans,T[i]),T[i]=0;
		cout<<Ans<<endl;
	}

	return 0;
}

由於常數等原因,這分程式碼可以拿到\(50\)分的好成績。但我們可以通過另一種做法將常數/複雜度降低。

另一種做法

觀察:

\[\sum_{d=1}^n\sum_{k=1}^\frac{n}{d}\mu(k)\sum_{i=1}^\frac{n}{kd}\sum_{j=1}^\frac{n}{kd}d^{kd(i+j)} \]

\[=\sum_{d=1}^n \sum_{k=1}^\frac{n}{d} \mu(k)(d^{kd}+d^{2kd}+...+d^{\frac{n}{kd}*kd=n})^2. \]

這裡同樣觀察式子發現可以直接算。前一部分是\(O(n\ln n)\)\(n\)倍調和級數的複雜度,後面帶上一個\(O(\log n)\)精細處理的等比數列求求和複雜度。

(程式碼中的優化即使不加也是可以過的)

#define __AVX__ 1
#define __AVX2__ 1
#define __SSE__ 1
#define __SSE2__ 1
#define __SSE2_MATH__ 1
#define __SSE3__ 1
#define __SSE4_1__ 1
#define __SSE4_2__ 1
#define __SSE_MATH__ 1
#define __SSSE3__ 1
#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
#include <immintrin.h>
#include <emmintrin.h>
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <string>
#include <bitset>
using namespace std;
const int MAXN=1.5e6+10;
int mod,T;
bitset<MAXN+1>vis;
int p[MAXN+1],cnt,mu[MAXN+1],N;
inline int Mod(long long a, int pp){
    return a>=pp ? a%pp : a>=0 ? a : a+pp;
}
inline int add(int x,int y){return Mod( (1ll+x+y+mod-1ll),mod);}
inline int mul(int x,int y){return Mod(1ll*x*y,mod);}
void pretreatment(){
	mu[1]=1;
	for(int i=2;i<=MAXN;++i){
		if(!vis[i])p[++cnt]=i,mu[i]=-1;
		for(int j=1;j<=cnt&&i*p[j]<=MAXN;++j){
			vis[i*p[j]]=1;
			if(Mod(i,p[j])==0)break;
			mu[i*p[j]]=-mu[i];
		}
	}
}
inline int qpow(int a,int b){
	if(!b)return 1;
	if(a<=1||b==1)return a;
	int res=1;
	while(b){
		if(b&1)res=mul(res,a);
		a=mul(a,a);b>>=1;
	}
	return res;
}
inline int calc(int x,int y){
	if(y==1)return x;
	int res=calc(x,y>>1);
	res=add(res,mul(res,qpow(x,y>>1)));
	if(y&1)res=add(res,mul(x,qpow(x,y-1)));
	return res;
}

inline int Calc(int x,int y){int ans=calc(x,y);return mul(ans,ans);} 
int ssolve(int n,int d){
	int res=0;
	for(register int l=1;l<=n;++l){
		if(!mu[l])continue;
		res=add(res,mul(mu[l],Calc(qpow(d,l),n/l)));
	}
	return res;
}
int solve(int n){
	int ans=0;
	for(register int l=1;l<=n;l++){
		ans=add(ans,ssolve(n/l,qpow(l,l)));
	}
	return ans;
}
signed main(){
	scanf("%lld",&T);
	pretreatment();
	for(;T;T--){
		scanf("%lld%lld",&N,&mod);
		printf("%lld\n",solve(N));
	}
	return 0;
}

可以用整除分塊減少迴圈中乘法的使用,對程式碼速度可能有一定的提升。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=1.5e6+10;
int mod,T;
bitset<MAXN+1>vis;
int p[MAXN+1],cnt,mu[MAXN+1],N;
inline int Mod(long long a, int pp){return a>=pp ? a%pp : a>=0 ? a : a+pp;}
inline int add(int x,int y){return Mod( (1ll+x+y+mod-1ll),mod);}
inline int mul(int x,int y){return Mod(1ll*x*y,mod);}
inline int qpow(int a,int b){
	if(!b)return 1;
	if(a<=1||b==1)return a;
	a=Mod(a,mod);
	int res=1;
	while(b){
		if(b&1)res=mul(res,a);
		a=mul(a,a);b>>=1;
	}
	return res;
}
inline int calc(int x,int y){
	if(y==1)return x;
	int res=calc(x,y>>1);
	res=add(res,mul(res,qpow(x,y>>1)));
	if(y&1)res=add(res,mul(x,qpow(x,y-1)));
	return res;
}
inline int Calc(int x,int y){int ans=calc(x,y);return mul(ans,ans);} 
int ssolve(int n,int d){
	int res=0;
	for(register int l=1,r;l<=n;l=r+1){
		r=(n/(n/l));
		int D=n/l;
		for(int i=l;i<=r;++i){
			if(!mu[i])continue;
			res=add(res,mul(mu[i],Calc(qpow(d,i),D)));
		}
	}
	return res;
}
int solve(int n){
	int ans=0;
	for(register int l=1,r;l<=n;l=r+1){
		r=(n/(n/l));
		int D=n/l;
		for(int i=l;i<=r;++i)ans=add(ans,ssolve(D,qpow(i,i)));
	}
	return ans;
}
int main(){
	scanf("%d",&T);
	mu[1]=1;
	for(register int i=2;i<=MAXN;++i){
		if(!vis[i])p[++cnt]=i,mu[i]=-1;
		for(register int j=1;j<=cnt&&i*p[j]<=MAXN;++j){
			vis[i*p[j]]=1;
			if(Mod(i,p[j])==0)break;
			mu[i*p[j]]=-mu[i];
		}
	}
	for(;T;T--){
		scanf("%d%d",&N,&mod);
		printf("%d\n",solve(N));
	}
	return 0;
}

出這題的本意其實是想看看有沒有吊打\(\text{std}\)的做法的,筆者推了很久並沒有找到線性的做法。