CF739E Gosha is hunting(費用流/凸優化dp)
紀念合格考爆炸。
其實這個題之前就寫過部落格了,qwq但是不小心弄丟了,所以今天來補一下。
首先,一看到球的個數的限制,不難相當用網路流的流量來限制每個球使用的數量。
由於涉及到最大化期望,所以要使用最大費用最大流。
我們新建兩個點\(ss,tt\),分別表示兩種球。
那麼我們現在考慮應該怎麼計算期望呢。
首先,如果假設如果對於一個怪物用一個球,那麼連邊也就比較容易了
對於一個怪物\(x\)
我們\(ss -> x\),費用為\(p[i]\),流量為1。表示一個球在一個怪物上只能用一次。
\(tt\)也是同理。
然後對於每一個\(x->t\),費用是\(0\),流量是\(1\)
但是,要是每次不要求只能用一個球應該怎麼做呢。
我們考慮,這條邊的費用應該是多少。
兩個球都用的期望應該是\(1-(1-p_i)(1-q_i)\)
經過展開,我們發現應該是\(p_i+q_i-p_i\times q_i\)
那麼由於我們發現,由於用了兩個球,所以已經獲得了二者之和的收益,那麼在這一側,只需要在上述建圖的基礎上\(x->t\),費用是\(-p_i\times q_i\)即可。
最後跑一發最大費用最大流就能通過這個題qwq時間複雜度玄學。
#include<bits/stdc++.h> #define pb push_back #define mk make_pair #define ll long long #define db double using namespace std; inline int read() { int x=0,f=1;char ch=getchar(); while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } const int maxn = 4010; const int maxm = 3e6+1e2; const double eps = 1e-10; int point[maxn],nxt[maxm],to[maxm],pre[maxm],from[maxn]; double dis[maxn]; int vis[maxn]; double cost[maxm]; int flow[maxm]; double ans; int n,m,cnt=1; int s,t; void addedge(int x,int y,db w,int f) { nxt[++cnt]=point[x]; pre[cnt]=x; to[cnt]=y; cost[cnt]=w; flow[cnt]=f; point[x]=cnt; } void insert(int x,int y,db w,int f) { addedge(x,y,w,f); addedge(y,x,-w,0); } queue<int> q; bool spfa(int s) { for (int i=1;i<=maxn-3;i++) dis[i]=-1e9; memset(vis,0,sizeof(vis)); q.push(s); dis[s]=0; while (!q.empty()) { int x = q.front(); q.pop(); vis[x]=0; for (int i=point[x];i;i=nxt[i]) { int p = to[i]; if (dis[p]-(dis[x]+cost[i])<-eps && flow[i]>0) { from[p]=i; dis[p]=dis[x]+cost[i]; if (!vis[p]) { q.push(p); vis[p]=1; } } } } if (dis[t]==-1e9) return false; return true; } void mcf() { double x = 1e9; for (int i=from[t];i;i=from[pre[i]]) x=min(x,1.0*flow[i]); for (int i=from[t];i;i=from[pre[i]]) { flow[i]-=x; flow[i^1]+=x; ans+=x*cost[i]; } } void solve() { while (spfa(s)) mcf(); } db a[maxn],b[maxn]; int ss,sss; int aa,bb; int main() { n=read(),aa=read(),bb=read(); s=maxn-10; ss=s+1; t=s+3; sss=ss+1; insert(s,ss,0,aa); insert(s,sss,0,bb); for (int i=1;i<=n;i++) scanf("%lf",&a[i]); for (int i=1;i<=n;i++) scanf("%lf",&b[i]); for (int i=1;i<=n;i++) { insert(ss,i,a[i],1); insert(sss,i,b[i],1); insert(i,t,0,1); insert(i,t,-a[i]*b[i],1); } solve(); printf("%.4lf\n",ans); return 0; }
但是其實這個題的正解是凸優化\(dp\)
首先,先做一個最\(naive\)的想法。
我們令\(dp[i][j][k]\)表示前\(i\)個怪物,用了\(j\)一號球,用了\(k\)個二號球
那麼轉移也是顯然的。
每次只需要討論一下對於當前的怪物是用幾個球,用哪個即可。
但是這樣的複雜度是\(O(n^3)\)的。
顯然沒有辦法通過。
考慮怎麼優化。
由於題目中涉及到的正好用幾個球,並且通過打表發現函式是凸的,那麼我們就可以直接用凸優化來優化掉一維。
(其實是可以直接優化兩個的,但是我太懶,所以沒寫。)
我們對於當前二分的值,表示每選一個二號球,就可以多得到\(mid\)的期望。不限制選的個數。
那麼不難得到下面的這個轉移式子。
dp[i][j]=dp[i-1][j];
dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
if (j)
{
dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
}
然後通過調整\(mid\),通過正好選到\(k\)個二號球。
最後求一個\(dp\)陣列,然後記得把貢獻減去就行。
時間複雜度\(n^2log\),非常優秀。
(其實是如果精度太小會\(WA\),精度太大會\(TLE\))
但是完全可以做到\(nlog^2\)的。
給程式碼。
#include<bits/stdc++.h>
#define pb push_back
#define mk make_pair
#define ll long long
#define db double
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 2010;
const db eps = 1e-6;
struct ymh{
db val;
int num;
ymh operator + (const ymh &b) const
{
return (ymh){val+b.val,num+b.num};
}
};
ymh dp[maxn][maxn];
db a[maxn],b[maxn];
int n;
db l=-4,r=4;
inline int dcmp(double x,double y)
{
return x-y<-eps ? -1 : (x-y>eps ? 1 : 0);
}
inline ymh max(ymh a,ymh b)
{
int now = dcmp(a.val,b.val);
if (now==0)
{
if (a.num<b.num) return a;
else return b;
}
else
{
if(now==-1) return b;
else return a;
}
}
int numa,numb;
db aa[maxn];
db bb[maxn];
db both[maxn];
bool check(db lim)
{
for (int i=1;i<=n;i++) aa[i]=a[i];
for (int i=1;i<=n;i++) bb[i]=b[i]+lim;
for (int i=1;i<=n;i++) both[i]=1.0-(1.0-a[i])*(1.0-b[i])+lim;
for (register int i=1;i<=n;++i)
{
for (register int j=0;j<=numa;++j)
{
dp[i][j]=dp[i-1][j];
dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
if (j)
{
dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
}
// if (dp[i][j].num>numb) return false;
}
}
return dp[n][numa].num<=numb;
}
int main()
{
n=read(),numa=read(),numb=read();
for (int i=1;i<=n;i++) scanf("%lf",&a[i]);
for (int i=1;i<=n;i++) scanf("%lf",&b[i]);
double ans=0;
while (r-l>=eps)
{
db mid = (l+r)/2;
// memset(dp,0,sizeof(dp));
if (check(mid)) l=mid,ans=mid;
else r=mid;
//printf("%.4lf %d\n",mid,dp[n][numa].num);
}
//cout<<1<<endl;
//printf("%.4lf\n",ans);
//memset(dp,0,sizeof(dp));
check(ans);
//printf("")
//printf("%.4lf %d\n",dp[n][numa].val,dp[n][numa].num);
printf("%.4lf",dp[n][numa].val-1.0*numb*ans);
return 0;
}