1. 程式人生 > 實用技巧 >動態dp

動態dp

動態 dp

動態 \(dp\) 是由貓坤大佬在 WC 2018 提出來的黑科技。

他主要解決得是帶修改的 \(dp\) 問題,主要思路是由矩陣乘法來維護 \(dp\) 轉移

我們先來看一道模板題 動態dp

這道題,我們先來看不帶修改的情況,

轉移和狀態很容易就能列出來

\(f[i][0/1]\) 表示 以 \(i\) 為根的子樹中,且選 (不選) \(i\) 這個點的最大權值和

轉移方程可以寫成:

\(f[i][0] = \displaystyle \sum_{to \in son[i]} max (f[to][0],f[to][1])\)

\(f[i][1] = \displaystyle \sum_{to \in son[x]} f[to][0]\ + w[x]\)

當我們修改一個點的權值,那麼從他到根節點的路徑上的點的 \(dp\) 值都會受到影響。

所以,我們就可以只修改這一條鏈上的資訊(用樹鏈剖分來維護)。

動態 \(dp\) 的思想就是將 重兒子和輕兒子的貢獻分開來考慮。

\(g[i][0] = \displaystyle \sum _{to \in 輕兒子} max(f[to][0],f[to][1])\), $g[i][1] = \displaystyle \sum_{to \in 輕兒子} f[to][0]\ +w[i] $.

解釋一下 \(g[i][0]\) 表示 選 或不選 \(i\) 的輕兒子的最大貢獻和, \(g[i][1]\) 表示不選他輕兒子的價值和加上他本身的權值.

這兩個可以在求 \(f[x][0/1]\) 的時候順便維護出來

那麼,上面的方程就可以寫成。

\(f[i][0] = g[i][0] + max(f[son[i]][0],f[son][i][1])\)

\(f[i][1] = g[i][1] + f[son[i]][0]\) (\(son[x]\) 表示 \(x\) 的重兒子)

定義廣義矩陣乘法:

矩陣 \(c\) 為 矩陣 A 和矩陣 B 的乘積

那麼,\(c[i][j] = max( a[i][k] + b[k][j])\)

這個其實就是把普通的矩陣乘法的乘號改為加號,加號改為取 \(max\)

這種矩陣乘法也滿足普通矩陣乘法的性質,不相信的可以跑幾組資料試試

程式碼實現長這樣:

node operator *(node x,node y)
{
    node ans;
    for(int i = 0; i <= n; i++)
        for(int j = 0; j <= n; j++)
            for(int k = 0; k <= n; k++)
                ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
    return ans;
}

然後,我們就可以根據這個廣義矩陣乘法構造一個矩陣,即,

\[\left[ \begin{matrix} f[to] [0] \\ f[to] [1] \\ \end{matrix} \right] \tag{2} \times \left[ \begin{matrix} \cdots \\ \cdots \\ \end{matrix} \right] = \left[ \begin{matrix} \ f[i][0] \\ \ f[i][1] \\ \end{matrix} \right] \]

其中 第二個矩陣是我們要 確定的轉移矩陣。

第二個矩陣構造出來長這樣:

\[\left[ \begin{matrix} f[to] [0] \\ f[to] [1] \\ \end{matrix} \right] \tag{2} \times \left[ \begin{matrix} \ g[i][0] & g[i][1] \\ \ g[i][0] & -\infin \\ \end{matrix} \right] = \left[ \begin{matrix} \ f[i][0] \\ \ f[i][1] \\ \end{matrix} \right] \]

但,我們發現一個問題,我們矩陣中沒有這個 \(f\) 值,這,我們一開始的矩陣就沒法轉移啊。

其實,對於每一個重鏈的末尾節點,他的轉移矩陣就是他的 \(f\) 值,

也就是說重鏈的末端節點儲存了 f的準確值 -by treaker

這樣,我們就可以由這個矩陣來推出這一條重鏈上每個點的資訊。

我們的暴力做法就是一直跳他的父親就是一直跳他的父親節點,再把矩陣乘起來。

但,這樣的複雜度還是不能接受。我們可以把這個矩陣放到線段樹上來維護。

線段樹的每個葉子節點存這個點的轉移矩陣,每個節點表示左右兒子矩陣的乘積,

這樣就可以用 O(log n) 的時間得到我們想要的資訊。

在結合樹剖套線段樹的做法,就能做到維護這棵樹的資訊。

注意:矩陣乘法要右乘,因為我們是從 dfn序大的地方跳到 dfn序小的地方,線上段樹上對應的是從右往左。

我們上面的轉移矩陣就需要重構一下,變成

\[\left[ \begin{matrix} \ g[i][0] & g[i][0] \\ \ - \infin & g[i][1] \\ \end{matrix} \right] \tag{2} \times \left[ \begin{matrix} f[son[x]][0]\\ f[son[x]][1]\\ \end{matrix} \right] = \left[ \begin{matrix} \ f[i][0] \\ \ f[i][1] \\ \end{matrix} \right] \]

修改操作,上文我們提到了 ‘當我們修改一個點的權值,那麼從他到根節點的路徑上的點的 \(dp\) 值都會受到影響’。

我們修改的就是從這個點到根節點的路徑的矩陣,我們把這條鏈琛出來,發現他是重鏈和輕邊交替在一起的。

我們仔細觀察一下我們構造的矩陣,發現裡面維護的是輕兒子的資訊,不會涉及到重兒子。

對於重鏈,我們不用修改,但對於重鏈鏈頂 $top[x] $和輕邊 $fa[top[x]] $交替的地方,我們需要單點修改。

因為此時鏈頂 \(top[x]\) 屬於 $fa[top[x]] $輕兒子,他的修改會對他父親的轉移矩陣造成影響。

我們關鍵要算出他改變之後對他父親轉移矩陣的影響。

他的 \(f\) 值我們可以由下面的推出來,然後可以根據 \(g\) 陣列的定義算出他改變之後對他父親的影響。

就這樣在更新,每次跳重鏈,直到跳到根節點,我們的修改操作就大工告成了。

注意一下,我們不能修改之後在統計他的 輕兒子的值,那樣可能不對,我們就只能通過增量法來修改。

記錄一下原來狀態的矩陣,在記錄修改之後的矩陣,兩者之差就是對他父親轉移矩陣的貢獻。

每次修改操作的時間複雜度為 O(\(log^2 n\))

具體程式碼長這樣

void modify(int x,int val)
{
    base[dfn[x]].a[1][0] += val - w[x];//增量法統計他修改的貢獻
    w[x] = val;
    while(x)
    {
        node Old = get_node(top[x]);//記錄一下鏈頂修改之前的矩陣
        chenge(1,1,n,dfn[x]);//修改當前這個節點的轉移矩陣
        node New = get_node(top[x]);//得到鏈頂修改之後的轉移矩陣
        int fx = dfn[fa[top[x]]];
        base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改對他父親轉移矩陣的影響
        base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
        base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
        x = fa[top[x]];//跳鏈修改
    }
}

總程式碼:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int inf = 233333333;
const int N = 1e5+10;
int n,m,tot,x,y,num,val,u,v;
int head[N],top[N],dep[N],siz[N],fa[N],son[N],ord[N],end[N],f[N][2],g[N][2],dfn[N],w[N];
struct bian
{
    int to,net;
}e[N<<1];
struct node
{
    int a[2][2];
    node() {memset(a,-127/3,sizeof(a));}
}tr[N<<2],base[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
    return s * w;
}
node operator *(node x,node y)//廣義矩陣乘法
{
    node ans;
    for(int i = 0; i <= 1; i++)
    {
        for(int j = 0; j <= 1; j++)
        {
            for(int k = 0; k <= 1; k++)
            {
                ans.a[i][j] = max(ans.a[i][j],x.a[i][k] + y.a[k][j]);
            }
        }
    }
    return ans;
}
void add(int x,int y)
{
    e[++tot].to = y;
    e[tot].net = head[x];
    head[x] = tot;
}
void get_tree(int x)//樹剖預處理
{
    dep[x] = dep[fa[x]] + 1; siz[x] = 1;
    for(int i = head[x]; i; i = e[i].net)
    {
        int to = e[i].to;
        if(to == fa[x]) continue;
        fa[to] = x;
        get_tree(to);
        siz[x] += siz[to];
        if(siz[son[x]] < siz[to] || son[x] == -1) son[x] = to;
    }
}
void dfs(int x,int topp)
{
    top[x] = topp; dfn[x] = ++num; ord[num] = x;//END 記錄這條重鏈的鏈尾,ord記錄當前 當前這個節點的樹上編號
    if(son[x] == -1)
    {
        end[topp] = num;
        return;
    }
    dfs(son[x],topp);
    for(int i = head[x]; i; i = e[i].net)
    {
        int to = e[i].to;
        if(to == fa[x] || to == son[x]) continue;
        dfs(to,to);
    }
}
void dp(int x,int fa)//預處理為沒修改之前的 f 值與 g 值
{
    g[x][1] = f[x][1] = w[x];
    for(int i = head[x]; i; i = e[i].net)
    {
        int to = e[i].to;
        if(to == fa) continue;
        dp(to,x);
        f[x][0] += max(f[to][0], f[to][1]);
        f[x][1] += f[to][0];
        if(to != son[x])
        {
            g[x][0] += max(f[to][0], f[to][1]);
            g[x][1] += f[to][0];
        }
    }
}
void up(int o)
{
    tr[o] = tr[o<<1] * tr[o<<1|1];
}
void build(int o,int L,int R)//線段樹建樹
{
    if(L == R)
    {
    	int tmp = ord[L];
        tr[o].a[0][0] = tr[o].a[0][1] = g[tmp][0];//構造轉移矩陣
        tr[o].a[1][0] = g[tmp][1]; 
        base[L] = tr[o];
        return;
    }
    int mid = (L + R)>>1;
    build(o<<1,L,mid);
    build(o<<1|1,mid+1,R);
    up(o);
}
void chenge(int o,int l,int r,int x)//單點修改
{
    if(l == r)
    {
        tr[o] = base[l];
        return;
    }
    int mid = (l + r)>>1;
    if(x <= mid) chenge(o<<1,l,mid,x);
    if(x > mid) chenge(o<<1|1,mid+1,r,x);
    up(o);
}
node query(int o,int l,int r,int L,int R)//區間查詢
{
    if(L <= l && R >= r) return tr[o];
    int mid = (l + r)>>1;
    if(R <= mid) return query(o<<1,l,mid,L,R);
    if(L > mid) return query(o<<1|1,mid+1,r,L,R);
    return query(o<<1,l,mid,L,R) * query(o<<1|1,mid+1,r,L,R);
}
node get_node(int x)//得到鏈頂的 f 值
{
    return query(1,1,n,dfn[x],end[top[x]]);
}
void modify(int x,int val)
{
    base[dfn[x]].a[1][0] += val - w[x];//增量法統計他修改的貢獻
    w[x] = val;
    while(x)
    {
        node Old = get_node(top[x]);//記錄一下鏈頂修改之前的矩陣
        chenge(1,1,n,dfn[x]);//修改當前這個節點的轉移矩陣
        node New = get_node(top[x]);//得到鏈頂修改之後的轉移矩陣
        int fx = dfn[fa[top[x]]];
        base[fx].a[0][0] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);//算他修改對他父親轉移矩陣的影響
        base[fx].a[0][1] += max(New.a[0][0],New.a[1][0]) - max(Old.a[0][0],Old.a[1][0]);
        base[fx].a[1][0] += New.a[0][0] - Old.a[0][0];
        x = fa[top[x]];//跳鏈修改
    }
}
int main()
{
    n = read(); m = read();
    for(int i = 1; i <= n; i++)
    {
        w[i] = read();
        son[i] = -1;
    }
    for(int i = 1; i <= n-1; i++)
    {
    	u = read(); v = read();
    	add(u,v); add(v,u);
    }
    get_tree(1); dfs(1,1); dp(1,0); build(1,1,n);//預處理
    for(int i = 1; i <= m; i++)
    {
        x = read(); val = read();
        modify(x,val);//修改操作
        node ans = get_node(1);//得到新答案
        printf("%d\n",max(ans.a[0][0],ans.a[1][0]));
    }
    return 0;
}