1. 程式人生 > 實用技巧 >【BZOJ3640】JC的小蘋果(高斯消元)

【BZOJ3640】JC的小蘋果(高斯消元)

點此看題面

  • 一張\(n\)個點\(m\)條邊的無向圖,要從\(1\)號點走到\(n\)號點,初始體力為\(hp\)
  • 每當你走到編號為\(i\)的點時,體力都會失去\(a_i\),然後等概率選擇當前點的一條邊走出去。
  • 當體力值小於等於\(0\)的時候就失敗了,求走到\(n\)號點的概率。
  • \(n\le150,m\le5\times10^3,hp\le10^4,0\le a_i\le hp\)

成環的概率\(DP\)

\(f_{k,i}\)表示在體力值為\(k\)時走到\(i\)的概率,顯然有轉移方程:

\[f_{k,i}=\sum_{j=1}^{n-1}f_{k+a_i,j}\times \frac{w_{i,j}}{deg_j} \]

其中\(w_{i,j}\)表示\(i,j\)之間的邊數,注意此題有重邊有自環。

看起來這樣就完事了,但\(a_i\)可能等於\(0\),也就是說我們不能簡單地分層\(DP\),因為在\(k\)相同的狀態之間也可能存在轉移。

這種時候就要套路地想到高斯消元,把這個轉移式看作一個方程。

然而,如果直接暴力這麼去做顯然會\(T\)飛,因此要考慮優化。

高斯消元的優化

考慮我們總共要做\(hp\)次高斯消元,但實際上每次高斯消元的係數都是相同的,區別只在於等號右邊的值。

因此我們可以在一開始先做一遍高斯消元,預處理出\(p_{i,j}\)表示第\(j\)個式子等號右邊的值對第\(i\)個式子等號右邊的值的貢獻係數。

那麼接下來每次求解就變成\(O(n^2)\)了。

程式碼:\(O(n^3+n^2hp)\)

#include<bits/stdc++.h>
#define Tp template<typename Ty>
#define Ts template<typename Ty,typename... Ar>
#define Reg register
#define RI Reg int
#define Con const
#define CI Con int&
#define I inline
#define W while
#define N 150
#define M 5000
#define HP 10000
#define DB double
#define eps 1e-12
#define add(x,y) (e[++ee].nxt=lnk[x],e[lnk[x]=ee].to=y)
using namespace std;
int n,m,hp,a[N+5],d[N+5],w[N+5][N+5];DB f[HP+5][N+5];
namespace Gauss//高斯消元
{
	DB a[N+5][N+5],p[N+5][N+5],v[N+5],res[N+5];
	I void Add(CI i,CI j)//用第i行去消第j行
	{
		DB t=-a[j][i]/a[i][i];for(RI k=1;k<=n;++k) a[j][k]+=t*a[i][k],p[j][k]+=t*p[i][k];
	}
	I void Init()//初始化
	{
		RI i,j,k;DB t;for(i=1;i<=n;++i) for(p[i][i]=1,j=i+1;j<=n;++j) Add(i,j);//從上往下消成三角形
		for(i=n;i;--i) {for(j=1;j<=n;++j) p[i][j]/=a[i][i];for(a[i][i]=1,j=i-1;j;--j) Add(i,j);}//從下往上消得只剩對角線
	}
	I void Solve()//快速求解
	{
		RI i,j;for(i=1;i<=n;++i) for(j=1;j<=n;++j) res[i]+=p[i][j]*v[j];//根據預處理出的係數計算
	}
}
int main()
{
	RI i,j,k,x,y;for(scanf("%d%d%d",&n,&m,&hp),i=1;i<=n;++i) scanf("%d",a+i);
	for(i=1;i<=m;++i) scanf("%d%d",&x,&y),++w[x][y],++d[x],x^y&&(++w[y][x],++d[y]);//注意自環
	for(i=1;i<=n;++Gauss::a[i][i],++i) if(!a[i]) for(j=1;j^n;++j) Gauss::a[i][j]=-1.0*w[i][j]/d[j];//求出係數矩陣
	DB t=0;for(Gauss::Init(),Gauss::v[1]=1,k=hp;k;--k)
	{
		for(i=1;i<=n;++i) if(a[i]&&k+a[i]<=hp) for(j=1;j^n;++j) Gauss::v[i]+=f[k+a[i]][j]*w[i][j]/d[j];//求出等號右邊的值
		for(Gauss::Solve(),i=1;i<=n;++i) f[k][i]=Gauss::res[i],Gauss::v[i]=Gauss::res[i]=0;t+=f[k][n];//把值移到DP數組裡
	}return printf("%.8lf\n",t),0;//輸出答案
}