1. 程式人生 > >『LIS問題的優化:O(NlogN)效率的演算法』

『LIS問題的優化:O(NlogN)效率的演算法』


<更新提示>

<第一次更新>深究LIS問題


<正文>

LIS問題

求解長度為n的序列中最長上升子序列的長度。

分析

顯然可以令f[i]為以元素a[i]結尾的最長上升子序列的長度,則狀態轉移方程為: f [ i ]

= m a x { f [ j ] + 1 } ( f [
j ] < f [ i ] ) ,可以直接花 O ( n
2 2 )
的時間來暴力轉移。其程式碼實現如下:

#include<bits/stdc++.h>
using namespace std;
int n,a[1000080]={},f[100080]={};
int main()
{
    cin>>n;
    for(int i=1;i<=n;i++)cin>>a[i],f[i]=1;
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<i;j++)
        {
            if(a[j]<a[i])f[i]=max(f[i],f[j]+1);
        }
    }
    cout<<f[n]<<endl;
    return 0;
}

這是LIS問題的暴力解法,詳見『基礎DP專題:LIS,LCS和揹包九講(不包括泛化物品)及實現』。這裡不在深入討論。

那麼我們考慮一個升級版的LIS問題:

LIS模板題

題目描述
有N個整數,輸出這N個整數的最長上升序列、最長下降序列、最長不上升序列和最長不下降序列。
輸入格式
第一行,僅有一個數N。 N<=700000 第二行,有N個整數。 -10^9<=每個數<=10^9
輸出格式
第一行,輸出最長上升序列長度。 第二行,輸出最長下降序列長度。 第三行,輸出最長不上升序列長度。 第四行,輸出最長不下降序列長度。
樣例資料
input
10
1 3 0 8 6 2 3 1 4 2
output
4
4
4
4
時間限制:
1s
空間限制:
256MB

分析

70萬的資料量, O ( n 2 2 ) 的效率一遍都不夠,還要做四遍。顯然,我們必須更換策略。
O ( n ) 的演算法顯然難以實現,如果用高階資料結構優化,實現難度甚至不在noip提高組範圍內,永不考慮。
那我們考慮 O ( n l o g 2 n ) 的演算法, l o g 2 700000 = 19.417 ,所以 O ( n l o g 2 n ) 的效率做4遍近似於 O ( 700000 19.417 4 ) = O ( 13591900 4 ) O ( 5.4 10 7 ) 在考慮範圍內。我們嘗試尋找該效率的演算法。
怎麼實現呢?我們考慮一個數組d[i],代表長度為i的最長上升子序列的最小尾元素。仔細理解d[i]的含義。那麼我們如何完成這個陣列呢?花費 O ( n ) 時間線性遍歷序列。如果遇到一個數大於d陣列末尾的數,說明它可以成為最長上升子序列的一部分,將d陣列尾指標加一,存下這個數。如果這個數小於等於d陣列末尾的數,那麼我們在d陣列中找到第一個大於等於它的數所對應的位置,並將當前這個數替換掉該位置的數。為什麼要替換呢,首先是操作的合法性,既然d陣列的尾指標所指的數大於等於它,那麼它就可以插入到以前最長上升子序列的某個位置,那麼插入到第一個大於等於它的位置即可。這樣做是為了更好的響應以後的該操作及第一種情況,也就是說,這樣替換,它的潛力就更大了,就有可能組成更長的最長上升子序列。完成d陣列後,它的尾指標的大小就是最長上升子序列的長度。
那麼為什麼時間複雜度是 O ( n l o g 2 n ) 呢,顯然,線性掃描時間複雜度 O ( n ) ,那麼由第一種情況及定義可知,d陣列是單調上升的,那麼第二種情況查詢第一個大於等於它的數字就可以用二分查詢優化。所以總的時間複雜度就是 O ( n l o g 2 n )
再來考慮一個細節問題,為什麼要查詢第一個大於等於它是數而不是查詢第一個大於它的數呢?原因是我們在有相等的數時,要先替換掉相等的數,不然就成了最長不下降子序列的做法
關於二分查詢,我們可以手寫,當然,可以直接用lower_bound函式,剛好符合我們的需求。但是,避免返回地址出錯,我們從0開始存數就能避免。

最長上升子序列程式碼實現如下:

#include<bits/stdc++.h>
using namespace std;
int n,t=0,a[100080]={},d[100080]={};
int main()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    d[0]=a[0];
    for(int i=0;i<n;i++)
    {
        if(a[i]>d[t])d[++t]=a[i];
        else d[ lower_bound(d,d+t,a[i])-d ]=a[i];
    }
    cout<<t+1<<endl;
    return 0;
}

那麼最長不下降子序列的d[i]就是代表長度為i的最長不下降子序列的最小尾元素,我們將比較d陣列末尾元素與a[i]大小時將符號改為>=即可。由於最長不下降子序列中允許大小相同的元素存在,我們在查詢時如有相同不必替換,需要查詢第一個比它大的元素才能更優,即使用upper_bound函式即可。

最長不下降子序列程式碼實現如下:

#include<bits/stdc++.h>
using namespace std;
int n,t=0,a[100080]={},d[100080]={};
int main()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    d[0]=a[0];
    for(int i=0;i<n;i++)
    {
        if(a[i]>=d[t])d[++t]=a[i];
        else d[ upper_bound(d,d+t,a[i])-d ]=a[i];
    }
    cout<<t+1<<endl;
    return 0;
}

由此可以推出,最長下降子序列也是相同的。最長下降子序列的d[i]代表代表長度為i的最長下降子序列的最大尾元素。每一次掃描一個數,如果比d陣列尾元素小就加入d陣列,比d陣列尾元素大或相等就在單調下降的d陣列中找到第一個小於等於它的數,替換它,原理與最長上升子序列相同。不過,這個二分就要我們手寫了,難度不是很大,如果對二分的細節處理還有疑惑,可以看我的部落格『二分查詢和二分答案』
最長下降子序列程式碼實現如下:

#include<bits/stdc++.h>
using namespace std;
int n,t=0,a[100080]={},d[100080]={};
inline int find(int p[],int len,int num)
{
    int l=0,r=len;
    while(l+1<r)
    {
        int mid=(l+r)>>1;
        if(p[mid]<num)r=mid;
        else l=mid;
    }
    if(p[l]>num)return r;
    else return l;
}
int main()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    d[0]=a[0];
    for(int i=0;i<n;i++)
    {
        if(a[i]<d[t])d[++t]=a[i];
        else d[ find(d,t,a[i]) ]=a[i];
    }
    cout<<t+1<<endl;
    return 0;
}

最長不上升子序列的程式碼就與最長下降子序列相近了。最長不上升子序列的d[i]代表代表長度為i的最長不上升子序列的最大尾元素。在掃描時,如果a[i]小於等於d陣列尾元素,將其加入d陣列,否則找到第一個小於它的數,因為可以相等,所以找比它小的才能更優,與最長不下降子序列相同。而在程式碼上把最長下降子序列的比較的符號改為小於等於,在二分查詢時也把最後的比較改為大於等於即可。
最長不上升子序列程式碼實現如下:

#include<bits/stdc++.h>
using namespace std;
int n,t=0,a[100080]={},d[100080]={};
inline int find(int p[],int len,int num)
{
    int l=0,r=len;
    while(l+1<r)
    {
        int mid=(l+r)>>1;
        if(p[mid]<num)r=mid;
        else l=mid;
    }
    if(p[l]>=num)return r;
    else return l;
}
int main()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    d[0]=a[0];
    for(int i=0;i<n;i++)
    {
        if(a[i]<=d[t])d[++t]=a[i];
        else d[ find(d,t,a[i]) ]=a[i];
    }
    cout<<t+1<<endl;
    return 0;
}

那麼整合起來就是LIS模板題的標準答案:

#include<bits/stdc++.h>
using namespace std;
int n,a[700080]={},t1=0,t2=0,t3=0,t4=0;
int d1[700080]={},d2[700080]={},d3[700080]={},d4[700080]={};
inline int find1(int p[],int len,int num)
{
    int l=0,r=len;
    while(l+1<r)
    {
        int mid=(l+r)>>1;
        if(p[mid]<num)r=mid;
        else l=mid;
    }
    if(p[l]>num)return r;
    else return l;
}
inline int find2(int p[],int len,int num)
{
    int l=0,r=len;
    while(l+1<r)
    {
        int mid=(l+r)>>1;
        if(p[mid]<num)r=mid;
        else l=mid;
    }
    if(p[l]>=num)return r;
    else return l;
}
inline void a_()
{
    d1[0]=a[0];
    for(int i=1;i<n;i++)
    {
        if(a[i]>d1[t1])d1[++t1]=a[i];
        else d1[ lower_bound(d1,d1+t1,a[i])-d1 ]=a[i];
    }
    cout<<t1+1<<endl;
}
inline void b_()
{
    d2[0]=a[0];
    for(int i=1;i<n;i++)
    {
        if(a[i]<d2[t2])d2[++t2]=a[i];
        else d2[ find1(d2,t2,a[i]) ]=a[i];
    }
    cout<<t2+1<<endl;
} 
inline void c_()
{
    d3[0]=a[0];
    for(int i=1;i<n;i++)
    {
        if(a[i]<=d3[t3])d3[++t3]=a[i];
        else d3[ find2(d3,t3,a[i]) ]=a[i];
    }
    cout<<t3+1<<endl;
} 
inline void d_()
{
    d4[0]=a[0];
    for(int i=1;i<n;i++)
    {
        if(a[i]>=d4[t4])d4[++t4]=a[i];
        else d4[ upper_bound(d4,d4+t4,a[i])-d4 ]=a[i];
    }
    cout<<t4+1<<endl;
}
int main()
{
    cin>>n;
    for(int i=0;i<n;i++)cin>>a[i];
    a_();b_();c_();d_();
}

<後記>


<廢話>