1. 程式人生 > 實用技巧 >solution - 簡單題(K-Dimension Tree)

solution - 簡單題(K-Dimension Tree)

solution - 簡單題(K-Dimension Tree)

咕了這麼久,終於可以來講講KDT了。

說句實話,KDT的演算法是非常簡單的,但是很少有人能很快的寫對,總是會出現一些奇奇怪怪的BUG,我自己也寫了一個下午。主要是寫程式碼時注意結構的對稱性,以及演算法的模組性,一個function幹一件事就行。

說了這麼多,就開始講演算法吧。

#1 演算法描述

考慮一種二叉樹的結構,其中每一個節點有兩個功能:

  • 儲存這個節點\(\textbf{p}=[x_0, x_1, \cdots,x_{k-1}]\)的值
  • 儲存一個 \(k\) 維區域的記錄,並且通過這一個節點的左右兒子(\(\textbf p_l, \textbf p_r\)
    )將其在某一個維度將其分為兩半,形式化地,就是對於任意的節點\(\textbf a \in \{\textbf p_l \text{及其兒子}\}, \textbf b \in \{\textbf p_r \text{及其兒子}\}\),存在某一個維度向量\(\textbf T = [x_0,x_1,\cdots,x_{k-1}], \text{其中} x_t = 1, x_{p\not=t} = 0\),有$\textbf T \cdot \textbf a \leq \textbf T \cdot \textbf p < \textbf T \cdot \textbf b $

考慮到演算法的簡便性,我們人為規定

\(t\) 為這個節點的深度模 \(k\)

那麼就十分簡單了,可以很清楚的實現這個演算法的插入,查詢,刪除。

但是考慮到演算法的單次複雜度還是 \(\textrm O(n)\),需要優化。

我們可以使用替罪羊樹的思路優化,即一但某個節點的某個左右兒子的重量大於這個節點的重量的 \(\alpha\) 倍,那麼就直接重構這個樹。其中 \(\alpha\) 基本在 \(0.75\) 附近最好。

下面是程式碼實現。

#2 程式碼實現細節

#2.1 節點定義

先是定義節點。

template <int D>
struct KDT {
	KDT<D>* ls = nullptr, * rs = nullptr;
	int mx[D], mn[D], pos[D];
	int val; // 這個結點的值
	int sum; // 這個節點及其子節點的值的和
	int tot; // 這個節點及其子節點的數量
	const bool operator < (KDT t) const {
		return t.val < val;
	} // 為了pair所必須的
};

其中 mx, mn 為這個節點及其子樹的在所有維度的極大值與極小值。

pos 為這個節點的維度值。

其他的見註釋。

你可以注意到這裡使用了指標來定義。

#2.2 插入

這裡寫一下這個程式的虛擬碼:

$ \textbf {function}\ Insert (\text{這個節點的指標}nx, \text{插入的維度} \textbf d, \text{節點的值}val,\text{節點的深度depth(模過D)}) $

\(\ \ \ \ \textbf{if not exist } nx\ \textbf {then new }nx\)

$\ \ \ \ \textbf{else if } d_{val} \leq nx.pos_{val} $

$\ \ \ \ \ \ \ \ \textbf { then }nx \leftarrow insert(nx\rightarrow ls, d, val, depth + 1) $

\(\ \ \ \ \ \ \ \ \textbf { else } nx \leftarrow insert(n \rightarrow rs, d, val, depth + 1, rldep + 1);\)

$\ \ \ \ \textbf {if not } \text{balance} \textbf { then } rebuild(nx) $

\(\ \ \ \ update(nx)\)

$\ \ \ \ \textbf {return } nx $

\(\textbf {end function}\)

template <int D>
KDT<D>* insert(KDT<D>* nx, int d[D], int val, int depth, int rldep = 0) {
	if (depth >= D) depth -= D;
	if (nx == nullptr) {
		nx = new KDT<D>;
		for (int i = 0; i < D; i++) {
			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
		}
		nx->val = nx->sum = val;
		nx->tot = 1;
		return nx;
	}
	else {
		int flag = 1;
		for (int i = 0; i < D; i++) {
			if (d[i] != nx->pos[i])
				flag = 0;
		}
		if (flag) {
			nx->val += val;
			update(nx);
			return nx;
		}
		if (d[depth] <= nx->pos[depth]) {
			nx->ls = insert(nx->ls, d, val, depth + 1, rldep + 1);
		} else {
			nx->rs = insert(nx->rs, d, val, depth + 1, rldep + 1);
		}
		update(nx);
		int mx = 0;
		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
		if (mx > nx->tot * alpha) {
			pair<int, KDT<D> >* arr = 
                new pair<int, KDT<D> >[nx->tot + 10];
			pl = 0;
			pia(nx, arr);
			nx = rebuild(1, pl + 1, depth, arr);
			delete[]arr;
		}
		return nx;
	}
}

這裡請注意以下rebuild塊的寫法。

我的寫法是新建一塊記憶體儲存被刪除的節點 pair<int, KDT<D> >* arr = new pair<int, KDT<D> >[nx->tot + 10];

然後將其節點刪除,並放至arr中。

程式碼如下:

template <int D>
void pia(KDT<D>* ptr, pair<int, KDT<D> >* arr) {
	if (ptr != nullptr) {
		arr[++pl].second = *ptr;
		pia(ptr->ls, arr);
		pia(ptr->rs, arr);
		delete ptr;
	}
}

最後是rebuild的一塊。

程式碼如下:

template <int D>
KDT<D>* rebuild(int L, int R, int dep, pair<int, KDT<D> >* arr) {
	if (L >= R) return nullptr;
	if (dep > D) dep -= D;
	for (int i = L; i < R; i++) {
		arr[i].first = arr[i].second.pos[dep];
	}
	int mid = (L + R) >> 1;
	nth_element(arr + L, arr + mid, arr + R);
	KDT<D>* ret = new KDT<D>;
	*ret = arr[mid].second;
	ret->ls = rebuild(L, mid, dep + 1, arr);
	ret->rs = rebuild(mid + 1, R, dep + 1, arr);
	update(ret);
	return ret;
}

這裡為了偷懶才用了系統的nth_element,否則可以不用pair陣列。

#2.3查詢

這一塊比較簡單,不予贅述。其中allin函式表示這個節點及其子節點全在所給範圍之中。allout 相反。 in表示這個單獨的點是否在區域中。

程式碼如下:

template <int D>
int get_ans(KDT<D>* nx, int mx[D], int mn[D]) {
	if (nx == nullptr) return 0;
	if (allout(nx, mx, mn)) {
		return 0;
	}
	if (allin(nx, mx, mn)) return nx->sum;
	int ret = 0;
	if (in(nx, mx, mn)) {
		ret = nx->val;
	}
	ret += get_ans(nx->ls, mx, mn);
	ret += get_ans(nx->rs, mx, mn);
	return ret;
}

#3 程式碼呈現

#include<cstdio>
#include<algorithm>

const double alpha = 0.75;
const int maxn = 210000;

using namespace std;

template <int d>
struct kdt {
	kdt<d>* ls = nullptr, * rs = nullptr;
	int mx[d], mn[d], pos[d];
	int val;
	int sum;
	int tot;
	const bool operator < (kdt t) const {
		return t.val < val;
	}
};

//pair <int, kdt<t> > arr[maxn];
int pl = 0;
template <int d>
void update(kdt<d>* ret) {
	ret->tot = 1;
	ret->sum = ret->val;
	for (int i = 0; i < d; i++) {
		ret->mx[i] = ret->mn[i] = ret->pos[i];
	}
	if (ret->ls != nullptr) {
		for (int i = 0; i < d; i++)
			ret->mx[i] = max(ret->mx[i], ret->ls->mx[i]),
			ret->mn[i] = min(ret->mn[i], ret->ls->mn[i]);
		ret->sum += ret->ls->sum;
		ret->tot += ret->ls->tot;
	}
	if (ret->rs != nullptr) {
		for (int i = 0; i < d; i++)
			ret->mx[i] = max(ret->mx[i], ret->rs->mx[i]),
			ret->mn[i] = min(ret->mn[i], ret->rs->mn[i]);
		ret->sum += ret->rs->sum;
		ret->tot += ret->rs->tot;
	}
}

template <int d>
void pia(kdt<d>* ptr, pair<int, kdt<d> >* arr) {
	if (ptr != nullptr) {
		arr[++pl].second = *ptr;
		pia(ptr->ls, arr);
		pia(ptr->rs, arr);
		delete ptr;
	}
}

template <int d>
kdt<d>* rebuild(int l, int r, int dep, pair<int, kdt<d> >* arr) {
	if (l >= r) return nullptr;
	if (dep > d) dep -= d;
	for (int i = l; i < r; i++) {
		arr[i].first = arr[i].second.pos[dep];
	}
	int mid = (l + r) >> 1;
	nth_element(arr + l, arr + mid, arr + r);
	kdt<d>* ret = new kdt<d>;
	*ret = arr[mid].second;
	ret->ls = rebuild(l, mid, dep + 1, arr);
	ret->rs = rebuild(mid + 1, r, dep + 1, arr);
	update(ret);
	return ret;
}

template <int d>
kdt<d>* insert(kdt<d>* nx, int d[d], int val, int depth, int rldep = 0) {
	if (depth >= d) depth -= d;
	if (nx == nullptr) {
		nx = new kdt<d>;
		for (int i = 0; i < d; i++) {
			nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
		}
		nx->val = nx->sum = val;
		nx->tot = 1;
		return nx;
	}
	else {
		int flag = 1;
		for (int i = 0; i < d; i++) {
			if (d[i] != nx->pos[i])
				flag = 0;
		}
		if (flag) {
			nx->val += val;
			update(nx);
			return nx;
		}
		if (d[depth] < nx->pos[depth]) {
			nx->ls = insert(nx->ls, d, val, depth + 1, rldep + 1);
		} else {
			nx->rs = insert(nx->rs, d, val, depth + 1, rldep + 1);
		}
		update(nx);
		int mx = 0;
		if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
		if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
		if (mx > nx->tot * alpha) {
			pair<int, kdt<d> >* arr = new pair<int, kdt<d> >[nx->tot + 10];
			pl = 0;
			pia(nx, arr);
			nx = rebuild(1, pl + 1, depth, arr);
			delete[]arr;
		}
		return nx;
	}
}

template <int d>
int allin(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->mx[i] > mx[i]) {
			return 0;
		}
		if (nx->mn[i] < mn[i]) {
			return 0;
		}
	}
	return 1;
}

template <int d>
int allout(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->mn[i] > mx[i]) {
			return 1;
		}
		if (nx->mx[i] < mn[i]) {
			return 1;
		}
	}
	return 0;
}

template <int d>
int in(kdt<d>* nx, int mx[d], int mn[d]) {
	for (int i = 0; i < d; i++) {
		if (nx->pos[i] > mx[i]) {
			return 0;
		}
		if (nx->pos[i] < mn[i]) {
			return 0;
		}
	}
	return 1;
}

template <int d>
int get_ans(kdt<d>* nx, int mx[d], int mn[d]) {
	if (nx == nullptr) return 0;
	if (allout(nx, mx, mn)) {
		return 0;
	}
	if (allin(nx, mx, mn)) return nx->sum;
	int ret = 0;
	if (in(nx, mx, mn)) {
		ret = nx->val;
	}
	ret += get_ans(nx->ls, mx, mn);
	ret += get_ans(nx->rs, mx, mn);
	return ret;
}

int main() {
	kdt<2>* root = nullptr;
	int n;scanf("%d", &n);
	int lst_ans = 0;
	while (1) {
		int opt;
		scanf("%d", &opt);
		if (opt == 3) break;
		if (opt == 1) {
			int d[2] = { 0,0}, val;
			scanf("%d%d%d", d, d + 1, &val);
			d[0] ^= lst_ans, d[1] ^= lst_ans, val ^= lst_ans;
			root = insert(root, d, val, 0);
		}
		if (opt == 2) {
			int mx[2] = { 0,0 }, mn[2] = { 0,0 };
			scanf("%d%d%d%d", mn, mn + 1, mx, mx + 1);
			mx[0] ^= lst_ans, mx[1] ^= lst_ans;
			mn[0] ^= lst_ans, mn[1] ^= lst_ans;
			lst_ans = get_ans(root, mx, mn);
			printf("%d\n", lst_ans);
		}
	}
}