WQS二分題集
WQS二分,一種優化一類特殊DP的方法。
很多最優化問題都是形如“一堆物品,取與不取之間有限制。現在規定只取k個,最大/小化總收益”。
這類問題最自然的想法是:設f[i][j]表示前i個取j個的最大收益,轉移即可。復雜度O(n^2)。
那麽,如果在某些情況下,可以通過將問題稍作轉化,變成一個不強制選k個的DP,而最後DP出來的最優解一定正好選了k個,那麽問題就會簡化很多。
WQS二分就是基於這個思想。
首先考慮建一個二維坐標系,x軸是選的數的個數,y軸是最大收益,如果這個x-y圖像有凸性,那麽就可能通過給每個被選的數一個偏差值,將復雜度中的一個n變成log。因此,WQS二分又叫作凸優化/帶權二分。
來看一個題:[BZOJ2654]Tree
按照上面所說建立坐標系,發現x-y圖像的斜率單調遞增。是一個下凸函數。
我們考慮給每一條白邊減去某個值(一些地方是加上某個值,本質是一樣的)cost,那麽如果最終解選了x條邊,則得到的值為實際值-cost*x。考慮這個式子的幾何意義,就相當於將凸包通過斜率為cost的直線投影到y軸上。
可以發現,如果合適的選取cost值,可以使凸包上橫坐標為k的這個投影後的縱坐標最大,這時就可以直接得出這個點的值了。
我們二分cost,於是問題轉化為,求一棵每條白邊都減去cost的圖中的最小生成樹,直接求MST即可。
每次根據哪個點投影後的縱坐標最大調整二分邊界,這個類似於用一條直線去切這個凸包,根據切點橫坐標調整。
這裏需要註意一個問題,可能會存在k-1,k,k+1三點共線的情況,這時如果當前二分的直線正好與這三點平行。這是我們要保證它返回的切點一定在我們當前枚舉的二分區間之內。具體到這道題就是通過給等長的邊按顏色排序控制最終收益相同的方案中白邊的個數。
1 #include<cstdio> 2 #include<algorithm> 3 #define rep(i,l,r) for (int i=l; i<=r; i++) 4 typedef long long ll; 5 using namespace std; 6 7 constView Codeint N=100100; 8 int n,m,cnt,tot,k,ans,u[N],v[N],w[N],c[N],fa[N]; 9 struct E{ int u,v,w,c; }e[N]; 10 11 bool operator<(E a,E b){ return a.w==b.w ? a.c>b.c : a.w<b.w; } 12 int find(int x){ return x==fa[x] ? x : fa[x]=find(fa[x]); } 13 14 bool check(int x){ 15 tot=cnt=0; 16 rep(i,1,n) fa[i]=i; 17 rep(i,1,m){ 18 e[i].u=u[i]; e[i].v=v[i]; e[i].w=w[i]; e[i].c=c[i]; 19 if (!c[i]) e[i].w-=x; 20 } 21 sort(e+1,e+m+1); 22 rep(i,1,m){ 23 int p=find(e[i].u),q=find(e[i].v); 24 if (p!=q){ 25 fa[p]=q; tot+=e[i].w; 26 if (!e[i].c) cnt++; 27 } 28 } 29 return cnt<=k; 30 } 31 32 int main(){ 33 freopen("bzoj2654.in","r",stdin); 34 freopen("bzoj2654.out","w",stdout); 35 scanf("%d%d%d",&n,&m,&k); 36 rep(i,1,m) scanf("%d%d%d%d",&u[i],&v[i],&w[i],&c[i]),u[i]++,v[i]++; 37 int L=-105,R=105; 38 while(L<=R){ 39 int mid=(L+R)>>1; 40 if (check(mid)) L=mid+1,ans=tot+k*mid; else R=mid-1; 41 } 42 printf("%d\n",ans); 43 return 0; 44 }
再看一題:[BZOJ1150][CTSC2007]數據備份
這題的經典做法是可撤銷貪心,但也可以用WQS做。
首先同樣建出坐標系,發現是一個斜率單增的上凸包。先二分斜率去掉只選K個的限制,問題簡化成普通DP。
f[i][0/1]表示前i個數,第i個數選了/沒選,的最小代價。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=100010; 9 const ll inf=1e15; 10 int n,k,s[N]; 11 ll L,R,ans; 12 struct P{ ll v,x; }f[N][2]; 13 P min(P a,P b){ if (a.v<b.v || (a.v==b.v && a.x<b.x)) return a; else return b; } 14 15 bool jud(ll cost){ 16 memset(f,0x7f,sizeof(f)); 17 f[1][0]=(P){0,0}; 18 rep(i,2,n){ 19 f[i][0]=min(f[i-1][0],f[i-1][1]); 20 f[i][1]=(P){f[i-1][0].v+s[i]-s[i-1]-cost,f[i-1][0].x+1}; 21 } 22 f[n][0]=min(f[n][0],f[n][1]); 23 if (f[n][0].x<=k) { ans=f[n][0].v+k*cost; return 1; } else return 0; 24 } 25 26 int main(){ 27 freopen("bzoj1150.in","r",stdin); 28 freopen("bzoj1150.out","w",stdout); 29 scanf("%d%d",&n,&k); 30 rep(i,1,n) scanf("%d",&s[i]),R+=s[i]; 31 while (L<=R){ 32 ll mid=(L+R)>>1; 33 if (jud(mid)) L=mid+1; else R=mid-1; 34 } 35 printf("%lld\n",ans); 36 return 0; 37 }View Code
[BZOJ2151]種樹
同上題,設f[i][0/1][0/1]表示前i個數,第一個數選了/沒選,第i個數選了/沒選,的最大收益。
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 using namespace std; 6 7 const int N=200010,inf=1e9; 8 int n,m,ans,a[N],L,R; 9 struct P{ int v,x; }f[N][2][2]; 10 P max(P a,P b){ if (a.v>b.v || (a.v==b.v && a.x<b.x)) return a; else return b; } 11 P add(P s,int b){ return (P){s.v+b,s.x+1}; } 12 13 bool jud(int cost){ 14 memset(f,0,sizeof(f)); 15 f[1][1][1]=(P){a[1]-cost,1}; f[1][0][1]=f[1][1][0]=(P){-inf,0}; 16 rep(i,2,n){ 17 f[i][0][0]=max(f[i-1][0][0],f[i-1][0][1]); 18 f[i][1][0]=max(f[i-1][1][0],f[i-1][1][1]); 19 f[i][0][1]=add(f[i-1][0][0],a[i]-cost); 20 f[i][1][1]=add(f[i-1][1][0],a[i]-cost); 21 } 22 P s=max(max(f[n][0][0],f[n][0][1]),f[n][1][0]); 23 if (s.x<=m) { ans=s.v+m*cost; return 1; } else return 0; 24 } 25 26 int main(){ 27 freopen("bzoj2151.in","r",stdin); 28 freopen("bzoj2151.out","w",stdout); 29 scanf("%d%d",&n,&m); 30 if (m>n/2) { puts("Error!"); return 0; } 31 rep(i,1,n) scanf("%d",&a[i]); 32 L=-1001; R=1001; 33 while (L<=R){ 34 int mid=(L+R)>>1; 35 if (jud(mid)) R=mid-1; else L=mid+1; 36 } 37 printf("%d\n",ans); 38 return 0; 39 }View Code
[BZOJ5311]貞魚
同樣先二分斜率去掉K的限制,問題變為求最小沖突。
f[i]表示前i個人的最小沖突,s[i][j]表示沖突表的二維前綴和,則有f[i]=max{f[j]+(s[i][i]-s[i][j]-s[j][i]+2*s[j][j])/2}。
同時這個DP是有決策單調性的,於是問題就由O(n^2*k)優化到了O(nlognlogk)。
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #define rep(i,l,r) for (int i=(l); i<=(r); i++) 5 using namespace std; 6 7 const int N=4010; 8 int n,k,st,ed,s[N][N],f[N],g[N]; 9 struct P{ int x,l,r; }q[N]; 10 11 int rd(){ 12 int x=0; char ch=getchar(); 13 while (ch<‘0‘ || ch>‘9‘) ch=getchar(); 14 while (ch>=‘0‘ && ch<=‘9‘) x=(x<<3)+(x<<1)+(ch^48),ch=getchar(); 15 return x; 16 } 17 18 int cal(int j,int i){ return f[j]+((s[i][i]-s[i][j]-s[j][i]+s[j][j])>>1); } 19 20 bool chk(int i,int j,int k){ 21 int x=cal(i,k),y=cal(j,k); 22 return (x<y) || (x==y && g[i]<g[j]); 23 } 24 25 int find(int i,int j){ 26 int l=q[ed].l,r=n,res=0; 27 while (l<=r){ 28 int mid=(l+r)>>1; 29 if (chk(i,j,mid)) res=mid,r=mid-1; else l=mid+1; 30 } 31 return res; 32 } 33 34 void solve(int c){ 35 st=ed=1; q[1]=(P){0,0,n}; 36 rep(i,1,n){ 37 ++q[st].l; if (q[st].l>q[st].r) st++; 38 f[i]=cal(q[st].x,i)-c; g[i]=g[q[st].x]+1; 39 if (st>ed || chk(i,q[ed].x,n)){ 40 while (st<=ed && chk(i,q[ed].x,q[ed].l)) ed--; 41 if (st>ed) q[++ed]=(P){i,i,n}; 42 else{ 43 int x=find(i,q[ed].x); 44 q[ed].r=x-1; q[++ed]=(P){i,x,n}; 45 } 46 } 47 } 48 } 49 50 int main(){ 51 freopen("bzoj5311.in","r",stdin); 52 freopen("bzoj5311.out","w",stdout); 53 scanf("%d%d",&n,&k); 54 rep(i,1,n) rep(j,1,n) s[i][j]=s[i-1][j]+s[i][j-1]-s[i-1][j-1]+rd(); 55 int l=-s[n][n],r=0,res=0; 56 while (l<=r){ 57 int mid=(l+r)>>1; solve(mid); 58 if (g[n]<=k) res=mid,l=mid+1; else r=mid-1; 59 } 60 solve(res); printf("%d\n",f[n]+k*res); 61 return 0; 62 }View Code
[BZOJ5252]林克卡特樹
https://www.cnblogs.com/HocRiser/p/9055203.html
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 #define rep(i,l,r) for (int i=l; i<=r; i++) 5 typedef long long ll; 6 using namespace std; 7 8 const int N=300010; 9 int n,k,u,v,w,cnt,to[N<<1],nxt[N<<1],val[N<<1],h[N]; 10 ll mid,tot; 11 void add(int u,int v,int w){ to[++cnt]=v; val[cnt]=w; nxt[cnt]=h[u]; h[u]=cnt; } 12 struct P{ 13 ll x,y; 14 bool operator < (const P &b) const {return x==b.x? y>b.y : x<b.x;} 15 P operator + (const P &b) const {return (P){x+b.x,y+b.y};} 16 P operator + (int b) {return (P){x+b,y};} 17 }dp[3][N]; 18 P upd(P a){ return (P){a.x-mid,a.y+1}; } 19 20 void dfs(int u,int fa){ 21 dp[2][u]=max(dp[2][u],(P){-mid,1}); 22 for (int i=h[u],v; i; i=nxt[i]) 23 if ((v=to[i])!=fa){ 24 dfs(v,u); 25 dp[2][u]=max(dp[2][u]+dp[0][v],upd(dp[1][u]+dp[1][v]+val[i])); 26 dp[1][u]=max(dp[1][u]+dp[0][v],dp[0][u]+dp[1][v]+val[i]); 27 dp[0][u]=dp[0][u]+dp[0][v]; 28 } 29 dp[0][u]=max(dp[0][u],max(upd(dp[1][u]),dp[2][u])); 30 } 31 32 int main(){ 33 freopen("lct.in","r",stdin); 34 freopen("lct.out","w",stdout); 35 scanf("%d%d",&n,&k); k++; 36 rep(i,2,n) scanf("%d%d%d",&u,&v,&w),tot+=abs(w),add(u,v,w),add(v,u,w); 37 ll L=-tot,R=tot; 38 while (L<=R){ 39 mid=(L+R)>>1; memset(dp,0,sizeof(dp)); dfs(1,0); 40 if (dp[0][1].y<=k) R=mid-1; else L=mid+1; 41 } 42 memset(dp,0,sizeof(dp)); mid=L; dfs(1,0); printf("%lld\n",L*k+dp[0][1].x); 43 return 0; 44 }View Code
WQS二分的另外兩個題:CF958E2,CF739E
WQS二分題集