【題解】 「NOI2014」購票 dp+斜率優化+點分治 LOJ2249
Legend
Link \(\textrm{to LOJ}\)。
給定一棵內向樹,每個結點(除了根)有如下五個資訊:
- 父親結點 \(f_i\);
- 到父親的距離 \(s_i\);
- 起步價格 \(q_i\),表示在該節點乘坐交通工具的起步價格;
- 單位路程價格 \(p_i\),表示在該節點乘坐交通工具的單位路程價格;
- 最大路程限制 \(l_i\),超過這個距離的祖先結點不可以用該點的交通工具一次性到達;
其中每次乘車前都把必須指定終點並不能中途下車。
求所有結點到根的最短路。
資料範圍懶得寫了。
時空:\(\rm{3s/513MiB}\)。
Editorial
感覺是一道斜率優化強行上樹的題目。
事實就是這樣,我們可以很快推出來序列上的式子。
\(dp_i= \min\limits _{j=1}^{i-1} dp_j + (S_i - S_j)p_i + q_i\),其中 \(S_i\) 表示結點 \(i\) 到根的距離。
\(dp_j = p_i S_j + dp_i - S_ip_i - q_i\),轉移點位置 \((S_j,dp_j)\),斜率 \(p_i\),最小化截距。
哼哼,樹上斜率優化?不就是回溯的時候重置一下被修改的位置嗎?
哼哼,\(p_i\) 沒有單調性?就二分一下。
你激動地打完程式碼交上去發現只有 \(50\) 分,仔細一看,發現方程忘記了 \((S_i - S_j \le l_i)\) 的限制。
然後你就發現這個東西非常不好維護 >_<,窮途末路了嗎?
不……序列上這樣子的問題有個解決方法是 \(\rm{CDQ}\) 分治。
把序列分成前後兩部分,先計算左側的 \(dp\) 陣列,再統計跨越左右的,再遞迴右邊。
統計跨越左右的時候,要按照 \(l_i\) 從右到左排序,就可以照樣維護凸包了(只不過插入順序反過來了而已)。
來到樹上的話,就直接換成點分治就好了。
實現上的細節是當前的分治中心 \(x\) 並不會被計算到上半部分的子樹(對於序列就是左側),所以要進行暴力更新。
總複雜度就是 \(O(n \log^2 n)\)。
Code
注意到這題如果用叉積判凸包就會爆 \(\textrm{long long}\)。
於是我直接莽了一發用斜率,沒想到過了。
#include <bits/stdc++.h>
#define debug(...) ;//fprintf(stderr ,__VA_ARGS__)
#define LL long long
#define __FILE(x)\
freopen(#x".in" ,"r" ,stdin);\
freopen(#x".out" ,"w" ,stdout)
using namespace std;
const int MX = 2e5 + 233;
LL read(){
char k = getchar(); LL x = 0;
while(k < '0' || k > '9') k = getchar();
while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
return x;
}
int head[MX] ,tot;
struct edge{
int node ,next;
LL w;
}h[MX << 1];
void addedge(int u ,int v ,LL w ,int flg = 1){
h[++tot] = (edge){v ,head[u] ,w} ,head[u] = tot;
if(flg) addedge(v ,u ,w ,false);
}
int n ,t;
int fa[MX];
LL S[MX] ,p[MX] ,q[MX] ,lim[MX];
void getS(int x){
// debug("visiting %d\n" ,x);
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]) continue;
S[d] = S[x] + h[i].w;
getS(d);
}
}
int R ,sz[MX] ,mxsz[MX] ,subsize ,vis[MX];
void getGra(int x ,int f){
sz[x] = 1 ,mxsz[x] = 0;
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == f) continue;
getGra(d ,x);
sz[x] += sz[d];
mxsz[x] = max(mxsz[x] ,sz[d]);
}
mxsz[x] = max(mxsz[x] ,subsize - sz[x]);
if(mxsz[x] < mxsz[R]) R = x;
}
LL dp[MX];
void doit(int x ,int top);
void solve(int x){
debug("SOLVE %d\n" ,x);
vis[x] = 1;
int upper = fa[x];
while(!vis[upper]) upper = fa[upper];
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]){
if(!vis[d]){
mxsz[R = 0] = subsize = sz[d];
getGra(d ,x);
solve(R);
}
break;
}
}
for(int now = fa[x] ; now != upper && S[x] - S[now] <= lim[x] ; now = fa[now]){
dp[x] = min(dp[x] ,dp[now] + (S[x] - S[now]) * p[x] + q[x]);
}
// debug("solve %d\n" ,x);
doit(x ,upper);
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == fa[x]) continue;
mxsz[R = 0] = subsize = sz[d];
getGra(d ,x);
solve(R);
}
}
int down[MX] ,dcnt;
bool cmp(int a ,int b){return S[a] - lim[a] > S[b] - lim[b];}
void getDown(int x ,int f ,LL dist){
if(dist <= lim[x]) down[++dcnt] = x;
for(int i = head[x] ,d ; i ; i = h[i].next){
if(vis[d = h[i].node] || d == f) continue;
getDown(d ,x ,dist + h[i].w);
}
}
int que[MX] ,TAIL;
double slope(int j1 ,int j2){
return 1.0 * (dp[j1] - dp[j2]) / (S[j1] - S[j2]);
}
int search(int l ,int r ,int x){
++l;
int mid;
while(l <= r){
mid = (l + r) >> 1;
int j1 = que[mid - 1] ,j2 = que[mid];
if(slope(j1 ,j2) > p[x]){
l = mid + 1;
}
else{
r = mid - 1;
}
}
return l - 1;
}
void doit(int x ,int top){
dcnt = 0;
for(int i = head[x] ,d ; i ; i = h[i].next){
if((d = h[i].node) == fa[x]) continue;
getDown(d ,x ,h[i].w);
}
std::sort(down + 1 ,down + 1 + dcnt ,cmp);
// 越深的排序與越靠前
TAIL = 0;
que[++TAIL] = x;
int trcnt = fa[x];
for(int dd = 1 ; dd <= dcnt ; ++dd){
int now = down[dd];
while(trcnt != top && S[now] - S[trcnt] <= lim[now]){
while(1 < TAIL && slope(trcnt ,que[TAIL]) >= slope(que[TAIL] ,que[TAIL - 1])){
--TAIL;
}
que[++TAIL] = trcnt;
trcnt = fa[trcnt];
}
int tr = que[search(1 ,TAIL ,now)];
// printf("%d tr from %d\n" ,now ,tr);
dp[now] = min(dp[now] ,dp[tr] + (S[now] - S[tr]) * p[now] + q[now]);
}
}
int main(){
vis[0] = 1;
memset(dp ,0x3f ,sizeof dp);
n = read() ,t = read();
for(int i = 2 ; i <= n ; ++i){
fa[i] = read();
LL w = read();
p[i] = read();
q[i] = read();
lim[i] = read();
addedge(i ,fa[i] ,w);
}
getS(1);
getGra(1 ,0);
dp[1] = 0;
solve(1);
for(int i = 2 ; i <= n ; ++i)
printf("%lld\n" ,dp[i]);
return 0;
}