1. 程式人生 > 實用技巧 >P6242 【模板】線段樹 3 線段樹維護歷史最值+區間取min

P6242 【模板】線段樹 3 線段樹維護歷史最值+區間取min

P6242 【模板】線段樹 3

​ 線段樹維護歷史最值+區間取min。

​ 區間取min:

​ 線段樹維護一個區間最大值\((MaxA)\)和嚴格次大值\((se)\),還要維護最大值個數\(cnt\),區間和\(sum\),然後分情況:(設當前與\(k\)取min)

​ 當\(k >= t[o].MaxA\)時,直接返回;

​ 當\(t[o].se < k < t[o].MaxA\)時,\(t[o].sum += t[o].cnt * (k - t[o].MaxA)\)\(t[o].MaxA = k\)

​ 當\(k <= t[o].se\)時,繼續往下遞迴。

​ 具體維護看程式碼:

void up(int o) {
    if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
        t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
        t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
    }
    else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
        t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
        t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
    }
    else {
        t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
        t[o].se = max(t[ls(o)].se, t[rs(o)].se);
    }
    t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
    t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}

​ 維護歷史最大值:

​ 要維護4個標記:最大值加減標記\((add1)\),最大值歷史最大加減標記\((add1\)_ \()\),非最大值加減標記\((add2)\),非最大值歷史最大加減標記\((add2\)_\()\)

void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
    t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);

    t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_); // MaxB代表歷史最大值,用add1_更新
    t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_); //標記也記得更新
    t[o].MaxA += add1; t[o].add1 += add1; 

    t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);  
    if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}

完整程式碼:

#include <bits/stdc++.h>

#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)

using namespace std;

inline long long read() {
    long long s = 0, f = 1; char ch;
    while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
    for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
    return s * f;
}

const int N = 1e6 + 5, inf = 2e9;
int n, m;
long long x;
struct tree { 
    long long sum; 
    int MaxA, MaxB, cnt, se;
    int add1, add1_, add2, add2_; 
} t[N << 2];

void up(int o) {
    if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
        t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
        t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
    }
    else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
        t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
        t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
    }
    else {
        t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
        t[o].se = max(t[ls(o)].se, t[rs(o)].se);
    }
    t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
    t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}

void build(int o, int l, int r) {
    if(l == r) { 
        t[o].MaxA = t[o].MaxB = t[o].sum = read(); 
        t[o].se = -inf; t[o].cnt = 1; 
        return ; 
    }
    build(ls(o), l, mid); build(rs(o), mid + 1, r);
    up(o);
} 

void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
    t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);

    t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_);
    t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_);
    t[o].MaxA += add1; t[o].add1 += add1;

    t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);  
    if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}

void down(int o, int l, int r) {
    int tmp = max(t[ls(o)].MaxA, t[rs(o)].MaxA);
    if(t[ls(o)].MaxA == tmp) 
        modify(ls(o), l, mid, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
    else 
        modify(ls(o), l, mid, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_); //add1,add_是維護區間最大值的標記,如果這個區間沒有父節點的最大值,那麼最大值標記不下傳

    if(t[rs(o)].MaxA == tmp) 
        modify(rs(o), mid + 1, r, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
    else 
        modify(rs(o), mid + 1, r, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
    t[o].add1 = t[o].add1_ = t[o].add2 = t[o].add2_ = 0;
}

void change_add(int o, int l, int r, int x, int y, int k) {
    if(x <= l && y >= r) { modify(o, l, r, k, k, k, k); return ; }
    down(o, l, r);
    if(x <= mid) change_add(ls(o), l, mid, x, y, k);
    if(y > mid) change_add(rs(o), mid + 1, r, x, y, k);
    up(o);
}

void change_min(int o, int l, int r, int x, int y, int k) {
    if(t[o].MaxA <= k) return ;
    if(x <= l && y >= r && t[o].MaxA > k && t[o].se < k)  {
        modify(o, l, r, k - t[o].MaxA, k - t[o].MaxA, 0, 0);
        return ;
    }
    down(o, l, r);
    if(x <= mid) change_min(ls(o), l, mid, x, y, k);
    if(y > mid) change_min(rs(o), mid + 1, r, x, y, k);
    up(o); 
}

long long query_sum(int o, int l, int r, int x, int y) {
    if(x <= l && y >= r) return t[o].sum;
    down(o, l, r);
    long long res = 0;
    if(x <= mid) res += query_sum(ls(o), l, mid, x, y);
    if(y > mid) res += query_sum(rs(o), mid + 1, r, x, y);
    return res;
}

int query_A(int o, int l, int r, int x, int y) {
    if(x <= l && y >= r) { return t[o].MaxA; }
    down(o, l, r);
    int res = -inf;
    if(x <= mid) res = max(res, query_A(ls(o), l, mid, x, y));
    if(y > mid) res = max(res, query_A(rs(o), mid + 1, r, x, y));
    return res;
}

int query_B(int o, int l, int r, int x, int y) {
    if(x <= l && y >= r) { return t[o].MaxB; }
    down(o, l, r);
    int res = -inf;
    if(x <= mid) res = max(res, query_B(ls(o), l, mid, x, y));
    if(y > mid) res = max(res, query_B(rs(o), mid + 1, r, x, y));
    return res;
}

int main() {

    n = read(); m = read();
    build(1, 1, n);
    for(int i = 1, opt, l, r;i <= m; i++) {
        opt = read(); l = read(); r = read();
        if(opt == 1) x = read(), change_add(1, 1, n, l, r, x);
        if(opt == 2) x = read(), change_min(1, 1, n, l, r, x);
        if(opt == 3) printf("%lld\n", query_sum(1, 1, n, l, r)); 
        if(opt == 4) printf("%d\n", query_A(1, 1, n, l, r));
        if(opt == 5) printf("%d\n", query_B(1, 1, n, l, r));
    }

    return 0;
}