1. 程式人生 > 實用技巧 >[ICPC2020上海C] Sum of Log - 數位dp

[ICPC2020上海C] Sum of Log - 數位dp

[ICPC2020上海C] Sum of Log

Description

給定 \(x,y \le 10^9\)\(\sum_{i=0}^x \sum_{j=[i=0]}^y [i \& j=0][\log_2(i+j)+1]\)

Solution

顯然帶 \(\log\) 的那一項相當於是 \(i,j\) 的最大值的最高位,因此我們暴力列舉最高位是第幾位,由誰貢獻,後面的部分簡單數位 dp 即可

考場上的程式碼,寫得比較醜

(第一次正兒八經在考場上寫出數位dp,魔咒算是破了)

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 65;

const int mod = 1e+9+7;

int a[N], b[N], f[N][2][2];

int solve(int pos, int limi, int limj)
{
    if (pos == 0)
    {
        return 1;
    }
    if (~f[pos][limi][limj])
    {
        return f[pos][limi][limj];
    }
    int res = 0;
    if (limi == 0 || a[pos] == 1)
    {
        res += solve(pos - 1, limi, limj && b[pos] == 0);
    }
    if (limj == 0 || b[pos] == 1)
    {
        res += solve(pos - 1, limi && a[pos] == 0, limj);
    }
    res += solve(pos - 1, limi && a[pos] == 0, limj && b[pos] == 0);

    res %= mod;

    if (f[pos][limi][limj])
    {
        f[pos][limi][limj] = res;
    }

    return res;
}

int solve(int x, int y)
{
    memset(a, 0, sizeof a);
    memset(b, 0, sizeof b);

    for (int i = 0; i < N; i++)
    {
        f[i][0][1] = f[i][1][0] = f[i][1][1] = -1;
    }

    int n = 0, m = 0;
    while (x)
    {
        a[++n] = x % 2;
        x /= 2;
    }
    while (y)
    {
        b[++m] = y % 2;
        y /= 2;
    }

    int presumi = 0, presumj = 0;

    for (int i = m; i > n; i--)
        presumj += b[i];

    int ans = 0;

    for (int high = n; high >= 1; high--)
    {
        presumi += a[high];
        presumj += b[high];
        if (presumi >= 1)
        {
            // cout << "[";
            ans += solve(high - 1, presumi == 1 && a[high] == 1, presumj == 0) * high;
            ans %= mod;

            // cout << "]";
        }
    }

    return ans;
}

void solve()
{
    int x, y;
    scanf("%lld%lld", &x, &y);
    int ans = 0;
    ans += solve(x, y);
    ans += solve(y, x);
    printf("%lld\n", ans % mod);
}

signed main()
{


    ios::sync_with_stdio(false);

    memset(f, -1, sizeof f);
    int t;
    scanf("%lld", &t);
    while (t--)
    {
        solve();
    }

}