1. 程式人生 > 其它 >educational round #5 cf616E Sum of Remainders

educational round #5 cf616E Sum of Remainders

詢問 \(n\mod 1+n\mod 2+n\mod 3+\cdots +n\mod m\)\(10^9+7\) 取模的值.

\(1\leq n,m\leq 10^{13}\)

考慮 \(\mathrm O(m)\) 地計算肯定是不可行的 .

是否可以和分解因數一樣,只考慮 \(\mathrm O(\sqrt n)\) 之內的內容呢 ?

是可以的,觀察可得對於 \(1\leq x\leq \sqrt n\) ,對於 \(\lfloor\frac{n}{y}\rfloor =x\)\(y\) 的範圍是 $\lfloor \frac{n}{x+1}\rfloor+1 \leq y\leq \lfloor \frac{n}{x}\rfloor $ .

那麼,對於這些 \(y\) 的模數是 \(n-x\lfloor \frac{n}{x}\rfloor\)\(n-x(\lfloor \frac{n}{x+1}\rfloor+1)\) ,依次遞增 \(x\) . 這個直接等差序列求和 .

這是,可以發現,對於 \(m=x\)\(m\in [\lfloor \frac{n}{x+1}\rfloor+1 ,\lfloor \frac{n}{x}\rfloor ]\) 是可以 \(\mathrm O(1)\) 地求和的 .

隨著 \(x\) 的增加,可以求和的區間形如兩邊向中間衍生,知道相遇合並 .

所以,發現,只要列舉 \(x\in [1,\sqrt n]\)

,就可以得到所有 \([1,m]\) 的所有模數.

可能要特殊處理一下最後 \(x=\sqrt n\) 的情況 .

此題要非常注意取模 .

時間複雜度 : \(\mathrm O(\sqrt n)\)

空間複雜度 : \(\mathrm O(1)\)

code

#include<bits/stdc++.h>
using namespace std;
inline long long read(){
	char ch=getchar();
	while(ch<'0'||ch>'9')ch=getchar();
	long long res=0;
	while(ch>='0'&&ch<='9'){
		res=(res<<3)+(res<<1)+ch-'0';
		ch=getchar();
	}
	return res;
}
inline void print(int res){
	if(res==0){
		putchar('0');
		return;
	}
	int a[10],len=0;
	while(res>0){
		a[len++]=res%10;
		res/=10;
	}
	for(int i=len-1;i>=0;i--)
		putchar(a[i]+'0');
}
const int mod=1e9+7;
long long n,m;
long long ans=0;
int main(){
	int inv2=(mod+1)/2;
	n=read();m=read();
	long long s=0;
	while(s*s<=n)s++;
	if(s*s>n)s--;
	if(m<=s){
		for(int i=1;i<=m;i++){
			ans=(n%i+ans)%mod;
		}
		print(ans);
		putchar('\n');
		return 0;
	}
	int ans=0;
	for(long long i=1;(i+1)*(i+1)<=n;i++){
		ans=(ans+n%i)%mod;
		long long r=n/i,l=n/(i+1)+1;
		if(l>min(n-1,m))continue;
		r=min(r,min(n-1,m)); 
		long long R=n-r*i,L=n-l*i;
		long long tmp=1ll*(R+L)%mod*((r-l+1)%mod)%mod*inv2%mod; 
		ans=(ans+tmp)%mod;
	}
	if(1ll*s*s==n){
		if(s<=m){
			ans=(ans+n%s)%mod;	
		}
	}else{
		long long r=n/s;
		r=min(r,min(n-1,m));
		long long l=s+1; 
		if(l>r)l=r;
		long long R=n-r*s,L=n-l*s;
		long long tmp=1ll*(R+L)%mod*((r-l+1)%mod)%mod*inv2%mod;
		ans=(ans+tmp)%mod;
		if(l!=s)ans=(ans+n%s)%mod;
	}
	if(m>n)ans=(1ll*(m-n)%mod*(n%mod)%mod+ans)%mod;
	print(ans);
	putchar('\n');
	return 0;
}
/*inline? ll or int? size? min max?*/