1. 程式人生 > >HDU6058 Kanade's sum

HDU6058 Kanade's sum

Kanade’s sum

題目大意就是給你一個1~n的一個排列,求滿足題意的所有區間[l,r](區間長度至少為k且1 <= l <= r <= n)中第k大的數之和。

舉個例子:
比如說給你 n = 4, k = 2,其中數列是a[4] = {3,1,4,2}。因為k = 2, 所以區間長度至少為2.
那麼滿足提議的區間有下列這些:

[1,2]:{3,1},第2(k = 2)大的數是1
[1,3]:{3,1,4},第2(k = 2)大的數是3
[1,4]:{3,1,4,2},第2(k = 2)大的數是3
[2,3]:{1,4},第2(k = 2)大的數是1
[2,4]:{1,4,2},第2(k = 2)大的數是2
[3,4]:{4,2},第2(k = 2)大的數是2
sum = 1 + 3 + 3 + 1 + 2 + 2 = 12;

這道題的做法我知道的有兩種:

第一種,用一個for迴圈遍歷a[],對於每一個a[i],去找它前面第k大和它後面第k大的位置(為什麼是第k大而不是第k-1大呢?因為在第k大和第k-1大之間的數也是可以取的,要算上。先往下看),通過這些位置來計算出以a[i]為第k大的數的區間個數有幾個,則sum += a[i] * cnt即可
在這個for迴圈裡用兩個陣列left[]和right[]維護這些位置。left[1]表示a[i]左邊比a[i]大的最近數的位置,以此類推。
上述例子中對於a[2] = 1,left[1] = 1, right[1] = 3。記left[]的長度為lcnt,right[]的長度為rcnt

對於cnt的計算類似於乘法原理,比如我們要算a[i]為第k大的區間個數,就從1到lcnt遍歷left[]陣列,當我們以left[s]為a[i]左邊第s大的位置為最遠端(近似)時,右邊就只能以right[k - 1 - s]為最遠端(近似),因為在我們要算的這個區間裡比a[i]大的個數是k - 1個。那麼就這[left[s],right[k - 1 - s]]一個區間滿足嗎?當然不是,正確答案是[l,r],其中left[s + 1]< l <= left[s],right[k - 1 - s] <= r < right[k - s]。那麼根據乘法原理,左邊能取left[s] - left[s + 1]個,右邊能取right[k -s] - right[k - 1 -s]個,則cnt += (left[s] - left[s + 1])* (right[k -s] - right[k - 1 -s])。
這是正常情況,還有些特殊情況需要處理。這個就在程式裡自己特判了。
最後說下時間複雜度,這種方法其實不是很好,因為是o(n²)的,但是由於題目要求2s,n²勉強能過,大概跑了1.7s左右。

第二種:先用pos[a[i]]=i記錄下每個a[i]的位置,因為是1~n的排列,也就是說我們如果從1~n遍歷for迴圈,每次計算i為第k大的區間個數,用陣列模擬連結串列維護比i大的數,然後刪除那個結點。複雜度o(nk),0.46s。
兩種程式碼分別如下:

#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 5;
int a[maxn];
int left[maxn];
int right[maxn];
int main()
{
    int cas;
    scanf("%d", &cas);
    while(cas--)
    {
        int n, k;
        scanf("%d%d", &n, &k);
        for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
        ll sum  = 0;
        for(int i = 1; i <= n; i++)
        {
            int lcnt = 0, rcnt = 0, j = 0;
            for(j = i - 1; j >= 1; j--)
            {
                if(a[j] > a[i])
                {
                    lcnt ++;
                    left[lcnt] = i - j;
                }
                if(lcnt >= k) break;
            }
            if(j <= 1) left[++lcnt] = i;
            for(j = i + 1; j <= n; j++)
            {
                if(a[j] > a[i])
                {
                    rcnt ++;
                    right[rcnt] = j - i;
                }
                if(rcnt >= k) break;
            }
            if(j >= n) right[++rcnt] = n - i + 1;
            for(j = 1; j <= lcnt; j++)
            {
                if(k - j >= rcnt) continue;
                sum += (ll) a[i] *  (left[j] - left[j - 1])  * (right[k - j + 1] - right[k - j]);
            }
        }
        printf("%lld\n", sum);
    }
}
#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 10;
int a[maxn], pos[maxn], pre[maxn], nxt[maxn];
ll left[100], right[100];
int n,k;
void del(int x)
{
    pre[nxt[x]] = pre[x];
    nxt[pre[x]] = nxt[x];
}
ll cal(int x)
{
    int  lcnt = 0, rcnt = 0;
    for(int i = x; i > 0; i = pre[i])
    {
        lcnt ++;
        left[lcnt] = i - pre[i];
        if(lcnt == k) break;
    }
    for(int i = x; i <= n; i = nxt[i])
    {
        rcnt ++;
        right[rcnt] = nxt[i] - i;
        if(rcnt == k) break;
    }
    ll res = 0;
    for(int i = 1; i <= lcnt; i ++)
    {
        if(k - i + 1 <= rcnt)
        {
            res += left[i] * right[k - i + 1];
        }
    }
    return res;
}
int main()
{
    int lcnts = 0;
    int t;
    scanf("%d", &t);
    while(t--)
    {
        scanf("%d%d", &n, &k);
        for(int i = 1; i <= n; i++)
        {
            scanf("%d", &a[i]);
            pos[a[i]] = i;
            pre[i] = i - 1;
            nxt[i] = i + 1;
        }
        pre[0] = 0;
        nxt[n + 1] = n + 1;
        ll sum = 0;
        for(int i = 1; i <= n; i++)
        {
            int x = pos[i];
            sum += cal(x) * i;
            del(x);
        }
        printf("%I64d\n", sum);
    }
}