1. 程式人生 > 實用技巧 >BZOJ3598 - 方伯伯的商場之旅(數位dp)

BZOJ3598 - 方伯伯的商場之旅(數位dp)

題意

方伯伯有一天去參加一個商場舉辦的遊戲。商場派了一些工作人員排成一行。每個人面前有幾堆石子。說來也巧,位置在 i 的人面前的第 j 堆的石子的數量,剛好是 i 寫成 K 進位制後的第 j 位。
現在方伯伯要玩一個遊戲,商場會給方伯伯兩個整數 L,R。方伯伯要把位置在 [L, R] 中的每個人的石子都合併成一堆石子。每次操作,他可以選擇一個人面前的兩堆石子,將其中的一堆中的某些石子移動到另一堆,代價是移動的石子數量 * 移動的距離。商場承諾,方伯伯只要完成任務,就給他一些椰子,代價越小,給他的椰子越多。所以方伯伯很著急,想請你告訴他最少的代價是多少。
例如:10 進位制下的位置在 12312 的人,合併石子的最少代價為:
1 * 2 + 2 * 1 + 3 * 0 + 1 * 1 + 2 * 2 = 9
即把所有的石子都合併在第三堆

題解

先思考一下如何求最少代價。由於最後要合併為一堆,所以就是求移到哪一堆代價最小。假設pre和suf陣列代表字首和和字尾和。
假設當前的目標為i,每當目標右移一格(i+1),對代價的貢獻為pre[i] - suf[i + 1]。這個貢獻是單調遞增的,所以最小代價的位置為貢獻剛好由負轉正的位置。

這下就有dp的目標了。列舉目標位置,求出最小代價為目標位置的數和代價和,最後累加答案即可。狀態為當前位置p,貢獻sdif,目標位置的數字d(用於判斷是否是臨界位置,防止重複計數)。

細節詳見程式碼。

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define FILE freopen(".//data_generator//in.txt","r",stdin),freopen("res.txt","w",stdout)
#define FI freopen(".//data_generator//in.txt","r",stdin)
#define FO freopen("res.txt","w",stdout)
#define pb push_back
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 3e5 + 10;
const double eps = 1e-5;
typedef pair<ll, ll> PII;

ll f[40][8000][40];
ll cnt[40][8000][40];
int vis[40][8000][40];
int di[20];
int tar, tag, k;

PII dfs(int p, int sdif, int d, int lmt) {
    if(!p) {
        
        if(sdif <= 0 && 2 * d + sdif > 0) return mp(1, 0);
        return mp(0, 0);
    }
    if(!lmt && vis[p][sdif + 4000][d] == tag) 
        return mp(cnt[p][sdif + 4000][d], f[p][sdif + 4000][d]);
    ll res = 0;
    ll ct = 0;
    int maxx = lmt ? di[p] : (k - 1);
    for(int i = 0; i <= maxx; i++) {
        if(p > tar) {
            auto v = dfs(p - 1, sdif + i, i, i == maxx && lmt);
            ct += v.first;
            res += v.second + v.first * abs(tar - p) * i;
        } else {
            if(p == tar) {
                auto v = dfs(p - 1, sdif - i, i, i == maxx && lmt);
                ct += v.first;
                res += v.second + v.first * abs(tar - p) * i;
            } else {
                auto v = dfs(p - 1, sdif - i, d, i == maxx && lmt);
                ct += v.first;
                res += v.second + v.first * abs(tar - p) * i;
            }
        }
    }
    if(!lmt) {
        vis[p][sdif + 4000][d] = tag;
        f[p][sdif + 4000][d] = res;
        cnt[p][sdif + 4000][d] = ct;
    }
    return mp(ct, res);
}


ll solve(ll x) {
    int tot = 0;
    ll ans = 0;
    while(x) {
        di[++tot] = x % k;
        x /= k;
    }
    for(tar = 1; tar <= tot; tar++) {
        tag++;
        ans += dfs(tot, 0, 0, 1).second;
    }
    return ans;
}

int main() {
    IOS;
    ll l, r;
    cin >> l >> r >> k;
    cout << solve(r) - solve(l - 1) << endl;
}