1. 程式人生 > 其它 >洛谷 P6833 [Cnoi2020]雷雨(set優化dijkstra,set、pair、struct的結合)

洛谷 P6833 [Cnoi2020]雷雨(set優化dijkstra,set、pair、struct的結合)

傳送門


解題思路

分別從a、b、c三個點求單源最短路。

然後列舉兩條道路相交的節點(i,j),因為是點權,所以答案為 \(dis[0][i][j]+dis[1][i][j]+dis[2][i][j]+e[i][j]\)

注意用set進行的堆優化,要防止set丟失元素,所以要對pair的第二維(存點的座標)的結構體進行合理的過載運算子。

雖然沒有固定的順序,但是因為pair和set的原因,需要過載運算子。

例如過載'<',如果直接return a.x<b.x,會導致吞掉dis和橫座標相等的點(因為假設a,b橫座標相等,則a<b返回false,b<a也返回false)。
而當我們改為return a.x<=b.x,則不會出現問題。依舊假設a,b橫座標相等,則a<b返回true,b<a也返回true。

就因為這個東西調了一上午。。。

有大佬也對此問題做出了回覆

AC程式碼

#include<cstdio>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<cmath>
#include<algorithm>
#include<set>
using namespace std;
const int maxn=1005;
int n,m,a,b,c;
long long dis[3][maxn][maxn],e[maxn][maxn],ans=1e18;
struct node{
	int x,y;
	node(int x,int y):x(x),y(y){}
	friend bool operator <(node a,node b){
		return a.x<=b.x;
	}
};
node s(0,0);
void dij(int id){
	set<pair<long long,node> > q;
	dis[id][s.x][s.y]=0;
	q.insert(make_pair(0,s));
	while(!q.empty()){
		node u=q.begin()->second;
		q.erase(q.begin());
		if(u.x>1&&dis[id][u.x-1][u.y]>dis[id][u.x][u.y]+e[u.x][u.y]){
			node to(u.x-1,u.y);
			q.erase(make_pair(dis[id][u.x-1][u.y],to));
			dis[id][u.x-1][u.y]=dis[id][u.x][u.y]+e[u.x][u.y];
			q.insert(make_pair(dis[id][u.x-1][u.y],to));
		}
		if(u.y>1&&dis[id][u.x][u.y-1]>dis[id][u.x][u.y]+e[u.x][u.y]){
			node to(u.x,u.y-1);
			q.erase(make_pair(dis[id][u.x][u.y-1],to));
			dis[id][u.x][u.y-1]=dis[id][u.x][u.y]+e[u.x][u.y];
			q.insert(make_pair(dis[id][u.x][u.y-1],to));
		}
		if(u.x<n&&dis[id][u.x+1][u.y]>dis[id][u.x][u.y]+e[u.x][u.y]){
			node to(u.x+1,u.y);
			q.erase(make_pair(dis[id][u.x+1][u.y],to));
			dis[id][u.x+1][u.y]=dis[id][u.x][u.y]+e[u.x][u.y];
			q.insert(make_pair(dis[id][u.x+1][u.y],to));
		}
		if(u.y<m&&dis[id][u.x][u.y+1]>dis[id][u.x][u.y]+e[u.x][u.y]){
			node to(u.x,u.y+1);
			q.erase(make_pair(dis[id][u.x][u.y+1],to));
			dis[id][u.x][u.y+1]=dis[id][u.x][u.y]+e[u.x][u.y];
			q.insert(make_pair(dis[id][u.x][u.y+1],to));
		}
	}
}
int main(){
	ios::sync_with_stdio(false);
	memset(dis,0x3f,sizeof(dis));
	cin>>n>>m>>a>>b>>c;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=m;j++){
			cin>>e[i][j];
		}
	}
	s.x=1;
	s.y=a;
	dij(0);
	s.x=n;
	s.y=b;
	dij(1);
	s.y=c;
	dij(2);
	for(int i=1;i<=n;i++){
		for(int j=1;j<=m;j++){
			ans=min(ans,dis[0][i][j]+dis[1][i][j]+dis[2][i][j]+e[i][j]);
		}
	}
	cout<<ans;
    return 0;
}