1. 程式人生 > 其它 >動態DP(動態動態規劃?) & 動態點分治

動態DP(動態動態規劃?) & 動態點分治

動態 DP

NOIP居然會考這種東西,所以不得不來學一下

結合著上面這道題,可以看出,動態DP就是一個動態規劃問題加上了修改操作,

如果每一次修改我們都去跑一遍動態規劃,時間複雜度直接起飛,所以這時候就要想辦法優化。

首先看一下這道題如果不帶修改操作該怎麼做

\(f_{i,0/1}\) 表示以 \(i\) 為子樹不選/選 \(i\) 的最大點權獨立集

那麼狀態轉移方程就為 \(\begin{cases}f_{u,0}=\sum\max(f_{v,0},f_{v,1})\\f_{u,1}=a_u+\sum f_{v,0}\end{cases}\)

一個簡單的樹上DP,可以看出,每個結點只會對他的父親結點造成影響.

這時候如果我們修改了其中一個結點,那麼他只會對在他這條鏈上的祖先結點造成影響,

如果這條樹的高度比較平均,那麼就只需要 \(log(n)\) 次,可惜如果這棵樹退化成鏈,那麼一次修改就需要 \(n\)次,

顯然是不行的.既然是跟樹有關,可以想到一個數據結構,樹鏈剖分.因為我們的DP是從葉子結點往上轉移,

而樹鏈剖分中每條重鏈的鏈尾都是葉子結點,這就可以讓我們很好的進行DP的轉移,

同時樹鏈剖分可以讓我們快速的進行修改操作,那麼我們該怎麼把DP與樹鏈剖分結合?

將DP的轉移式稍微變一下形,令 \(g_{i,0}\) 表示所有 \(i\) 的所有輕兒子可取可不取的最大值, \(g_{i,1}\) 表示 \(i\)

的所有輕兒子都不去並取 \(i\) 的最大值

\[\begin{cases}f_{u,0}=g_{u,0}+max(f_{v,1},f_{v,0})\\f_{u,1}=g_{u,1}+f_{v,0}\end{cases} \]

這個東西並不好直接快速在樹上轉移,這個形式可以考慮用矩陣加速.

重定義矩陣乘法 \(c_{i,j}=max(a_{i,k}+b_{k,j})\)

這個為什麼可以套用在上面的轉移式?

再將上面的轉移式改一下,\(\begin{cases}f_{u,0}=max(f_{v,1}+g_{u,0},f_{v,0}+g_{u,0})\\f_{u,1}=max(g_{u,1}+f_{v,0},-\inf)\end{cases}\)

這樣就可套用矩陣乘法了.可以直接構造出一個矩陣

\[\begin{vmatrix}f_{v,0}&f_{v,1}\end{vmatrix}\times\begin{vmatrix}g_{u,0}&g_{u,1}\\g_{u,0}&-\inf\end{vmatrix}=\begin{vmatrix}f_{u,0}&f_{v,0}\end{vmatrix} \]

所以我們只需要線上段樹中維護一個轉移矩陣,最後求一下所有轉移矩陣的積就行.

注意線段樹做乘法時的順序是從父親結點到葉子節點,所以我們要交換一下矩陣的順序

\[\begin{vmatrix}g_{u,0}&g_{u,1}\\g_{u,1}&-\inf\end{vmatrix}\times\begin{vmatrix}f_{v,0}\\f_{v,1}\end{vmatrix}=\begin{vmatrix}f_{u,0}\\f_{u,1}\end{vmatrix} \]

這樣就可以快速進行轉移了

程式碼

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N = 1e5 + 5;

struct matrix
{
	int data[2][2];
	matrix() { memset(data, -0x3f, sizeof(data)); };
	matrix operator * (const matrix a) const
	{
		matrix c;
		for (int i = 0; i < 2; i++)
			for (int j = 0; j < 2; j++)
				for (int k = 0; k < 2; k++)
					c.data[i][j] = max(c.data[i][j], data[i][k] + a.data[k][j]);
		return c;
	}
} it[N];
struct tree
{
	int l, r;
	matrix mx;
} tr[4 * N];
int n, m, a[N], fa[N], siz[N], son[N];
int head[N], ver[2 * N], net[2 * N], idx, ed[N];
int top[N], tot, id[N], f[N][2], dfsn[N];

void add(int a, int b)
{
	net[++idx] = head[a], ver[idx] = b, head[a] = idx;
}

void dfs1(int u, int f)
{
	siz[u] = 1, fa[u] = f;
	for (int i = head[u]; i; i = net[i])
	{
		int v = ver[i];
		if (v == f)
			continue;
		dfs1(v, u);
		siz[u] += siz[v];
		if (siz[v] > siz[son[u]])
			son[u] = v;
	}
}

void dfs2(int u, int t)
{
	dfsn[u] = ++tot, id[tot] = u, top[u] = t;
	f[u][0] = 0, f[u][1] = a[u], ed[t] = max(ed[t], tot);
	it[u].data[0][0] = it[u].data[0][1] = 0;
	it[u].data[1][0] = a[u];
	if (!son[u])
		return;
	dfs2(son[u], t);
	f[u][0] += max(f[son[u]][0], f[son[u]][1]);
	f[u][1] += f[son[u]][0];
	for (int i = head[u]; i; i = net[i])
	{
		int v = ver[i];
		if (v == fa[u] || v == son[u])
			continue;
		dfs2(v, v);
		f[u][0] += max(f[v][0], f[v][1]);
		f[u][1] += f[v][0];
		it[u].data[0][0] += max(f[v][0], f[v][1]);
		it[u].data[0][1] = it[u].data[0][0];
		it[u].data[1][0] += f[v][0];
	}
}

void pushup(int p)
{
	tr[p].mx = tr[p << 1].mx * tr[p << 1 | 1].mx;
}

void build(int l, int r, int p)
{
	tr[p].l = l, tr[p].r = r;
	
	if (l == r)
	{
		tr[p].mx = it[id[l]];
		return;
	}
	int mid = (l + r) >> 1;
	build(l, mid, p << 1);
	build(mid + 1, r, p << 1 | 1);
	pushup(p);
}

void update_tree(int x, int p)
{
	if (tr[p].l == tr[p].r)
	{
		tr[p].mx = it[id[x]];
		return;
	}
	int mid = (tr[p].l + tr[p].r) >> 1;
	if (x <= mid)
		update_tree(x, p << 1);
	else
		update_tree(x, p << 1 | 1);
	pushup(p);
}

matrix query(int l, int r, int p)
{
	if (tr[p].l >= l && tr[p].r <= r)
		return tr[p].mx;
	int mid = (tr[p].l + tr[p].r) >> 1;
	matrix res;
	if (r <= mid)	
		return query(l, r, p << 1);
	else if (l > mid)
		return query(l, r, p << 1 | 1);
	else
		return query(l, r, p << 1) * query(l, r, p << 1 | 1);
	return res;
}

void update_path(int u, int w)
{
	it[u].data[1][0] += w - a[u], a[u] = w;
	matrix ta, tb;
	while (u)
	{
		ta = query(dfsn[top[u]], ed[top[u]], 1);
		update_tree(dfsn[u], 1);
		tb = query(dfsn[top[u]], ed[top[u]], 1);
		u = fa[top[u]];
		it[u].data[0][0] += max(tb.data[0][0], tb.data[1][0]) - max(ta.data[0][0], ta.data[1][0]);
		it[u].data[0][1] = it[u].data[0][0];
		it[u].data[1][0] += tb.data[0][0] - ta.data[0][0];
	}
}

int main()
{
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++)
		scanf("%d", &a[i]);
	for (int i = 1; i < n; i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	build(1, n, 1);
	while (m--)
	{
		int x, y;
		scanf("%d%d", &x, &y);
		update_path(x, y);
		matrix ans = query(dfsn[1], ed[1], 1);
		printf("%d\n", max(ans.data[1][0], ans.data[0][0]));
	}
	return 0;
}