1. 程式人生 > 其它 >程式碼源每日一題Div1 103 題解

程式碼源每日一題Div1 103 題解

題目連結

簡要題意

一個序列的最大差,定義為該序列最大值與最小值的差。

給定一個長度為 \(n\) 的數列 \(\{a_n\}\),求出該數列所有連續子串的最大差之和。

\(n\leq 5*10^5,0\leq a_i\leq 10^8\)

本題目原型為CF817D Imbalanced Array

題解

列舉子串肯定不行,雙指標也只是算最大/最小而不是求和,所以我們只能算每個數的貢獻。

我們記一個子序列的最大值為 \(f(l,r)\),最小值為 \(g(l,r)\),那麼我們要求的值即為:

\[\sum\limits_{i=1}^n\sum\limits_{j=i}^n(f(l,r)-g(l,r))=\sum\limits_{i=1}^n\sum\limits_{j=i}^nf(l,r)-\sum\limits_{i=1}^n\sum\limits_{j=i}^ng(l,r) \]

這個等式意味著一個數列的 f 和 g 可以分開來單獨計算,而不是合併起來一起計算。

分開計算後,我們分別使用單調棧來分別處理即可:對於某個元素 \(x\),直接找到左右邊界,然後乘法原理算出所有區間的數量,最後逐漸累加即可。

同時,注意到這種計算必須得讓所有元素各不相同,否則會導致一個區間被算了多次或者出現重複等情況,所以必須得進行一次離散化操作。

離散化操作+單調棧,總複雜度 \(O(n\log n)\)

#include <bits/stdc++.h>
using namespace std;
#define LL long long
const int N = 1000010;
int n, a[N];
int f[N], g[N];
namespace DC {
    struct Node {
        int val, id;
        bool operator < (const Node &rhs) {
            return val < rhs.val;
        }
    } arr[N];
    void solve() {
        for (int i = 1; i <= n; ++i)
            arr[i] = (Node){a[i], i};
        sort(arr + 1, arr + n + 1);
        for (int i = 1; i <= n; ++i)
            a[arr[i].id] = i;
    }
    int get(int id) { return arr[id].val; }
};
LL solve1() {
    memset(f, 0, sizeof(f));
    memset(g, 0, sizeof(g));
    stack<int> s;
    for (int i = 1; i <= n; ++i) {
        while (!s.empty() && a[s.top()] < a[i]) {
            int x = s.top(); s.pop();
            f[x] = i;
        }
        s.push(i);
    }
    for (int i = 1; i <= n; ++i)
        if (f[i] == 0) f[i] = n + 1;
    while (!s.empty()) s.pop();
    for (int i = n; i >= 1; --i) {
        while (!s.empty() && a[s.top()] < a[i]) {
            int x = s.top(); s.pop();
            g[x] = i;
        }
        s.push(i);
    }
    LL res = 0;
    for (int i = 1; i <= n; ++i)
        res += 1LL * DC::get(a[i]) * (i - g[i]) * (f[i] - i);
    return res;
}
LL solve2() {
    memset(f, 0, sizeof(f));
    memset(g, 0, sizeof(g));
    stack<int> s;
    for (int i = 1; i <= n; ++i) {
        while (!s.empty() && a[s.top()] > a[i]) {
            int x = s.top(); s.pop();
            f[x] = i;
        }
        s.push(i);
    }
    for (int i = 1; i <= n; ++i)
        if (f[i] == 0) f[i] = n + 1;
    while (!s.empty()) s.pop();
    for (int i = n; i >= 1; --i) {
        while (!s.empty() && a[s.top()] > a[i]) {
            int x = s.top(); s.pop();
            g[x] = i;
        }
        s.push(i);
    }
    LL res = 0;
    for (int i = 1; i <= n; ++i)
        res += 1LL * DC::get(a[i]) * (i - g[i]) * (f[i] - i);
    return res;
}
int main()
{
    cin >> n;
    for (int i = 1; i <= n; ++i)
        cin >> a[i];
    DC::solve();
    cout << solve1() - solve2() << endl;
    return 0;
}