1. 程式人生 > 其它 >【2022 省選訓練賽 Contest 18 A】B(Splay)

【2022 省選訓練賽 Contest 18 A】B(Splay)

B

題目連結:2022 省選訓練賽 Contest 18 A

題目大意

給你 n 個點,每個點第 i 天的代價是 b[i]+(i-1)a[i]。
然後要你在前 m 天每天選一個點,然後最小化總代價。

思路

首先考慮怎麼 DP,觀察到如果你選好的一個點集,你可以很容易的得到最優的選的方案。
即按 \(a_i\) 從大到小排,然後每次依次選,這樣增長的量就會最小。

我們可以按 \(a_i\) 從大到小排序來 DP,設 \(f_{i,j}\) 為搞定前 \(i\) 個,選了 \(j\) 個的最小費用。
然後每次下一個選或者不選轉移,顯然超時。
考慮維護每個 \(f_i\) 陣列,亦或者說是看一個新的數會產生什麼。

然後有一些性質。
首先就是 \(f_{i,j}\) 的點集必定是 \(f_{i,j+1}\) 點集的子集,你可以用反證法來得到。
(反正意思就是你不選這個點集的肯定不如 \(f_{i,j}\) 的點集加上一個數更優)

然後接著就是加入一個數,它能修改的範圍一定是一段字尾。
這個也是可以通過反證法得到,根據前面的性質就有了。

那我們就可以通過二分出貢獻的字尾。
然後你要支援區間插入,區間加,區間二分,可以用平衡樹來實現。
(過程就是維護每個點當前代表的選的個數,以及新選這個需要的費用)

(那你貢獻就是一個區間加值,就是後面的一段都要加上你這個 \(a_i\),因為你每往後一個位置,就多了 \(a_i\)

的費用)

程式碼

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC target("avx")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("-falign-jumps")
#pragma GCC optimize("-falign-loops")
#pragma GCC optimize("-falign-labels")
#pragma GCC optimize("-fdevirtualize")
#pragma GCC optimize("-fcaller-saves")
#pragma GCC optimize("-fcrossjumping")
#pragma GCC optimize("-fthread-jumps")
#pragma GCC optimize("-funroll-loops")
#pragma GCC optimize("-fwhole-program")
#pragma GCC optimize("-freorder-blocks")
#pragma GCC optimize("-fschedule-insns")
#pragma GCC optimize("inline-functions")
#pragma GCC optimize("-ftree-tail-merge")
#pragma GCC optimize("-fschedule-insns2")
#pragma GCC optimize("-fstrict-aliasing")
#pragma GCC optimize("-fstrict-overflow")
#pragma GCC optimize("-falign-functions")
#pragma GCC optimize("-fcse-skip-blocks")
#pragma GCC optimize("-fcse-follow-jumps")
#pragma GCC optimize("-fsched-interblock")
#pragma GCC optimize("-fpartial-inlining")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("-freorder-functions")
#pragma GCC optimize("-findirect-inlining")
#pragma GCC optimize("-fhoist-adjacent-loads")
#pragma GCC optimize("-frerun-cse-after-loop")
#pragma GCC optimize("inline-small-functions")
#pragma GCC optimize("-finline-small-functions")
#pragma GCC optimize("-ftree-switch-conversion")
#pragma GCC optimize("-foptimize-sibling-calls")
#pragma GCC optimize("-fexpensive-optimizations")
#pragma GCC optimize("-funsafe-loop-optimizations")
#pragma GCC optimize("inline-functions-called-once")
#pragma GCC optimize("-fdelete-null-pointer-checks")

#include<cstdio>
#include<algorithm>
#define ll long long 

using namespace std;

const int N = 1e6 + 100;
struct node {
	ll a, b;
}a[N];
int n, k, rt;
ll ans;

ll re; char c;
ll read() {
	re = 0; c = getchar();
	while (c < '0' || c > '9') c = getchar();
	while (c >= '0' && c <= '9') {
		re = (re << 3) + (re << 1) + c - '0'; 
		c = getchar();
	}
	return re;
}

bool cmp(node x, node y) {
	return x.a > y.a;
}

ll clac(int x, ll k) {
	return (k - 1) * a[x].a + a[x].b;
}

struct SPLAY {
	int tot, pl[N], ls[N], rs[N], lzyp[N], fa[N];
	ll val[N], lzyv[N];
	
	bool lrs(int x) {return ls[fa[x]] == x;}
	
	int newpoint(int x, ll k) {
		int now = ++tot;
		val[now] = k; pl[now] = x;
		return now; 
	}
	
	void downv(int x, ll va) {
		val[x] += va; lzyv[x] += va;
	}
	
	void downp(int x, int p) {
		pl[x] += p; lzyp[x] += p;
	}
	
	void down(int now) {
		if (lzyv[now]) {
			if (ls[now]) downv(ls[now], lzyv[now]);
			if (rs[now]) downv(rs[now], lzyv[now]);
			lzyv[now] = 0;
		}
		if (lzyp[now]) {
			if (ls[now]) downp(ls[now], lzyp[now]);
			if (rs[now]) downp(rs[now], lzyp[now]);
			lzyp[now] = 0;
		}
	}
	
	void rotate(int x) {
		int y = fa[x], z = fa[y];
		int b = lrs(x) ? rs[x] : ls[x];
		if (z) (lrs(y) ? ls[z] : rs[z]) = x;
		if (lrs(x)) rs[x] = y, ls[y] = b;
			else ls[x] = y, rs[y] = b;
		fa[x] = z; fa[y] = x;
		if (b) fa[b] = y;
	}
	
	void Splay(int x) {
		while (fa[x]) {
			if (fa[fa[x]]) {
				if (lrs(x) == lrs(fa[x])) rotate(fa[x]);
					else rotate(x);
			}
			rotate(x);
		}
		rt = x;
	}
	
	int get(int p) {
		int now = rt, re = p;
		while (now) {
			down(now);
			if (val[now] < clac(p, pl[now])) now = rs[now];
				else re = pl[now], now = ls[now];
		}
		return re;
	}
	
	void insert(int x) {
		int now = rt;
		while (1) {
			down(now);
			int tmp = pl[now] < pl[x] ? rs[now] : ls[now];
			if (!tmp) {
				(pl[now] < pl[x] ? rs[now] : ls[now]) = x;
				fa[x] = now; break;
			}
			else {
				now = tmp;
			}
		}
		Splay(x);
	}
	
	void count(int now) {
		if (pl[now] <= k) ans += val[now];
		down(now);
		if (ls[now]) count(ls[now]);
		if (rs[now]) count(rs[now]);
	}
}T;

int main() {
	n = read(); k = read();
	for (int i = 1; i <= n; i++) a[i] = (node){read(), read()};
	sort(a + 1, a + n + 1, cmp);
	
	rt = T.newpoint(1, clac(1, 1));
	for (int i = 2; i <= n; i++) {
		int pl = T.get(i);
		T.insert(T.newpoint(pl, clac(i, pl)));
		if (T.rs[rt]) {
			T.downp(T.rs[rt], 1);
			T.downv(T.rs[rt], a[i].a);
		}
	}
	
	T.count(rt);
	printf("%lld", ans);
	
	return 0;
}