solution - 簡單題(K-Dimension Tree)
solution - 簡單題(K-Dimension Tree)
咕了這麼久,終於可以來講講KDT了。
說句實話,KDT的演算法是非常簡單的,但是很少有人能很快的寫對,總是會出現一些奇奇怪怪的BUG,我自己也寫了一個下午。主要是寫程式碼時注意結構的對稱性,以及演算法的模組性,一個function幹一件事就行。
說了這麼多,就開始講演算法吧。
#1 演算法描述
考慮一種二叉樹的結構,其中每一個節點有兩個功能:
- 儲存這個節點\(\textbf{p}=[x_0, x_1, \cdots,x_{k-1}]\)的值
- 儲存一個 \(k\) 維區域的記錄,並且通過這一個節點的左右兒子(\(\textbf p_l, \textbf p_r\)
考慮到演算法的簡便性,我們人為規定
那麼就十分簡單了,可以很清楚的實現這個演算法的插入,查詢,刪除。
但是考慮到演算法的單次複雜度還是 \(\textrm O(n)\),需要優化。
我們可以使用替罪羊樹的思路優化,即一但某個節點的某個左右兒子的重量大於這個節點的重量的 \(\alpha\) 倍,那麼就直接重構這個樹。其中 \(\alpha\) 基本在 \(0.75\) 附近最好。
下面是程式碼實現。
#2 程式碼實現細節
#2.1 節點定義
先是定義節點。
template <int D> struct KDT { KDT<D>* ls = nullptr, * rs = nullptr; int mx[D], mn[D], pos[D]; int val; // 這個結點的值 int sum; // 這個節點及其子節點的值的和 int tot; // 這個節點及其子節點的數量 const bool operator < (KDT t) const { return t.val < val; } // 為了pair所必須的 };
其中 mx, mn
為這個節點及其子樹的在所有維度的極大值與極小值。
pos
為這個節點的維度值。
其他的見註釋。
你可以注意到這裡使用了指標來定義。
#2.2 插入
這裡寫一下這個程式的虛擬碼:
$ \textbf {function}\ Insert (\text{這個節點的指標}nx, \text{插入的維度} \textbf d, \text{節點的值}val,\text{節點的深度depth(模過D)}) $
\(\ \ \ \ \textbf{if not exist } nx\ \textbf {then new }nx\)
$\ \ \ \ \textbf{else if } d_{val} \leq nx.pos_{val} $
$\ \ \ \ \ \ \ \ \textbf { then }nx \leftarrow insert(nx\rightarrow ls, d, val, depth + 1) $
\(\ \ \ \ \ \ \ \ \textbf { else } nx \leftarrow insert(n \rightarrow rs, d, val, depth + 1, rldep + 1);\)
$\ \ \ \ \textbf {if not } \text{balance} \textbf { then } rebuild(nx) $
\(\ \ \ \ update(nx)\)
$\ \ \ \ \textbf {return } nx $
\(\textbf {end function}\)
template <int D>
KDT<D>* insert(KDT<D>* nx, int d[D], int val, int depth, int rldep = 0) {
if (depth >= D) depth -= D;
if (nx == nullptr) {
nx = new KDT<D>;
for (int i = 0; i < D; i++) {
nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
}
nx->val = nx->sum = val;
nx->tot = 1;
return nx;
}
else {
int flag = 1;
for (int i = 0; i < D; i++) {
if (d[i] != nx->pos[i])
flag = 0;
}
if (flag) {
nx->val += val;
update(nx);
return nx;
}
if (d[depth] <= nx->pos[depth]) {
nx->ls = insert(nx->ls, d, val, depth + 1, rldep + 1);
} else {
nx->rs = insert(nx->rs, d, val, depth + 1, rldep + 1);
}
update(nx);
int mx = 0;
if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
if (mx > nx->tot * alpha) {
pair<int, KDT<D> >* arr =
new pair<int, KDT<D> >[nx->tot + 10];
pl = 0;
pia(nx, arr);
nx = rebuild(1, pl + 1, depth, arr);
delete[]arr;
}
return nx;
}
}
這裡請注意以下rebuild
塊的寫法。
我的寫法是新建一塊記憶體儲存被刪除的節點 pair<int, KDT<D> >* arr = new pair<int, KDT<D> >[nx->tot + 10];
然後將其節點刪除,並放至arr
中。
程式碼如下:
template <int D>
void pia(KDT<D>* ptr, pair<int, KDT<D> >* arr) {
if (ptr != nullptr) {
arr[++pl].second = *ptr;
pia(ptr->ls, arr);
pia(ptr->rs, arr);
delete ptr;
}
}
最後是rebuild
的一塊。
程式碼如下:
template <int D>
KDT<D>* rebuild(int L, int R, int dep, pair<int, KDT<D> >* arr) {
if (L >= R) return nullptr;
if (dep > D) dep -= D;
for (int i = L; i < R; i++) {
arr[i].first = arr[i].second.pos[dep];
}
int mid = (L + R) >> 1;
nth_element(arr + L, arr + mid, arr + R);
KDT<D>* ret = new KDT<D>;
*ret = arr[mid].second;
ret->ls = rebuild(L, mid, dep + 1, arr);
ret->rs = rebuild(mid + 1, R, dep + 1, arr);
update(ret);
return ret;
}
這裡為了偷懶才用了系統的nth_element
,否則可以不用pair
陣列。
#2.3查詢
這一塊比較簡單,不予贅述。其中allin
函式表示這個節點及其子節點全在所給範圍之中。allout
相反。 in
表示這個單獨的點是否在區域中。
程式碼如下:
template <int D>
int get_ans(KDT<D>* nx, int mx[D], int mn[D]) {
if (nx == nullptr) return 0;
if (allout(nx, mx, mn)) {
return 0;
}
if (allin(nx, mx, mn)) return nx->sum;
int ret = 0;
if (in(nx, mx, mn)) {
ret = nx->val;
}
ret += get_ans(nx->ls, mx, mn);
ret += get_ans(nx->rs, mx, mn);
return ret;
}
#3 程式碼呈現
#include<cstdio>
#include<algorithm>
const double alpha = 0.75;
const int maxn = 210000;
using namespace std;
template <int d>
struct kdt {
kdt<d>* ls = nullptr, * rs = nullptr;
int mx[d], mn[d], pos[d];
int val;
int sum;
int tot;
const bool operator < (kdt t) const {
return t.val < val;
}
};
//pair <int, kdt<t> > arr[maxn];
int pl = 0;
template <int d>
void update(kdt<d>* ret) {
ret->tot = 1;
ret->sum = ret->val;
for (int i = 0; i < d; i++) {
ret->mx[i] = ret->mn[i] = ret->pos[i];
}
if (ret->ls != nullptr) {
for (int i = 0; i < d; i++)
ret->mx[i] = max(ret->mx[i], ret->ls->mx[i]),
ret->mn[i] = min(ret->mn[i], ret->ls->mn[i]);
ret->sum += ret->ls->sum;
ret->tot += ret->ls->tot;
}
if (ret->rs != nullptr) {
for (int i = 0; i < d; i++)
ret->mx[i] = max(ret->mx[i], ret->rs->mx[i]),
ret->mn[i] = min(ret->mn[i], ret->rs->mn[i]);
ret->sum += ret->rs->sum;
ret->tot += ret->rs->tot;
}
}
template <int d>
void pia(kdt<d>* ptr, pair<int, kdt<d> >* arr) {
if (ptr != nullptr) {
arr[++pl].second = *ptr;
pia(ptr->ls, arr);
pia(ptr->rs, arr);
delete ptr;
}
}
template <int d>
kdt<d>* rebuild(int l, int r, int dep, pair<int, kdt<d> >* arr) {
if (l >= r) return nullptr;
if (dep > d) dep -= d;
for (int i = l; i < r; i++) {
arr[i].first = arr[i].second.pos[dep];
}
int mid = (l + r) >> 1;
nth_element(arr + l, arr + mid, arr + r);
kdt<d>* ret = new kdt<d>;
*ret = arr[mid].second;
ret->ls = rebuild(l, mid, dep + 1, arr);
ret->rs = rebuild(mid + 1, r, dep + 1, arr);
update(ret);
return ret;
}
template <int d>
kdt<d>* insert(kdt<d>* nx, int d[d], int val, int depth, int rldep = 0) {
if (depth >= d) depth -= d;
if (nx == nullptr) {
nx = new kdt<d>;
for (int i = 0; i < d; i++) {
nx->pos[i] = nx->mn[i] = nx->mx[i] = d[i];
}
nx->val = nx->sum = val;
nx->tot = 1;
return nx;
}
else {
int flag = 1;
for (int i = 0; i < d; i++) {
if (d[i] != nx->pos[i])
flag = 0;
}
if (flag) {
nx->val += val;
update(nx);
return nx;
}
if (d[depth] < nx->pos[depth]) {
nx->ls = insert(nx->ls, d, val, depth + 1, rldep + 1);
} else {
nx->rs = insert(nx->rs, d, val, depth + 1, rldep + 1);
}
update(nx);
int mx = 0;
if (nx->ls != nullptr) mx = max(mx, nx->ls->tot);
if (nx->rs != nullptr) mx = max(mx, nx->rs->tot);
if (mx > nx->tot * alpha) {
pair<int, kdt<d> >* arr = new pair<int, kdt<d> >[nx->tot + 10];
pl = 0;
pia(nx, arr);
nx = rebuild(1, pl + 1, depth, arr);
delete[]arr;
}
return nx;
}
}
template <int d>
int allin(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->mx[i] > mx[i]) {
return 0;
}
if (nx->mn[i] < mn[i]) {
return 0;
}
}
return 1;
}
template <int d>
int allout(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->mn[i] > mx[i]) {
return 1;
}
if (nx->mx[i] < mn[i]) {
return 1;
}
}
return 0;
}
template <int d>
int in(kdt<d>* nx, int mx[d], int mn[d]) {
for (int i = 0; i < d; i++) {
if (nx->pos[i] > mx[i]) {
return 0;
}
if (nx->pos[i] < mn[i]) {
return 0;
}
}
return 1;
}
template <int d>
int get_ans(kdt<d>* nx, int mx[d], int mn[d]) {
if (nx == nullptr) return 0;
if (allout(nx, mx, mn)) {
return 0;
}
if (allin(nx, mx, mn)) return nx->sum;
int ret = 0;
if (in(nx, mx, mn)) {
ret = nx->val;
}
ret += get_ans(nx->ls, mx, mn);
ret += get_ans(nx->rs, mx, mn);
return ret;
}
int main() {
kdt<2>* root = nullptr;
int n;scanf("%d", &n);
int lst_ans = 0;
while (1) {
int opt;
scanf("%d", &opt);
if (opt == 3) break;
if (opt == 1) {
int d[2] = { 0,0}, val;
scanf("%d%d%d", d, d + 1, &val);
d[0] ^= lst_ans, d[1] ^= lst_ans, val ^= lst_ans;
root = insert(root, d, val, 0);
}
if (opt == 2) {
int mx[2] = { 0,0 }, mn[2] = { 0,0 };
scanf("%d%d%d%d", mn, mn + 1, mx, mx + 1);
mx[0] ^= lst_ans, mx[1] ^= lst_ans;
mn[0] ^= lst_ans, mn[1] ^= lst_ans;
lst_ans = get_ans(root, mx, mn);
printf("%d\n", lst_ans);
}
}
}