【CF671D】Roads in Yusland
阿新 • • 發佈:2021-10-20
題目
題目連結:https://codeforces.com/problemset/problem/671/D
- 給定一棵 \(n\) 個點的以 \(1\) 為根的樹。
- 有 \(m\) 條路徑 \((x,y)\),保證 \(y\) 是 \(x\) 或 \(x\) 的祖先,每條路徑有一個權值。
- 你要在這些路徑中選擇若干條路徑,使它們能覆蓋每條邊,同時權值和最小。
- \(n,m \le 3 \times 10^5\)。
思路
設 \(f[x]\) 表示覆蓋點 \(x\) 子樹內所有邊以及 \(x\) 與其父親的邊的最小代價。
但是很明顯 \(f[x]\) 不能簡單轉移。因為有可能花更多代價,覆蓋 \(x\) 的祖先更多,這種情況是可能最優的。
所以可以對每一個點維護一個堆,存可能的最優解。
考慮點 \(y\)
也就是說,\(y\) 的所有方案只需要同時加上一個常數,然後扔到 \(x\) 的堆裡就好了。直接上左偏樹,然後需要搞一個子樹加的標記。
但是當某一個方案覆蓋不到 \(x\) 與其父親的邊的時候,這個方案就需要刪掉了。在每次合併完後不斷判斷堆頂是否需要刪掉即可。
新建一個虛根連向 \(1\),再加一條代價為 \(0\) 的路徑,最後輸出虛根的 \(f\) 即可。
時間複雜度 \(O(n\log m)\)
程式碼
#include <bits/stdc++.h> #define mp make_pair #define fi first #define se second using namespace std; typedef long long ll; const int N=300010; int n,m,tot,head[N],rt[N],dep[N]; ll f[N]; bool flag; vector<pair<int,int> > a[N]; struct edge { int next,to; }e[N*2]; void add(int from,int to) { e[++tot]=(edge){head[from],to}; head[from]=tot; } struct LeftistTree { int tot,dis[N],pos[N],lc[N],rc[N]; ll val[N],lazy[N]; int insert(pair<int,int> b) { tot++; val[tot]=b.se; pos[tot]=b.fi; return tot; } void pushdown(int x) { if (lazy[x]) { if (lc[x]) val[lc[x]]+=lazy[x],lazy[lc[x]]+=lazy[x]; if (rc[x]) val[rc[x]]+=lazy[x],lazy[rc[x]]+=lazy[x]; lazy[x]=0; } } int merge(int x,int y) { if (!x || !y) return x|y; pushdown(x); pushdown(y); if (val[x]>val[y] || (val[x]==val[y] && x>y)) swap(x,y); rc[x]=merge(rc[x],y); if (dis[rc[x]]>dis[lc[x]]) swap(lc[x],rc[x]); dis[x]=dis[rc[x]]+1; return x; } int pop(int x) { pushdown(x); return merge(lc[x],rc[x]); } }lit; void dfs(int x,int fa) { dep[x]=dep[fa]+1; for (int i=0;i<(int)a[x].size();i++) rt[x]=lit.merge(rt[x],lit.insert(a[x][i])); for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (v!=fa) { dfs(v,x); f[x]+=f[v]; if (flag) return; lit.lazy[rt[v]]-=f[v]; lit.val[rt[v]]-=f[v]; rt[x]=lit.merge(rt[x],rt[v]); } } lit.lazy[rt[x]]+=f[x]; lit.val[rt[x]]+=f[x]; while (rt[x] && dep[lit.pos[rt[x]]]>=dep[x]) rt[x]=lit.pop(rt[x]); if (!rt[x]) { flag=1; return; } f[x]=lit.val[rt[x]]; } int main() { memset(head,-1,sizeof(head)); scanf("%d%d",&n,&m); for (int i=1,x,y;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } n++; add(n,1); for (int i=1,x,y,z;i<=m;i++) { scanf("%d%d%d",&x,&y,&z); a[x].push_back(mp(y,z)); } a[1].push_back(mp(0,0)); lit.dis[0]=-1; dep[0]=-1; dfs(n,0); if (flag) cout<<"-1"; else cout<<f[n]; return 0; }