AtCoder Beginner Contest 245 Ex - Product Modulo 2
這個題解是基於官方題解,但是官方題解我看了好久才看懂,所以在官方題解的基礎上增加一些解釋。
建議也看看B站裡dls的講解視訊。
拆分
藉助CRT可以將題目拆分成多個子問題,最後再合併,就可以得到原問題的答案。
CRT的式子\(x \equiv a_i \mod m_i\),且\(m_i\)兩兩互素。
把\(x\)看成\(N\),\(m_i = {p_{i}}^{e_i}\),這樣的話CRT的條件還是滿足的。這裡\(M = \prod_{i} {p_i}^{e_i}\)且對於每一個\(i\)都是一個子問題。
然後可以通過一個類似exLucas的過程,將子問題的答案拼成原問題的答案。
由於各個子問題之間是相互獨立的,所以將所有子問題的答案乘起來就是原問題的答案。
解決子問題
記子問題為\(f(p, e, K, N)\)。
現在的問題就是要解決子問題,可以用生成函式結合快速冪來做。
首先,就是說\(N_1\)和\(N_2\)的質因子分解中,\(p_i\)的指數相同,那麼\(f(p, e, K, N_1) = f(p, e, K, N_2)\),這個可以用歸納法證明。所以可以將指數相同的數合併成一類,這樣方便計算,最後要計算答案的時候再除以類的大小,就能得到原本的答案。
特別的,可以認為\(N = 0\)包含無窮大個\(p\)。
注意到\(p^c \mod p^e\)在\(c > e\)的時候都為\(0\),所以可以將\(c > e\)的歸為一類。
容易證明至多有\(O(\log N)\)個類的問題。
然後,包含\(K\)個數的答案可以由包含\(K - 1\)個數的答案推導得到,就是說如果\(k = i + j\),那麼包含\(k\)個\(p\)的答案可以由包含\(i\)個\(p\)的和包含\(j\)個\(p\)的答案合併得到。
但是一個一個推太慢了,所以使用快速冪來加速,因為是線性組合,所以成立。
其實就是說構造一個生成函式\(f(x) = \sum_{i = 0}^{e} a_i x^i\),表示\(a_i\)表示包含\(i\)個\(p\)的方案數,特別的\(a_e\)表示包含大於等於\(e\)個\(p\)的方案數,則\(f^{K}(x)\)
只有一個數的時候,可以方便的計算出\(a_i\)的值,就是\(a_{e} = 0, a_{e - 1} = p - 1, a_{i} = a_{i + 1} \times p\)。
注意
官方題解中也有提到,就是說可能出現除\(0\)的情況,但是因為\(M \le 10^{12}\),所以這個時候指數\(e = 1\),這個可以特判一下。
\(N = 0\),那麼就是\(K\)個數中任意一個為零即可,用所有方案減去全不為零的方案就是答案,即\(p^{K} - (p-1)^{K}\)。
\(N \ne 0\),這個時候,前\(K - 1\)個數可以是任意非零元素,而最後一個元素是唯一確定的,即\(a_K = N \times (\prod_{i = 1}^{K - 1} a_i)^{-1}\)。所以方案數是\((p-1)^{K-1}\)
AC程式碼
// Problem: Ex - Product Modulo 2
// Contest: AtCoder - AtCoder Beginner Contest 245
// URL: https://atcoder.jp/contests/abc245/tasks/abc245_h
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define CPPIO \
std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
#ifdef BACKLIGHT
#include "debug.h"
#else
#define logd(...) ;
#endif
using i64 = int64_t;
using u64 = uint64_t;
void solve_case(int Case);
int main() {
CPPIO;
int T = 1;
// std::cin >> T;
for (int t = 1; t <= T; ++t) {
solve_case(t);
}
return 0;
}
template <typename ValueType, ValueType mod_, typename SupperType = int64_t>
class Modular {
private:
ValueType value_;
ValueType normalize(SupperType value) const {
if (value >= 0 && value < mod_)
return value;
value %= mod_;
if (value < 0)
value += mod_;
return value;
}
ValueType power(ValueType value, size_t exponent) const {
ValueType result = 1;
ValueType base = value;
while (exponent) {
if (exponent & 1)
result = SupperType(result) * base % mod_;
base = SupperType(base) * base % mod_;
exponent >>= 1;
}
return result;
}
public:
Modular() : value_(0) {}
Modular(const SupperType& value) : value_(normalize(value)) {}
ValueType value() const { return value_; }
Modular inv() const { return Modular(power(value_, mod_ - 2)); }
Modular power(size_t exponent) { return Modular(power(value_, exponent)); }
friend Modular operator+(const Modular& lhs, const Modular& rhs) {
ValueType result = lhs.value() + rhs.value() >= mod_
? lhs.value() + rhs.value() - mod_
: lhs.value() + rhs.value();
return Modular(result);
}
friend Modular operator-(const Modular& lhs, const Modular& rhs) {
ValueType result = lhs.value() - rhs.value() < 0
? lhs.value() - rhs.value() + mod_
: lhs.value() - rhs.value();
return Modular(result);
}
friend Modular operator*(const Modular& lhs, const Modular& rhs) {
ValueType result = SupperType(1) * lhs.value() * rhs.value() % mod_;
return Modular(result);
}
friend Modular operator/(const Modular& lhs, const Modular& rhs) {
ValueType result = SupperType(1) * lhs.value() * rhs.inv().value() % mod_;
return Modular(result);
}
};
template <typename StreamType,
typename ValueType,
ValueType mod,
typename SupperType = int64_t>
StreamType& operator<<(StreamType& out,
const Modular<ValueType, mod, SupperType>& modular) {
return out << modular.value();
}
template <typename StreamType,
typename ValueType,
ValueType mod,
typename SupperType = int64_t>
StreamType& operator>>(StreamType& in,
Modular<ValueType, mod, SupperType>& modular) {
ValueType value;
in >> value;
modular = Modular<ValueType, mod, SupperType>(value);
return in;
}
// using mint = Modular<int, 1'000'000'007>;
using mint = Modular<int, 998'244'353>;
std::string to_string(mint v) {
return to_string(v.value());
}
std::vector<std::pair<i64, i64>> factor(i64 n) {
std::vector<std::pair<i64, i64>> pe;
for (i64 i = 2; i * i <= n; ++i) {
if (n % i == 0) {
i64 e = 0;
while (n % i == 0) {
++e;
n = n / i;
}
pe.emplace_back(i, e);
}
}
if (n > 1)
pe.emplace_back(n, 1);
return pe;
}
void solve_case(int Case) {
i64 k, n, m;
std::cin >> k >> n >> m;
auto f = [](i64 p, i64 e, i64 k, i64 n) -> mint {
auto cp = [](i64 n, i64 p) -> i64 {
if (n == 0)
return 1e9;
i64 e = 0;
while (n % p == 0) {
++e;
n = n / p;
}
return e;
};
auto mul = [&](const std::vector<mint>& a,
const std::vector<mint>& b) -> std::vector<mint> {
std::vector<mint> c(e + 1);
for (i64 i = 0; i <= e; ++i) {
for (i64 j = 0; j <= e; ++j) {
c[std::min(e, i + j)] = c[std::min(e, i + j)] + a[i] * b[j];
}
}
return c;
};
if (e == 1) {
i64 pe = 1;
for (int i = 1; i <= e; ++i)
pe = pe * p;
n = n % pe;
if (n == 0)
return mint(p).power(k) - mint(p - 1).power(k);
else
return mint(p - 1).power(k - 1);
}
i64 c = cp(n, p);
if (c > e)
c = e;
std::vector<mint> r(e + 1), x(e + 1);
r[0] = 1;
x[e] = 1;
x[e - 1] = p - 1;
for (i64 i = e - 2; i >= 0; --i)
x[i] = x[i + 1] * p;
auto d = x;
while (k) {
if (k & 1)
r = mul(r, x);
x = mul(x, x);
k >>= 1;
}
logd(r);
return r[c] / d[c];
};
mint ans(1);
auto pe = factor(m);
logd(pe);
for (auto [p, e] : pe) {
ans = ans * f(p, e, k, n);
}
std::cout << ans.value() << "\n";
}