斜率優化動態規劃
前言
斜率優化通常使用單調佇列輔助進行實現,用於優化 D P DP DP 的時間複雜度,比較抽象,需要讀者有較高的數學素養。
適用範圍
使用單調佇列優化
D
P
DP
DP ,通常可以解決型如:
d
p
[
i
]
=
m
i
n
(
f
(
j
)
)
+
g
(
i
)
dp[i]=min(f(j))+g(i)
dp[i]=min(f(j))+g(i) 的狀態轉移方程。其中
f
(
i
)
f(i)
f(i) 是隻關於
i
i
i 的函式,
g
(
j
)
g(j)
g(j) 是隻關於
j
j
j 的函式。樸素的解決方法是在第二層迴圈中列舉
j
j
而斜率優化利用上述方法進行改進,實現對於型如: d p [ i ] = m i n ( f ( i , j ) ) + g ( i ) dp[i]=min(f(i,j))+g(i) dp[i]=min(f(i,j))+g(i) 的狀態轉移方程。對比第一種情況,可以發現函式 f f f 函式與兩個值 i , j i,j i,j 都有關,簡單地使用單調佇列是無法優化的。這時候就開始引入主題斜率優化了。
下面結合一道例題來具體詳解。題目來自於
H
N
O
I
2008
HNOI2008
題目大意
有 n n n 個數字 C C C,把它分為若干組,給出另一個數 L L L ,每組的花費為 ( i − j + ∑ k = i j C k − L ) 2 (i-j+\sum_{k=i}^jC_k-L)^2 (i−j+∑k=ijCk−L)2,總花費為所有組的花費之和。求最小總花費。
思路
先考慮樸素的 d p dp dp 做法。
設 d p [ i ] dp[i] dp[i] 為將前 i i i 個數字分組後的最小花費。求和可以考慮使用字首和來優化,設字首和陣列為 p r e pre pre 。則狀態轉移方程可以寫為:
d
p
[
i
]
=
M
i
n
(
d
p
[
j
]
+
(
s
u
m
[
i
]
−
s
u
m
[
j
]
)
+
(
i
−
(
j
+
1
)
)
−
L
)
2
,
0
≤
j
<
i
)
dp[i]=Min(dp[j]+(sum[i]-sum[j])+(i-(j+1))-L)^2,0≤j<i)
即是:
d p [ i ] = M i n ( d p [ j ] + ( s u m [ i ] − s u m [ j ] + i − j − L − 1 ) 2 , 0 ≤ j < i ) dp[i]=Min(dp[j]+(sum[i]-sum[j]+i-j-L-1)^2,0≤j<i) dp[i]=Min(dp[j]+(sum[i]−sum[j]+i−j−L−1)2,0≤j<i)
那麼 s u m sum sum 陣列可以初始化為:
for(int i = 1; i <= n; i++) {
Quick_Read(val[i]);
sum[i] = sum[i - 1] + val[i];
}
設 p r e [ i ] = s u m [ i ] + i pre[i]=sum[i]+i pre[i]=sum[i]+i ,再進一步設 l = L + 1 l=L+1 l=L+1 那麼狀態轉移方程可以寫為:
d p [ i ] = M i n ( d p [ j ] + ( p r e [ i ] − p r e [ j ] − l ) 2 , 0 ≤ j < i ) dp[i]=Min(dp[j]+(pre[i]-pre[j]-l)^2,0≤j<i) dp[i]=Min(dp[j]+(pre[i]−pre[j]−l)2,0≤j<i)
狀態轉移
int Get_Dp(int i, int j) {
return dp[j] + (pre[i] - pre[j] - l) * (pre[i] - pre[j] - l);
}
若列舉 j j j ,則時間複雜度為 O ( n ) 2 O(n)^2 O(n)2 ,時間複雜度不優。使用斜率優化可以對其進行優化。
假設當前列舉到 i i i ,需要得到 i i i 的狀態。假設有兩個決策點 j j j , k k k ,滿足決策點 j j j 優於決策點 k k k 。用符號語言可以表達為:
d p [ j ] + ( p r e [ i ] − p r e [ j ] − l ) 2 < d p [ k ] + ( p r e [ i ] − p r e [ k ] − l ) 2 dp[j]+(pre[i]-pre[j]-l)^2<dp[k]+(pre[i]-pre[k]-l)^2 dp[j]+(pre[i]−pre[j]−l)2<dp[k]+(pre[i]−pre[k]−l)2
展開得:
d p [ j ] + p r e [ i ] 2 + p r e [ j ] 2 + l 2 − 2 × p r e [ i ] × p r e [ j ] − 2 × l × p r e [ i ] + 2 × l × p r e [ j ] < d p [ k ] + p r e [ i ] 2 + p r e [ k ] 2 + l 2 − 2 × p r e [ i ] × p r e [ k ] − 2 × l × p r e [ i ] + 2 × l × p r e [ k ] dp[j]+pre[i]^2+pre[j]^2+l^2-2\times pre[i]\times pre[j]-2\times l\times pre[i]+2\times l\times pre[j]<dp[k]+pre[i]^2+pre[k]^2+l^2-2\times pre[i]\times pre[k]-2\times l\times pre[i]+2\times l\times pre[k] dp[j]+pre[i]2+pre[j]2+l2−2×pre[i]×pre[j]−2×l×pre[i]+2×l×pre[j]<dp[k]+pre[i]2+pre[k]2+l2−2×pre[i]×pre[k]−2×l×pre[i]+2×l×pre[k]
進一步整理得 :
d p [ j ] + p r e [ j ] 2 − d p [ k ] − p r e [ k ] 2 < ( p r e [ i ] − l ) × 2 × ( p r e [ j ] − p r e [ k ] ) dp[j]+pre[j]^2-dp[k]-pre[k]^2<(pre[i]-l)\times 2\times (pre[j] - pre[k]) dp[j]+pre[j]2−dp[k]−pre[k]2<(pre[i]−l)×2×(pre[j]−pre[k])
觀察可得:左邊的式子只與 j j j 和 k k k 有關,但右邊的式子還與 i i i 有關。也可以發現若滿足上述式子,則會有 j j j 優於 k k k 。再分類討論:
- j > k j>k j>k ,則 p r e [ j ] > p r e [ k ] pre[j]>pre[k] pre[j]>pre[k],移項得 d p [ j ] + p r e [ j ] 2 − ( d p [ k ] + p r e [ k ] 2 ) p r e [ j ] − p r e [ k ] < p r e [ i ] − l \frac{dp[j]+pre[j]^2-(dp[k]+pre[k]^2)}{pre[j] - pre[k]}<pre[i]-l pre[j]−pre[k]dp[j]+pre[j]2−(dp[k]+pre[k]2)<pre[i]−l , p r e [ i ] − l pre[i]-l pre[i]−l 可以 看為一個常數。那麼意味著點 j ( d p [ j ] + p r e [ j ] 2 , p r e [ j ] ) j(dp[j]+pre[j]^2,pre[j]) j(dp[j]+pre[j]2,pre[j]) 與點 k ( d p [ k ] + p r e [ k ] 2 , p r e [ k ] ) k(dp[k]+pre[k]^2,pre[k]) k(dp[k]+pre[k]2,pre[k]) 所構成的直線的斜率小於 p r e [ i ] − l pre[i]-l pre[i]−l 這個常數。
- j < k j<k j<k ,則 p r e [ j ] < p r e [ k ] pre[j]<pre[k] pre[j]<pre[k],移項得 d p [ j ] + p r e [ j ] 2 − ( d p [ k ] + p r e [ k ] 2 ) p r e [ j ] − p r e [ k ] > p r e [ i ] − l \frac{dp[j]+pre[j]^2-(dp[k]+pre[k]^2)}{pre[j] - pre[k]}>pre[i]-l pre[j]−pre[k]dp[j]+pre[j]2−(dp[k]+pre[k]2)>pre[i]−l , p r e [ i ] − l pre[i]-l pre[i]−l 可以 看為一個常數。那麼意味著點 j ( d p [ j ] + p r e [ j ] 2 , p r e [ j ] ) j(dp[j]+pre[j]^2,pre[j]) j(dp[j]+pre[j]2,pre[j]) 與點 k ( d p [ k ] + p r e [ k ] 2 , p r e [ k ] ) k(dp[k]+pre[k]^2,pre[k]) k(dp[k]+pre[k]2,pre[k]) 所構成的直線的斜率大於 p r e [ i ] − l pre[i]-l pre[i]−l 這個常數。
獲得分子的函式:
int Get_Up(int j, int k) {
return dp[j] + pre[j] * pre[j] - dp[k] - pre[k] * pre[k];
}
獲得分母的函式:
int Get_Down(int j, int k) {
return pre[j] - pre[k];
}
有了上述的一級結論,可以進一步推匯出二級結論:
設
x
,
y
x,y
x,y 的斜率表示為
k
(
x
,
y
)
k(x,y)
k(x,y) 。若存在三點
a
,
b
,
c
a,b,c
a,b,c ,有
k
(
a
,
b
)
>
k
(
b
,
c
)
k(a,b)>k(b,c)
k(a,b)>k(b,c) ,即是影象形成上凸的形狀時,那麼點
b
b
b 絕對不是最優的。
分類討論:
- k ( a , b ) > k ( b , c ) > p r e [ i ] − l k(a,b)>k(b,c)>pre[i]-l k(a,b)>k(b,c)>pre[i]−l ,則對於上述結論可以得出 a a a 比 b b b 更優,捨去 b b b 。
- p r e [ i ] − l > k ( a , b ) > k ( b , c ) pre[i]-l>k(a,b)>k(b,c) pre[i]−l>k(a,b)>k(b,c) ,則對於上述結論可以得出 c c c 比 b b b 更優,捨去 b b b 。
- p r e [ i ] − l < k ( a , b ) pre[i]-l<k(a,b) pre[i]−l<k(a,b) 且 p r e [ i ] − l > k ( b , c ) pre[i]-l>k(b,c) pre[i]−l>k(b,c) ,則對於上述結論可以得出 a a a 和 c c c 都比 b b b 更優,捨去 b b b 。
那麼就可以得出答案的點必須滿足
k
(
a
1
,
a
2
)
<
k
(
a
2
,
a
3
)
<
.
.
.
<
k
(
a
m
−
1
,
a
m
)
k(a_1,a_2)<k(a_2,a_3)<...<k(a_{m-1},a_m)
k(a1,a2)<k(a2,a3)<...<k(am−1,am) 。全部呈現出下凸狀態,如下圖。
這樣下標遞增,斜率遞增的點集可以使用單調佇列來維護。
找出當前最優的點為 q u e [ h e a d ] que[head] que[head] ,即隊頭元素。
while(Get_Up(que[head + 1], que[head]) <= 2 * (pre[i] - l) * Get_Down(que[head + 1], que[head]) && head < tail)
head++;
用當前點 i i i 來更新佇列,使得該佇列呈下凸之勢。
while(Get_Up(que[tail], que[tail - 1]) * Get_Down(i, que[tail]) >= Get_Up(i, que[tail]) * Get_Down(que[tail], que[tail - 1]) && head < tail)
tail--;
按照上述方法進行狀態轉移,得到的 d p [ n ] dp[n] dp[n] 就是當前的最優解。
C++程式碼
程式碼比較短,一氣呵成。(注意要開 l o n g long long l o n g long long)
#include <cstdio>
#define int long long
void Quick_Read(int &N) {
N = 0;
int op = 1;
char c = getchar();
while(c < '0' || c > '9') {
if(c == '-')
op = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
N = (N << 1) + (N << 3) + (c ^ 48);
c = getchar();
}
N *= op;
}
void Quick_Write(int N) {
if(N < 0) {
putchar('-');
N = -N;
}
if(N >= 10)
Quick_Write(N / 10);
putchar(N % 10 + 48);
}
const int MAXN = 5e5 + 5;
int dp[MAXN];
int pre[MAXN], val[MAXN];
int n, l;
int que[MAXN];
int head, tail;
int Get_Dp(int i, int j) {
return dp[j] + (pre[i] - pre[j] - l) * (pre[i] - pre[j] - l);
}
int Get_Up(int j, int k) {
return dp[j] + pre[j] * pre[j] - dp[k] - pre[k] * pre[k];
}
int Get_Down(int j, int k) {
return pre[j] - pre[k];
}
void Line_Dp() {
head = 1;
tail = 1;
for(int i = 1; i <= n; i++) {
while(Get_Up(que[head + 1], que[head]) <= 2 * (pre[i] - l) * Get_Down(que[head + 1], que[head]) && head < tail)
head++;
dp[i] = Get_Dp(i, que[head]);
while(Get_Up(que[tail], que[tail - 1]) * Get_Down(i, que[tail]) >= Get_Up(i, que[tail]) * Get_Down(que[tail], que[tail - 1]) && head < tail)
tail--;
que[++tail] = i;
}
Quick_Write(dp[n]);
}
void Read() {
Quick_Read(n);
Quick_Read(l);
l++;
for(int i = 1; i <= n; i++) {
Quick_Read(val[i]);
pre[i] = pre[i - 1] + val[i] + 1;
}
}
signed main() {
Read();
Line_Dp();
return 0;
}