1. 程式人生 > 其它 >數論預習筆記-二次剩餘

數論預習筆記-二次剩餘

模板題面

模板題

預習筆記

學的是 cipolla ,因為看不懂 Tonelli-Shanks 。
找前數競同學學了複數乘法之後對證明過程恍然大悟。。。

首先建立一個複數域,雖然不知道為什麼要建一個複數域。。。
然後所有的數都可以用複數表示,即 \(A + Bi\) 的形式。

用隨機數找到一個 \(n\) 的非二次剩餘 \(a\),根據密度,期望次數為 2 次。
定義 \(i^2 = a ^ 2 - n\)

其中用等號代替同餘符號。

  • 引理1:
    \((a + b) ^ p = a ^ p + b ^ p \ (\bmod m)\)
\[(a + b) ^ p = \sum_{i = 0} ^ {p} C_{p} ^ {i} a ^ {p - i}b ^ i = \sum_{i = 0} ^ {p} \frac{p!}{(p - i)!i!}a ^ {p - i}b ^ i \]

根據魔法的分配律,那麼中間含有 \(p\)

的項就會被消去。

  • 引理2:
    \(i ^ p = -i \ (\bmod m)\)
\[i ^ p = i ^ {p - 1} * i = (i ^ 2) ^ {\frac{p - 1}{2}} * i = (a ^ 2 - n) ^ {\frac{p - 1}{2}} * i=-i \]
  • 引理3:
    \(a^p = a \ (\bmod m)\)
    即費馬小定理。

所以 \(x = (a + i) ^ {\frac{p +1}{2}} = ((a + i) ^ {p + 1}) ^ \frac{1}{2} = ((a + i) ^ p * (a + i)) ^ \frac{1}{2} = ((a ^ p + i ^ p) * (a + i)) ^ \frac{1}{2} = ((a - i)(a + i)) ^ \frac{1}{2} = (a ^ 2 - i ^ 2) ^ \frac{1}{2} = n ^ \frac{1}{2}\)

程式碼

#include<cstdio>
#include<cctype>
#include<ctime>
#include<algorithm>
 
#define LL long long

LL I_pow, t;

struct complex {
	LL x, y;
	inline complex mul(complex a, complex b, LL p) {
		complex ans = {0, 0};
		ans.x = (a.x * b.x % p + a.y * b.y % p * I_pow % p) % p;
		ans.y = (a.x * b.y % p + b.x * a.y % p) % p;
		return ans;
	}
};

inline LL power(LL a, LL b, LL p) {
	LL res = 1;
	for(; b; b >>= 1, a = a * a % p)
		if(b & 1) res = res * a % p;
	return res;
}

inline LL complex_pow(complex a, LL b, LL p) {
	complex res = {1, 0};
	for(; b; b >>= 1, a = a.mul(a, a, p))
		if(b & 1) res = res.mul(res, a, p);
	return res.x % p; 
}

LL cipolla(LL n, LL p) {
	n %= p;
	if(p == 2) return n;
	if(power(n, (p - 1) / 2, p) == p - 1) return -1;
	LL a;
	while(true) {
		a = rand() % p;
		I_pow = ((a * a % p - n) % p + p) % p;
		if(power(I_pow, (p - 1) / 2, p) == p - 1) break;
	}
	return complex_pow((complex){a, 1}, (p + 1) / 2, p);
}

inline int read() {
	int x = 0, f = 1, c = getchar();
	for(; !isdigit(c); c = getchar())
		if(c == '-')
			f = -1;
	for(; isdigit(c); c = getchar())
		x = x * 10 + c - 48;
	return x * f;
}

int main() {
	srand(time(0));
	t = read();
	while(t--) {
		LL n = read(), mod = read();
		if(!n) {
			puts("0"); continue;
		}
		LL ans1 = cipolla(n, mod);
		if(ans1 == -1) {
			puts("Hola!"); continue;
		}
		LL ans2 = mod - ans1;
		if(ans1 > ans2) std::swap(ans1, ans2);
		if(ans1 == ans2) printf("%lld\n", ans1);
		else printf("%lld %lld\n", ans1, ans2);
	}
	return 0;
}