1. 程式人生 > 實用技巧 >HDU 1688 Sightseeing (次短路計數)

HDU 1688 Sightseeing (次短路計數)

題目傳送門

題目大意

\(s\)\(t\)的最短路條數加上比最短路大1的路徑條數\((1\leq n\leq 10^3,1\leq m\leq 10^4,1\leq L\leq 10^3,0\leq ans\leq10^9)(\)單向邊\()\)

Solution

最短路計數與次短路計數的模板題,只要注意判斷一下次短路是否正好為最短路\(+1\)即可。

每次取隊首元素來鬆弛最短路和次短路,注意只有鬆弛最短路失敗時才鬆弛次短路,具體見程式碼。

注意每次d陣列更新都要入隊

因為是多組資料,別忘了重置cnt以及其他陣列。

#include<bits/stdc++.h>
using namespace std;
#define maxn1 1005
#define maxn2 10005
#define INF 0x3f3f3f3f
int T;
template<typename T>void read(T& x){
	int f=0;x=0;char ch=getchar();
	while(ch<'0'||ch>'9'){f|=(ch=='-');ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
	if(f)x=-x;
}
int head[maxn1],to[maxn2],nxt[maxn2],w[maxn2],cnt=0;
void add(int u,int v,int ww){
	nxt[++cnt]=head[u];
	to[cnt]=v;
	w[cnt]=ww;
	head[u]=cnt;
}
struct node{
	int d,x,flag;
	node(int d,int x,int flag):d(d),x(x),flag(flag){}
	bool operator < (const node& a)const{
		return d>a.d;
	}
};
int d[maxn1][2],vis[maxn1][2],num[maxn1][2];
int dij(int s,int t){
	memset(d,INF,sizeof(d));
	memset(num,0,sizeof(num));
	memset(vis,0,sizeof(vis));
	priority_queue<node>q;
	d[s][0]=0;
	num[s][0]=1;
	q.push(node(0,s,0));
	while(!q.empty()){
		int u=q.top().x,flag=q.top().flag;
		q.pop();
		if(vis[u][flag])continue;
		vis[u][flag]=1;
		for(int i=head[u];i!=-1;i=nxt[i]){
			int v=to[i];
			if(d[v][0]>d[u][flag]+w[i]){
				if(d[v][0]!=INF){//一個小優化 
				d[v][1]=d[v][0];
				num[v][1]=num[v][0];
			    q.push(node(d[v][1],v,1));//每次d陣列更新都要入隊 
			    }
				d[v][0]=d[u][flag]+w[i];
				num[v][0]=num[u][flag];
				q.push(node(d[v][0],v,0));
			}else if(d[v][0]==d[u][flag]+w[i]){
				num[v][0]+=num[u][flag];
			}else if(d[v][1]>d[u][flag]+w[i]){
				d[v][1]=d[u][flag]+w[i];
				num[v][1]=num[u][flag];
				q.push(node(d[v][1],v,1));
			}else if(d[v][1]==d[u][flag]+w[i]){
				num[v][1]+=num[u][flag];
			}
		}
	}
	int ans=num[t][0];
	if(d[t][1]==d[t][0]+1)
	    ans+=num[t][1];
	return ans;
}
int main(){
	read(T);
	int n,m,s,t,u,v,ww;
	while(T--){
		read(n),read(m);
		memset(head,-1,sizeof(head));
		cnt=0;//別忘了重置cnt
		for(int i=1;i<=m;++i){
			read(u),read(v),read(ww);
			add(u,v,ww);
		}
		read(s),read(t);
		printf("%d\n",dij(s,t));
	}
	return 0;
}