洛谷3244 落憶楓音 (拓撲圖dp+式子)
阿新 • • 發佈:2018-12-04
題目大意就是 給你一個DAG
然後新增一條邊 ,詢問以1為根的生成樹的個數
QWQ
首先假設沒有新增的邊
答案就應該是
QWQ就相當於每個點選擇一個父親。
那麼加入一條邊,我們會有一些不合法的情況,那就是包含一條
路徑,剩下隨便選的方案數。假設全集是
,然後路徑上的點的集合是
,那我們實際上求的就是
其中
表示
集合中所有點的入度的乘積
然後對於這個東西,我們可以考慮拓撲圖上dp的方式
來解決
//假設我們添加了一條x->y的邊,要想不合法,就是求y->x的路徑條數
//所以我們要將令起點,也就是y的初值f[y]=ans
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
in[y]++;
point[x]=cnt;
}
int qsm(int i,int j)
{
int ans=1;
while (j)
{
if (j&1) ans=ans*i%mod;
i=i*i%mod;
j>>=1;
}
return ans;
}
void tpsort()
{
//cout<<ans<<endl;
for (int i=1;i<=n;i++)
{
if (!in[i]) q.push(i);
}
while (!q.empty())
{
int now = q.front();
q.pop();
//cout<<now<<endl;
//int ymh=0;
//if (now==y) ymh=1;
f[now]=f[now]*qsm(d[now],mod-2)%mod;
//cout<<now<<" "<<f[now]<<endl;
for (int i=point[now];i;i=nxt[i])
{
int p =to[i];
in[p]--;
f[p]=(f[p]+f[now])%mod;
if (!in[p]) q.push(p);
}
}
}
下面是整個的程式碼
// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 2e5+1e2;
const int maxm = 2*maxn;
const int mod = 1e9+7;
int point[maxn],nxt[maxm],to[maxm];
int n,m;
int cnt,in[maxn];
queue<int> q;
int ans;
int f[maxn];
int x,y;
int d[maxn];
//假設我們添加了一條x->y的邊,要想不合法,就是求y->x的路徑條數
//所以我們要將令起點,也就是y的初值f[y]=ans
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
in[y]++;
point[x]=cnt;
}
int qsm(int i,int j)
{
int ans=1;
while (j)
{
if (j&1) ans=ans*i%mod;
i=i*i%mod;
j>>=1;
}
return ans;
}
void tpsort()
{
//cout<<ans<<endl;
for (int i=1;i<=n;i++)
{
if (!in[i]) q.push(i);
}
while (!q.empty())
{
int now = q.front();
q.pop();
//cout<<now<<endl;
//int ymh=0;
//if (now==y) ymh=1;
f[now]=f[now]*qsm(d[now],mod-2)%mod;
//cout<<now<<" "<<f[now]<<endl;
for (int i=point[now];i;i=nxt[i])
{
int p =to[i];
in[p]--;
f[p]=(f[p]+f[now])%mod;
if (!in[p]) q.push(p);
}
}
}
signed main()
{
n=read(),m=read(),x=read(),y=read();
for (int i=1;i<=m;i++)
{
int u=read(),v=read();
addedge(u,v);
}
ans=1;
for (int i=2;i<=n;i++)
{
if (i==y) ans=ans*(in[i]+1)%mod,d[i]=in[i]+1;
else ans=ans*in[i]%mod,d[i]=in[i];
}
f[y]=ans;
if (x==1)
{
cout<<ans<<"\n";
return 0;
}
tpsort();
cout<<(ans-f[x]+mod)%mod<<endl;
return 0;
}