P6242 【模板】線段樹 3 線段樹維護歷史最值+區間取min
阿新 • • 發佈:2020-09-17
P6242 【模板】線段樹 3
線段樹維護歷史最值+區間取min。
區間取min:
線段樹維護一個區間最大值\((MaxA)\)和嚴格次大值\((se)\),還要維護最大值個數\(cnt\),區間和\(sum\),然後分情況:(設當前與\(k\)取min)
當\(k >= t[o].MaxA\)時,直接返回;
當\(t[o].se < k < t[o].MaxA\)時,\(t[o].sum += t[o].cnt * (k - t[o].MaxA)\),\(t[o].MaxA = k\);
當\(k <= t[o].se\)時,繼續往下遞迴。
具體維護看程式碼:
void up(int o) { if(t[ls(o)].MaxA > t[rs(o)].MaxA) { t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt; t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA); } else if(t[ls(o)].MaxA < t[rs(o)].MaxA) { t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt; t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA); } else { t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt; t[o].se = max(t[ls(o)].se, t[rs(o)].se); } t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB); t[o].sum = t[ls(o)].sum + t[rs(o)].sum; }
維護歷史最大值:
要維護4個標記:最大值加減標記\((add1)\),最大值歷史最大加減標記\((add1\)_ \()\),非最大值加減標記\((add2)\),非最大值歷史最大加減標記\((add2\)_\()\)。
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) { t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt); t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_); // MaxB代表歷史最大值,用add1_更新 t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_); //標記也記得更新 t[o].MaxA += add1; t[o].add1 += add1; t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_); if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2; }
完整程式碼:
#include <bits/stdc++.h>
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e6 + 5, inf = 2e9;
int n, m;
long long x;
struct tree {
long long sum;
int MaxA, MaxB, cnt, se;
int add1, add1_, add2, add2_;
} t[N << 2];
void up(int o) {
if(t[ls(o)].MaxA > t[rs(o)].MaxA) {
t[o].MaxA = t[ls(o)].MaxA; t[o].cnt = t[ls(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].MaxA);
}
else if(t[ls(o)].MaxA < t[rs(o)].MaxA) {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[rs(o)].cnt;
t[o].se = max(t[rs(o)].se, t[ls(o)].MaxA);
}
else {
t[o].MaxA = t[rs(o)].MaxA; t[o].cnt = t[ls(o)].cnt + t[rs(o)].cnt;
t[o].se = max(t[ls(o)].se, t[rs(o)].se);
}
t[o].MaxB = max(t[ls(o)].MaxB, t[rs(o)].MaxB);
t[o].sum = t[ls(o)].sum + t[rs(o)].sum;
}
void build(int o, int l, int r) {
if(l == r) {
t[o].MaxA = t[o].MaxB = t[o].sum = read();
t[o].se = -inf; t[o].cnt = 1;
return ;
}
build(ls(o), l, mid); build(rs(o), mid + 1, r);
up(o);
}
void modify(int o, int l, int r, int add1, int add1_, int add2, int add2_) {
t[o].sum += 1ll * add1 * t[o].cnt + 1ll * add2 * (r - l + 1 - t[o].cnt);
t[o].MaxB = max(t[o].MaxB, t[o].MaxA + add1_);
t[o].add1_ = max(t[o].add1_, t[o].add1 + add1_);
t[o].MaxA += add1; t[o].add1 += add1;
t[o].add2_ = max(t[o].add2_, t[o].add2 + add2_);
if(t[o].se != -inf) t[o].se += add2; t[o].add2 += add2;
}
void down(int o, int l, int r) {
int tmp = max(t[ls(o)].MaxA, t[rs(o)].MaxA);
if(t[ls(o)].MaxA == tmp)
modify(ls(o), l, mid, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(ls(o), l, mid, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_); //add1,add_是維護區間最大值的標記,如果這個區間沒有父節點的最大值,那麼最大值標記不下傳
if(t[rs(o)].MaxA == tmp)
modify(rs(o), mid + 1, r, t[o].add1, t[o].add1_, t[o].add2, t[o].add2_);
else
modify(rs(o), mid + 1, r, t[o].add2, t[o].add2_, t[o].add2, t[o].add2_);
t[o].add1 = t[o].add1_ = t[o].add2 = t[o].add2_ = 0;
}
void change_add(int o, int l, int r, int x, int y, int k) {
if(x <= l && y >= r) { modify(o, l, r, k, k, k, k); return ; }
down(o, l, r);
if(x <= mid) change_add(ls(o), l, mid, x, y, k);
if(y > mid) change_add(rs(o), mid + 1, r, x, y, k);
up(o);
}
void change_min(int o, int l, int r, int x, int y, int k) {
if(t[o].MaxA <= k) return ;
if(x <= l && y >= r && t[o].MaxA > k && t[o].se < k) {
modify(o, l, r, k - t[o].MaxA, k - t[o].MaxA, 0, 0);
return ;
}
down(o, l, r);
if(x <= mid) change_min(ls(o), l, mid, x, y, k);
if(y > mid) change_min(rs(o), mid + 1, r, x, y, k);
up(o);
}
long long query_sum(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) return t[o].sum;
down(o, l, r);
long long res = 0;
if(x <= mid) res += query_sum(ls(o), l, mid, x, y);
if(y > mid) res += query_sum(rs(o), mid + 1, r, x, y);
return res;
}
int query_A(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxA; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_A(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_A(rs(o), mid + 1, r, x, y));
return res;
}
int query_B(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) { return t[o].MaxB; }
down(o, l, r);
int res = -inf;
if(x <= mid) res = max(res, query_B(ls(o), l, mid, x, y));
if(y > mid) res = max(res, query_B(rs(o), mid + 1, r, x, y));
return res;
}
int main() {
n = read(); m = read();
build(1, 1, n);
for(int i = 1, opt, l, r;i <= m; i++) {
opt = read(); l = read(); r = read();
if(opt == 1) x = read(), change_add(1, 1, n, l, r, x);
if(opt == 2) x = read(), change_min(1, 1, n, l, r, x);
if(opt == 3) printf("%lld\n", query_sum(1, 1, n, l, r));
if(opt == 4) printf("%d\n", query_A(1, 1, n, l, r));
if(opt == 5) printf("%d\n", query_B(1, 1, n, l, r));
}
return 0;
}