1. 程式人生 > >HDU3045 Picnic Cows (斜率DP優化)(數形結合)

HDU3045 Picnic Cows (斜率DP優化)(數形結合)

滿足 -s 坐標軸 明顯 spl 信息 更新 cow 常數

轉自PomeCat:

“DP的斜率優化——對不必要的狀態量進行拋棄,對不優的狀態量進行擱置,使得在常數時間內找到最優解成為可能。斜率優化依靠的是數形結合的思想,通過將每個階段和狀態的答案反映在坐標系上尋找解答的單調性,來在一個單調的答案(下標)隊列中O(1)得到最優解。”

https://wenku.baidu.com/view/b97cd22d0066f5335a8121a3.html

“一些試題中繁雜的代數關系身後往往隱藏著豐富的幾何背景,而借助背景圖形的性質,可以使那些原本復雜的數量關系和抽象的概念,顯得直觀,從而找到設計算法的捷徑。”—— 周源《淺談數形結合思想在信息學競賽中的應用》

斜率優化的核心即為數形結合,具體來說,就是以DP方程為基礎,通過變形來使得原方程形如一個函數解析式,再通過建立坐標系的方式,將每一個DP方程代表的狀態表示在坐標系中,在確定“斜率”單調性的前提下,進行形如單調隊列操作的舍解和維護操作。

一個算法總是用於解決實際問題的,所以結合例題來說是最好的:

Picnic Cows(HDU3045)

題目大意:
給出一個有N (1<= N <=400000)個正數的序列,要求把序列分成若幹組(可以打亂順序),每組的元素個數不能小於T (1 < T <= N)。每一組的代價是每個元素與最小元素的差之和,總代價是每個組的代價之和,求總代價的最小值。

樣例輸入包含:
第一行 N
第二行 N個數,如題意

樣例輸出包含:
第一行 最小的總代價

分析:
首先,審題。可以打亂序列順序,又知道代價為組內每個元素與最小值差之和,故想到貪心,先將序列排序(用STL sort)。
先從最簡單的DP方程想起:
容易想到:

f[i] = min( f[j] + (a[j + 1 -> i] - Min k) ) (0 <= j < i)

– –> f[i] = min( f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j ) )

Min k 代表序列 j + 1 -> i 內的最小值,排序後可以簡化為a[j + 1]。提取相似項合並成前綴和sum。這個方程的思路就是枚舉 j 不斷地計算狀態值更新答案。但是數據規模達到了 40000 ,這種以O(n ^ 2)為絕對上界方法明顯行不通。所以接下來我們要引入“斜率”來優化。

首先要對方程進行變形:
f[i] = f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j )
– –> f[i] = (f[j] - sum[j] + a[j + 1] * j) - i * a[j + 1] + sum[i]
(此步將只由i決定的量與只由j決定的量分開)
由於 sum[i] 在當前枚舉到 i 的狀態下是一個不變量,所以在分析

時可以忽略(因為對決策優不優沒有影響)(當然寫的時候肯定不能忽略)

令 i = k
a[j + 1] = x
f[j] - sum[j] + a[j + 1] * j = y
f[i] = b
原方程變為
– –> b = y - k * x
移項
– –> y = k * x + b

是不是很眼熟? 沒錯,這就是直線的解析式。觀察這個式子,我們可以發現,當我們吧許許多多的答案放在坐標系上構成點集,且枚舉到 i 時,過每一個點的斜率是一樣的!! 很抽象? 看圖

技術分享

可以看出,我們要求的f[i]就是截距,明顯,延長最右邊的直線交於坐標軸可得最小值。難道只要每次提取最靠近 i 的狀態就行了嘛?現實沒有那麽美好。

技術分享

像這樣的情況,過2直線的截距明顯比過3直線的截距要小, 意味著更優(在找求解最小值問題時),這種情況下我們之前的猜想便行不通。

那怎麽辦呢?這時就要用到斜率優化的核心思想——維護凸包。
何為凸包?
不懂得同學還是戳這裏:http://baike.baidu.com/link?url=OWX7haiZHtuKD2hjbEBVouUGxKXIMvXDnKra0xDhxuz9zGttTg_JoRwmUcbrGD9Xp2BnbCJDF8BblaQffDBvblg0xNqgIOXCVZpQ7ZNBkWG

其實我們要維護的凸包與這個凸包並無關系,只是在圖中長得像罷了。
那為什麽要維護凸包呢?
還要看圖:
技術分享

這就是一個下凸包,由圖可見,最前面的那個點的截距最小,也詮釋了維護凸包的真正含義(想一想優先隊列,是不是隊首最優?)。那還是有人會提出疑問,為什麽非要維護這樣的凸包呢? 答案就是,f[i]明顯是遞增的(相比於f[j] 加上一個代價),所以會在圖中自然而然地顯現出 Y 隨著 X增加而增加的情況,呈現出凸包的模樣。

可能這個過程比較晦澀難懂,沒懂的同學可以多看幾遍。

現在我們回到對 的分析

現在假設我們正在枚舉 j 來更新答案,有一個數 k, j < k < i
再來假設 k 比 j 優(之所以要假設正是要推出具體情況方便舍解)

則有

f[k] + sum[i] - sum[k] - a[k + 1] * (i - k) <= 
f[j] + sum[i] - sum[j] - a[j + 1] * (i - j) (k > j)

移項消項得

f[k] - sum[k] + a[k+ 1] * k - (f[j] - sum[j] + a[j + 1] * j) <= i * (a[k + 1] - a[j+ 1])

將只與 i 有關的元素留下,剩下的除過去, 得到

f[k] - sum[k] + a[k+ 1] * k - (f[j] - sum[j] + a[j + 1] * j) / (a[k + 1] - a[j + 1])<= i 

(這裏註意判斷除數是否為負, 要變號,當然這裏排序過後對於 k > j a[k] 肯定大於 a[j])

這個式子什麽意思呢?我用人類的語言解釋一下。
設 Xi = a[i], Yi = f[i] - sum[i] + a[i + 1] * i
那麽原式即為如下形式:

(Yk - Yj) / (Xk - Xj) <= i

意思就是當有k 和 j 滿足 j < k 的前提下 滿足此不等式
則證明 j 沒有 k 優

而這個式子的左邊數學意義是斜率, 而右邊是一個遞增的變量, 所以遞增的 i 會淘汰隊列裏的元素, 而為了高效的淘汰, 我們會(在這道題裏)選用斜率遞增的單調隊列,也就是上凸包。(再看看前面的圖,是不是斜率在遞增)

我們還可以從另一個角度理解維護上凸包的理由

仔細觀察下面的圖:

一開始,1 號點的截距比2號點更優

技術分享

隨著斜率的變化,兩個點的截距變得一樣了

然後,斜率接著變化,1號點開始沒有2號點優了,所以要舍棄

技術分享

後面的過程,3號點會漸漸超過2號點,並淘汰掉2號點

技術分享

分析整個過程,最優點依次是 1 -> 2 -> 3,滿足單調的要求

但是如果是一個上凸包會怎樣呢?

這裏只給出最終圖(有興趣的同學可以自己推一推),可以預見的是,在1趕超2前,3先趕超了2就破壞了順序,因此不行

技術分享

思路大概是清晰了,現在只剩下代碼實現方面的問題了

下面就看看單調隊列的操作

先將推出的X, Y用函數表示方便計算:
X:

dnt X( int i, int j )
{
    return a[j + 1] - a[i + 1];
}

 
  • 1
  • 2
  • 3
  • 4

(dnt 是 typedef 的 long long)

Y:

dnt Y( int i, int j )
{
    return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
}

 
  • 1
  • 2
  • 3
  • 4

處理隊首:

for(; h + 1 < t && Y(Q[h + 1], Q[h + 2]) <= i * X(Q[h + 1], Q[h + 2]); h++);
 
  • 1

從隊尾維護單調性:
(這裏是一個下凸包所以兩點之間的斜率要遞增,即 斜率(1, 2) < 斜率(2, 3), 前一個斜率比後一個小)

for(; h + 1 < t && Y(Q[t - 1], Q[t]) * X(Q[t], cur) >= X(Q[t - 1], Q[t]) * Y(Q[t], cur); t--);

 
  • 1

(註意,要把除法寫成乘的形式,不然精度可能會出問題)

斜率優化部分已經完結(說起來挺復雜其實代碼很短),接下來就放出AC代碼:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;

typedef long long dnt;

int n, T, Q[405005];
dnt sum[405005], f[405005], a[405005];

dnt Y( int i, int j )
{
    return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
}

dnt X( int i, int j )
{
    return a[j + 1] - a[i + 1];
}

dnt DP( int i, int j )
{
    return f[j] + (sum[i] - sum[j]) - (i - j) * a[j + 1];
}

inline dnt R()
{
    static char ch;
    register dnt res, T = 1;
    while( ( ch = getchar() ) < 0  || ch > 9 )if( ch == - )T = -1; 
        res = ch - 48;
    while( ( ch = getchar() ) <= 9 && ch >= 0)
        res = res * 10 + ch - 48;
    return res*T;
}

int main()
{
    sum[0] = 0;
    while(~scanf( "%d%d", &n, &T ))
    {
        a[0] = 0, f[0] = 0;
        for(int i = 1; i <= n; i++)
            scanf( "%I64d", &a[i] );
        sort(a + 1, a + n + 1);
        for(int i = 1; i <= n; i++)
            sum[i] = sum[i - 1] + a[i];
        int h = 0, t = 0;
        Q[++t] = 0;
        for(int i = 1; i <= n; i++)
        {
            int cur = i - T + 1;
            for(; h + 1 < t && Y(Q[h + 1], Q[h + 2]) <= i * X(Q[h + 1], Q[h + 2]); h++);
            f[i] = DP(i, Q[h + 1]);
            if(cur < T) continue;
            for(; h + 1 < t && Y(Q[t - 1], Q[t]) * X(Q[t], cur) >= X(Q[t - 1], Q[t]) * Y(Q[t], cur); t--);
            Q[++t] = cur;
        }
        printf( "%I64d\n", f[n] );
    }   
    return 0;
}

註釋版本:

技術分享
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=800010;
long long  dp[maxn],q[maxn];
long long  a[maxn],sum[maxn];
long long getdp(long long i,long long j)
{
    return dp[j]+(sum[i]-sum[j])-a[j+1]*(i-j);
}
long long getdy(long long j,long long k)//得到 yj-yk  k<j
{
    return dp[j]-sum[j]+j*a[j+1]-(dp[k]-sum[k]+k*a[k+1]);
}
long long getdx(long long j,long long k)//得到 xj-xk  k<j
{
    return a[j+1]-a[k+1];
}
int main()
{
    long long i,j,n,k,head,tail,m;
    while(~scanf("%lld%lld",&n,&m)){
        head=tail=0;
        sum[0]=q[0]=dp[0]=q[1]=0;
        for(i=1;i<=n;i++) scanf("%lld",&a[i]);        
        sort(a+1,a+n+1);
        for(i=1;i<=n;i++) sum[i]=sum[i-1]+a[i];
        for(i=1;i<=n;i++){
                       //刪去隊首斜率小於目前斜率的點 
            while(head<tail&&(getdy(q[head+1],q[head])<=i*getdx(q[head+1],q[head]))) head++;
            dp[i]=getdp(i,q[head]);
            j=i-m+1;
            if(j<m) continue;
            //接下來是對j而不是i進行處理 ,保證了間隔大於m-1的要求 
            while(head<tail&&(getdy(j,q[tail])*getdx(j,q[tail-1])<=getdy(j,q[tail-1])*getdx(j,q[tail]))) tail--;
            q[++tail]=j;
        }
        printf("%lld\n",dp[n]);
    }
    return 0;
}
View Code

HDU3045 Picnic Cows (斜率DP優化)(數形結合)