[POI2007]ZAP-Queries
阿新 • • 發佈:2018-12-12
嘟嘟嘟
挺好的題
\[\begin{align*} ans &= \sum_{i = 1} ^ {a} \sum_{j = 1} ^ {b} [\gcd(i, j) = d] \\ &= \sum_{i = 1} ^ {\lfloor \frac{a}{d} \rfloor} \sum_{j = 1} ^ {\lfloor \frac{b}{d} \rfloor} [\gcd(i, j) = 1] \end{align*}\]
令\(n = \lfloor \frac{a}{d} \rfloor\),\(m = \lfloor \frac{b}{d} \rfloor\),根據莫比烏斯函式的性質:\(\sum_{k | x} \mu(k) = [x = 1]\)
\[\begin{align*} ans &= \sum_{i = 1} ^ {n} \sum_{j = 1} ^ {m} \sum_{d' | \gcd(i, j)} \mu(d') \\ &= \sum_{d' = 1} ^ {\min(n, m)} \mu(d') \sum_{i = 1, d' | i} ^ {n} \sum_{j = 1, d' | j} ^ {m} 1 \\ &= \sum_{d' = 1} ^ {\min(n, m)} \mu(d') \lfloor \frac{n}{d'} \rfloor \lfloor \frac{m}{d'} \rfloor \end{align*}\]
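化簡到這裡,注意到\(\lfloor \frac{n}{d'} \rfloor\)和\(\lfloor \frac{m}{d'} \rfloor\)只有\(O(\sqrt{n} + \sqrt{m})\)種不同取值,就可以用數論分塊的思想列舉\(d'\),把這兩個值都不變的一段\(d'\)一起計算。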
所以要先預處理\(\mu(i)\)的字首和。
// [POI2007] ZAP-Queries: for each query (a, b, d), count the pairs (x, y)
// with 1 <= x <= a, 1 <= y <= b and gcd(x, y) = d.
// With n = a / d and m = b / d the answer is
//   sum_{k=1}^{min(n,m)} mu(k) * floor(n / k) * floor(m / k),
// evaluated with divisor blocks over a prefix sum of the Mobius function.
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define rg register
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Sieve limit: must exceed the largest a / d and b / d a query can produce
// (presumably the problem's bound of 50000 -- confirm against the statement).
const int maxn = 5e4 + 5;

// Reads one (optionally negative) integer from stdin.
// Non-digit characters before the number are skipped; a '-' immediately
// before the first digit makes the result negative.
// Returns 0 if end of input is reached before any digit is seen.
inline ll read()
{
  ll ans = 0;
  // int, not char: getchar() may return EOF, and passing a negative char
  // (bytes >= 0x80 on signed-char platforms) to isdigit() is undefined.
  int ch = getchar(), last = ' ';
  // Stop at EOF instead of spinning forever on truncated input.
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}

// Writes x in decimal to stdout, with no trailing separator.
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n;  // unused at file scope (shadowed by locals in solve/main)
int prime[maxn], v[maxn], phi[maxn], mu[maxn];
ll sum[maxn];

// Linear (Euler) sieve over [1, maxn):
//   prime[1..prime[0]] - primes found; prime[0] holds their count
//   v[i]   - smallest prime factor of i (for i >= 2)
//   phi[i] - Euler's totient (filled in but not used by solve)
//   mu[i]  - Mobius function
//   sum[i] - mu[1] + ... + mu[i], so any range sum of mu is O(1)
void init()
{
  phi[1] = mu[1] = 1;
  for(int i = 2; i < maxn; ++i)
  {
    if(!v[i]) v[i] = i, prime[++prime[0]] = i, phi[i] = i - 1, mu[i] = -1;
    // Check j against the prime count before reading prime[j].
    for(int j = 1; j <= prime[0] && i * prime[j] < maxn; ++j)
    {
      int k = i * prime[j];
      v[k] = prime[j];
      if(i % prime[j] == 0)
      {
        // prime[j] already divides i, so prime[j]^2 divides k.
        phi[k] = prime[j] * phi[i];
        mu[k] = 0;
        // Larger primes are not the smallest factor of i * p; stop here.
        break;
      }
      phi[k] = (prime[j] - 1) * phi[i];
      mu[k] = -mu[i];
    }
  }
  for(int i = 1; i < maxn; ++i) sum[i] = sum[i - 1] + mu[i];
}

// Returns sum_{k=1}^{min(n,m)} mu(k) * floor(n/k) * floor(m/k), i.e. the number
// of coprime pairs (x, y) with 1 <= x <= n and 1 <= y <= m.
// Requires init() to have run and n, m < maxn; returns 0 when n <= 0 or m <= 0.
ll solve(int n, int m)
{
  ll ret = 0;
  int Min = min(n, m);
  for(int l = 1, r; l <= Min; l = r + 1)
  {
    // floor(n/k) and floor(m/k) are both constant for every k in [l, r].
    r = min(n / (n / l), m / (m / l));
    ret += (sum[r] - sum[l - 1]) * (n / l) * (m / l);
  }
  return ret;
}

// Reads T queries "a b d" and prints one answer per line.
int main()
{
  init();
  int T = read();
  // A counted loop so a negative T runs zero times instead of ~2^31 times.
  for(; T > 0; --T)
  {
    int n = read(), m = read(), d = read();
    // gcd of two positive integers is >= 1, so no pair matches d <= 0; this
    // also avoids the division by zero that d == 0 used to cause.
    write(d > 0 ? solve(n / d, m / d) : 0), enter;
  }
  return 0;
}