線段樹(區間樹)之區間染色和4n推導過程
阿新 • • 發佈:2018-12-18
前言
線段樹(區間樹)是什麼呢?有了二叉樹、二分搜尋樹,線段樹又是幹什麼的呢?最經典的線段樹問題:區間染色;正如它的名字而言,主要解決區間的問題
一、線段樹說明
1、什麼是線段樹?
線段樹首先是二叉樹,並且是平衡二叉樹(它是一 棵空樹或它的左右兩個子樹的高度差的絕對值不超過1,並且左右兩個子樹都是一棵平衡二叉樹),並且具有二分性質。
如下圖,就是一顆線段樹:
假如,用陣列表示線段樹,如果區間有n個元素,陣列表示需要有多少節點?
2、4n節點推導過程
要進行一下,如果對推導過程不感興趣的,可以直接記住結論,需要4n個節點,推導過程如下圖: PS:依舊是全部落格園最醜圖,當感覺有進步啊!是不是推薦一下,鼓勵一下啊
說明:感覺用盡了洪荒之力,才推匯出來了。感覺高考之後再也不會用到等比公式了,但又用到了,還是緣分未盡啊,哈哈哈!最後,都放棄了,一直推導不出來,忘卻了最後一層的null,假設是滿二叉樹,按最大值進行估算,所以4n是完全夠大的!
二、為什麼要使用線段樹
線段樹主要解決一些區間問題的,如下:
1、區間染色
有一面牆,長度為n,每次選擇一段牆進行染色,m次操作之後,我們可以看見多少種顏色?
2、區間查詢
查詢區間[i,j]的最大值、最小值,或者區間數字和;實質:基於區間的統計查詢。
例如:2018年註冊使用者中消費最高的使用者?消費最低的使用者?學習最長時間的使用者?
三、程式碼實現
1、建立線段樹
二叉樹具有天然遞迴性質,所以用遞迴相對簡單,用迭代也是可以的,我才用遞迴實現,程式碼如下:
template<class T> class SegmentTree { private: T *tree; T *data; int size; std::function<T(T, T)> function; int leftChild(int index) { //左孩子下標;例如用陣列儲存,根節點是下標0,則左孩子為1,右孩子為2 returnindex * 2 + 1; } int rightChild(int index) { //右孩子下標 return index * 2 + 2; } void buildSegmentTree(int treeIndex, int l, int r) { if (l == r) { tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2; //中間值求法,防止整型溢位 buildSegmentTree(leftTreeIndex, l, mid); //構建左子樹 buildSegmentTree(rightTreeIndex, mid + 1, r); //構建右子樹 tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]); } public: SegmentTree(T arr[], int n, std::function<T(T, T)> function) { //建構函式,構建一棵樹 this->function = function; data = new T[n]; for (int i = 0; i < n; ++i) { data[i] = arr[i]; } tree = new T[n * 4]; //分配4n節點 size = n; buildSegmentTree(0, 0, size - 1); } };
2、線段樹查詢
線段樹具有二分查詢性質,所以二分查詢那種思路就可以了,程式碼如下:
T query(int treeIndex, int l, int r, int queryL, int queryR) { if (l == queryL && r == queryR) { return tree[treeIndex]; } int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (queryL >= mid + 1) { return query(rightTreeIndex, mid + 1, r, queryL, queryR); } else if (queryR <= mid) { return query(leftTreeIndex, l, mid, queryL, queryR); } T leftResult = query(leftTreeIndex, l, mid, queryL, mid); T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR); return function(leftResult, rightResult); } T query(int queryL, int queryR) { assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR); return query(0, 0, size - 1, queryL, queryR); }
3、整體程式碼
SegmentTree.h如下:
#ifndef SEGMENT_TREE_SEGMENTTREE_H #define SEGMENT_TREE_SEGMENTTREE_H #include <cassert> #include <functional> template<class T> class SegmentTree { private: T *tree; T *data; int size; std::function<T(T, T)> function; int leftChild(int index) { return index * 2 + 1; } int rightChild(int index) { return index * 2 + 2; } void buildSegmentTree(int treeIndex, int l, int r) { if (l == r) { tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2; buildSegmentTree(leftTreeIndex, l, mid); buildSegmentTree(rightTreeIndex, mid + 1, r); tree[treeIndex] = function(tree[leftTreeIndex], tree[rightTreeIndex]); } T query(int treeIndex, int l, int r, int queryL, int queryR) { if (l == queryL && r == queryR) { return tree[treeIndex]; } int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (queryL >= mid + 1) { return query(rightTreeIndex, mid + 1, r, queryL, queryR); } else if (queryR <= mid) { return query(leftTreeIndex, l, mid, queryL, queryR); } T leftResult = query(leftTreeIndex, l, mid, queryL, mid); T rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR); return function(leftResult, rightResult); } public: SegmentTree(T arr[], int n, std::function<T(T, T)> function) { this->function = function; data = new T[n]; for (int i = 0; i < n; ++i) { data[i] = arr[i]; } tree = new T[n * 4]; size = n; buildSegmentTree(0, 0, size - 1); } int getSize() { return size; } T get(int index) { assert(index >= 0 && index < size); return data[index]; } T query(int queryL, int queryR) { assert(queryL >= 0 && queryL < size && queryR >= 0 && queryR < size && queryL <= queryR); return query(0, 0, size - 1, queryL, queryR); } void print() { std::cout << "["; for (int i = 0; i < size * 4; ++i) { if (tree[i] != NULL) { std::cout << tree[i]; } else { std::cout << "0"; } if (i != size * 4 - 1) { std::cout << ", "; } } std::cout << "]" << std::endl; } }; #endif //SEGMENT_TREE_SEGMENTTREE_HView Code
main.cpp如下:
#include <iostream> #include "SegmentTree.h" int main() { int nums[] = {-2, 0, 3, -5, 2, -1}; SegmentTree<int> *segmentTree = new SegmentTree<int>(nums, sizeof(nums) / sizeof(int), [](int a, int b) -> int { return a + b; }); std::cout << segmentTree->query(2,5) << std::endl; segmentTree->print(); return 0; }
4、演示
執行結果,如下:
5、時間複雜度分析
更新 O(logn)
查詢 O(logn)