[筆記] [題解] 狀壓$DP$&洛谷P1433
[筆記] [題解] 狀壓\(DP\)&洛谷P1433
狀壓\(DP\)
狀壓\(dp\)是動態規劃的一種,通過將狀態壓縮為整數來達到優化轉移的目的 -- \(OI wiki\)
狀態壓縮的思想是用二進位制來表示狀態.
狀壓\(dp\)的時間複雜度是\(O(n^22^n)\)的,通常只能用於\(n \leq 21\)的資料範圍
題目思路
我們定義陣列\(dp[i][j]\)表示的是老鼠走到第\(i\)個乳酪,且之前走過的狀態為\(j\)時所用的最短的距離.舉個栗子:\(dp[9][166]\)代表的就是老鼠走到了第\(9\)塊乳酪,之前已經走過第\(2,3,6,8\)塊乳酪,因為\(166\)
對於\(dp\)陣列的初始化:首先先賦值為最大值,可以使用memset(dp,127,sizeof(dp))
,接著如果只經過第\(i\)塊乳酪,那麼對應的\(dp\)陣列的值就是原點到這個乳酪的座標的直線距離.用程式表示就是dp[i][1 << (i - 1)] = dis[0][i]
.為什麼是左移\(i-1\)位呢?我們可以手推一下,假設現在我們只到了第\(1\)塊乳酪,那麼對應的狀態(二進位制)因該是\(0001\),轉換為十進位制就是\(1\),\(1=2^0\),所以我們其實是左移\(0\)
接下來就是轉移方程了,\(dp[i][k] = min(dp[i][k],dp[j][k-2^{i-1}] + dis[i][j])\),\(dp[j][k-2^{i-1}]\)表示的是現在老鼠在\(j\)點,並且沒有走過\(i\)點的最短距離,\(dis[i][j]\)就是\(i,j\)之間的距離,為什麼可以表示沒有走過\(i\)的狀態呢?因為\(dp[i][k]\)要求的就是走過了\(i\)的最佳答案,所以\(k\)
關於統計答案,就是找出\(min(dp[i][2^n-1])\)就可以了,因為已經明確了最終的狀態是全部乳酪都要吃,所以對於每一個\(i\)只要保證描述狀態的那一維也就是\(2^n-1\)也就保證了每一個乳酪都被吃了,所以每一個\(i\)都是成立的合法解,所以取最小值就行.再理解一下,為什麼\(2^n-1\)就可以保證代表每一個乳酪都被取到了呢?根據上文的描述我們知道乳酪\(i\)被取了就意味著在\(dp\)陣列第二維描述狀態的數的二進位制位上第\(i\)位是\(1\),那麼其實我們只要在統計答案的時候保證描述狀態的每一位都是\(1\)即可,那麼我們先假設\(n=3\),那麼對應的最終狀態的二進位制表示因該是\(111\),十進位制下就是\(7\),而\(7=8-1,8=2^3,8\)的二進位制表示是\(1000\),那麼再在\(8\)的二進位制位上\(-1\),就變成\(111\),也就是答案了.
這就是狀壓\(dp\)的基本思路和解法,其實應用過的範圍不是很廣,因為可以支援的資料範圍不大,但是這種用二進位制來優化狀態表示和轉移的方法還是很有學習價值的.
解題程式碼
下面就是洛谷上例題的程式碼了:
在實現的時候還是要注意精度,因為要不然的話太小的小數比如\(0.001\)就會被計算成\(0\),導致\(WA\),所以儘量所有含有實數運算的函式和引數都要定義成\(double\)型別
#include <bits/stdc++.h>
using namespace std;
double dis[20][20];
struct point{
double x,y;
}pos[20];
double dp[20][34000];
int n;
double pow(double x){
return x * x;
}
double calc_dis(int x,int y){
return sqrt(pow(pos[x].x - pos[y].x) + pow(pos[x].y - pos[y].y));
}
int main(){
double ans;
memset(dp,127,sizeof(dp));
ans = dp[0][0];
scanf("%d",&n);
for(int i = 1;i <= n;i++){
scanf("%lf%lf",&pos[i].x,&pos[i].y);
}
pos[0].x = pos[0].y = 0;
for(int i = 0;i <= n;i++){
for(int j = i + 1;j <= n;j++){
dis[i][j] = dis[j][i] = calc_dis(i,j);
}
}
for(int i = 1;i <= n;i++)
dp[i][1 << (i - 1)] = dis[0][i];
for(int k = 1;k < (1 << n);k++){
for(int i = 1;i <= n;i++){
if((k & (1 << (i - 1))) == 0)continue;
for(int j = 1;j <= n;j++){
if(i == j)continue;
if((k & (1 << (j - 1))) == 0)continue;
dp[i][k] = min(dp[i][k],dp[j][k - (1 << (i - 1))] + dis[i][j]);
}
}
}
for(int i = 1;i <= n;i++)ans = min(ans,dp[i][(1 << n) - 1]);
printf("%.2lf\n",ans);
return 0;
}