1. 程式人生 > 其它 >POJ1077 Eight(A* + 康託展開)

POJ1077 Eight(A* + 康託展開)

前置知識:康託展開 和 康託逆展開

解決什麼問題?

能構造一個 \(1\sim N\) 的全排列 和 \(0\sim N!-1\) 之間的雙射,在解決全排列的雜湊問題上有奇效。
康託展開即是將全排列對映到自然數,而康託逆展開則是將自然數對映到一個全排列。

怎麼解決?

對於一個 \(N\) 項全排列 \(\{p\}\),定義其康託展開就是其在所有 \(1\sim N\) 排列中的字典序的排名,再定義 \(rank[i]\) 表示除去 \(p_1 \sim p_{i-1}\),有多少個正整數小於 \(p_i\)。那麼顯然,有計算式:

\[Cantor(p) = \sum_{i = 1}^N rank[i] \times (N-i)! \]

稍微解釋一下,就是逐步確定總的排名,依次計算每一位對排名的貢獻。對於第 \(i\)

\(p_i\),比它小的可能性有 \(rank[i]\) 種,而每一種貢獻了 \((N-i)!\) 個排名。
那麼逆展開怎麼做?也只要逆向逐步確定即可,並且根據康託展開的定義,我們能且僅能確定一個全排列。
具體實現如下:

int cantor(vector<int> p) {
	int ret = 0;
	for (int i = 0; i < 9; i++) {
		int rnk = p[i]; for (int j = 0; j < i; j++) if (p[j] < p[i]) rnk--;
		ret += (rnk - 1) * fac[8 - i];
	}
	return ret; // rank in [0, 9! - 1]
}

vector<int> cantorRev(int num) {
	vector<int> ret, p; ret.resize(9); p.resize(9);
	for (int i = 0; i < 9; i++) p[i] = i + 1;
	for (int i = 8; i >= 0; i--) {
		int rnk = num / fac[i] + 1;
		for (int j = 0; j < 9; j++) if (p[j]) if (--rnk == 0) { ret[8 - i] = p[j]; p[j] = 0; break; }
		num %= fac[i];
	}
	return ret;
}

補充一下,我的實現中康託展開的複雜度是 \(O(N^2)\) 的,而我們顯然能夠使用樹狀陣列快速統計出 \(rank[i]\),時間複雜度就能降為 \(O(N\log N)\),但是本題中 \(N=9\),優化意義不大,因此沒有寫樹狀陣列。

這麼做有什麼好處?

相比線性雜湊等做法,程式碼實現方便、美觀,且這樣節省空間,因為構造的是 \(0 \sim N!-1\) 的對映,在時間複雜度上其實並沒有太大的優勢。

道理我都懂,這題怎麼做?

本題做法很多,有樸素 \(bfs\)、雙向 \(bfs\)\(A*\) 等等。
雙向 \(bfs\) 是個很不錯的做法,從 \(S\)\(T\) 同時開始搜尋,直到路徑第一次有交即可。
本文主要使用啟發式搜尋演算法 \(A*\)

實現,定義的啟發式估價函式是所有對應點對的曼哈頓距離之和,具體內容已經在《演算法競賽進階指南》中講述過,我在此不再贅述。
通過康託展開技巧,我們可以省去 \(map\) 等雜湊方法,程式碼也更加清晰。
剛開始我使用的是 \(vector\) 來操作,結果 \(T\) 飛了

展開檢視 vector 版程式碼
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
#define mp make_pair
typedef pair<int, int> pii;
const int maxn = 500005;
const int inf = 0x3f3f3f3f;
int vis[maxn], dis[maxn], f[maxn];
int fac[10];
char dir[4] = {'u', 'd', 'l', 'r'};
int head[maxn], nxt[maxn << 2], tail[maxn << 2], type[maxn << 2], ecnt;

void init() {
	fac[0] = 1;
	for (int i = 1; i <= 9; i++) fac[i] = fac[i - 1] * i;
	memset(head, 0, sizeof(head));
}

int cantor(vector<int> p) {
	int ret = 0;
	for (int i = 0; i < 9; i++) {
		int rnk = p[i]; for (int j = 0; j < i; j++) if (p[j] < p[i]) rnk--;
		ret += (rnk - 1) * fac[8 - i];
	}
	return ret; // rank in [0, 9! - 1]
}

vector<int> cantorRev(int num) {
	vector<int> ret, p; ret.resize(9); p.resize(9);
	for (int i = 0; i < 9; i++) p[i] = i + 1;
	for (int i = 8; i >= 0; i--) {
		int rnk = num / fac[i] + 1;
		for (int j = 0; j < 9; j++) if (p[j]) if (--rnk == 0) { ret[8 - i] = p[j]; p[j] = 0; break; }
		num %= fac[i];
	}
	return ret;
}

int mmp[5][5], mmpT[5][5];
int manhattan(int num) {
	vector<int> p = cantorRev(num); int ret = 0;
	for (int i = 0; i < 9; i++) mmp[(i / 3) + 1][(i % 3) + 1] = p[i];
	for (int i = 0; i < 9; i++) mmpT[(i / 3) + 1][(i % 3) + 1] = i;
	for (int i = 1; i <= 3; i++) for (int j = 1; j <= 3; j++)
		for (int x = 1; x <= 3; x++) for (int y = 1; y <= 3; y++)
			if (mmp[i][j] == mmpT[x][y]) ret += abs(i - x) + abs(j - y);
	return ret;
}

void addedge(int u, int v, int t) {
	nxt[++ecnt] = head[u];
	head[u] = ecnt;
	tail[ecnt] = v;
	type[ecnt] = t;
}

void print(int num) {
	vector<int> tmp = cantorRev(num);
	for (int i = 0; i < 9; i++) cout << tmp[i] << " "; cout << endl;
}

pii last[maxn];
void Astar(int S, int T) {
	memset(dis, inf, sizeof(dis));
	memset(vis, 0, sizeof(vis));
	priority_queue<pii> pq;
	dis[S] = 0; pq.push(mp(-(0 + f[S]), S));
	while (!pq.empty()) {
		pii cur = pq.top(); pq.pop();
		int u = cur.second;
		if (vis[u]) continue; vis[u] = 1;
		if (u == T) break;
		for (int e = head[u]; e; e = nxt[e]) {
			int v = tail[e];
			if (dis[v] > dis[u] + 1) {
				dis[v] = dis[u] + 1;
				last[v] = mp(u, type[e]);
				pq.push(mp(-(dis[v] + f[v]), v));
			}
		}
	}
}

int main() {
	init(); vector<int> p; p.resize(9);
	for (int i = 0; i < 9; i++) {
		char c[5]; scanf("%s", c);
		if (c[0] == 'x') p[i] = 9;
		else p[i] = c[0] - '0';
	}
	int S = cantor(p), T = 0;
	for (int i = 0; i < fac[9]; i++) {
		p = cantorRev(i); int pos = 0;
		for (int j = 0; j < 9; j++) if (p[j] == 9) { pos = j; break; }
		if (pos - 3 >= 0) { // U
			swap(p[pos - 3], p[pos]);
			addedge(i, cantor(p), 0);
			swap(p[pos - 3], p[pos]);
		}
		if (pos + 3 < 9) { // D
			swap(p[pos + 3], p[pos]);
			addedge(i, cantor(p), 1);
			swap(p[pos + 3], p[pos]);
		}
		if (pos > 0 && pos / 3 == (pos - 1) / 3) { // L
			swap(p[pos - 1], p[pos]);
			addedge(i, cantor(p), 2);
			swap(p[pos - 1], p[pos]);
		}
		if (pos < 8 && pos / 3 == (pos + 1) / 3) { // R
			swap(p[pos + 1], p[pos]);
			addedge(i, cantor(p), 3);
			swap(p[pos + 1], p[pos]);
		}
	}
	for (int i = 0; i < fac[9]; i++) f[i] = manhattan(i);
	Astar(S, T);
	if (!vis[T]) puts("unsolvable");
	else {
		string ans;
		for (int u = T; u != S; u = last[u].first) ans.push_back(dir[last[u].second]);
		reverse(ans.begin(), ans.end());
		printf("%s\n", ans.c_str());
	}
	return 0;
}
改成陣列就能過了
展開檢視 AC 程式碼
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <cmath>
#include <queue>
#include <algorithm>
using namespace std;
#define mp make_pair
typedef pair<int, int> pii;
const int maxn = 500005;
const int inf = 0x3f3f3f3f;
int vis[maxn], dis[maxn], f[maxn];
int fac[10];
char dir[4] = {'u', 'd', 'l', 'r'};
int head[maxn], nxt[maxn << 2], tail[maxn << 2], type[maxn << 2], ecnt;

void init() {
	fac[0] = 1;
	for (int i = 1; i <= 9; i++) fac[i] = fac[i - 1] * i;
	memset(head, 0, sizeof(head));
}

int cantor(int *p) {
	int ret = 0;
	for (int i = 0; i < 9; i++) {
		int rnk = p[i]; for (int j = 0; j < i; j++) if (p[j] < p[i]) rnk--;
		ret += (rnk - 1) * fac[8 - i];
	}
	return ret; // rank in [0, 9! - 1]
}

int mmp[5][5], mmpT[5][5];
int manhattan(int *p) {
	int ret = 0;
	for (int i = 0; i < 9; i++) mmp[(i / 3) + 1][(i % 3) + 1] = p[i];
	for (int i = 0; i < 9; i++) mmpT[(i / 3) + 1][(i % 3) + 1] = i;
	for (int i = 1; i <= 3; i++) for (int j = 1; j <= 3; j++)
		for (int x = 1; x <= 3; x++) for (int y = 1; y <= 3; y++)
			if (mmp[i][j] == mmpT[x][y]) ret += abs(i - x) + abs(j - y);
	return ret;
}

void addedge(int u, int v, int t) {
	nxt[++ecnt] = head[u];
	head[u] = ecnt;
	tail[ecnt] = v;
	type[ecnt] = t;
}

pii last[maxn];
void Astar(int S, int T) {
	memset(dis, inf, sizeof(dis));
	memset(vis, 0, sizeof(vis));
	priority_queue<pii> pq;
	dis[S] = 0; pq.push(mp(-(0 + f[S]), S));
	while (!pq.empty()) {
		pii cur = pq.top(); pq.pop();
		int u = cur.second;
		if (vis[u]) continue; vis[u] = 1;
		if (u == T) break;
		for (int e = head[u]; e; e = nxt[e]) {
			int v = tail[e];
			if (dis[v] > dis[u] + 1) {
				dis[v] = dis[u] + 1;
				last[v] = mp(u, type[e]);
				pq.push(mp(-(dis[v] + f[v]), v));
			}
		}
	}
}

int main() {
	init(); int p[9];
	for (int i = 0; i < 9; i++) {
		char c[5]; scanf("%s", c);
		if (c[0] == 'x') p[i] = 9;
		else p[i] = c[0] - '0';
	}
	int S = cantor(p), T = 0, id = 0;
	for (int i = 0; i < 9; i++) p[i] = i + 1;
	do {
		int pos = 0;
		for (int j = 0; j < 9; j++) if (p[j] == 9) { pos = j; break; }
		if (pos - 3 >= 0) { // U
			swap(p[pos - 3], p[pos]);
			addedge(id, cantor(p), 0);
			swap(p[pos - 3], p[pos]);
		}
		if (pos + 3 < 9) { // D
			swap(p[pos + 3], p[pos]);
			addedge(id, cantor(p), 1);
			swap(p[pos + 3], p[pos]);
		}
		if (pos > 0 && pos / 3 == (pos - 1) / 3) { // L
			swap(p[pos - 1], p[pos]);
			addedge(id, cantor(p), 2);
			swap(p[pos - 1], p[pos]);
		}
		if (pos < 8 && pos / 3 == (pos + 1) / 3) { // R
			swap(p[pos + 1], p[pos]);
			addedge(id, cantor(p), 3);
			swap(p[pos + 1], p[pos]);
		}
		f[id] = manhattan(p);
		id++;
	} while (next_permutation(p, p + 9));
	Astar(S, T);
	if (!vis[T]) puts("unsolvable");
	else {
		string ans;
		for (int u = T; u != S; u = last[u].first) ans.push_back(dir[last[u].second]);
		reverse(ans.begin(), ans.end());
		printf("%s\n", ans.c_str());
	}
	return 0;
}
時間複雜度 $O(9!\log 9!)$

反思與總結

我是上來就將整個圖建好的,雖然改成了陣列,但是執行速度還是堪憂。
看了網上的其它解法,好像只要邊跑 \(A*\) 邊建圖即可,因為圖中大部分點都是訪問不到的。
這題判無解有個基於逆序對的高階思想,我並沒有使用,因為 \(A*\) 完全跑的完,只要最後判一下終點有沒有被訪問過即可(其實是因為我不會證明這個結論,就不放上來獻醜了)。