[bzoj4504]K個串【可持久化線段樹】【堆】
阿新 • • 發佈:2019-01-06
【題目連結】
【題解】
首先記下每個點向右所控制的區間,就是它到下一個與它相同的位置-1。
我們考慮對於每個左端點維護一棵線段樹下標表示以該點為右端點的區間的答案。
那麼左端點為1的區間可以暴力求出。
對於兩個相鄰的左端點,只有所控制的區間會減去的值。用可持久化線段樹+標記永久化即可。
然後將每個點的對應最大值放入堆中,每次取出最大的並將該左端點的次大值放入。
時間複雜度
【程式碼】
# include <bits/stdc++.h>
# define ll long long
# define inf 0x3f3f3f3f
# define N 100100
using namespace std;
int read(){
int tmp = 0, fh = 1; char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') fh = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9'){tmp = tmp * 10 + ch - '0'; ch = getchar(); }
return tmp * fh;
}
const ll infll = 0x3f3f3f3f3f3f3f3fll;
struct Tree{
int pl, pr, id;
ll mx, tag;
}T[N * 100];
struct Node{
ll sum; int belong, id;
};
bool operator < (Node x, Node y){return x.sum < y.sum; }
priority_queue <Node> hp;
map <int, int> mp;
int n, place, k, nxt[N], num[N], cnt[N], rt[N];
ll sum[N], ans;
void pushup(int p){
ll l = T[T[p].pl].mx + T[T[p].pl].tag, r = T[T[p].pr].mx + T[T[p].pr].tag;
if (l > r) T[p].mx = l, T[p].id = T[T[p].pl].id;
else T[p].mx = r, T[p].id = T[T[p].pr].id;
}
void build(int &p, int l, int r){
p = ++place;
if (l != r){
int mid = (l + r) / 2;
build(T[p].pl, l, mid);
build(T[p].pr, mid + 1, r);
pushup(p);
}
else {
T[p].mx = sum[l];
T[p].id = l;
}
}
void modify(int &p, int las, int ql, int qr, ll x, int l, int r){
p = ++place;
T[p] = T[las];
if (ql == l && qr == r){
T[p].tag += x;
return;
}
int mid = (l + r) / 2;
if (mid >= qr) modify(T[p].pl, T[p].pl, ql, qr, x, l, mid);
else if (mid < ql) modify(T[p].pr, T[p].pr, ql, qr, x, mid + 1, r);
else modify(T[p].pl, T[p].pl, ql, mid, x, l, mid),
modify(T[p].pr, T[p].pr, mid + 1, qr, x, mid + 1, r);
pushup(p);
}
Node query(int p, int x){
return (Node){T[p].mx + T[p].tag, x, T[p].id};
}
int main(){
//freopen("A.in", "r", stdin);
//freopen("A.out", "w", stdout);
n = read(), k = read();
for (int i = 1; i <= n; i++) num[i] = read();
for (int i = n; i >= 1; i--){
if (mp.find(num[i]) != mp.end()) nxt[i] = mp[num[i]];
else nxt[i] = n + 1;
mp[num[i]] = i;
}
for (int i = 1; i <= n; i++) cnt[i] = n - i + 1;
mp.clear();
for (int i = 1; i <= n; i++){
sum[i] = sum[i - 1];
if (mp.find(num[i]) == mp.end()) sum[i] += num[i];
mp[num[i]] = 1;
}
build(rt[1], 1, n);
Node now = query(rt[1], 1);
hp.push(now);
for (int i = 2; i <= n; i++){
rt[i] = rt[i - 1];
if (i != nxt[i - 1]) modify(rt[i], rt[i], i, nxt[i - 1] - 1, -num[i - 1], 1, n);
modify(rt[i], rt[i], i - 1, i - 1, -infll, 1, n);
Node now = query(rt[i], i);
hp.push(now);
}
for (int i = 1; i <= k; i++){
Node now = hp.top(); hp.pop();
ans = now.sum;
cnt[now.belong]--;
if (cnt[now.belong] != 0){
modify(rt[now.belong], rt[now.belong], now.id, now.id, -infll, 1, n);
now = query(rt[now.belong], now.belong);
hp.push(now);
}
//printf("%lld\n", ans);
}
printf("%lld\n", ans);
return 0;
}