1. 程式人生 > 實用技巧 >「筆記」數位DP

「筆記」數位DP

目錄

寫在前面

19 年前聽 zlq 講課的時候學的東西,當時只會抄板子,現在來重學一波= =

一個板子水一天題(不事

引入

「SCOI2009」Windy 數

給定引數 \(l,r\),求 \([l,r]\) 中不含前導零且相鄰兩個數字之差至少為 \(2\) 的正整數的個數。
\(1\le l\le r\le 2\times 10^9\)
1S,512MB。

這是一個經典的數位 DP 的例子。其模型一般是給定一些對於數的限制條件,求在給定範圍內滿足限制的數的貢獻。
通過數位 DP 一般可以在 \(O(m\log_{10}{(n)})\)

的時間內解決此問題,其中 \(m\) 是數碼種類數,\(n\) 是取值的最大值。

求解

首先將詢問 \([l,r]\) 內合法的數的個數拆成詢問 \([0\sim l-1]\)\([0, r]\) 內合法的數的個數,之後考慮數位 DP。
數位 DP 有遞推 和 記憶化搜尋兩種寫法,由於記憶化搜尋更容易理解與實現,我們一般採用記憶化搜尋解決此類問題。以下也僅介紹記憶化搜尋的解法。

先考慮爆搜。考慮列舉所有範圍內的數,搜尋的同時檢查是否滿足給定的限制條件。注意考慮前導零與是否達到列舉的上界,其程式碼如下所示:

int numlth, num[kN]; //儲存給定值的從高位到低位的十進位制拆分。
//now_:當前填到第幾位; last_:now_ - 1 位填的數;
//zero_:前 now_ - 1 位是否均為 0; lim_:前 now_ - 1 位是否達到列舉的上界(與 num 相同)
int Dfs(int now_, int last_, bool zero_, bool lim_) {
  if (now_ > numlth) return 1; //當前列舉的數合法
  int ret = 0; 
  //列舉第 now_ 位填的數,up 為該位填數的上界
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) { 
    if (abs(i - last_) < 2) continue ;
    if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up); //前 now_ 位均為 0
    else ret += Dfs(now_ + 1, i, false, lim_ &&i == up);
  }
  return ret;
}
//ans[0, x] = Dfs(1, 11, true, true);

發現當列舉的數字首的性質相同,即 dfs 的四個引數相同時,dfs 的返回值相同。
比如當列舉到 \(020\underline{?}??\)\(010\underline{?}??\) 時,dfs 的引數均為 (4, 0, false, false)。表示它們字首的性質相同,列舉之後位數得到的答案顯然也相同。
簡單記憶化即可避免重複列舉過程。

//f[i][j][0/1][0/1] 表示 dfs(i, j, 0/1, 0/1) 的答案。
int numlth, num[kN], f[kN][kN][2][2];
int Dfs(int now_, int last_, bool zero_, bool lim_) {
  if (now_ > numlth) return 1;
  if (f[now_][last_][zero_][lim_] != -1) return f[now_][last_][zero_][lim_];
  int ret = 0;
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
    if (abs(i - last_) < 2) continue ;
    if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
    else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
  }
  return f[now_][last_][zero_][lim_] = ret;
}
//ans[0, x] = Dfs(1, 11, true, true);

特判優化

發現上述 dfs 的過程中,\(\operatorname{lim} = 1\)\(\operatorname{zero} = 1\) 的狀態只會被列舉到 1 次,即只會重複呼叫 dfs(now_, last_, 0, 0)。對這兩維的記憶化對減少列舉次數是做負功的。
於是可以通過特判去除這兩維,如下所示:

//f[i][j] 表示 dfs(i, j, 0, 0) 的答案。
int Dfs(int now_, int last_, bool zero_, bool lim_) {
  if (now_ > numlth) return 1;
  if (!lim_ && f[now_][last_] != -1) return f[now_][last_];
  int ret = 0;
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
    if (abs(i - last_) < 2) continue ;
    if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
    else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
  }
  if (!lim_) f[now_][last_] = ret;
  return ret;
}

可以感性理解特判的實際意義。若 dfs 的引數 \(\operatorname{lim} = 0\) 時,表示字首比上界小,後面的位數可以隨意填。因此字首性質相同的所有子問題是完全等價的,因此可以記憶化。
\(\operatorname{zero} = 1\)\(\operatorname{lim} = 0\) 一定是配套出現的,因此也可以特判掉。

這樣時空複雜度均變為了原來的 \(\frac{1}{4}\)。在其他題目中也可以套用此模板,將 0/1 維特判掉,減小時空複雜度。
可能有___出題人卡直接記憶化的寫法,比如這題:

程式碼

引入問題的完整程式碼。

//知識點:數位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 15;
//=============================================================
int numlth, f[kN][kN];
std::vector <int> num;
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
int Dfs(int now_, int last_, bool zero_, bool lim_) {
  if (now_ > numlth) return 1;
  if (!lim_ && f[now_][last_] != -1) return f[now_][last_];
  int ret = 0;
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
    if (abs(i - last_) < 2) continue ;
    if (zero_ && !i) ret += Dfs(now_ + 1, 11, true, lim_ && i == up);
    else ret += Dfs(now_ + 1, i, false, lim_ && i == up);
  }
  if (!lim_) f[now_][last_] = ret;
  return ret;
}
int Calc(int val_) {
  num.clear();
  num.push_back(0);
  for (int tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
  for (int i = 1, j = num.size() - 1; i < j; ++ i, -- j) {
    std::swap(num[i], num[j]);
  }
  numlth = num.size() - 1;
  memset(f, -1, sizeof (f));
  return Dfs(1, 11, true, true);
}
//=============================================================
int main() {
  int a = read(), b = read();
  printf("%d\n", Calc(b) - Calc(a - 1));
  return 0; 
}

例題

「ZJOI2010」數字計數

給定兩個正整數 \(a\)\(b\),求在 \([a,b]\) 中的所有整數中,每個數碼各出現了多少次。
\(1\le a\le b\le 10^{12}\)
1S,512MB。

與引入問題不同的是,這題要求的是數碼的數量,限制了每個數的貢獻,求貢獻和。
套路類似,考慮對每個數碼分開求解,dfs 時記錄已列舉字首的貢獻量。
Dfs(int now_, LL sum_, bool zero_, bool lim_, int digit_) 表示前 \(\operatorname{now} - 1\) 位含有數碼 \(\operatorname{digit}\) 的數量為 \(\operatorname{sum}\)、字首是否全為前導零、字首是否達到上界,滿足上述條件的所有數中數碼 \(\operatorname{digit}\) 的數量。
邊界是搜尋到第 \(\operatorname{length}+1\) 位,此時返回 \(\operatorname{sum}\) 的值。
與套路類似地,發現一些 \(\operatorname{now}\)\(\operatorname{sum}\) 相等的搜尋狀態會被重複訪問,簡單記憶化即可。
總複雜度 \(O(10^2\log_{10}(n))\) 級別。

//知識點:數位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 20;
//=============================================================
LL numlth, f[kN][kN];
std::vector <int> num;
//=============================================================
inline LL read() {
  LL f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
LL Dfs(int now_, LL sum_, bool zero_, bool lim_, int digit_) {
  if (now_ > numlth) return sum_;
  if (!lim_ && f[now_][sum_] != -1) return f[now_][sum_];
  LL ret = 0;
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
    if (zero_ && !i) ret += Dfs(now_ + 1, sum_, true, lim_ && i == up, digit_);
    else ret += Dfs(now_ + 1, sum_ + (i == digit_), false, lim_ && i == up, digit_);
  }
  if (!lim_) f[now_][sum_] = ret;
  return ret;
}
LL Calc(LL val_, int digit_) {
  num.clear();
  num.push_back(0);
  for (LL tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
  for (int i = 1, j = num.size() - 1; i < j; ++ i, -- j) std::swap(num[i], num[j]);
  numlth = num.size() - 1;
  memset(f, -1, sizeof (f));
  return Dfs(1, 0, true, true, digit_);
}
//=============================================================
int main() {
  LL a = read(), b = read();
  for (int i = 0; i <= 9; ++ i) printf("%lld ", Calc(b, i) - Calc(a - 1, i));
  return 0; 
}

還有一種考慮每個位置填入指定數碼後對應的數的個數的無腦寫法,看程式碼就能看懂。

//知識點:暴力
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 13;
//=============================================================
LL f[kN];
//=============================================================
inline LL read() {
  LL f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
LL Calc(LL val_, LL digit_) {
  LL ret = (!digit_);
  for (LL tmp = val_, pow10 = 1; tmp; tmp /= 10ll, pow10 *= 10ll) {
    LL pre = tmp / 10ll + 1;
    if (! digit_) {
      if (pre == 1) continue;
      if (0 < tmp % 10) ret += (pre - 1ll) * pow10;
      if (0 == tmp % 10) ret += (pre - 2ll) * pow10 + val_ % pow10 + 1;
      continue;
    }
    if (digit_ > tmp % 10) ret += (pre - 1ll) * pow10;
    if (digit_ == tmp % 10) ret += (pre - 1ll) * pow10 + val_ % pow10 + 1;
    if (digit_ < tmp % 10) ret += pre * pow10;
  }
  return ret;
}
//=============================================================
int main() { 
  LL a = read(), b = read();
  for (int i = 0; i <= 9; ++ i) printf("%lld ", Calc(b, i) - Calc(a - 1, i));
  return 0; 
}

「AHOI2009」同類分佈

給定兩個正整數 \(a\)\(b\),求在 \([a,b]\) 中的所有整數中,各位數之和能整除原數的數的個數。
\(1\le a\le b\le 10^{18}\)
3S,512MB。

考慮到各位數之和與原數在 dfs 中都是變數,不易檢驗合法性。但發現各位數之和不大於 \(9\times 12\),考慮先列舉各位數之和,再在 dfs 時維護字首的餘數,以檢查是否合法。
同樣設 Dfs(int now_, int sum_, int p_, bool zero_, bool lim_, int val_),其中 \(\operatorname{sum}\) 為字首的各數位之和,\(p\) 為原數模 \(\operatorname{val}\) 的餘數。
邊界是搜尋到第 \(\operatorname{length}+1\) 位,此時返回 \([\operatorname{sum}=\operatorname{val} \land \, p = 0]\)
對數位和和餘數簡單記憶化即可,總複雜度 \(O(2\cdot10^2\log_{10}^3(n))\) 級別。

//知識點:數位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
#define LL long long
const int kN = 20;
//=============================================================
int numlth;
LL f[kN][9 * kN][9 * kN];
std::vector <int> num;
//=============================================================
inline LL read() {
  LL f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
LL Dfs(int now_, int sum_, int p_, bool zero_, bool lim_, int val_) {
  if (now_ > numlth) return (sum_ == val_ && !p_);
  if (!lim_ && f[now_][sum_][p_] != -1) return f[now_][sum_][p_];
  LL ret = 0;
  for (int i = 0, up = lim_ ? num[now_] : 9; i <= up; ++ i) {
    if (zero_ && !i) ret += Dfs(now_ + 1, sum_, 10 * p_ % val_, true, lim_ && i == up, val_);
    else ret += Dfs(now_ + 1, sum_ + i, (10 * p_ + i) % val_, false, lim_ && i == up, val_);
  }
  if (!zero_ && !lim_) f[now_][sum_][p_] = ret;
  return ret;
}
LL Calc(LL val_) {
  num.clear();
  num.push_back(0);
  for (LL tmp = val_; tmp; tmp /= 10) num.push_back(tmp % 10);
  for (int i = 1, j = numlth = num.size() - 1; i < j; ++ i, -- j) {
    std::swap(num[i], num[j]);
  }
  LL ret = 0;
  for (int i = 1; i <= 9 * numlth; ++ i) {
    memset(f, -1, sizeof (f));
    ret += Dfs(1, 0, 0, true, true, i);
  }
  // printf("%lld %lld\n", val_, ret);
  return ret;
}
//=============================================================
int main() {
  LL a = read(), b = read();
  printf("%lld\n", Calc(b) - Calc(a - 1));
  return 0; 
}

套路題們

P3413 SAC#1 - 萌數

給定兩個正整數 \(a\)\(b\),求在 \([a,b]\) 中的所有整數中,存在長度至少為2的迴文子串的數的個數。
\(1\le a< b\le 10^{1000}\)
1S,128MB。

存在長度至少為2的迴文子串等價於沒有連續相等的三位,dfs 時記錄前兩位即可。程式碼 Link


「CQOI2016」手機號碼

給定兩個正整數 \(a\)\(b\),求在 \([a,b]\) 中的所有整數中,至少有三個相鄰的相同數字,且 8 和 4 不同時存在的數的個數。
\(10^{10}\le a\le b\le 10^{11}\)
1S,256MB。

狀態多設幾維即可,記錄前兩位,字首中是否有有三個相鄰的相同數字,字首中是否有 8,字首中是否有 4。程式碼 Link


P4317 花神的數論題

給定一正整數 \(a\),求在 \([1,a]\) 中的所有整數的二進位制拆分中 1 的個數的乘積。
\(1\le a \le 10^{15}\)
1S,128MB。

二進位制拆分 \(a\),同「AHOI2009」同類分佈,列舉二進位制中 1 的個數 dfs 即可。
注意不要亂取模。程式碼 Link

「SDOI2014」數數

給定一個整數 \(n\),一大小為 \(m\) 的數字串集合 \(s\)
求不以 \(s\) 中任意一個數字串作為子串的,不大於 \(n\) 的數字的個數。
\(1\le n\le 10^{1201}\)\(1\le m\le 100\)\(1\le \sum |s_i|\le 1500\)\(n\) 沒有前導零,\(s_i\) 可能存在前導零。
1S,128MB。

題目要求不以 \(s\) 中任意一個數字串作為子串,想到這題:「JSOI2007」文字生成器。首先套路地對給定集合的串構建 ACAM,並在 ACAM 上標記所有包含集合內的子串的狀態。
之後考慮在 ACAM 上模擬串匹配的過程做數位 DP。發現字首所在狀態儲存了字首的所有資訊,可以將其作為 dfs 的引數。
Dfs(int now_, int pos_, bool zero_, bool lim_) { 表示字首匹配到的 ACAM 的狀態為 \(\operatorname{pos}\) 時,合法的數字的數量。轉移時沿 ACAM 上的轉移函式轉移,避免轉移到被標記的狀態。
存在 \(\operatorname{trans}(0, 0) = 0\),這樣直接 dfs 也能順便處理不同長度的數字串。
總複雜度 \(O(\log_{10}(n)\sum |s_i|)\) 級別。

//知識點:ACAM,數位 DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <queue>
#define LL long long
const int kN = 1500 + 10;
const int mod = 1e9 + 7;
//=============================================================
int n, m, ans;
char num[kN], s[kN];
//=============================================================
inline int read() {
  int f = 1, w = 0;
  char ch = getchar();
  for (; !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
  return f * w;
}
void Chkmax(int &fir, int sec) {
  if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
  if (sec < fir) fir = sec;
}
namespace ACAM {
  const int kSigma = 10;
  int node_num, tr[kN][kSigma], last[kN], fail[kN];
  int f[kN][kN];
  bool tag[kN];
  void Insert(char *s_) {
    int u_ = 0, lth = strlen(s_ + 1);
    for (int i = 1; i <= lth; ++ i) {
      if (! tr[u_][s_[i] - '0']) tr[u_][s_[i] - '0'] = ++ node_num;
      u_ = tr[u_][s_[i] - '0'];
      last[u_] = s_[i] - '0';
    }
    tag[u_] = true;
  }
  void Build() {
    std:: queue <int> q;
    for (int i = 0; i < kSigma; ++ i) {
      if (tr[0][i]) q.push(tr[0][i]);
    }
    while (!q.empty()) {
      int u_ = q.front(); q.pop();
      tag[u_] |= tag[fail[u_]];
      for (int i = 0; i < kSigma; ++ i) {
        int v_ = tr[u_][i];
        if (v_) {
          fail[v_] = tr[fail[u_]][i];
          q.push(v_);
        } else {
          tr[u_][i] = tr[fail[u_]][i];
        }
      }
    }
  }
  int Dfs(int now_, int pos_, bool zero_, bool lim_) {
    if (now_ > n) return 1;
    if (!zero_ && !lim_ && f[now_][pos_] != -1) return f[now_][pos_];
    int ret = 0;
    for (int i = 0, up = lim_ ? num[now_] - '0': 9; i <= up; ++ i) {
      int v_ = tr[pos_][i];
      if (tag[v_]) continue;
      if (zero_ && !i) ret += Dfs(now_ + 1, 0, true, lim_ && i == num[now_] - '0');
      else ret += Dfs(now_ + 1, v_, false, lim_ && i == num[now_] - '0');
      ret %= mod;
    }
    if (!zero_ && !lim_) f[now_][pos_] = ret;
    return ret;
  }
  int DP() {
    memset(f, -1, sizeof (f));
    return Dfs(1, 0, true, true);
  }
}
//=============================================================
int main() {
  scanf("%s", num + 1);
  n = strlen(num + 1);
  m = read();
  for (int i = 1; i <= m; ++ i) {
    scanf("%s", s + 1);
    ACAM::Insert(s);
  }
  ACAM::Build();
  printf("%d\n", ACAM::DP());
  return 0; 
}

寫在最後

鳴謝:

數位dp 筆記 - Flandre-Zhu