1. 程式人生 > 實用技巧 >HDU-4747 Mex 線段樹應用 Mex性質

HDU-4747 Mex 線段樹應用 Mex性質

HDU-4747 Mex 線段樹應用 Mex性質

題意

給定長度為\(n\)的陣列\(a\),求

\[\sum \sum mex(i,j) \]

其中\(mex(i,j)\)表示區間\(mex(a_i...a_j)的值\)

\[1\leq n \leq 2\times 10^5\\ 1\leq a_i \leq 10^9 \]

分析

此題我認為還是不太好想到的

首先如果只求一維,由於單調性,求\(\sum mex(1,i)\)是可以在\(O(n)\)下完成的。

然後注意到第二維即\(\sum mex(2,i)\)該如何計算,這個時候\(1\)相當於沒有了,1在這一維上產生的影響就是當前下一個等於\(a[1]\)

的元素之前的一段。後面的顯然和原來的保持不變,這就讓我們想到了用區間維護。

那麼\(a[1]\)會如何影響\([2,next[a[1]] - 1]\)呢?

再次想到\(mex\)在“字首”意義上的單調性,我們只需要把其中大於\(a[1]\)的部分變為\(a[1]\)即可,其他部分的\(mex\)並不會受影響

最後由於遞推,不要忘記單點修改。

所以問題就轉化成了

  • 求出每個數的下一個等於它的數出現的位置
  • 求出第一個大於等於\(a[i]\)的位置
  • 修改某一段區間的值

這些都可以用線段樹實現,當然要注意一些細節,比如\(lazy\)標記應該設定\(-1\),否則\(mx\)會無法下傳,以及下一個位置陣列應該在最後加上\(n + 1\)

程式碼

struct Tree {
    int lazy;
    int sum;
    int mx;
    int l, r;
};

int n;
Tree node[maxn << 2];
int a[maxn];
int mex[maxn];
int nxt[maxn];

void push_up(int i) {
    node[i].sum = node[i << 1].sum + node[i << 1 | 1].sum;
    node[i].mx = max(node[i << 1].mx, node[i << 1 | 1].mx);
}

void build(int i, int l, int r) {
    node[i].l = l;
    node[i].r = r;
    if (l == r) {
        node[i].sum = mex[l];
        node[i].mx = mex[l];
        return;
    }
    int mid = l + r >> 1;
    build(i << 1, l, mid);
    build(i << 1 | 1, mid + 1, r);
    push_up(i);
}

void push_down(int i, int m) {
    if (node[i].lazy >= 0) {
        node[i << 1].lazy = node[i].lazy;
        node[i << 1 | 1].lazy = node[i].lazy;
        node[i << 1].sum = node[i].lazy * (m - (m >> 1));
        node[i << 1 | 1].sum = node[i].lazy * (m >> 1);
        node[i << 1].mx = node[i << 1 | 1].mx = node[i].lazy;
        node[i].lazy = -1;
    }
}

void update(int i, int l, int r, int val) {
    if (node[i].l > r || node[i].r < l) return;
    if (node[i].l >= l && node[i].r <= r) {
        //bug;
        node[i].lazy = val;
        node[i].sum = (node[i].r - node[i].l + 1) * val;
        node[i].mx = val;
        //cout << node[i].mx << ' ' << i << '\n';
        return;
    }
    push_down(i, node[i].r - node[i].l + 1);
    update(i << 1, l, r, val);
    update(i << 1 | 1, l, r, val);
    push_up(i);
}

int query(int i, int x) {
    if (node[i].l == node[i].r) {
        return node[i].l;
    }
    if (node[i << 1].mx > x) return query(i << 1, x);
    else return query(i << 1 | 1, x);
}


signed main() {
    while (scanf("%lld", &n)) {
        if (!n) break;
        for (int i = 1; i <= n; i++)
            a[i] = readint(), nxt[i] = n + 1;
        for (int i = 1; i < 4 * n; i++) {
            node[i].l = node[i].r = node[i].sum = node[i].lazy = -1, node[i].mx = 0;
        }
        unordered_map<int, int> mp;
        int cur = 0;
        for (int i = 1; i <= n; i++) {
            if (mp[a[i]]) nxt[mp[a[i]]] = i, mp[a[i]] = i;
            else mp[a[i]] = i;
            while (mp[cur]) cur++;
            mex[i] = cur;
        }
        build(1, 1, n);
        ll ans = 0;
        for (int i = 1; i <= n; i++) {
            ans += node[1].sum;
            if (node[1].mx > a[i]) {
                int l = query(1,a[i]);
                int r = nxt[i] - 1;
                if (l <= r) update(1, l, r, a[i]);
            }
            update(1, i, i, 0);
        }
        cout << ans << '\n';
    }
}