1. 程式人生 > 其它 >[USACO12NOV]Balanced Trees G

[USACO12NOV]Balanced Trees G

tag:點分治


怎麼兩篇題解都沒了,我來補一篇

這種樹上路徑問題很容易想到是點分治,考慮如何計算兩條路徑拼起來的答案。

首先一定是(...(...(...(...這樣一條路徑和...)...)...)這樣一條路徑拼起來,然後因為是求 \(\max\),所以求出兩邊的最大深度再取 \(\max\) 就行了。這裡以左邊為例。


假設當前是一條合法路徑,那麼當前的答案就是處理括號序列時用的那個棧的歷史最大值,所以在dfs的時候拿個mx變數記錄一下。

如果當前是一個“(”,而且棧為空,就要把mx變數加一,然後可以忽略掉這個“(”【因為拼起來的路徑是合法的,這裡就可以預設右邊有一個“)”把它匹配掉了】。

比如說當前是“(”,父親到重心是“()(())”,mx\(2\),由於右邊路徑一定有一個“)”與當前的“(”匹配,所以兩段路徑拼起來是“( ()(()) ... )”。

相當於中間那部分括號的深度整體+1,所以讓mx+1就行。

對於右邊路徑的處理也是類似的,可以結合程式碼理解一下:

/*
up為括號序列棧頂
mxup為歷史最大值
cntl為多餘出來的,需要用右邊的")"去匹配的"("
*/
if(a[x]=='(') val[x].up++;
else val[x].up--;
val[x].mxup = max(val[x].mxup,-val[x].up);
if(val[x].up>0)
    val[x].cntl++,
    val[x].mxup++,
    val[x].up = 0;

然後可以用一個桶去記錄最大值,以 cntl/cntr 為下標。


注意一些小細節

  • 在dfs的時候最好令一邊包含重心,另一邊不包含重心

  • 不要漏了到重心的本身就合法的鏈

  • 在拼路徑的時候要判斷是否存在對應的路徑


#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=0;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
	for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48)));
	if(flag)n=-n;
}

enum{
    MAXN = 40005
};

int n;

struct _{
    int nxt, to;
    _(int nxt=0, int to=0):nxt(nxt),to(to){}
}edge[MAXN<<1];
int fst[MAXN], tot;

inline void Add_Edge(int f, int t){
    edge[++tot] = _(fst[f], t); fst[f] = tot;
    edge[++tot] = _(fst[t], f); fst[t] = tot;
}

char a[MAXN];

inline void upd(int &x, int y){x = max(x,y);}

int sz[MAXN], Size, Weight, Center;
char vis[MAXN], Vis[MAXN];
void Get_Center(int x){
    vis[x] = true;
    sz[x] = 1;
    int w=0;
    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(vis[v] or Vis[v]) continue;
        Get_Center(v);
        sz[x] += sz[v];
        upd(w,sz[v]);
    }   
    upd(w,Size-sz[x]);
    if(w < Weight)
        Weight = w, Center = x;
    vis[x] = false;
}

struct ele{int mxup, mxdown, up, down, cntl, cntr;}val[MAXN];

int q[MAXN], top;
int mxl[MAXN], mxr[MAXN];
void dfs(int x){
    vis[x] = true; q[++top] = x; sz[x] = 1;

    if(a[x]=='(') val[x].up++; else val[x].up--;
    upd(val[x].mxup,-val[x].up);
    if(val[x].up>0)
        val[x].cntl++,
        val[x].mxup++,
        val[x].up = 0;

    if(a[x]=='(') val[x].down++; else val[x].down--;
    upd(val[x].mxdown,val[x].down);
    if(val[x].down<0)
        val[x].cntr++,
        val[x].mxdown++,
        val[x].down = 0;

    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(vis[v] or Vis[v]) continue;
        val[v] = val[x]; dfs(v);
        sz[x] += sz[v];
    }
    vis[x] = false;
}

int ans, dep;
void solve(int x){
    Weight = Size; Get_Center(x); x = Center; Vis[x] = true;

    ele base = (ele){0,0,0,0,0,0};
    if(a[x]=='(') base.cntl = 1; else base.up = -1; base.mxup = 1;
    if(!base.up) mxl[base.cntl] = base.mxup;
    mxr[0] = 0;

    int prv; top = 0;
    int ml=0, mr=0;
    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(Vis[v]) continue;
        val[v] = base; prv = top+1;
        dfs(v);
        for(register int i=prv; i<=top; i++){
            ele cur = val[q[i]];
            upd(ml,cur.cntl); upd(mr,cur.cntr);
            if(!cur.up and ~mxr[cur.cntl]) upd(ans,max(mxr[cur.cntl],cur.mxup));
            if(!cur.down and ~mxl[cur.cntr]) upd(ans,max(mxl[cur.cntr],cur.mxdown));
        }
        for(register int i=prv; i<=top; i++){
            ele cur = val[q[i]];
            if(!cur.up) upd(mxl[cur.cntl],cur.mxup);
            if(!cur.down) upd(mxr[cur.cntr],cur.mxdown);
        }
    }
    fill(mxl,mxl+ml+1,-1); fill(mxr,mxr+mr+1,-1);

    for(register int u=fst[x]; u; u=edge[u].nxt){
        int v=edge[u].to;
        if(Vis[v]) continue;
        Size = sz[v]; solve(v);
    }
}

char tmp[10];

int main(){
    freopen("1.in","r",stdin);
	freopen("2.out","w",stdout);
    // double tt=clock();
    Read(n);
    for(register int i=2; i<=n; i++){
        int fa; scanf("%d",&fa);
        Add_Edge(i,fa);
    }
    memset(mxl,-1,sizeof mxl);
    memset(mxr,-1,sizeof mxr);
    for(register int i=1; i<=n; i++) scanf("%s",tmp), a[i] = tmp[0];
    Size = n; solve(1);
    cout<<ans<<endl;
    // printf("%.6lf\n",(clock()-tt)/CLOCKS_PER_SEC);
    return 0;
}