1. 程式人生 > 其它 >AtCoder Beginner Contest 245 Ex - Product Modulo 2

AtCoder Beginner Contest 245 Ex - Product Modulo 2

這個題解是基於官方題解,但是官方題解我看了好久才看懂,所以在官方題解的基礎上增加一些解釋。

建議也看看B站裡dls的講解視訊。

拆分

藉助CRT可以將題目拆分成多個子問題,最後再合併,就可以得到原問題的答案。

CRT的式子\(x \equiv a_i \mod m_i\),且\(m_i\)兩兩互素。

\(x\)看成\(N\)\(m_i = {p_{i}}^{e_i}\),這樣的話CRT的條件還是滿足的。這裡\(M = \prod_{i} {p_i}^{e_i}\)且對於每一個\(i\)都是一個子問題。

然後可以通過一個類似exLucas的過程,將子問題的答案拼成原問題的答案。

由於各個子問題之間是相互獨立的,所以將所有子問題的答案乘起來就是原問題的答案。

解決子問題

記子問題為\(f(p, e, K, N)\)

現在的問題就是要解決子問題,可以用生成函式結合快速冪來做。

首先,就是說\(N_1\)\(N_2\)的質因子分解中,\(p_i\)的指數相同,那麼\(f(p, e, K, N_1) = f(p, e, K, N_2)\),這個可以用歸納法證明。所以可以將指數相同的數合併成一類,這樣方便計算,最後要計算答案的時候再除以類的大小,就能得到原本的答案。

特別的,可以認為\(N = 0\)包含無窮大個\(p\)

注意到\(p^c \mod p^e\)\(c > e\)的時候都為\(0\),所以可以將\(c > e\)的歸為一類。

容易證明至多有\(O(\log N)\)個類的問題。

然後,包含\(K\)個數的答案可以由包含\(K - 1\)個數的答案推導得到,就是說如果\(k = i + j\),那麼包含\(k\)\(p\)的答案可以由包含\(i\)\(p\)的和包含\(j\)\(p\)的答案合併得到。

但是一個一個推太慢了,所以使用快速冪來加速,因為是線性組合,所以成立。

其實就是說構造一個生成函式\(f(x) = \sum_{i = 0}^{e} a_i x^i\),表示\(a_i\)表示包含\(i\)\(p\)的方案數,特別的\(a_e\)表示包含大於等於\(e\)\(p\)的方案數,則\(f^{K}(x)\)

就是\(K\)個數相乘之後結果的方案數。

只有一個數的時候,可以方便的計算出\(a_i\)的值,就是\(a_{e} = 0, a_{e - 1} = p - 1, a_{i} = a_{i + 1} \times p\)

注意

官方題解中也有提到,就是說可能出現除\(0\)的情況,但是因為\(M \le 10^{12}\),所以這個時候指數\(e = 1\),這個可以特判一下。

\(N = 0\),那麼就是\(K\)個數中任意一個為零即可,用所有方案減去全不為零的方案就是答案,即\(p^{K} - (p-1)^{K}\)

\(N \ne 0\),這個時候,前\(K - 1\)個數可以是任意非零元素,而最後一個元素是唯一確定的,即\(a_K = N \times (\prod_{i = 1}^{K - 1} a_i)^{-1}\)。所以方案數是\((p-1)^{K-1}\)

AC程式碼

// Problem: Ex - Product Modulo 2
// Contest: AtCoder - AtCoder Beginner Contest 245
// URL: https://atcoder.jp/contests/abc245/tasks/abc245_h
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>

#define CPPIO \
  std::ios::sync_with_stdio(false), std::cin.tie(0), std::cout.tie(0);
#ifdef BACKLIGHT
#include "debug.h"
#else
#define logd(...) ;
#endif

using i64 = int64_t;
using u64 = uint64_t;

void solve_case(int Case);

int main() {
  CPPIO;
  int T = 1;
  // std::cin >> T;
  for (int t = 1; t <= T; ++t) {
    solve_case(t);
  }
  return 0;
}

template <typename ValueType, ValueType mod_, typename SupperType = int64_t>
class Modular {
 private:
  ValueType value_;

  ValueType normalize(SupperType value) const {
    if (value >= 0 && value < mod_)
      return value;
    value %= mod_;
    if (value < 0)
      value += mod_;
    return value;
  }

  ValueType power(ValueType value, size_t exponent) const {
    ValueType result = 1;
    ValueType base = value;
    while (exponent) {
      if (exponent & 1)
        result = SupperType(result) * base % mod_;
      base = SupperType(base) * base % mod_;
      exponent >>= 1;
    }
    return result;
  }

 public:
  Modular() : value_(0) {}

  Modular(const SupperType& value) : value_(normalize(value)) {}

  ValueType value() const { return value_; }

  Modular inv() const { return Modular(power(value_, mod_ - 2)); }

  Modular power(size_t exponent) { return Modular(power(value_, exponent)); }

  friend Modular operator+(const Modular& lhs, const Modular& rhs) {
    ValueType result = lhs.value() + rhs.value() >= mod_
                           ? lhs.value() + rhs.value() - mod_
                           : lhs.value() + rhs.value();
    return Modular(result);
  }

  friend Modular operator-(const Modular& lhs, const Modular& rhs) {
    ValueType result = lhs.value() - rhs.value() < 0
                           ? lhs.value() - rhs.value() + mod_
                           : lhs.value() - rhs.value();
    return Modular(result);
  }

  friend Modular operator*(const Modular& lhs, const Modular& rhs) {
    ValueType result = SupperType(1) * lhs.value() * rhs.value() % mod_;
    return Modular(result);
  }

  friend Modular operator/(const Modular& lhs, const Modular& rhs) {
    ValueType result = SupperType(1) * lhs.value() * rhs.inv().value() % mod_;
    return Modular(result);
  }
};
template <typename StreamType,
          typename ValueType,
          ValueType mod,
          typename SupperType = int64_t>
StreamType& operator<<(StreamType& out,
                       const Modular<ValueType, mod, SupperType>& modular) {
  return out << modular.value();
}
template <typename StreamType,
          typename ValueType,
          ValueType mod,
          typename SupperType = int64_t>
StreamType& operator>>(StreamType& in,
                       Modular<ValueType, mod, SupperType>& modular) {
  ValueType value;
  in >> value;
  modular = Modular<ValueType, mod, SupperType>(value);
  return in;
}
// using mint = Modular<int, 1'000'000'007>;
using mint = Modular<int, 998'244'353>;
std::string to_string(mint v) {
  return to_string(v.value());
}

std::vector<std::pair<i64, i64>> factor(i64 n) {
  std::vector<std::pair<i64, i64>> pe;

  for (i64 i = 2; i * i <= n; ++i) {
    if (n % i == 0) {
      i64 e = 0;
      while (n % i == 0) {
        ++e;
        n = n / i;
      }
      pe.emplace_back(i, e);
    }
  }
  if (n > 1)
    pe.emplace_back(n, 1);
  return pe;
}

void solve_case(int Case) {
  i64 k, n, m;
  std::cin >> k >> n >> m;

  auto f = [](i64 p, i64 e, i64 k, i64 n) -> mint {
    auto cp = [](i64 n, i64 p) -> i64 {
      if (n == 0)
        return 1e9;
      i64 e = 0;
      while (n % p == 0) {
        ++e;
        n = n / p;
      }
      return e;
    };

    auto mul = [&](const std::vector<mint>& a,
                   const std::vector<mint>& b) -> std::vector<mint> {
      std::vector<mint> c(e + 1);
      for (i64 i = 0; i <= e; ++i) {
        for (i64 j = 0; j <= e; ++j) {
          c[std::min(e, i + j)] = c[std::min(e, i + j)] + a[i] * b[j];
        }
      }
      return c;
    };

    if (e == 1) {
      i64 pe = 1;
      for (int i = 1; i <= e; ++i)
        pe = pe * p;
      n = n % pe;
      if (n == 0)
        return mint(p).power(k) - mint(p - 1).power(k);
      else
        return mint(p - 1).power(k - 1);
    }

    i64 c = cp(n, p);
    if (c > e)
      c = e;

    std::vector<mint> r(e + 1), x(e + 1);
    r[0] = 1;
    x[e] = 1;
    x[e - 1] = p - 1;
    for (i64 i = e - 2; i >= 0; --i)
      x[i] = x[i + 1] * p;
    auto d = x;
    while (k) {
      if (k & 1)
        r = mul(r, x);
      x = mul(x, x);
      k >>= 1;
    }
    logd(r);

    return r[c] / d[c];
  };

  mint ans(1);
  auto pe = factor(m);
  logd(pe);
  for (auto [p, e] : pe) {
    ans = ans * f(p, e, k, n);
  }

  std::cout << ans.value() << "\n";
}