1. 程式人生 > 其它 >Solution -「CF113D」Museum

Solution -「CF113D」Museum

\(S(u)\) 表示 \(u\) 結點的相鄰結點的集合。

又記 \(p(u)\) 表示走到了 \(u\) 且下一步繼續留在 \(u\) 結點的概率,那麼下一步離開 \(u\) 結點的概率即為 \(1 - p(u)\)

\(f(i, j)\) 表示 Petya 在 \(i\) 且 Vasya 在 \(j\) 的概率。

可知所有的形如 \(f(i, i)\) 的狀態都是不能用於轉移的,因為它們已經是末狀態了。

故有

\[f(i, j) = \sum_{k \in S(i), k \neq j} f(k, j) \times \frac{1 - p(k)} {|S(k)|} \times p(j) + \sum_{k \in S(j), k \neq i} f(i, k) \times \frac{1 - p(k)} {|S(k)|} \times p(i) + \sum_{u \in S(i), v \in S(j), u \neq v} f(u, v) \times \frac{1 - p(u)} {|S(u)|} \times \frac{1 - p(v)} {|S(v)|} \]

顯然這個轉移是有後效性的,無法用簡單的遞推做。故考慮高斯消元,將該式轉換為我們熟悉的方程形式進行求解即可。

共有 \(n ^ 2\) 只方程,時間複雜度 \(O(n ^ 6)\)。注意邊界條件,即初始狀態時,轉移式需要加 \(1\),感性理解一下。

#include <cstdio>
#include <vector>
using namespace std;

int Abs(int x) { return x < 0 ? -x : x; }
int Max(int x, int y) { return x > y ? x : y; }
int Min(int x, int y) { return x < y ? x : y; }

int read() {
    int x = 0, k = 1;
    char s = getchar();
    while (s < '0' || s > '9') {
        if (s == '-')
            k = -1;
        s = getchar();
    }
    while ('0' <= s && s <= '9') {
        x = (x << 3) + (x << 1) + (s ^ 48);
        s = getchar();
    }
    return x * k;
}

void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}

void print(int x, char c) {
    write(x);
    putchar(c);
}

const int MAXN = 23 * 23 + 5;

struct Elimination {
    bool free[MAXN];
    int n, m, rk, opt;
    double a[MAXN][MAXN], eps;
    Elimination() { eps = 1e-15; }
    Elimination(int N, int M) {
        n = N;
        m = M;
    }
    double Abs(double x) { return x < eps ? -x : x; }
    void Swap(double &x, double &y) {
        double t = x;
        x = y;
        y = t;
    }

    void clear() {
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= m; j++) a[i][j] = 0;
    }

    void calc() {
        int r = 1, c = 1;
        for (; r <= n && c <= m; r++, c++) {
            int pos = r;
            for (int i = r + 1; i <= n; i++)
                if (Abs(a[i][c]) > Abs(a[pos][c]))
                    pos = i;
            if (Abs(a[pos][c]) < eps) {
                r--;
                continue;
            }
            if (pos != r)
                for (int i = c; i <= m; i++) Swap(a[r][i], a[pos][i]);
            double t;
            for (int i = 1; i <= n; i++)
                if (i != r && Abs(a[i][c]) > eps) {
                    t = a[i][c] / a[r][c];
                    for (int j = m; j >= c; j--) a[i][j] -= t * a[r][j];
                }
        }
        rk = r;
    }
};

int deg[MAXN];
double p[MAXN];
vector<int> mp[MAXN];

void Add_Edge(int u, int v) {
    mp[u].push_back(v);
    mp[v].push_back(u);
}

struct node {
    int x, y;
    node() {}
    node(int X, int Y) {
        x = X;
        y = Y;
    }
    int Get(int n) { return (x - 1) * n + y; }
};

int main() {
    int n = read(), m = read(), x = read(), y = read();
    for (int i = 1, u, v; i <= m; i++) {
        u = read(), v = read();
        Add_Edge(u, v);
        deg[u]++, deg[v]++;
    }
    for (int i = 1; i <= n; i++) scanf("%lf", &p[i]);
    Elimination q;
    q.n = n * n;
    q.m = q.n + 1;
    q.clear();
    q.a[node(x, y).Get(n)][q.m] = -1;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) {
            int pos = node(i, j).Get(n);
            q.a[pos][pos]--;
            if (i != j)
                q.a[pos][pos] += p[i] * p[j];
            for (size_t k1 = 0; k1 < mp[i].size(); k1++)
                for (size_t k2 = 0; k2 < mp[j].size(); k2++) {
                    int u = mp[i][k1], v = mp[j][k2];
                    if (u == v)
                        continue;
                    q.a[pos][node(u, v).Get(n)] += (1 - p[u]) * (1 - p[v]) / deg[u] / deg[v];
                }
            for (size_t k = 0; k < mp[i].size(); k++)
                if (mp[i][k] != j)
                    q.a[pos][node(mp[i][k], j).Get(n)] += (1 - p[mp[i][k]]) / deg[mp[i][k]] * p[j];
            for (size_t k = 0; k < mp[j].size(); k++)
                if (mp[j][k] != i)
                    q.a[pos][node(i, mp[j][k]).Get(n)] += (1 - p[mp[j][k]]) / deg[mp[j][k]] * p[i];
        }
    q.calc();
    for (int i = 1; i <= n; i++)
        printf("%.9f\n", q.a[node(i, i).Get(n)][q.m] / q.a[node(i, i).Get(n)][node(i, i).Get(n)]);
    return 0;
}