程式人生 > Prim和Kruskal求最小生成樹

Prim和Kruskal求最小生成樹

前兩天TCP/IP協議課上,老師談到圖論中的幾個簡單演算法,發現自己不是很熟練,所以馬上掛了一套MST的題目來練練手,題目很簡單,但由於課程比較多,所以還有三個題沒來得及刷,同時在此%一波典

假設給出一個無向帶權圖 G=(V,E),其中 V 為點集,E 為邊集.

首先是Prim演算法,基於貪心的思想:任選一個點形成一棵樹,然後不斷地從剩下的點中挑出與這棵樹距離最小的點(即連接樹中某個點與該點的邊權最小),把它加入樹中;重複這個操作,直到n個點都加入到樹中形成最小生成樹.使用鄰接矩陣實現時,時間複雜度為O(|V|^2).

關於Kruskal演算法,也是基於貪心的思想,同時利用了並查集:先把所有邊按權值從小到大排序,再依次考察每一條邊,如果這條邊的兩個端點在並查集中的祖先(根)不是同一個,就把這條邊加入到邊的集合中並合併兩個端點所在的集合,直到找到n-1條邊為止.時間複雜度為O(ElogE).

從時間複雜度來看,Prim的O(|V|^2)與邊數無關,適用於點比較少而邊很多的稠密圖;Kruskal的O(|E|log|E|)取決於邊數,適合邊較少的稀疏圖.

最小生成樹基礎題:

/*Prim求最小生成樹*/
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b.
#define mem(a,b) memset(a,b,sizeof(a))
// Sentinel meaning "no candidate found yet" in Prim's minimum scan.
#define inf 0x3f3f3f3f
// Array capacity; vertices are indexed 1..n, so n must stay below maxn.
const int maxn = 1e2+10;
// n: number of vertices; q: number of already-built roads read in main.
int n,q;
// mp[i][j]: cost between vertices i and j (a full n x n matrix read from input).
int mp[maxn][maxn];
int lowcost[maxn];// lowcost[i]: cheapest known edge weight connecting vertex i to the current tree; Prim() sets lowcost[i] = -1 once i has been added to the MST
int mst[maxn];// mst[i]: tree endpoint of the edge achieving lowcost[i], so <mst[i], i> is an MST edge once i is added; mst[1] = 0 because vertex 1 is the root

/*
 * Prim's algorithm on the dense matrix mp over vertices 1..n, rooted at 1.
 * Marks each vertex added to the tree with lowcost[v] = -1 and records in
 * mst[v] the tree endpoint of the edge used to reach it.
 * Returns the total weight of the n-1 chosen edges.
 */
int Prim()
{
    const int kInTree = -1;

    // Seed the tree with vertex 1; everyone else hangs off vertex 1 for now.
    lowcost[1] = kInTree;
    mst[1] = 0;
    for (int v = 2; v <= n; ++v)
    {
        mst[v] = 1;
        lowcost[v] = mp[1][v];
    }

    int total = 0;
    for (int added = 1; added < n; ++added)
    {
        // Find the cheapest vertex still outside the tree.
        int best = 0;
        int bestCost = inf;
        for (int v = 2; v <= n; ++v)
        {
            if (lowcost[v] == kInTree)
                continue;
            if (lowcost[v] < bestCost)
            {
                bestCost = lowcost[v];
                best = v;
            }
        }

        total += bestCost;
        lowcost[best] = kInTree;

        // The new tree vertex may offer cheaper links to the remaining ones.
        for (int v = 2; v <= n; ++v)
        {
            const int w = mp[best][v];
            if (w < lowcost[v])
            {
                lowcost[v] = w;
                mst[v] = best;
            }
        }
    }
    return total;
}

/*
 * Reads test cases until EOF: n, an n x n cost matrix, then q pairs of
 * vertices whose road already exists (cost 0). Prints the MST weight.
 */
int main()
{
    while (scanf("%d", &n) != EOF)
    {
        for (int r = 1; r <= n; ++r)
        {
            for (int c = 1; c <= n; ++c)
                scanf("%d", &mp[r][c]);
        }

        scanf("%d", &q);
        while (q-- != 0)
        {
            int a, b;
            scanf("%d%d", &a, &b);
            // An existing road can be reused for free.
            mp[a][b] = 0;
            mp[b][a] = 0;
        }

        printf("%d\n", Prim());
    }
    return 0;
}


/*Prim求最小生成樹*/
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used with inf to reset mp).
#define mem(a,b) memset(a,b,sizeof(a))
// "No edge" sentinel; memset with it writes the byte 0x3f, giving 0x3f3f3f3f per int.
#define inf 0x3f3f3f3f
// Vertex capacity; vertices are indexed 0..n-1.
const int maxn = 1e3+10;
// NOTE(review): maxm is not referenced by the visible code of this program.
const int maxm = 1e4+10;
// n: number of vertices; m: number of edges in the current test case.
int n,m;
// Adjacency matrix of edge weights; inf means "no edge".
int mp[maxn][maxn];
// lowcost[i]: cheapest edge from the tree to vertex i; -1 once i is in the tree.
int lowcost[maxn];
// mst[i]: tree endpoint of the edge that achieves lowcost[i].
int mst[maxn];

/*
 * Prim's algorithm on the adjacency matrix mp over vertices 0..n-1, growing
 * the tree from vertex 0.
 *
 * On return, lowcost[v] == -1 exactly for the vertices that were added to the
 * tree (the caller counts them to detect a disconnected graph), and mst[v]
 * holds the tree endpoint of the edge used to reach v.
 *
 * Returns the total weight of the chosen edges. If some vertex is unreachable,
 * the scan stops as soon as no further vertex can be added, and the value is
 * the weight of the partial tree only.
 */
int Prim()
{
    lowcost[0] = -1;
    mst[0] = 0;
    for( int i = 1; i < n; i++)
    {
        lowcost[i] = mp[0][i];
        mst[i] = 0;
    }

    int sum = 0;
    for( int i = 1; i < n; i++)
    {
        // Cheapest vertex not yet in the tree. Candidates start at 1, so id == 0
        // means nothing outside the tree is reachable.
        int mn = inf;
        int id = 0;
        for( int j = 1; j < n; j++)
        {
            if( lowcost[j] != -1 && lowcost[j] < mn)
            {
                mn = lowcost[j];
                id = j;
            }
        }

        // Disconnected graph: stop instead of adding lowcost[0] (the root's -1
        // marker) to the sum and relaxing from the root again.
        if( id == 0)
            break;

        sum += mn;
        lowcost[id] = -1;
        // Relax edges out of the newly added vertex.
        for( int j = 1; j < n; j++)
        {
            if( mp[id][j] < lowcost[j])
            {
                lowcost[j] = mp[id][j];
                mst[j] = id;
            }
        }
    }

    return sum;
}

/*
 * Reads test cases until EOF: n vertices (0..n-1) and m weighted edges.
 * Prints the MST weight, or "impossible" if the graph is disconnected.
 */
int main()
{
    while( ~scanf("%d%d",&n,&m))
    {
        mem(mp,inf);
        for( int i = 1; i <= m; i++)
        {
            int u,v,x;
            scanf("%d%d%d",&u,&v,&x);
            // Skip self-loops (never part of a spanning tree) and endpoints outside
            // 0..n-1 (Prim only scans that range, and they could index past mp).
            if( u == v || u < 0 || v < 0 || u >= n || v >= n)
                continue;
            // Parallel edges: keep only the cheapest weight.
            mp[u][v] = min(mp[u][v],x);
            mp[v][u] = min(mp[v][u],x);
        }

        int ans = Prim();
        // Prim() marks every vertex it adds to the tree with lowcost == -1.
        int cnt = 0;
        for( int i = 0; i < n; i++)
            if( lowcost[i] == -1)
                cnt++;
        if( cnt != n)
            puts("impossible");
        else
            printf("%d\n",ans);
        // A blank line follows every test case (presumably required by the judge's output format).
        puts("");
    }

    return 0;
}


/*Prim*/
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used with inf to reset mp).
#define mem(a,b) memset(a,b,sizeof(a))
// "No edge" sentinel; memset with it writes the byte 0x3f, giving 0x3f3f3f3f per int.
#define inf 0x3f3f3f3f
// Array capacity; vertices are indexed 1..n.
const int maxn = 1e2+10;
// Number of vertices; 0 terminates the input.
int n;
// Adjacency matrix of edge weights.
int mp[maxn][maxn];
// lowcost[i]: cheapest edge from the tree to vertex i; -1 once i is in the tree.
int lowcost[maxn];
// mst[i]: tree endpoint of the edge that achieves lowcost[i].
int mst[maxn];

/*
 * Prim's algorithm over vertices 1..n using the matrix mp, rooted at vertex 1.
 * lowcost[v] becomes -1 when v joins the tree; mst[v] is its tree neighbour.
 * Returns the sum of the selected edge weights.
 */
int Prim()
{
    // Vertex 1 seeds the tree; every other vertex starts with its direct edge to 1.
    lowcost[1] = -1;
    mst[1] = 0;
    for (int v = 2; v <= n; ++v)
    {
        lowcost[v] = mp[1][v];
        mst[v] = 1;
    }

    int total = 0;
    for (int round = 2; round <= n; ++round)
    {
        // Pick the cheapest vertex still outside the tree (0 if none).
        int next = 0;
        for (int v = 2, best = inf; v <= n; ++v)
        {
            const bool outside = lowcost[v] != -1;
            if (outside && lowcost[v] < best)
            {
                best = lowcost[v];
                next = v;
            }
        }

        total += lowcost[next];
        lowcost[next] = -1;

        // Relax: the new tree vertex may offer cheaper links to the rest.
        for (int v = 2; v <= n; ++v)
        {
            if (mp[next][v] < lowcost[v])
            {
                lowcost[v] = mp[next][v];
                mst[v] = next;
            }
        }
    }
    return total;
}

/*
 * Reads complete graphs until a 0: n, then n*(n-1)/2 weighted edges.
 * Prints the MST weight of each.
 */
int main()
{
    while (scanf("%d", &n) != EOF && n != 0)
    {
        mem(mp, inf);

        const int edgeCount = n * (n - 1) / 2;
        for (int k = 0; k < edgeCount; ++k)
        {
            int a, b, w;
            scanf("%d%d%d", &a, &b, &w);
            mp[a][b] = w;
            mp[b][a] = w;
        }

        printf("%d\n", Prim());
    }
    return 0;
}


#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used with inf to reset mp).
#define mem(a,b) memset(a,b,sizeof(a))
// "No edge" sentinel; memset with it writes the byte 0x3f, giving 0x3f3f3f3f per int.
#define inf 0x3f3f3f3f
// Array capacity; vertices are indexed 1..m.
const int maxn = 1e2+10;
// n: number of edges to read; m: number of vertices (Prim() iterates 1..m).
int n,m;
// Adjacency matrix of edge weights; inf means "no edge".
int mp[maxn][maxn];
// lowcost[i]: cheapest edge from the tree to vertex i; -1 once i is in the tree.
int lowcost[maxn];
// mst[i]: tree endpoint of the edge that achieves lowcost[i].
int mst[maxn];

/*
 * Prim's algorithm on the adjacency matrix mp over vertices 1..m (m is the
 * vertex count in this program), growing the tree from vertex 1.
 *
 * On return, lowcost[v] == -1 exactly for the vertices that were added to the
 * tree, and mst[v] is the tree endpoint of the edge used to reach v.
 * Returns the total weight of the chosen edges. If some vertex is unreachable,
 * the scan stops early and the value covers only the reachable component.
 */
int Prim()
{
    lowcost[1] = -1;
    mst[1] = 0;
    for( int i = 2; i <= m; i++)
    {
        lowcost[i] = mp[1][i];
        mst[i] = 1;
    }

    int sum = 0;
    for( int i = 2; i <= m; i++)
    {
        // Cheapest vertex not yet in the tree; id stays 0 if none is reachable.
        int mn = inf;
        int id = 0;
        for( int j = 2; j <= m; j++)
        {
            if( lowcost[j] != -1 && lowcost[j] < mn)
            {
                mn = lowcost[j];
                id = j;
            }
        }

        // Disconnected graph: stop instead of adding the unused slot lowcost[0]
        // to the sum on every remaining round.
        if( id == 0)
            break;

        sum += mn;
        lowcost[id] = -1;
        // Relax edges out of the newly added vertex.
        for( int j = 2; j <= m; j++)
        {
            if( mp[id][j] < lowcost[j])
            {
                lowcost[j] = mp[id][j];
                mst[j] = id;
            }
        }
    }
    return sum;
}

/*
 * Reads test cases until n == 0: n candidate edges over m vertices (1..m).
 * Prints the MST weight, or "?" if not every vertex can be connected.
 */
int main()
{
    while (scanf("%d%d", &n, &m) != EOF && n != 0)
    {
        mem(mp, inf);
        for (int e = 0; e < n; ++e)
        {
            int a, b, w;
            scanf("%d%d%d", &a, &b, &w);
            // Keep only the cheapest of any parallel edges.
            if (w < mp[a][b])
                mp[a][b] = w;
            if (w < mp[b][a])
                mp[b][a] = w;
        }

        const int total = Prim();

        // Prim() marks every vertex it reaches with lowcost == -1.
        int reached = 0;
        for (int v = 1; v <= m; ++v)
            reached += (lowcost[v] == -1);

        if (reached == m)
            printf("%d\n", total);
        else
            puts("?");
    }
    return 0;
}


#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used to clear vis).
#define mem(a,b) memset(a,b,sizeof(a))
// "Infinite" distance: marks point pairs that may not be linked.
#define inf 1e20
// Tolerance used by sgn() when comparing doubles.
#define eps 1e-8
// Array capacity; points are indexed 1..n.
const int maxn = 1e2+10;
// Number of points in the current test case.
int n;
// mp[i][j]: distance between points i and j, or inf if the pair is not allowed.
double mp[maxn][maxn];
// lc[i]: cheapest known link from the current tree to point i.
double lc[maxn];
// vis[i]: 1 once point i has been added to the tree.
int vis[maxn];
// A lattice point with integer coordinates.
struct Point{
    int x,y;
    // Leaves x and y uninitialized, like a plain aggregate.
    Point(){}
    Point( int px, int py) : x(px), y(py) {}
};
// Input points, used at indices 1..n.
Point p[maxn];

// Sign of x with tolerance: 0 when |x| < eps, otherwise -1 for negative and 1 for positive.
int sgn( double x)
{
    return fabs(x) < eps ? 0 : (x < 0 ? -1 : 1);
}

// Square of an int. No overflow check: |x| must stay at or below 46340.
int sqr( int x)
{
    const int product = x * x;
    return product;
}

// Euclidean distance between two integer points.
double dist( Point a, Point b)
{
    const int dx = a.x - b.x;
    const int dy = a.y - b.y;
    return sqrt(1.0 * (sqr(dx) + sqr(dy)));
}

/*
 * Prim's algorithm over points 1..n using the distance matrix mp, rooted at 1.
 * Returns the total length of the spanning tree, or inf if some point cannot
 * be reached (the caller treats any result >= inf as "impossible").
 */
double Prim()
{
    mem(vis,0);
    vis[1] = 1;
    for( int i = 2; i <= n; i++)
        lc[i] = mp[1][i];

    double sum = 0;
    for( int i = 2; i <= n; i++)
    {
        // Cheapest unvisited point; id stays 0 if none is reachable.
        double mn = inf;
        int id = 0;
        for( int j = 2; j <= n; j++)
        {
            if( !vis[j] && lc[j] < mn)
            {
                mn = lc[j];
                id = j;
            }
        }

        // Disconnected: report failure directly instead of marking slot 0 and
        // relaxing through row mp[0], which main never fills in.
        if( id == 0)
            return inf;

        sum += mn;
        vis[id] = 1;
        // Relax links out of the newly added point.
        for( int j = 2; j <= n; j++)
        {
            if( !vis[j] && mp[id][j] < lc[j])
                lc[j] = mp[id][j];
        }
    }
    return sum;
}

int main()
{
    int T;
    scanf("%d",&T);
    while( T--)
    {
        scanf("%d",&n);
        for( int i = 1; i <= n; i++)
            scanf("%d%d",&p[i].x,&p[i].y);

        for( int i = 1; i <= n; i++)
        {
            for( int j = 1; j <= n; j++)
            {
                mp[i][j] = inf;
                double dis = dist(p[i],p[j]);
                if( sgn(dis-10.0) >= 0 && sgn(dis-1000.0) <= 0)
                    mp[i][j] = dis;
            }
        }
        double ans = Prim();
        if( ans >= inf)
            puts("oh!");
        else
            printf("%.1f\n",ans*100);
    }

    return 0;
}


#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used for mp and vis).
#define mem(a,b) memset(a,b,sizeof(a))
// "No edge" sentinel; memset with it writes the byte 0x3f, giving 0x3f3f3f3f per int.
#define inf 0x3f3f3f3f
// Array capacity; vertices are indexed 1..n.
const int maxn = 5e2+10;
// n: vertices; m: edges to read; k: number of vertex groups whose members get zero-cost links.
int n,m,k;
// Adjacency matrix of edge weights; inf means "no edge".
int mp[maxn][maxn];
// Scratch buffer holding the members of one group while it is read.
int con[maxn];
// lc[i]: cheapest known link from the current tree to vertex i.
int lc[maxn];
// vis[i]: true once vertex i is in the tree.
bool vis[maxn];
// Result of Prim(); Prim() adds inf before returning when a vertex is unreachable.
int ans;

/*
 * Prim's algorithm over vertices 1..n, rooted at 1; writes the tree weight
 * into the global ans. If some vertex is unreachable, inf is added to ans and
 * the function returns immediately, so ans ends up >= inf.
 */
void Prim()
{
    mem(vis,0);
    vis[1] = 1;
    for (int v = 2; v <= n; ++v)
        lc[v] = mp[1][v];

    ans = 0;
    int pos = 1;
    for (int step = 2; step <= n; ++step)
    {
        int cheapest = inf;
        for (int v = 2; v <= n; ++v)
        {
            if (vis[v])
                continue;
            if (lc[v] < cheapest)
            {
                cheapest = lc[v];
                pos = v;
            }
        }

        // Accumulate before checking, so an unreachable vertex pushes ans to >= inf.
        ans += cheapest;
        if (cheapest == inf)
            return;

        vis[pos] = 1;
        for (int v = 2; v <= n; ++v)
        {
            if (!vis[v] && mp[pos][v] < lc[v])
                lc[v] = mp[pos][v];
        }
    }
}

/*
 * For each of T test cases: read n vertices, m weighted edges, and k groups
 * whose members are linked at zero cost. Print the MST weight, or -1 if the
 * graph cannot be connected.
 */
int main()
{
    int T;
    scanf("%d",&T);
    while( T--)
    {
        mem(mp,inf);
        scanf("%d%d%d",&n,&m,&k);
        for( int i = 0; i < m; i++)
        {
            int u,v,x;
            scanf("%d%d%d",&u,&v,&x);
            // Parallel edges: keep the cheapest one.
            mp[u][v] = min(mp[u][v],x);
            mp[v][u] = min(mp[v][u],x);
        }

        // Each group lists t vertices; every pair inside a group gets cost 0.
        while( k--)
        {
            int t;
            scanf("%d",&t);
            for( int i = 0; i < t; i++)
                scanf("%d",&con[i]);
            for( int i = 0; i < t; i++)
                for( int j = i+1; j < t; j++)
                    mp[con[i]][con[j]] = mp[con[j]][con[i]] = 0;
        }

        Prim();
        // Prim() adds inf to ans when some vertex is unreachable, which leaves
        // ans >= inf. It equals inf exactly when every edge taken so far cost 0,
        // so a strict '>' would print inf as the answer instead of -1.
        if( ans >= inf)
            puts("-1");
        else
            printf("%d\n",ans);
    }

    return 0;
}


#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b (used for mp and vis).
#define mem(a,b) memset(a,b,sizeof(a))
// "No edge" sentinel; memset with it writes the byte 0x3f, giving 0x3f3f3f3f per int.
#define inf 0x3f3f3f3f
// Array capacity; vertices are indexed 1..n (letter 'A' maps to 1).
const int maxn = 1e2+10;
// Number of vertices; 0 terminates the input.
int n;
// Adjacency matrix of edge weights; inf means "no edge".
int mp[maxn][maxn];
// lc[i]: cheapest known link from the current tree to vertex i.
int lc[maxn];
// vis[i]: true once vertex i is in the tree.
bool vis[maxn];
// Result of Prim(): total weight of the spanning tree.
int ans;

/*
 * Prim's algorithm over vertices 1..n rooted at 1; stores the tree weight in
 * the global ans. If a vertex is unreachable, inf is added to ans and the
 * function returns early.
 */
void Prim()
{
    mem(vis,0);
    vis[1] = 1;
    for (int v = 2; v <= n; ++v)
        lc[v] = mp[1][v];

    ans = 0;
    int chosen = 1;
    for (int round = 2; round <= n; ++round)
    {
        // Scan for the cheapest vertex outside the tree.
        int low = inf;
        for (int v = 2; v <= n; ++v)
        {
            if (!vis[v] && lc[v] < low)
            {
                low = lc[v];
                chosen = v;
            }
        }

        ans += low;
        if (low == inf)
            return;
        vis[chosen] = 1;

        // Relax links out of the newly added vertex.
        for (int v = 2; v <= n; ++v)
        {
            if (vis[v])
                continue;
            if (mp[chosen][v] < lc[v])
                lc[v] = mp[chosen][v];
        }
    }
}

/*
 * Reads test cases until n == 0. Each case has n-1 lines: a vertex letter, its
 * number of edges t, then t pairs (neighbour letter, cost). Letters 'A','B',...
 * map to vertices 1,2,... Prints the MST weight.
 */
int main()
{
    while( ~scanf("%d",&n) && n)
    {
        mem(mp,inf);
        for( int i = 1; i < n; i++)
        {
            char ch;
            int t;
            // The leading space in " %c" skips any whitespace (newlines, spaces,
            // '\r') before the letter, instead of assuming exactly one separator
            // character the way a bare getchar() does.
            scanf(" %c%d",&ch,&t);
            int u = ch-'A'+1;
            for( int k = 0; k < t; k++)
            {
                int x;
                scanf(" %c%d",&ch,&x);
                int v = ch-'A'+1;
                // Parallel edges: keep the cheapest one.
                mp[u][v] = min(mp[u][v],x);
                mp[v][u] = min(mp[v][u],x);
            }
        }
        Prim();
        printf("%d\n",ans);
    }

    return 0;
}


/*Kruskal*/
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<iostream>

using namespace std;
// Fill every byte of array a with the value b.
#define mem(a,b) memset(a,b,sizeof(a))
// Large "infinity" sentinel. No trailing ';' so it can be used inside expressions.
#define inf 1e20
// Point capacity; points are indexed 0..n-1.
const int maxn = 1e2+10;
// Edge capacity: the complete graph on n <= maxn points has at most
// maxn*(maxn-1)/2 = 5995 edges, which fits.
const int maxm = 1e4+10;
// Number of points in the current test case.
int n;
// A point in the plane.
struct Point{
    double x,y;
};
// Input points, indices 0..n-1.
Point p[maxn];
// Undirected edge a-b with Euclidean length dis.
struct Edge{
    int a,b;
    double dis;
    Edge(){}
    Edge( int _a, int _b, double _dis) : a(_a), b(_b), dis(_dis) {}
};
// Candidate edges of the complete graph.
Edge e[maxm];
int sz;//邊的條數 (number of edges currently stored in e)
// Accumulated length of the minimum spanning tree.
double ans;
// Union-find parent array: pre[x] == x means x is a root.
int pre[maxn];

// Reset the union-find forest so each of the n points is its own root.
void Init()
{
    for (int v = n - 1; v >= 0; --v)
        pre[v] = v;
}

// Square of a double.
double sqr( double x)
{
    const double squared = x * x;
    return squared;
}

// Euclidean distance between two points.
double dist( Point p0, Point p1)
{
    const double dx = p0.x - p1.x;
    const double dy = p0.y - p1.y;
    return sqrt(sqr(dx) + sqr(dy));
}

// Strict weak ordering for sort(): shorter edges come first.
bool cmp( Edge e0, Edge e1)
{
    return e1.dis > e0.dis;
}

/*
 * Union-find root lookup with full path compression: after the call, every
 * node on the path from x points directly at the root. Iterative two-pass form.
 */
int Find( int x)
{
    int root = x;
    while (pre[root] != root)
        root = pre[root];

    // Second pass: re-point every node on the path straight at the root.
    while (pre[x] != root)
    {
        const int parent = pre[x];
        pre[x] = root;
        x = parent;
    }
    return root;
}

// Merge the sets containing x and y by hanging x's root under y's root.
void Union( int x, int y)
{
    const int rootX = Find(x);
    const int rootY = Find(y);
    if (rootX == rootY)
        return;
    pre[rootX] = rootY;
}

void Kruskal()
{
    sort(e,e+sz,cmp);
    int edge = 0;
    for( int i = 0; i < sz && edge != n-1; i++)
    {
        if( Find(e[i].a) != Find(e[i].b))
        {
            Union(e[i].a,e[i].b);
            ans += e[i].dis;
            edge++;
        }
    }
}

int main()
{
    while( ~scanf("%d",&n))
    {
        for( int i = 0; i < n; i++)
            scanf("%lf%lf",&p[i].x,&p[i].y);
        sz = 0;
        for( int i = 0; i < n; i++)
        {
            for( int j = i+1; j < n; j++)
            {
                double dis = dist(p[i],p[j]);
                e[sz++] = Edge(i,j,dis);
            }
        }

        Init();
        ans = 0;
        Kruskal();
        printf("%.2f\n",ans);
    }

    return 0;
}