1. 程式人生 > >[bzoj4504]K個串【可持久化線段樹】【堆】

[bzoj4504]K個串【可持久化線段樹】【堆】

【題目連結】
  
【題解】
  首先記下每個點向右所控制的區間,就是它到下一個與它相同的位置-1。
  我們考慮對於每個左端點維護一棵線段樹下標表示以該點為右端點的區間的答案。
  那麼左端點為1的區間可以O(N)暴力求出。
  對於兩個相鄰的左端點i,i+1,只有i所控制的區間會減去i的值。用可持久化線段樹+標記永久化即可。
  然後將每個點的對應最大值放入堆中,每次取出最大的並將該左端點的次大值放入。
  時間複雜度O((N+K)logN)
【程式碼】

# include <bits/stdc++.h>
# define    ll      long long
# define inf 0x3f3f3f3f # define N 100100 using namespace std; int read(){ int tmp = 0, fh = 1; char ch = getchar(); while (ch < '0' || ch > '9') {if (ch == '-') fh = -1; ch = getchar(); } while (ch >= '0' && ch <= '9'){tmp = tmp * 10 + ch - '0'; ch = getchar(); } return
tmp * fh; } const ll infll = 0x3f3f3f3f3f3f3f3fll; struct Tree{ int pl, pr, id; ll mx, tag; }T[N * 100]; struct Node{ ll sum; int belong, id; }; bool operator < (Node x, Node y){return x.sum < y.sum; } priority_queue <Node> hp; map <int, int> mp; int n, place, k, nxt[N], num[N], cnt[N], rt[N]; ll sum[N], ans; void
pushup(int p){ ll l = T[T[p].pl].mx + T[T[p].pl].tag, r = T[T[p].pr].mx + T[T[p].pr].tag; if (l > r) T[p].mx = l, T[p].id = T[T[p].pl].id; else T[p].mx = r, T[p].id = T[T[p].pr].id; } void build(int &p, int l, int r){ p = ++place; if (l != r){ int mid = (l + r) / 2; build(T[p].pl, l, mid); build(T[p].pr, mid + 1, r); pushup(p); } else { T[p].mx = sum[l]; T[p].id = l; } } void modify(int &p, int las, int ql, int qr, ll x, int l, int r){ p = ++place; T[p] = T[las]; if (ql == l && qr == r){ T[p].tag += x; return; } int mid = (l + r) / 2; if (mid >= qr) modify(T[p].pl, T[p].pl, ql, qr, x, l, mid); else if (mid < ql) modify(T[p].pr, T[p].pr, ql, qr, x, mid + 1, r); else modify(T[p].pl, T[p].pl, ql, mid, x, l, mid), modify(T[p].pr, T[p].pr, mid + 1, qr, x, mid + 1, r); pushup(p); } Node query(int p, int x){ return (Node){T[p].mx + T[p].tag, x, T[p].id}; } int main(){ //freopen("A.in", "r", stdin); //freopen("A.out", "w", stdout); n = read(), k = read(); for (int i = 1; i <= n; i++) num[i] = read(); for (int i = n; i >= 1; i--){ if (mp.find(num[i]) != mp.end()) nxt[i] = mp[num[i]]; else nxt[i] = n + 1; mp[num[i]] = i; } for (int i = 1; i <= n; i++) cnt[i] = n - i + 1; mp.clear(); for (int i = 1; i <= n; i++){ sum[i] = sum[i - 1]; if (mp.find(num[i]) == mp.end()) sum[i] += num[i]; mp[num[i]] = 1; } build(rt[1], 1, n); Node now = query(rt[1], 1); hp.push(now); for (int i = 2; i <= n; i++){ rt[i] = rt[i - 1]; if (i != nxt[i - 1]) modify(rt[i], rt[i], i, nxt[i - 1] - 1, -num[i - 1], 1, n); modify(rt[i], rt[i], i - 1, i - 1, -infll, 1, n); Node now = query(rt[i], i); hp.push(now); } for (int i = 1; i <= k; i++){ Node now = hp.top(); hp.pop(); ans = now.sum; cnt[now.belong]--; if (cnt[now.belong] != 0){ modify(rt[now.belong], rt[now.belong], now.id, now.id, -infll, 1, n); now = query(rt[now.belong], now.belong); hp.push(now); } //printf("%lld\n", ans); } printf("%lld\n", ans); return 0; }