BSGS演算法 學習筆記
技術標籤:OI技巧
BSGS演算法 學習筆記
文章目錄
簡介
BSGS(Baby-step Giant-step,大步小步,北上廣深,拔山蓋世)演算法,是用來求解高次同餘方程
a
x
≡
b
(
m
o
d
n
)
,
a
⊥
n
a^x\equiv b\pmod n,a\perp n
ax≡b(modn),a⊥n,或者換個寫法
x
≡
log
a
b
(
m
o
d
n
)
x\equiv\log_ab\pmod n
x≡logab(modn),即離散對數。
前置知識
尤拉定理:
a
φ
(
n
)
≡
1
(
m
o
d
n
)
,
a
⊥
n
a^{\varphi(n)}\equiv 1\pmod n,a\perp n
演算法思想
我們根據尤拉定理, a x m o d n a^x\bmod n axmodn的值每經過 φ ( n ) \varphi(n) φ(n) 就會有個迴圈,所以理論上我們只要看看 x ∈ [ 1 , φ ( n ) ] x\in[1,\varphi(n)] x∈[1,φ(n)] 內有沒有 x x x 滿足 a x ≡ b a^x\equiv b ax≡b。
其實也並不一定非得是
φ
(
n
)
\varphi(n)
φ(n) 吧,檢驗的範圍再大一點也不影響結果,為了省去這個求
φ
(
n
)
\varphi(n)
φ(n) 的過程~~(分解因數怪麻煩的)~~,我們直接檢驗
x
∈
[
1
,
n
−
1
]
x\in [1,n-1]
演算法採用一種分塊解決的思想,這也是 “大步” “小步” 的由來。下面是演算法的推導過程。
a x ≡ b ( m o d n ) a^x\equiv b\pmod n ax≡b(modn)
設定一個引數 k k k ,把 x x x 寫成一種稍有改變的“帶餘除法”的形式。
x
=
p
k
−
q
(
1
≤
p
≤
⌊
n
k
⌋
+
1
,
1
≤
q
≤
k
)
x=pk-q(1\le p\le\lfloor \dfrac{n}{k}\rfloor+1,1\le q \le k)
這樣 x x x 的值域可以包含 [ 0 , n − 1 ] [0,n-1] [0,n−1] 的整個區間。
於是
a p k − q ≡ b ( m o d n ) a^{pk-q}\equiv b\pmod n apk−q≡b(modn)
a p k ≡ b a q ( m o d n ) a^{pk}\equiv ba^q\pmod n apk≡baq(modn)
把這個式子分成左右兩部分,先列舉 q ∈ [ 1 , k ] q\in[1,k] q∈[1,k],計算處所有的 b a q m o d n ba^q\bmod n baqmodn,存入一個HASH表中。接下來,只要再列舉 p ∈ [ 1 , ⌊ n k ⌋ + 1 ] p\in[1,\lfloor \dfrac{n}{k}\rfloor+1] p∈[1,⌊kn⌋+1] ,計算 a p k a^{pk} apk,在HASH表中查詢是否存在這樣的 q q q 滿足上面那個等式即可。如果有多個 q q q,應選擇最大的那個,這樣可以使得 x = p k − q x=pk-q x=pk−q 最小。
時間複雜度為 O ( k + ⌊ n k ⌋ ) \mathcal O(k + \lfloor \dfrac{n}{k}\rfloor) O(k+⌊kn⌋),顯然當 k = n k=\sqrt n k=n 時有最小複雜度為 O ( n ) \mathcal O(\sqrt n) O(n )。
例題
P3846 [TJOI2007] 可愛的質數/【模板】BSGS
參考程式碼
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) ? EOF : *ss++)
ll read() {
ll x = 0, f = 1; char ch = getchar();
for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
return x * f;
}
#define ing long long
namespace HS {
const int MAXNODE = (1 << 16) + 10;
const int BS = 1491001;
int nxt[MAXNODE], num[MAXNODE], val[MAXNODE], cnt, head[BS];
int hs(int x) {
return x % BS;
}
void ins(int x, int i) {
int key = hs(x);
num[++cnt] = x; val[cnt] = i; nxt[cnt] = head[key]; head[key] = cnt;
}
int qry(int x) {
int key = hs(x);
int mx = -1;
for(int i = head[key]; i; i = nxt[i])
if(num[i] == x) mx = max(mx, val[i]);
return mx;
}
}
int qpow(int a, int n, int mod) {
int ret = 1;
for(; n; n >>= 1, a = 1ll * a * a % mod)
if(n & 1) ret = 1ll * ret * a % mod;
return ret;
}
int p, b, n, m;
signed main() {
p = read(), b = read(), n = read();
if(b % p == 0 && n) {
printf("no solution\n");
return 0;
}
m = sqrt(p);
for(int i = 1, k = 1ll * n * b % p; i <= m; i++, k = 1ll * k * b % p)
HS::ins(k, i);
for(int i = 1, bas = qpow(b, m, p), k = bas; i <= m+1; i++, k = 1ll * k * bas % p) {
int t = HS::qry(k);
if(t != -1) {
printf("%lld\n", 1ll * m * i - t);
return 0;
}
}
printf("no solution\n");
return 0;
}