1. 程式人生 > 實用技巧 >動態 DP

動態 DP

一道入門 DP + 修改 = 動態 DP。
模板題為例,多次詢問樹的最大獨立集,帶修改。
先有 naive 的 DP,記 \(f_{u,0/1}\) 表示 \(u\) 點不選/選時以 \(u\) 為根的子樹的最大獨立集權值,有

\[\begin{cases}f_{u,0}=\sum\limits_{v\in \operatorname{son}u}\max\{f_{v,0},f_{v,1}\} \\ f_{u,1}=w_u+\sum\limits_{v\in\operatorname{son}u}f_{v,0}\end{cases} \]

修改一個點的權值,只有這個點到根結點的 DP 值發生了變化,於是我們考慮重鏈剖分來快速轉移。
說到快速轉移,又想到了用矩陣。所以我們先來把轉移方程改成可用矩陣轉移的形式。再記兩個 DP 值

\[\begin{cases}g_{u,0}=\sum\limits_{v\in\operatorname{lson}u}\max\{f_{v,0},f_{v,1}\} \\ g_{u,1}=w_u+\sum\limits_{v\in\operatorname{lson}u}f_{v,0}\end{cases} \]

\(\operatorname{lson}u\) 表示 \(u\) 的所有輕兒子。這樣做的目的是將輕重兒子的轉移分開,更好維護。
現在重寫轉移方程

\[\begin{cases}f_{u,0}=g_{u,0}+\max\{f_{\operatorname{hson}u,0},f_{\operatorname{hson}u,1}\} \\ f_{u,1}=g_{u,1}+f_{\operatorname{hson}u,0}\end{cases} \]

\(\operatorname{son}u\) 表示 \(u\) 的重兒子。這樣我們就丟掉了煩人的 \(\sum\) 了。
上面的轉移形式很像矩陣 \(A_{i,j}=\max\limits_{k}\{B_{i,k}+C_{k,j}\}\) 的轉移形式,所以我們用這個新定義的矩陣乘法來重寫剛才的式子(記 \(v=\operatorname{hson}u\)):

\[\begin{bmatrix}f_{u,0} \\ f_{u,1}\end{bmatrix}=\begin{bmatrix}g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty\end{bmatrix}\begin{bmatrix}f_{v,0} \\ f_{v,1}\end{bmatrix} \]

這樣寫後,一個點的 DP 值就等於這個點所在重鏈的葉子結點的矩陣乘這條鏈上的所有的轉移矩陣,這個可以用線段樹來維護。修改時沿著鏈向上跳,線段樹上單點修改即可。
更多細節詳見程式碼(註釋自我感覺跟詳細)

#include <bits/stdc++.h>
using namespace std;

const int N=1e5+5;
struct matrix
{
    int a[2][2];
    matrix() {memset(a,0xcf,sizeof(a));}
    int* const operator[](const int i) {return a[i];}
    const int* const operator[](const int i) const {return a[i];}
    matrix operator*(matrix b)
    {
        matrix c;
        for(int i=0;i<2;++i)
            for(int j=0;j<2;++j)
                for(int k=0;k<2;++k)
                    c[i][j]=max(c[i][j],a[i][k]+b[k][j]);
        return c;
    }
}g[N],t[N<<2];//g[]是轉移矩陣,t[]是線段樹
int n,Q,son[N],siz[N],dfn[N],id[N],fa[N],ids,top[N],nd[N];
//id[]是這個點的dfs序,dfn[]是id[]的逆對映,nd[]只有在鏈頂有定義,代表這條鏈的底的dfs序
int f[N][2],w[N];//f[]是 DP 值
vector<int> G[N];

void dfs1(int u,int faz)
{
    siz[u]=1,fa[u]=faz;
    for(int v:G[u]) if(v!=faz)
    {
        dfs1(v,u),siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
}

void dfs2(int u,int tp)
{
    top[u]=tp,id[u]=++ids,dfn[ids]=u;
    //按照定義初始化矩陣和 DP 值
    g[u][0][0]=g[u][0][1]=0;
    g[u][1][0]=f[u][1]=w[u],nd[tp]=max(nd[tp],ids);
    if(son[u])
    {
        dfs2(son[u],tp);
        f[u][0]+=max(f[son[u]][0],f[son[u]][1]);
        f[u][1]+=f[son[u]][0];
    }
    for(int v:G[u]) if(v!=fa[u]&&v!=son[u])
    {
        dfs2(v,v);
        f[u][0]+=max(f[v][0],f[v][1]);
        g[u][0][0]+=max(f[v][0],f[v][1]);
        f[u][1]+=f[v][0],g[u][0][1]=g[u][0][0];
        g[u][1][0]+=f[v][0];
        //按照定義轉移即可
    }
}

//下面的線段樹維護的是鏈上的轉移矩陣之積
void build(int rt,int l,int r)
{
    if(l==r) {t[rt]=g[dfn[l]]; return;}
    int mid=l+r>>1;
    build(rt<<1,l,mid),build(rt<<1|1,mid+1,r);
    t[rt]=t[rt<<1]*t[rt<<1|1];
}

void upd(int rt,int lc,int rc,int p)
{
    if(lc==rc) {t[rt]=g[dfn[lc]]; return;}
    int mid=lc+rc>>1;
    if(p<=mid) upd(rt<<1,lc,mid,p);
    else upd(rt<<1|1,mid+1,rc,p);
    t[rt]=t[rt<<1]*t[rt<<1|1];
}

matrix query(int rt,int lc,int rc,int l,int r)
{
    if(l<=lc&&r>=rc) return t[rt];
    int mid=lc+rc>>1;
    if(r<=mid) return query(rt<<1,lc,mid,l,r);
    else if(l>mid) return query(rt<<1|1,mid+1,rc,l,r);
    else return query(rt<<1,lc,mid,l,r)*query(rt<<1|1,mid+1,rc,l,r);
}

void upd_val(int x,int y)
{
    g[x][1][0]+=y-w[x],w[x]=y; //先將自己的值修改了
    matrix bef,aft; //需要分別記修改前後的轉移矩陣,靠這個差值更新
    while(x)
    {
        bef=query(1,1,n,id[top[x]],nd[top[x]]);
        upd(1,1,n,id[x]);
        aft=query(1,1,n,id[top[x]],nd[top[x]]);
        x=fa[top[x]]; //這個點是鏈頂的父親,所以這條鏈對於父親來說是一條輕鏈,所以父親的矩陣必須修改
        g[x][0][0]+=max(aft[0][0],aft[1][0])-max(bef[0][0],bef[1][0]);
        g[x][0][1]=g[x][0][0],g[x][1][0]+=aft[0][0]-bef[0][0];
    }
}

int main()
{
    scanf("%d%d",&n,&Q);
    for(int i=1;i<=n;++i) scanf("%d",w+i);
    for(int i=1,a,b;i<n;++i)
    {
        scanf("%d%d",&a,&b);
        G[a].push_back(b),G[b].push_back(a);
    }
    dfs1(1,0),dfs2(1,1),build(1,1,n);
    for(int i=1,x,y;i<=Q;++i)
    {
        scanf("%d%d",&x,&y);
        upd_val(x,y);
        matrix ans=query(1,1,n,id[1],nd[1]);
        //注意到一個點的 DP 值就是這個點所在鏈的底到這個點的轉移矩陣之積
        //似乎我們在 dfs 後就一直沒有用 f[] 了?因為每個葉子結點
        //的轉移矩陣剛好與其 f[] 值相等,所以我們不需要了
        printf("%d\n",max(ans[0][0],ans[1][0]));
    }
    return 0;
}