1. 程式人生 > 實用技巧 >【題解】 「NOI2014」購票 dp+斜率優化+點分治 LOJ2249

【題解】 「NOI2014」購票 dp+斜率優化+點分治 LOJ2249

Legend

Link \(\textrm{to LOJ}\)

給定一棵內向樹,每個結點(除了根)有如下五個資訊:

  • 父親結點 \(f_i\)
  • 到父親的距離 \(s_i\)
  • 起步價格 \(q_i\),表示在該節點乘坐交通工具的起步價格;
  • 單位路程價格 \(p_i\),表示在該節點乘坐交通工具的單位路程價格;
  • 最大路程限制 \(l_i\),超過這個距離的祖先結點不可以用該點的交通工具一次性到達;

其中每次乘車前都把必須指定終點並不能中途下車。

求所有結點到根的最短路。

資料範圍懶得寫了。

時空:\(\rm{3s/513MiB}\)

Editorial

感覺是一道斜率優化強行上樹的題目。

事實就是這樣,我們可以很快推出來序列上的式子。

\(dp_i= \min\limits _{j=1}^{i-1} dp_j + (S_i - S_j)p_i + q_i\),其中 \(S_i\) 表示結點 \(i\) 到根的距離。

\(dp_j = p_i S_j + dp_i - S_ip_i - q_i\),轉移點位置 \((S_j,dp_j)\),斜率 \(p_i\),最小化截距。

哼哼,樹上斜率優化?不就是回溯的時候重置一下被修改的位置嗎?

哼哼,\(p_i\) 沒有單調性?就二分一下。

你激動地打完程式碼交上去發現只有 \(50\) 分,仔細一看,發現方程忘記了 \((S_i - S_j \le l_i)\) 的限制。

然後你就發現這個東西非常不好維護 >_<,窮途末路了嗎?

不……序列上這樣子的問題有個解決方法是 \(\rm{CDQ}\) 分治。

把序列分成前後兩部分,先計算左側的 \(dp\) 陣列,再統計跨越左右的,再遞迴右邊。

統計跨越左右的時候,要按照 \(l_i\) 從右到左排序,就可以照樣維護凸包了(只不過插入順序反過來了而已)。

來到樹上的話,就直接換成點分治就好了。

實現上的細節是當前的分治中心 \(x\) 並不會被計算到上半部分的子樹(對於序列就是左側),所以要進行暴力更新。

總複雜度就是 \(O(n \log^2 n)\)

Code

注意到這題如果用叉積判凸包就會爆 \(\textrm{long long}\)

於是我直接莽了一發用斜率,沒想到過了。

#include <bits/stdc++.h>

#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define LL long long
#define __FILE(x)\
	freopen(#x".in" ,"r" ,stdin);\
	freopen(#x".out" ,"w" ,stdout)

using namespace std;

const int MX = 2e5 + 233;

LL read(){
	char k = getchar(); LL x = 0;
	while(k < '0' || k > '9') k = getchar();
	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
	return x;
}

int head[MX] ,tot;
struct edge{
	int node ,next;
	LL w;
}h[MX << 1];
void addedge(int u ,int v ,LL w ,int flg = 1){
	h[++tot] = (edge){v ,head[u] ,w} ,head[u] = tot;
	if(flg) addedge(v ,u ,w ,false);
}

int n ,t;
int fa[MX];
LL S[MX] ,p[MX] ,q[MX] ,lim[MX];

void getS(int x){
	// debug("visiting %d\n" ,x);
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if((d = h[i].node) == fa[x]) continue;
		S[d] = S[x] + h[i].w;
		getS(d);
	}
}

int R ,sz[MX] ,mxsz[MX] ,subsize ,vis[MX];
void getGra(int x ,int f){
	sz[x] = 1 ,mxsz[x] = 0;
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if(vis[d = h[i].node] || d == f) continue;
		getGra(d ,x);
		sz[x] += sz[d];
		mxsz[x] = max(mxsz[x] ,sz[d]);
	}
	mxsz[x] = max(mxsz[x] ,subsize - sz[x]);
	if(mxsz[x] < mxsz[R]) R = x;
}

LL dp[MX];
void doit(int x ,int top);
void solve(int x){
	debug("SOLVE %d\n" ,x);
	vis[x] = 1;
	int upper = fa[x];
	while(!vis[upper]) upper = fa[upper];
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if((d = h[i].node) == fa[x]){
			if(!vis[d]){
				mxsz[R = 0] = subsize = sz[d];
				getGra(d ,x);
				solve(R);
			}
			break;
		}
	}
	for(int now = fa[x] ; now != upper && S[x] - S[now] <= lim[x] ; now = fa[now]){
		dp[x] = min(dp[x] ,dp[now] + (S[x] - S[now]) * p[x] + q[x]);
	}
	// debug("solve %d\n" ,x);
	doit(x ,upper);
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if(vis[d = h[i].node] || d == fa[x]) continue;
		mxsz[R = 0] = subsize = sz[d];
		getGra(d ,x);
		solve(R);
	}
}

int down[MX] ,dcnt;

bool cmp(int a ,int b){return S[a] - lim[a] > S[b] - lim[b];}

void getDown(int x ,int f ,LL dist){	
	if(dist <= lim[x]) down[++dcnt] = x;
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if(vis[d = h[i].node] || d == f) continue;
		getDown(d ,x ,dist + h[i].w);
	}
}

int que[MX] ,TAIL;

double slope(int j1 ,int j2){
	return 1.0 * (dp[j1] - dp[j2]) / (S[j1] - S[j2]);
}

int search(int l ,int r ,int x){
	++l;
	int mid;
	while(l <= r){
		mid = (l + r) >> 1;
		int j1 = que[mid - 1] ,j2 = que[mid];
		if(slope(j1 ,j2) > p[x]){
			l = mid + 1;
		}
		else{
			r = mid - 1;
		}
	}
	return l - 1;
}

void doit(int x ,int top){

	dcnt = 0;
	for(int i = head[x] ,d ; i ; i = h[i].next){
		if((d = h[i].node) == fa[x]) continue;
		getDown(d ,x ,h[i].w);
	}

	std::sort(down + 1 ,down + 1 + dcnt ,cmp);
	// 越深的排序與越靠前

	TAIL = 0;
	que[++TAIL] = x;
	int trcnt = fa[x];

	for(int dd = 1 ; dd <= dcnt ; ++dd){
		int now = down[dd];
		while(trcnt != top && S[now] - S[trcnt] <= lim[now]){
			while(1 < TAIL && slope(trcnt ,que[TAIL]) >= slope(que[TAIL] ,que[TAIL - 1])){
				--TAIL;
			}
			que[++TAIL] = trcnt;
			trcnt = fa[trcnt];
		}
		int tr = que[search(1 ,TAIL ,now)];
		// printf("%d tr from %d\n" ,now ,tr);
		dp[now] = min(dp[now] ,dp[tr] + (S[now] - S[tr]) * p[now] + q[now]);
	}
}

int main(){
	vis[0] = 1;
	memset(dp ,0x3f ,sizeof dp);
	n = read() ,t = read();
	for(int i = 2 ; i <= n ; ++i){
		fa[i] = read();
		LL w = read();
		p[i] = read();
		q[i] = read();
		lim[i] = read();
		addedge(i ,fa[i] ,w);
	}
	getS(1);

	getGra(1 ,0);

	dp[1] = 0;
	solve(1);
	for(int i = 2 ; i <= n ; ++i)
		printf("%lld\n" ,dp[i]);
	return 0;
}