CF Round #783 D - Optimal Partition
阿新 • • 發佈:2022-05-12
D - Optimal Partition
線段樹 + dp
設 s[i] 為字首和陣列,f[i] 為考慮前 i 個數的最大答案。當列舉到第 i 個數時,有狀態轉移方程如下
\[1. \;s[i] -s[j]>0,\;即\;s[j]<s[i]\;(j<i)\\f[i]=max(f[j]+i-j)\;即\;f[i]-i=max(f[j]-j)\\2. \;s[i] -s[j]==0,\;即\;s[j==s[i]\;(j<i)\\f[i]=max(f[j])\\\\3. \;s[i] -s[j]<0,\;即\;s[j]>s[i]\;(j<i)\\f[i]=max(f[j]-i+j)\;即\;f[i]+i=max(f[j]+j)\\ \]可用 s[i] 作為下標(離散化),建立三顆線段樹,分別維護 f[i] - i, f[i], f[i] + i 的最大值
對於第一種情況,設 s[i] 離散化後的下標為 idx, 找到 1 ~ idx 中 f[i] - i 的最大值後 + i 即使當前答案
第二、三種同理
注意狀態轉移的起點是 s[0] = 0, f[0] = 0, 離散化時要考慮 s[0], 並且一開始將 s[0] 的答案 f[0] 插入到線段樹中
線段樹包含 0 ~ n 共 n + 1 個點,要開 n + 1 的空間
#include <iostream> #include <cstring> #include <algorithm> #include <vector> using namespace std; typedef long long ll; const int N = 5e5 + 10; const ll INF = 4e18; int n; ll a[N], s[N], f[N]; vector<ll> alls; struct Node { int l, r; ll maxn; }tr[3][N<<2]; void pushup(int id, int u) { tr[id][u].maxn = max(tr[id][u<<1].maxn, tr[id][u<<1|1].maxn); } void build(int id, int u, int l, int r) { tr[id][u] = {l, r}; if (l == r) { tr[id][u].maxn = -INF; return; } int mid = l + r >> 1; build(id, u << 1, l, mid); build(id, u << 1 | 1, mid + 1, r); pushup(id, u); } void modify(int id, int u, int idx, ll k) { Node &root = tr[id][u]; if (root.l == idx && root.r == idx) { root.maxn = max(root.maxn, k); return; } int mid = root.l + root.r >> 1; if (idx <= mid) modify(id, u << 1, idx, k); else modify(id, u << 1 | 1, idx, k); pushup(id, u); } ll query(int id, int u, int l, int r) { Node &root = tr[id][u]; if (root.l >= l && root.r <= r) return root.maxn; int mid = root.l + root.r >> 1; ll v = -INF; if (l <= mid) v = query(id, u << 1, l, r); if (r > mid) v = max(v, query(id, u << 1 | 1, l, r)); return v; } int find(ll x) { return lower_bound(alls.begin(), alls.end(), x) - alls.begin() + 1; } int main() { ios::sync_with_stdio(false), cin.tie(0), cout.tie(0); int T; cin >> T; while(T--) { cin >> n; alls.clear(); alls.push_back(0); for (int i = 1; i <= n; i++) { cin >> a[i]; s[i] = s[i-1] + a[i]; alls.push_back(s[i]); } sort(alls.begin(), alls.end()); alls.erase(unique(alls.begin(), alls.end()), alls.end()); for (int i = 0; i < 3; i++) { build(i, 1, 1, n + 1); modify(i, 1, find(s[0]), 0); } for (int i = 1; i <= n; i++) { int idx = find(s[i]); ll t1 = query(0, 1, 1, idx - 1); ll t2 = query(1, 1, idx, idx); ll t3 = query(2, 1, idx + 1, n + 1); f[i] = max({t1 + i, t2, t3 - i}); modify(0, 1, idx, f[i] - i); modify(1, 1, idx, f[i]); modify(2, 1, idx, f[i] + i); } cout << f[n] << endl; } return 0; }