1. 程式人生 > 其它 >類歐幾里得演算法 重學筆記

類歐幾里得演算法 重學筆記

Solution

以前學過,但是太爛,而且很有侷限性,今重學一遍。

假設我們要解決的問題是求:

\[\sum_{x=0}^{n} x^{k1}\lfloor\frac{ax+b}{c}\rfloor^{k2} \]

可以分為以下幾種情況進行討論:

**情況一:** \(a=0\) 或者 \(\lfloor\frac{an+b}{c}\rfloor=0\)

此時 \(\lfloor\frac{ax+b}{c}\rfloor\) 對所有 \(x\in[0,n]\) 都等於同一個常數 \(t=\lfloor\frac{an+b}{c}\rfloor\),答案就是 \(t^{k2}\sum_{x=0}^{n}x^{k1}\),只需要 \(k1\) 次冪的前綴和就好了。

**情況二:** \(a\ge c\)

\(q=\lfloor\frac{a}{c}\rfloor,r=a\mod c\) ,那麼可以得到答案就是:

\[\sum_{x=0}^{n} x^{k1}\left(qx+\lfloor\frac{rx+b}{c}\rfloor\right)^{k2}=\sum_{i=0}^{k2} q^i\binom{k2}{i}\sum_{x=0}^{n} x^{k1+i}\lfloor\frac{rx+b}{c}\rfloor^{k2-i} \]

直接遞迴即可。

**情況三:** \(b\ge c\)

\(q=\lfloor\frac{b}{c}\rfloor,r=b\mod c\),同理可以得到答案就是:

\[\sum_{i=0}^{k2} \binom{k2}{i}q^i\sum_{x=0}^{n} x^{k1}\lfloor\frac{ax+r}{c}\rfloor^{k2-i} \]
**情況四:** \(\max(a,b)<c\)

可以把 \(\lfloor\frac{ax+b}{c}\rfloor^{k2}\) 拆開,變成:

\[\sum_{j=0}^{\lfloor\frac{ax+b}{c}\rfloor-1}((j+1)^{k2}-j^{k2}) \]

那麼答案就是:

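\[\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\sum_{x=0}^{n} x^{k1}[x>\lfloor\frac{cj+c-b-1}{a}\rfloor] =\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\sum_{i=0}^{n} i^{k1}-\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\times \sum_{i=0}^{\lfloor\frac{cj+c-b-1}{a}\rfloor}i^{k1} \]

其中 \(M=\lfloor\frac{an+b}{c}\rfloor\)。這裡用到了 \(j<\lfloor\frac{ax+b}{c}\rfloor\iff c(j+1)\le ax+b\iff x>\lfloor\frac{cj+c-b-1}{a}\rfloor\),並且 \(j\le M-1\) 保證了 \(\lfloor\frac{cj+c-b-1}{a}\rfloor<n\)。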

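然後前面這部分對 \(j\) 求和會裂項成 \(M^{k2}\),所以它就是 \(M^{k2}\sum_{i=0}^{n}i^{k1}\),只需要 \(k1\) 次冪的前綴和。考慮如何算後面那一部分。你發現 \(\sum_{i=0}^{t}i^{k1}\) 是關於 \(t\) 的 \(k1+1\) 次多項式(這裡 \(t=\lfloor\frac{cj+c-b-1}{a}\rfloor\)),假設第 \(i\) 次係數為 \(B_i\)。再用二項式定理把 \((j+1)^{k2}-j^{k2}\) 展開成 \(\sum_{i=0}^{k2-1}\binom{k2}{i}j^i\),那麼後面那一部分就可以寫成: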

\[\sum_{i=0}^{k2-1}\binom{k2}{i}\sum_{j=0}^{k1+1} B_j\sum_{x=0}^{M-1} x^i\lfloor\frac{cx+c-b-1}{a}\rfloor^j \]

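也可以遞迴了:這正是參數為 \((M-1,\ c,\ c-b-1,\ a)\) 的同型子問題。\(a\) 與 \(c\) 的位置互換,所以和歐幾里得演算法一樣,遞迴層數是 \(O(\log\max(a,c))\) 的。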

Code

#include <bits/stdc++.h>
using namespace std;

// Loop-index type. The `register` specifier was dropped: it was removed from
// C++ in C++17, and compilers reject or warn about it under -std=c++17.
// Because `int` is itself a macro (below), `Int` still expands to `long long`.
#define Int int
#define mod 1000000007 // prime modulus used for all arithmetic
// Every `int` after this line is really `long long` (hence `signed main`).
#define int long long
#define MAXN 15 // array bound; the code indexes up to 12 (column k + 2 for k = 10)

// Reads the next (optionally negative) decimal integer from stdin into t.
// Non-digit characters before the number are skipped. A '-' counts as the sign
// only when it immediately precedes the first digit, so '\t', '\r', etc. can no
// longer flip the sign. On end of input t is left at 0 instead of looping forever.
// The character right after the number is consumed.
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar (); // int, not char, so EOF stays distinguishable
	bool neg = false;
	while (c < '0' || c > '9'){
		if (c == EOF) return;
		neg = (c == '-');
		c = getchar ();
	}
	while (c >= '0' && c <= '9'){
		t = t * 10 + (c - '0');
		c = getchar ();
	}
	if (neg) t = -t;
}
// Reads several integers in order: read(a, b, c) is read(a); read(b); read(c).
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Writes the decimal representation of x to stdout (no trailing separator).
// Digits are peeled off with % and / on x itself and x is never negated, so the
// most negative value of a signed T (whose negation overflows) prints correctly.
template <typename T> inline void write (T x){
	char buf[64]; // enough digits for any integer type up to 128 bits
	int len = 0;
	const bool neg = x < 0;
	do{
		const T d = x % 10; // in [-9, 0] when x is negative
		buf[len ++] = char ('0' + (d < 0 ? -d : d));
		x /= 10;
	}while (x != 0);
	if (neg) putchar ('-');
	while (len > 0) putchar (buf[-- len]);
}
// Raises a to b when b is strictly greater (a := max(a, b)).
template <typename T> void chkmax (T &a,T b){
	if (a < b) a = b;
}
// Lowers a to b when b is strictly smaller (a := min(a, b)).
template <typename T> void chkmin (T &a,T b){
	if (b < a) a = b;
}

// Modular primitives (modulus `mod`). dec/add expect both operands in [0, mod).
// mul multiplies in long long, so for operands below mod the product fits in 64 bits.
int mul (int a,int b){return static_cast<long long> (a) * b % mod;}
int dec (int a,int b){return a < b ? a - b + mod : a - b;}
int add (int a,int b){const int s = a + b;return s >= mod ? s - mod : s;}
// Returns a^b mod `mod` by binary exponentiation (b >= 0). a^0 is 1, including
// 0^0, which Func::gen relies on. The base is first reduced into [0, mod) so that
// mul never multiplies operands >= mod, which would overflow for large a.
int qkpow (int a,int b){
	a %= mod;
	if (a < 0) a += mod;
	int res = 1;
	for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
	return res;
}
// Modular inverse via Fermat's little theorem (mod is prime): x^(mod-2).
// x must not be a multiple of mod (otherwise the result is 0).
int inv (int x){return qkpow (x,mod - 2);}
void Add (int &a,int b){a = add (a,b);}
void Sub (int &a,int b){a = dec (a,b);}

struct node{
	int t[MAXN][MAXN];
	node(){memset (t,0,sizeof (t));}
	int * operator [](const int key){return t[key];}
};

// C[i][j]: binomial coefficients C(i, j) mod `mod`, filled by main for 0 <= j <= i <= 10.
// mat: scratch augmented matrix used by Func::gen / Func::Gauss to interpolate
// the prefix-power polynomials.
int C[MAXN][MAXN],mat[MAXN][MAXN];
struct Func{//處理每個F(i,x) \sum_{k=0}^{x} k^i 的i+1次函式
	int a[MAXN];
	int & operator [](const int key){return a[key];}
	void Gauss (int K){
		for (Int i = 0;i <= K;++ i){
			int tmp = i;
			for (Int j = i + 1;j <= K;++ j) if (mat[j][i]){tmp = j;break;}
			if (tmp ^ i) swap (mat[tmp],mat[i]);
			for (Int j = i + 1,iv = inv (mat[i][i]);j <= K;++ j){
				int del = mul (mat[j][i],iv);
				for (Int k = i;k <= K + 1;++ k) Sub (mat[j][k],mul (del,mat[i][k]));
			}
		}
		for (Int i = K;~i;-- i){
			for (Int j = i + 1;j <= K;++ j) Sub (mat[i][K + 1],mul (a[j],mat[i][j]));
			a[i] = mul (mat[i][K + 1],inv (mat[i][i]));
		}
	}
	void gen(int k){
		for (Int i = 0,res = 0;i <= k + 1;++ i) Add (res,qkpow (i,k)),mat[i][k + 2] = res;
		for (Int i = 0;i <= k + 1;++ i) for (Int j = 0,res = 1;j <= k + 1;++ j,res = mul (res,i)) mat[i][j] = res;
		Gauss (k + 1);
	}
	int getit (int k,int x){
		int res = 0;
		for (Int i = k + 1;i >= 0;-- i) res = add (a[i],mul (res,x));
		return res;
	}
}f[MAXN];

// Brute-force reference: directly sums x^{k1} * ((a*x + b) / c)^{k2} for
// x = 0..n, mod `mod`. Linear in n, so only usable for small n (e.g. to
// cross-check getit); it is not called from main.
int relans (int n,int a,int b,int c,int k1,int k2){
	int acc = 0;
	int x = 0;
	while (x <= n){
		const int q = (a * x + b) / c % mod; // integer quotient, reduced
		const int term = mul (qkpow (x,k1),qkpow (q,k2));
		Add (acc,term);
		++ x;
	}
	return acc;
}

// Returns the table ans with
//   ans[k1][k2] = sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   (mod `mod`)
// for every k1, k2 >= 0 with k1 + k2 <= 10 (other entries stay 0), using 0^0 = 1.
// Requires f[0..10] and C[0..10][*] to be built first (main does this).
// NOTE(review): assumes n, a, b >= 0 and c > 0 — confirm against the input constraints.
// The recursion reduces (a, c) the same way Euclid's algorithm does.
node getit (int n,int a,int b,int c){
	const int LIM = 10; // largest k1 + k2 handled; matches the precomputation in main
	node ans;
	if (a == 0 || a * n + b < c){
		// floor((a*x+b)/c) is the same value t for every x in [0, n]:
		// answer = t^{k2} * sum_{x=0}^{n} x^{k1}.
		const int t = (a * n + b) / c % mod;
		for (int k1 = 0;k1 <= LIM;++ k1){
			const int pre = f[k1].getit (k1,n); // sum_{x=0}^{n} x^{k1}
			for (int k2 = 0,tp = 1;k1 + k2 <= LIM;++ k2,tp = mul (tp,t))
				ans[k1][k2] = mul (tp,pre);
		}
	}
	else if (a >= c){
		// floor((ax+b)/c) = q*x + floor((rx+b)/c) with q = a/c, r = a%c;
		// expand (q*x + y)^{k2} with the binomial theorem.
		const int q = a / c,r = a % c,qm = q % mod;
		node lst = getit (n,r,b,c);
		for (int k1 = 0;k1 <= LIM;++ k1)
			for (int k2 = 0;k1 + k2 <= LIM;++ k2)
				for (int i = 0,qp = 1;i <= k2;++ i,qp = mul (qp,qm))
					Add (ans[k1][k2],mul (mul (qp,C[k2][i]),lst[k1 + i][k2 - i]));
	}
	else if (b >= c){
		// floor((ax+b)/c) = q + floor((ax+r)/c) with q = b/c, r = b%c.
		const int q = b / c,r = b % c,qm = q % mod;
		node lst = getit (n,a,r,c);
		for (int k1 = 0;k1 <= LIM;++ k1)
			for (int k2 = 0;k1 + k2 <= LIM;++ k2)
				for (int i = 0,qp = 1;i <= k2;++ i,qp = mul (qp,qm))
					Add (ans[k1][k2],mul (mul (qp,C[k2][i]),lst[k1][k2 - i]));
	}
	else{
		// Here 0 < a < c, b < c and M = floor((an+b)/c) >= 1. Using
		//   y^{k2} = sum_{j=0}^{y-1} ((j+1)^{k2} - j^{k2})   and
		//   j < floor((ax+b)/c)  <=>  x > t_j = floor((c*j + c - b - 1) / a),
		// ans = M^{k2} * S(n) - sum_{j=0}^{M-1} ((j+1)^{k2} - j^{k2}) * S(t_j),
		// where S(t) = sum_{i=0}^{k1+1} f[k1][i] * t^i. Expanding the difference as
		// sum_{i<k2} C(k2,i) j^i turns the j-sum into the table for (M-1, c, c-b-1, a).
		const int M = (a * n + b) / c;
		node lst = getit (M - 1,c,c - b - 1,a);
		const int Mm = M % mod;
		for (int k1 = 0;k1 <= LIM;++ k1){
			const int pre = f[k1].getit (k1,n); // S(n)
			for (int k2 = 0;k1 + k2 <= LIM;++ k2){
				int cur = mul (qkpow (Mm,k2),pre); // for k2 = 0 the loop below is empty
				for (int i = 0;i < k2;++ i)
					for (int j = 0;j <= k1 + 1;++ j)
						Sub (cur,mul (mul (C[k2][i],f[k1][j]),lst[i][j]));
				ans[k1][k2] = cur;
			}
		}
	}
	return ans;
}

signed main(){
	for (Int i = 0;i <= 10;++ i) f[i].gen (i);
	for (Int i = 0;i <= 10;++ i){
		C[i][0] = 1;
		for (Int j = 1;j <= i;++ j) C[i][j] = add (C[i - 1][j],C[i - 1][j - 1]);
	}
	int T;read (T);
	while (T --> 0){
		int n,a,b,c,k1,k2;read (n,a,b,c,k1,k2);
		node ans = getit (n,a,b,c);write (ans[k1][k2]),putchar ('\n');
	}
	return 0;
}