[學習筆記]動態dp
其實就過了模板。
感覺就是帶修改的dp
【模板】動態dp
給定一棵n個點的樹,點帶點權。
有m次操作,每次操作給定x,y表示修改點x的權值為y。
你需要在每次操作之後求出這棵樹的最大權獨立集的權值大小。
n,m<=1e5
參考題解:shadowice1984
n^2 DP簡單又自然。
但是對於1e5次修改就不行了。
每一次修改會影響整個到根的鏈上的值。
採用樹剖。
ldp[i][0/1]表示i選不選,對於所有的輕兒子dp值。
dp[i][0/1]表示i選不選,對於總共的所有兒子的dp值。
ldp[i][0]=∑max(ldp[lightson][1],ldp[lightson][0])
ldp[i][1]=∑ldp[lightson][0]
dp[i][0]=ldp[i][0]+max(dp[heavyson][1],dp[heavyson][0])
dp[i][1]=ldp[i][1]+dp[heavyson][0]
可以先把這個dp都求出來。
然後怎麼維護?自然要用線段樹維護dfs序。
採用矩陣。
a*b定義為:
c[i][j]=max(a[i][k]+b[k][j])
有結合律。
線段樹維護區間矩陣乘積。(注意從右往左乘,自下而上)
只要在最前面乘上一個初始矩陣
第一行是0,第二行是-inf的矩陣。
就可以求出某個點的最終dp值了。
修改的時候,暴力修改這個 點的ldp0,ldp1
但是還會影響這個fa[top[x]]的ldp0,ldp1
所以要求出dp[top[x]],dp[top[y]]為了避免常數過大,
用一個數組記錄dp值,然後把前後兩次最大值的差值來修改fa[top[x]]的ldp0,ldp1
然後跳一條鏈,到fa[top[x]]
這樣單次修改log^2n
每次返回max(dp[1][0],dp[1][1])
普通線段樹版:(3000ms)
#include<bits/stdc++.h> #define reg register int #define il inline #definenumb (ch^'0') #define mid ((l+r)>>1) using namespace std; typedef long long ll; il void rd(int &x){ char ch;x=0;bool fl=false; while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true); for(x=numb;isdigit(ch=getchar());x=x*10+numb); (fl==true)&&(x=-x); } namespace Miracle{ const int N=1e5+5; const int inf=0x3f3f3f3f; int n,m; struct node{ int nxt,to; }e[2*N]; int hd[N],cnt; void add(int x,int y){ e[++cnt].nxt=hd[x]; e[cnt].to=y; hd[x]=cnt; } struct tr{ int a[3][3]; void init(int x,int y){//x:ldp0 y:ldp1 a[1][1]=x,a[2][1]=x; a[1][2]=y,a[2][2]=-inf; } void pre(){ memset(a,-inf,sizeof a); } void st(){ a[1][1]=0,a[1][2]=-inf,a[2][1]=-inf,a[2][2]=0; } tr operator *(const tr& b){ tr c;c.pre(); for(reg i=1;i<=2;++i){ for(reg k=1;k<=2;++k){ for(reg j=1;j<=2;++j){ c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]); } } }return c; } void op(){ cout<<left<<setw(10)<<a[1][1]<<" "<<left<<setw(10)<<a[1][2]<<endl; cout<<left<<setw(10)<<a[2][1]<<" "<<left<<setw(10)<<a[2][2]<<endl; cout<<endl; } }s[N],t[4*N],A; int dfn[N],top[N],dfn2[N],fdfn[N],sz[N],dep[N],son[N]; int nd[N];//tot;//num of heavy chain int fa[N]; int df; int ldp[N][2],dp[N][2]; int w[N]; void dfs1(int x,int d){ dep[x]=d; sz[x]=1; for(reg i=hd[x];i;i=e[i].nxt){ int y=e[i].to; if(y==fa[x]) continue; fa[y]=x; dfs1(y,d+1); if(sz[y]>sz[son[x]]){ son[x]=y; } } } void dfs2(int x){ dfn[x]=++df;fdfn[df]=x; if(!top[x]) { top[x]=x;nd[top[x]]=x; } if(son[x]) top[son[x]]=top[x],nd[top[x]]=son[x],dfs2(son[x]); dp[x][1]=w[x]; ldp[x][1]=w[x]; for(reg i=hd[x];i;i=e[i].nxt){ int y=e[i].to; if(y==son[x]||y==fa[x]) continue; dfs2(y); ldp[x][0]+=max(dp[y][0],dp[y][1]); ldp[x][1]+=dp[y][0]; } if(son[x]){ dp[x][1]=ldp[x][1]+dp[son[x]][0]; dp[x][0]=ldp[x][0]+max(dp[son[x]][1],dp[son[x]][0]); } s[x].init(ldp[x][0],ldp[x][1]); } void pushup(int x){ t[x]=t[x<<1|1]*t[x<<1]; } void build(int x,int l,int r){ if(l==r){ t[x]=s[fdfn[l]];return; } build(x<<1,l,mid);build(x<<1|1,mid+1,r); pushup(x); } tr query(int x,int l,int r,int L,int R){ if(L<=l&&r<=R){ return t[x]; } tr ret;ret.st(); if(mid<R) ret=ret*query(x<<1|1,mid+1,r,L,R); if(L<=mid) ret=ret*query(x<<1,l,mid,L,R); return ret; } void add(int x,int l,int r,int to,int p,int c){ if(l==r){ if(p) t[x].a[1][2]+=c; else t[x].a[1][1]+=c,t[x].a[2][1]+=c; return; } if(to<=mid) add(x<<1,l,mid,to,p,c); else if(mid<to) add(x<<1|1,mid+1,r,to,p,c); pushup(x); } int tmp[2]; int to[2]; int upda(int x,int y){ tmp[0]=tmp[1]=0; to[0]=to[1]=0; tmp[1]=y-w[x]; w[x]=y; while(x){ tr anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]); to[0]=anc.a[1][1],to[1]=anc.a[1][2]; add(1,1,n,dfn[x],0,tmp[0]); add(1,1,n,dfn[x],1,tmp[1]); anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]); tmp[0]=max(anc.a[1][1],anc.a[1][2])-max(to[0],to[1]); tmp[1]=anc.a[1][1]-to[0]; x=fa[top[x]]; } tr ans=A*query(1,1,n,dfn[top[1]],dfn[nd[top[1]]]); return max(ans.a[1][1],ans.a[1][2]); } int main(){ scanf("%d%d",&n,&m); for(reg i=1;i<=n;++i)rd(w[i]); int x,y; for(reg i=1;i<=n-1;++i){ rd(x);rd(y);add(x,y);add(y,x); } dfs1(1,1); dfs2(1); build(1,1,n); A.a[1][1]=0,A.a[1][2]=0; A.a[2][1]=-inf,A.a[2][2]=-inf; while(m--){ rd(x);rd(y); printf("%d\n",upda(x,y)); } return 0; } } int main(){ Miracle::main(); return 0; } /* Author: *Miracle* Date: 2018/11/12 16:29:49 */
zkw線段樹版:(1500ms)
#include<bits/stdc++.h> #define reg register int #define il inline #define numb (ch^'0') #define mid ((l+r)>>1) using namespace std; typedef long long ll; il void rd(int &x){ char ch;x=0;bool fl=false; while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true); for(x=numb;isdigit(ch=getchar());x=x*10+numb); (fl==true)&&(x=-x); } namespace Miracle{ const int N=1e5+5; const int inf=0x3f3f3f3f; int n,m; struct node{ int nxt,to; }e[2*N]; int hd[N],cnt; il void add(int x,int y){ e[++cnt].nxt=hd[x]; e[cnt].to=y; hd[x]=cnt; } struct tr{ int a[3][3]; void init(int x,int y){//x:ldp0 y:ldp1 a[1][1]=x,a[2][1]=x; a[1][2]=y,a[2][2]=-inf; } void pre(){ memset(a,-inf,sizeof a); } void st(){ a[1][1]=0,a[1][2]=-inf,a[2][1]=-inf,a[2][2]=0; } tr operator *(const tr& b) const{ tr c;c.pre(); for(reg i=1;i<=2;++i){ for(reg k=1;k<=2;++k){ for(reg j=1;j<=2;++j){ c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]); } } }return c; } void op(){ cout<<left<<setw(10)<<a[1][1]<<" "<<left<<setw(10)<<a[1][2]<<endl; cout<<left<<setw(10)<<a[2][1]<<" "<<left<<setw(10)<<a[2][2]<<endl; cout<<endl; } }s[N],t[4*N],A; int dfn[N],top[N],dfn2[N],fdfn[N],sz[N],dep[N],son[N]; int nd[N];//tot;//num of heavy chain int fa[N]; int df; int ldp[N][2],dp[N][2]; int w[N]; il void dfs1(int x,int d){ dep[x]=d; sz[x]=1; for(reg i=hd[x];i;i=e[i].nxt){ int y=e[i].to; if(y==fa[x]) continue; fa[y]=x; dfs1(y,d+1); if(sz[y]>sz[son[x]]){ son[x]=y; } } } il void dfs2(int x){ dfn[x]=++df;fdfn[df]=x; if(!top[x]) { top[x]=x;nd[top[x]]=x; } if(son[x]) top[son[x]]=top[x],nd[top[x]]=son[x],dfs2(son[x]); dp[x][1]=w[x]; ldp[x][1]=w[x]; for(reg i=hd[x];i;i=e[i].nxt){ int y=e[i].to; if(y==son[x]||y==fa[x]) continue; dfs2(y); ldp[x][0]+=max(dp[y][0],dp[y][1]); ldp[x][1]+=dp[y][0]; } if(son[x]){ dp[x][1]=ldp[x][1]+dp[son[x]][0]; dp[x][0]=ldp[x][0]+max(dp[son[x]][1],dp[son[x]][0]); } s[x].init(ldp[x][0],ldp[x][1]); } int up; il void build(){ up=1; for(;up<=n+1;up<<=1); for(reg i=up;i<=up+up-1;++i){ if(i>=up+1&&i<=up+n) t[i]=s[fdfn[i-up]]; else t[i]=A; } for(reg i=up-1;i;--i) t[i]=t[i<<1|1]*t[i<<1]; } il void chan(int to,int c0,int c1){ reg i=up+to; t[i].a[1][1]+=c0;t[i].a[2][1]+=c0; t[i].a[1][2]+=c1; for(i>>=1;i;i>>=1){ t[i]=t[i<<1|1]*t[i<<1]; } // cout<<" after chan "<<endl; } il tr query(int l,int r){ tr le,ri;le.st();ri.st(); for(reg s=up+l-1,e=up+r+1;s^e^1;s>>=1,e>>=1){ // cout<<s<<" "<<e<<endl; if(!(s&1)) le=t[s^1]*le; if(e&1) ri=ri*t[e^1]; } return ri*le; } int tmp[2]; int to[2]; il int upda(int x,int y){ tmp[0]=tmp[1]=0; to[0]=to[1]=0; tmp[1]=y-w[x]; w[x]=y; while(x){ //tr anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]); to[0]=dp[top[x]][0],to[1]=dp[top[x]][1]; chan(dfn[x],tmp[0],tmp[1]); tr anc=A*query(dfn[top[x]],dfn[nd[top[x]]]); tmp[0]=max(anc.a[1][1],anc.a[1][2])-max(to[0],to[1]); tmp[1]=anc.a[1][1]-to[0]; dp[top[x]][0]=anc.a[1][1],dp[top[x]][1]=anc.a[1][2]; x=fa[top[x]]; } return max(dp[1][0],dp[1][1]); } int main(){ scanf("%d%d",&n,&m); for(reg i=1;i<=n;++i)rd(w[i]); int x,y; for(reg i=1;i<=n-1;++i){ rd(x);rd(y);add(x,y);add(y,x); } dfs1(1,1); dfs2(1); A.a[1][1]=0,A.a[1][2]=0; A.a[2][1]=-inf,A.a[2][2]=-inf; build(); while(m--){ rd(x);rd(y); printf("%d\n",upda(x,y)); } return 0; } } int main(){ // freopen("data.in","r",stdin); // freopen("my.out","w",stdout); Miracle::main(); return 0; } /* Author: *Miracle* Date: 2018/11/12 16:29:49 */