Codeforces 1092 F Tree with Maximum Cost (換根 + dfs)
阿新 • • 發佈:2018-12-22
題意:
給你一棵無根樹,每個節點有個權值$a_i$,指定一個點u,定義$\displaystyle value = \sum^v a_i*dist(u,v)$,求value的最大值
n,ai<=2e5
思路:
其實就是找一個節點作為根滿足上述最大的value
直接列舉是$O(n^2)$的,肯定不行,我們要用到換根法
換根適用於這種無根樹找根,兩個跟直接產生的結果又有聯絡,可以相互轉換的情況
對於這一題,我們讓sum[u] = 以u為根的子樹的$\sum a_i$
這樣,從父親節點u向兒子節點v轉移的時候,
假設此時的value(整棵樹以u為根)為res,我們要將res的值轉化為以v為根的value
大前提:此時u是整棵樹的根! //沒有這個大前提也可以,你要預處理一下每個節點祖先的$\sum a_i$,然後在下面的操作中搞一下,但是我們完全可以通過只改變sum[u],sum[v]的值來決定到底誰才是整棵樹的根,因為無論u,v誰是根,其他節點的sum[]都是不變的!嘻嘻
首先$value_v$相比$value_u$,根(v或u)與以v為根的子樹中的每一個節點的距離都小了1
在value上表現為 res -= sum[v]
其次在以v為根的子樹之外的節點,跟到那些節點的距離都大了1
所以sum[u] -= sum[v], res += sum[u]
此時因為v要成為整個樹的根,所以sum[v]+=sum[u]
程式碼:
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<string> #include<stack> #include<queue> #include<deque> #include<set> #include<vector> #include<map> #include<functional> #define fst first #define sc second #define pb push_back #define mem(a,b) memset(a,b,sizeof(a)) #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lc root<<1 #define rc root<<1|1 #define lowbit(x) ((x)&(-x)) using namespace std; typedef double db; typedef long double ldb; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PI; typedef pair<ll,ll> PLL; const db eps = 1e-6; const int mod = 1e9+7; const int maxn = 2e6+100; const int maxm = 2e6+100; const int inf = 0x3f3f3f3f; const db pi = acos(-1.0); vector<int>g[maxn]; int a[maxn]; ll res, ans; ll sum[maxn]; void dfs(int x, int fa, int h){ int sz = g[x].size(); res += 1ll*h*a[x]; sum[x] = a[x]; for(int i = 0; i < sz; i++){ if(g[x][i] == fa)continue; dfs(g[x][i], x, h+1); sum[x] += sum[g[x][i]]; } return; } void dfs2(int x, int fa){ ans = max(res, ans); int sz = g[x].size(); for(int i = 0; i < sz; i++){ int y = g[x][i]; if(y == fa) continue; res -= sum[y]; sum[x] -= sum[y]; res += sum[x]; sum[y] += sum[x]; dfs2(y, x); sum[y] -= sum[x]; res -= sum[x]; sum[x] += sum[y]; res += sum[y]; } return; } int main(){ int n; scanf("%d", &n); mem(sum, 0); for(int i = 1; i <= n; i++){ scanf("%d", &a[i]); } for(int i = 1; i < n; i++){ int x, y; scanf("%d %d",&x,&y); g[x].pb(y); g[y].pb(x); } res = 0; ans = 0; dfs(1,-1,0); dfs2(1,-1); printf("%lld", ans); return 0; } /* */
明天(今天)還得磨錘子,趕緊睡覺了