題解 「THUPC 2017」小 L 的計算題 / Sum
阿新 • • 發佈:2020-08-19
題目大意
給出長度為 \(n\) 的序列 \(a_1,a_2,\dots,a_n\),對於每個 \(k\in [1,n]\),求出:
\[\sum_{i=1}^{n}a_i^k \]
多組測試資料,每組 \(n\le 2\times 10^5\),答案對 \(998244353\) 取模;每組資料輸出這 \(n\) 個答案的異或和。
思路
我們考慮對答案構造生成函式:
\[F(x)=\sum_{k=0}^{\infty} \sum_{i=1}^{n}a_i^kx^k \]
\[=\sum_{i=1}^{n}\frac{1}{1-a_ix} \]
\[=\sum_{i=1}^{n}(1+\frac{a_ix}{1-a_ix}) \]
\[=n-x(\sum_{i=1}^{n}\frac{-a_i}{1-a_ix}) \]
注意到:
\[(\ln(1-a_ix))^{'}=\frac{-a_i}{1-a_ix} \]
於是:
\[F(x)=n-x(\sum_{i=1}^{n}\ln(1-a_ix))^{'} \]
\[=n-x(\ln \prod_{i=1}^{n}(1-a_ix))^{'} \]
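因此只需先用分治 + NTT 求出 \(P(x)=\prod_{i=1}^{n}(1-a_ix)\)(\(\Theta(n\log^2 n)\)),再用多項式求逆算出 \((\ln P(x))'=P'(x)/P(x)\)(\(\Theta(n\log n)\)),即得 \(k\ge 1\) 時的答案 \([x^k]F(x)=-[x^{k-1}]\frac{P'(x)}{P(x)}\)。總時間複雜度 \(\Theta(n\log^2 n)\) 。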
\(\texttt{Code}\)
// THUPC 2017 "小 L 的計算題 / Sum".
//
// For every k in [1, n] compute f_k = sum_{i=1}^{n} a_i^k (mod 998244353)
// and print f_1 xor f_2 xor ... xor f_n for each test case.
//
// With P(x) = prod_{i=1}^{n} (1 - a_i x) the generating function is
//     sum_{k>=0} f_k x^k = n - x * (ln P(x))' = n - x * P'(x) / P(x),
// hence f_k = -[x^{k-1}] (P'/P) for k >= 1.  P is built by divide-and-conquer
// NTT multiplication in O(n log^2 n); P'/P needs one power-series inverse.
// (ln P)' is evaluated directly as P' * P^{-1}: integrating and then
// differentiating again, as der(ln(P)) would, is an identity here.
//
// Pass optimisation flags on the command line (e.g. -O2). The former
// `#pragma GCC diagnostic error "-f..."` lines named options that do not
// control warnings, so they had no effect and were dropped.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

using poly = std::vector<int>;  // coefficients, lowest degree first

constexpr int kMod = 998244353;   // 119 * 2^23 + 1, primitive root 3
constexpr int kLogMaxLen = 19;    // largest supported NTT length is 2^19
constexpr int kMaxLen = 1 << kLogMaxLen;

// Modular helpers; arguments are expected to lie in [0, kMod).
inline int mul(int a, int b) { return static_cast<int>(1LL * a * b % kMod); }
inline int dec(int a, int b) { return a >= b ? a - b : a + kMod - b; }
inline int add(int a, int b) { return a + b >= kMod ? a + b - kMod : a + b; }

// a^k mod kMod by binary exponentiation (k >= 0).
int qkpow(int a, int k) {
    int res = 1;
    for (; k; k >>= 1, a = mul(a, a))
        if (k & 1) res = mul(res, a);
    return res;
}

// Modular inverse via Fermat's little theorem (x must be non-zero mod kMod).
int inv(int x) { return qkpow(x, kMod - 2); }

int w[kMaxLen];    // w[h + k] = (primitive 2h-th root of unity)^k, h a power of two, k < h
int rev[kMaxLen];  // rev[i] = bit reversal of i over kLogMaxLen bits

// Fills the twiddle and bit-reversal tables; call once before any ntt().
void init_ntt() {
    for (int i = 0; i < kMaxLen; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (kLogMaxLen - 1));
    const int wn = qkpow(3, (kMod - 1) / kMaxLen);  // primitive kMaxLen-th root
    w[kMaxLen / 2] = 1;
    for (int i = kMaxLen / 2 + 1; i < kMaxLen; ++i) w[i] = mul(w[i - 1], wn);
    // Each coarser level reuses every other root of the next finer level.
    for (int i = kMaxLen / 2 - 1; i > 0; --i) w[i] = w[i << 1];
}

// In-place NTT of a[0, lim) (type = 1) or its inverse (type = -1).
// lim must be a power of two not exceeding kMaxLen, and a.size() >= lim.
void ntt(poly &a, int lim, int type) {
    // Larger lengths would need a negative shift below (undefined behaviour).
    assert(lim > 0 && lim <= kMaxLen && (lim & (lim - 1)) == 0);
    assert(static_cast<int>(a.size()) >= lim);
    static int d[kMaxLen];
    const int shift = kLogMaxLen - __builtin_ctz(static_cast<unsigned>(lim));
    for (int i = 0; i < lim; ++i) d[rev[i] >> shift] = a[i];
    for (int h = 1; h < lim; h <<= 1)
        for (int j = 0; j < lim; j += h << 1)
            for (int k = 0; k < h; ++k) {
                const int x = mul(w[h + k], d[j + h + k]);
                d[j + h + k] = dec(d[j + k], x);
                d[j + k] = add(d[j + k], x);
            }
    for (int i = 0; i < lim; ++i) a[i] = d[i];
    if (type == -1) {
        // Inverse transform: reverse a[1..lim-1], then scale by 1/lim.
        std::reverse(a.begin() + 1, a.begin() + lim);
        const int inv_lim = inv(lim);
        for (int i = 0; i < lim; ++i) a[i] = mul(a[i], inv_lim);
    }
}

// Product of two polynomials (empty if either factor is empty).
poly operator*(poly a, poly b) {
    if (a.empty() || b.empty()) return {};
    const int d = static_cast<int>(a.size() + b.size()) - 1;
    int lim = 1;
    while (lim < d) lim <<= 1;
    a.resize(lim);
    b.resize(lim);
    ntt(a, lim, 1);
    ntt(b, lim, 1);
    for (int i = 0; i < lim; ++i) a[i] = mul(a[i], b[i]);
    ntt(a, lim, -1);
    a.resize(d);
    return a;
}

// First n coefficients of 1 / a(x) by Newton iteration b <- b (2 - a b).
// Requires a[0] != 0 when n > 0; returns {} when n <= 0.
poly inv(const poly &a, int n) {
    if (n <= 0) return {};
    assert(!a.empty() && a[0] != 0);
    poly b(1, inv(a[0]));
    poly c;
    // At loop entry b == a^{-1} mod x^{l/4}; each step doubles the precision.
    for (int l = 4; (l >> 2) < n; l <<= 1) {
        const int half = l >> 1;
        c.assign(l, 0);
        for (int i = 0; i < half && i < n && i < static_cast<int>(a.size()); ++i) c[i] = a[i];
        b.resize(l);
        ntt(c, l, 1);
        ntt(b, l, 1);
        for (int i = 0; i < l; ++i) b[i] = mul(b[i], dec(2, mul(b[i], c[i])));
        ntt(b, l, -1);
        b.resize(half);
    }
    b.resize(n);
    return b;
}

// Formal derivative a'(x); an empty or constant polynomial yields {}.
poly der(poly a) {
    if (a.empty()) return a;
    for (int i = 0; i + 1 < static_cast<int>(a.size()); ++i) a[i] = mul(a[i + 1], i + 1);
    a.pop_back();
    return a;
}

// First n coefficients of (ln p)' = p' / p; requires p[0] != 0 when n > 0.
poly log_derivative(const poly &p, int n) {
    poly q = der(p) * inv(p, n);
    q.resize(n);
    return q;
}

// Reads one (optionally negative) integer, skipping non-digit characters.
// Returns false if EOF is hit before any digit (t is then 0).
template <typename T>
bool read(T &t) {
    t = 0;
    int c = std::getchar();  // int, not char, so EOF is detected reliably
    int sign = 1;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -sign;
        c = std::getchar();
    }
    if (c == EOF) return false;
    while (c >= '0' && c <= '9') {
        t = t * 10 + (c - '0');
        c = std::getchar();
    }
    t *= sign;
    return true;
}

// Writes an integer in decimal to stdout.
template <typename T>
void write(T x) {
    if (x < 0) {
        std::putchar('-');
        x = -x;
    }
    if (x > 9) write(x / 10);
    std::putchar(static_cast<char>('0' + x % 10));
}

// prod_{i=l}^{r} (1 - a[i] x) by balanced divide and conquer; {1} when l > r.
poly divide(const std::vector<int> &a, int l, int r) {
    if (l > r) return poly{1};
    if (l == r) return poly{1, dec(0, a[l])};  // dec keeps the value in [0, kMod)
    const int mid = l + (r - l) / 2;
    return divide(a, l, mid) * divide(a, mid + 1, r);
}

int main() {
    init_ntt();
    int T = 0;
    if (!read(T)) return 0;
    std::vector<int> a;
    while (T-- > 0) {
        int n = 0;
        read(n);
        n = std::max(n, 0);
        a.assign(n + 1, 0);  // 1-indexed; a[0] is unused
        for (int i = 1; i <= n; ++i) {
            long long v = 0;
            read(v);
            a[i] = static_cast<int>((v % kMod + kMod) % kMod);
        }
        const poly q = log_derivative(divide(a, 1, n), n);
        int ans = 0;
        for (int k = 1; k <= n; ++k) ans ^= dec(0, q[k - 1]);  // f_k = -[x^{k-1}] q
        write(ans);
        std::putchar('\n');
    }
    return 0;
}