1. 程式人生 > 其它 >[題解][LG-P3791]普通數學題

[題解][LG-P3791]普通數學題

可以想到將\(d\)函式的字首和轉化一下:

\[S_d(i)=\sum_{i=1}^n d(i)=\sum_{i=1}^n \left\lfloor\frac{n}{i}\right\rfloor \]

那麼就可以求出\(O(\sqrt{n})\)求出每一個字首和了(整除分塊),加上一個unordered_map去除不必要的計算。

將問題修改一下:\(i \leq n, j \leq m \longrightarrow i < n + 1, j < m + 1\)

接下來解決\(\mathrm{xor}\)的問題,首先可以發現,我們可以轉化一下問題,求出每一個\(d(i)\)的貢獻。可以(比較難)想到使用數位\(dp\)

的思想,設定前面幾位是相同的,然後固定某一位比原數小,然後後面的所有位隨便放。

這裡設定\(len_a, len_b, a, b\),其中\(a\)的後\(len_a\)是隨便放的,強制令\(a\)是第\(0\)\(len_a-1\)都是\(0\), 第\(len_a\)位比\(n\)小(只能是\(0\)了),然後\(len_a\)(含)之後的位都和\(n\)相同,\(len_b, b\)同理。

\(mx=\max\{len_a, len_b\}, mn=\min\{len_a, len_b\}\),一個數\(num\)的二進位制第\(p\)位是\(num_p\)

那麼可以發現,對於\(mx\)

(含)之後的位,都是可以確定的(廢話),使用題目的公式計算出來就好了,我們設計算出的結果是\(pre\)。那麼公式可以是:

\[pre=a \land b \land x \land \lnot2^{mx} \]

但是對於第\(1\)\(mx-1\)的位,可以不管題目提供的\(x\),因為任何一種可能都能夠配出來。設一種可能為\(val\),考慮第\(p\)位,若\(p>mn\),那麼只有一種可能,就是在\(len\)的那個數(現在假設為\(len_a\)較大)中配一個\(v~\mathrm{xor}~x_p~\mathrm{xor}~b_p\)。假設\(p\leq mn\),那麼就有兩種可能:\(v~\mathrm{xor}~x_p=a_p~\mathrm{xor}~b_p\)

。那麼對於任何一種可能,都有\(2^{mn}\)\(a,b\)的配合情況。所有的可能就是\([pre,pre+2^{mx}-1]\)中的整數了。那麼一組\(len_a, len_b, a, b\)的答案就是:

\[ans(len_a, len_b, a, b)=\left(S_d\left(pre+2^{mx}-1\right)-S_d(pre-1)\right)\times 2^{mn} \]

那麼最終答案就是所有合法的\(len_a, len_b, a,b\)答案之和了。

PS:記得想要得到一個\(2^p(p>31)\),必須要寫1ll<<p,而不是1<<p

#include <bits/stdc++.h>
#define LL long long

using namespace std;

const LL MOD = 998244353;
unordered_map<LL, LL> mp;

LL n, m, vl;

LL calc_sd(LL n) {
	if (n < 0) return 0;
	if (mp.count(n)) return mp[n];
	LL ans = 0;
	for (LL l = 1, r; l <= n; l = r + 1) {
		r = n / (n / l);
		ans = (ans + (r - l + 1) * (n / l) % MOD) % MOD;
	}
	return mp[n] = ans;
}

LL calc_ans(LL x, LL y, LL lx, LL ly) {
	if (lx > ly) swap(x, y), swap(lx, ly);
	LL pre = (x ^ y ^ vl) & (~((1ll << ly) - 1));
	LL val1 = calc_sd(pre + (1ll << ly) - 1), val2 = calc_sd(pre - 1);
	return (val1 - val2 + MOD) % MOD * (1ll << lx) % MOD;
}
int main() {
	scanf("%lld%lld%lld", &n, &m, &vl), n++, m++;
	LL ans = 0;
	for (int i = 0; i <= 50; i++) if (n & (1ll << i)) 
		for (int j = 0; j <= 50; j++) if (m & (1ll << j)) 
			ans = (ans + calc_ans(n ^ (1ll << i), m ^ (1ll << j), i, j)) % MOD;
	printf("%lld\n", ans);
	return 0;
}