用Java泛化寫的伸展樹
阿新 • 發佈：2018-12-25
此程式對原作者的實現進行了泛化，並修改了 findMax 和 findMin 兩個方法，使用起來方便了許多。詳細說明可參見原作者的文章（原文此處為「點選開啟連結」的超連結）。
附上原始碼：
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

/**
 * A generic top-down splay tree that maps keys of type {@code K} to values of type {@code V}.
 *
 * <p>Notes:
 * <ol>
 *   <li>The {@code find*} methods are mutating operations: each one splays the tree so that the
 *       accessed node (or the last node on the search path) becomes the root.</li>
 *   <li>No operation is thread-safe, because every operation, lookups included, may restructure
 *       the tree.</li>
 * </ol>
 *
 * <p>Keys are ordered by the comparator passed to the constructor. If that comparator is
 * {@code null}, the keys' natural ordering is used and the keys must implement
 * {@link Comparable}.
 *
 * @param <K> key type
 * @param <V> value type
 * @author clj 2014-03-31
 */
public class SplayTreeTemplate<K, V> {

    /** A tree node holding one key/value pair and links to its two children. */
    static class SplayNode<K, V> {
        K key;
        V val;
        SplayNode<K, V> left;
        SplayNode<K, V> right;

        public SplayNode(K key, V val) {
            this.key = key;
            this.val = val;
        }

        /** Returns the node's value rendered with {@link String#valueOf(Object)}. */
        @Override
        public String toString() {
            return String.valueOf(val);
        }
    }

    /** Key ordering; {@code null} means natural ordering of {@code K}. */
    private final Comparator<? super K> comparator;
    private SplayNode<K, V> root;
    private int count = 0;

    /**
     * Creates a splay tree over an existing binary search tree.
     *
     * @param root root of an existing tree whose keys are already in BST order; may be
     *     {@code null}
     * @param c key comparator, or {@code null} to use the keys' natural ordering
     */
    public SplayTreeTemplate(SplayNode<K, V> root, Comparator<? super K> c) {
        this.root = root;
        this.comparator = c;
        this.count = countOfNodes(root);
    }

    /**
     * Creates an empty splay tree.
     *
     * @param c key comparator, or {@code null} to use the keys' natural ordering
     */
    public SplayTreeTemplate(Comparator<? super K> c) {
        this.comparator = c;
    }

    /**
     * Compares two keys with the configured comparator, or by natural ordering when none was
     * given (the same convention as {@link java.util.TreeMap}).
     *
     * @throws ClassCastException if no comparator was given and {@code a} is not Comparable
     */
    @SuppressWarnings("unchecked")
    private int compare(K a, K b) {
        if (comparator != null) {
            return comparator.compare(a, b);
        }
        return ((Comparable<? super K>) a).compareTo(b);
    }

    /**
     * Counts the nodes of a subtree. Uses an explicit stack rather than recursion so that a
     * degenerate (list-shaped) input tree cannot overflow the call stack.
     */
    private int countOfNodes(SplayNode<K, V> subtree) {
        if (subtree == null) {
            return 0;
        }
        int nodes = 0;
        Deque<SplayNode<K, V>> pending = new ArrayDeque<SplayNode<K, V>>();
        pending.push(subtree);
        while (!pending.isEmpty()) {
            SplayNode<K, V> node = pending.pop();
            nodes++;
            if (node.left != null) {
                pending.push(node.left);
            }
            if (node.right != null) {
                pending.push(node.right);
            }
        }
        return nodes;
    }

    /** Returns the current root node (may be {@code null} for an empty tree). */
    public SplayNode<K, V> getRoot() {
        return root;
    }

    /**
     * Looks up a key, splaying the tree so that the key (or the last node visited while
     * searching for it) becomes the root.
     *
     * @return true if the key is present, false otherwise
     * @throws NoSuchElementException if the tree is empty ({@code throws Exception} is kept for
     *     source compatibility)
     */
    public boolean find(K key) throws Exception {
        if (root == null) {
            throw new NoSuchElementException("tree is empty.");
        }
        splay(key);
        return compare(root.key, key) == 0;
    }

    /**
     * Splays the maximum node of a subtree to the top and returns it. The returned node has no
     * right child.
     *
     * @throws NoSuchElementException if {@code rootNode} is null
     */
    private SplayNode<K, V> findMax(SplayNode<K, V> rootNode) throws Exception {
        if (rootNode == null) {
            throw new NoSuchElementException("tree or subtree is empty.");
        }
        // pseudoNode.right collects the "left tree": nodes known to be smaller than the max.
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        // Largest node of the left tree; its right link is where the next node is attached.
        SplayNode<K, V> leftMax = pseudoNode;
        SplayNode<K, V> t = rootNode;
        while (true) {
            SplayNode<K, V> parent = t.right;
            if (parent == null) {
                break; // t is the maximum
            }
            if (parent.right == null) {
                // zag: parent is the maximum; move t onto the left tree
                t.right = null;
                leftMax.right = t;
                leftMax = t;
                t = parent;
                break;
            }
            // zag-zag: rotate parent above t, then move both onto the left tree
            SplayNode<K, V> tmp = parent.right;
            rotateRightChild(t, parent); // leaves parent.right == null
            leftMax.right = parent;
            leftMax = parent;
            t = tmp; // new root of the middle tree
        }
        // Re-assemble: t is the max (no right child); the left tree becomes t.left.
        leftMax.right = t.left;
        t.left = pseudoNode.right;
        return t;
    }

    /**
     * Splays the entry with the largest key to the root and returns its value.
     *
     * @throws NoSuchElementException if the tree is empty
     */
    public V findMax() throws Exception {
        root = findMax(this.root);
        return root.val;
    }

    /**
     * Splays the minimum node of a subtree to the top and returns it. The returned node has no
     * left child.
     *
     * @throws NoSuchElementException if {@code rootNode} is null
     */
    private SplayNode<K, V> findMin(SplayNode<K, V> rootNode) throws Exception {
        if (rootNode == null) {
            throw new NoSuchElementException("tree or subtree is empty.");
        }
        // pseudoNode.left collects the "right tree": nodes known to be larger than the minimum.
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        // Smallest node of the right tree; its left link is where the next node is attached.
        SplayNode<K, V> rightMin = pseudoNode;
        SplayNode<K, V> t = rootNode;
        while (true) {
            SplayNode<K, V> parent = t.left;
            if (parent == null) {
                break; // t is the minimum
            }
            if (parent.left == null) {
                // zig: parent is the minimum; move t onto the right tree
                t.left = null;
                rightMin.left = t;
                rightMin = t;
                t = parent;
                break;
            }
            // zig-zig: rotate parent above t, then move both onto the right tree
            SplayNode<K, V> tmp = parent.left;
            rotateLeftChild(t, parent); // leaves parent.left == null
            rightMin.left = parent;
            rightMin = parent;
            t = tmp; // new root of the middle tree
        }
        // Re-assemble: t is the min (no left child); the right tree becomes t.right.
        rightMin.left = t.right;
        t.right = pseudoNode.left;
        return t;
    }

    /**
     * Splays the entry with the smallest key to the root and returns its value.
     *
     * @throws NoSuchElementException if the tree is empty
     */
    public V findMin() throws Exception {
        root = findMin(this.root);
        return root.val;
    }

    /**
     * Removes the entry with the largest key.
     *
     * @return the removed value
     * @throws NoSuchElementException if the tree is empty
     */
    public V deleteMax() throws Exception {
        V max = findMax(); // the max is now the root and has no right child
        this.root = root.left;
        count--;
        return max;
    }

    /**
     * Removes the entry with the smallest key.
     *
     * @return the removed value
     * @throws NoSuchElementException if the tree is empty
     */
    public V deleteMin() throws Exception {
        V min = findMin(); // the min is now the root and has no left child
        this.root = root.right;
        count--;
        return min;
    }

    /**
     * Inserts a key/value pair; the new (or updated) node becomes the root. If the key already
     * exists, its value is replaced and the size does not change.
     */
    public void insert(K key, V val) throws Exception {
        if (root == null) {
            this.root = new SplayNode<K, V>(key, val);
            count++;
            return;
        }
        splay(key);
        int cmp = compare(key, root.key);
        if (cmp == 0) {
            root.val = val; // existing key: replace the value
            return;
        }
        SplayNode<K, V> node = new SplayNode<K, V>(key, val);
        if (cmp < 0) {
            // key lies between root.left and root: root (with its right subtree) goes right
            node.left = root.left;
            node.right = root;
            root.left = null;
        } else {
            // key lies between root and root.right: root (with its left subtree) goes left
            node.left = root;
            node.right = root.right;
            root.right = null;
        }
        this.root = node;
        count++;
    }

    /**
     * Removes the entry with the given key.
     *
     * @return the removed value
     * @throws NoSuchElementException if the tree is empty or the key is absent
     */
    public V remove(K key) throws Exception {
        if (root == null) {
            throw new NoSuchElementException("tree is empty.");
        }
        splay(key);
        if (compare(root.key, key) != 0) {
            throw new NoSuchElementException("value not found.");
        }
        SplayNode<K, V> removed = root;
        if (root.left == null) {
            // the removed key is the minimum: its right subtree is the new tree
            root = root.right;
        } else {
            // splay the max of the left subtree to its top (it then has no right child)
            // and attach the old right subtree there
            SplayNode<K, V> leftSubTreeRoot = this.findMax(this.root.left);
            leftSubTreeRoot.right = this.root.right;
            root = leftSubTreeRoot;
        }
        count--;
        return removed.val;
    }

    /** Rotates {@code parent} (the left child of {@code grandparent}) above it and detaches parent.left. */
    private void rotateLeftChild(SplayNode<K, V> grandparent, SplayNode<K, V> parent) {
        grandparent.left = parent.right;
        parent.right = grandparent;
        // split the parent from the middle tree
        parent.left = null;
    }

    /** Rotates {@code parent} (the right child of {@code grandparent}) above it and detaches parent.right. */
    private void rotateRightChild(SplayNode<K, V> grandparent, SplayNode<K, V> parent) {
        grandparent.right = parent.left;
        parent.left = grandparent;
        // split the parent from the middle tree
        parent.right = null;
    }

    /** Splays the whole tree around {@code key}. */
    public void splay(K key) {
        this.root = splay(this.root, key);
    }

    /**
     * Top-down splay of a subtree around {@code key}.
     *
     * @param rootNode root of the subtree to splay (may be null)
     * @param key the target key
     * @return the new root: the node holding {@code key}, or the last node on its search path
     */
    private SplayNode<K, V> splay(SplayNode<K, V> rootNode, K key) {
        if (rootNode == null) {
            return null;
        }
        // One sentinel serves both side trees: pseudoNode.right roots the left tree (keys below
        // the target) and pseudoNode.left roots the right tree (keys above the target).
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        SplayNode<K, V> leftMax = pseudoNode; // largest node of the left tree
        SplayNode<K, V> rightMin = pseudoNode; // smallest node of the right tree
        SplayNode<K, V> t = rootNode;
        while (true) {
            int comp = compare(key, t.key);
            if (comp == 0) {
                break;
            } else if (comp < 0) {
                SplayNode<K, V> parent = t.left;
                if (parent == null) {
                    break;
                }
                if (compare(key, parent.key) < 0) {
                    if (parent.left == null) {
                        // zig
                        t.left = null;
                        rightMin.left = t;
                        rightMin = t;
                        t = parent;
                        break;
                    }
                    // zig-zig
                    SplayNode<K, V> tmp = parent.left;
                    rotateLeftChild(t, parent); // leaves parent.left == null
                    rightMin.left = parent;
                    rightMin = parent;
                    t = tmp;
                } else {
                    // key >= parent.key: zig, or zig-zag simplified to zig
                    t.left = null;
                    rightMin.left = t;
                    rightMin = t;
                    t = parent;
                }
            } else {
                SplayNode<K, V> parent = t.right;
                if (parent == null) {
                    break;
                }
                if (compare(key, parent.key) > 0) {
                    if (parent.right == null) {
                        // zag
                        t.right = null;
                        leftMax.right = t;
                        leftMax = t;
                        t = parent;
                        break;
                    }
                    // zag-zag
                    SplayNode<K, V> tmp = parent.right;
                    rotateRightChild(t, parent); // leaves parent.right == null
                    leftMax.right = parent;
                    leftMax = parent;
                    t = tmp;
                } else {
                    // key <= parent.key: zag, or zag-zig simplified to zag
                    t.right = null;
                    leftMax.right = t;
                    leftMax = t;
                    t = parent;
                }
            }
        }
        // Re-assemble (also correct when the loop made no moves: the sentinel links just
        // round-trip t's own children).
        leftMax.right = t.left;
        rightMin.left = t.right;
        t.left = pseudoNode.right;
        t.right = pseudoNode.left;
        return t;
    }

    /** Returns the number of entries in the tree. */
    public int getSize() {
        return count;
    }

    /** Returns true if the tree holds no entries. */
    public boolean isEmpty() {
        return count == 0;
    }

    /** Test utility: prints the values of a subtree in key order, each preceded by a space. */
    public static void recursiveInOrderTraverse(SplayNode<?, ?> root) {
        if (root == null) {
            return;
        }
        recursiveInOrderTraverse(root.left);
        System.out.format(" %s", root.val);
        recursiveInOrderTraverse(root.right);
    }

    /**
     * Test utility: prints a subtree level by level, placing each key in the column it would
     * occupy in a complete binary tree. Every cell is two characters wide.
     *
     * @param root subtree to print
     * @param n node count of the tree (not used by the layout)
     */
    public static void displayBinaryTree(SplayNode<?, ?> root, int n) {
        if (root == null) {
            return;
        }
        LinkedList<SplayNode<?, ?>> queue = new LinkedList<SplayNode<?, ?>>();
        // all nodes of each level
        List<List<SplayNode<?, ?>>> nodesList = new ArrayList<List<SplayNode<?, ?>>>();
        // per level: each node's position within a complete binary tree of that level
        List<List<Integer>> nextPosList = new ArrayList<List<Integer>>();
        queue.add(root);
        int levelNodes = 1;
        int nextLevelNodes = 0;
        List<SplayNode<?, ?>> levelNodesList = new ArrayList<SplayNode<?, ?>>();
        List<Integer> nextLevelNodesPosList = new ArrayList<Integer>();
        List<Integer> levelNodesPosList = new ArrayList<Integer>();
        levelNodesPosList.add(0); // root position
        nextPosList.add(levelNodesPosList);
        int levelNodesTotal = 1;
        while (!queue.isEmpty()) {
            SplayNode<?, ?> node = queue.remove();
            if (levelNodes == 0) {
                // current level exhausted: move on to the next one
                nodesList.add(levelNodesList);
                nextPosList.add(nextLevelNodesPosList);
                levelNodesPosList = nextLevelNodesPosList;
                levelNodesList = new ArrayList<SplayNode<?, ?>>();
                nextLevelNodesPosList = new ArrayList<Integer>();
                levelNodes = nextLevelNodes;
                levelNodesTotal = nextLevelNodes;
                nextLevelNodes = 0;
            }
            levelNodesList.add(node);
            int pos = levelNodesPosList.get(levelNodesTotal - levelNodes);
            if (node.left != null) {
                queue.add(node.left);
                nextLevelNodes++;
                nextLevelNodesPosList.add(2 * pos);
            }
            if (node.right != null) {
                queue.add(node.right);
                nextLevelNodes++;
                nextLevelNodesPosList.add(2 * pos + 1);
            }
            levelNodes--;
        }
        // save the last level's nodes list
        nodesList.add(levelNodesList);
        int maxLevel = nodesList.size() - 1;

        // columns of a complete tree of this height: 2^(maxLevel+1) - 1
        int cols = 1;
        for (int i = 0; i <= maxLevel; i++) {
            cols <<= 1;
        }
        cols--;
        SplayNode<?, ?>[][] tree = new SplayNode<?, ?>[maxLevel + 1][cols];
        for (int currLevel = 0; currLevel <= maxLevel; currLevel++) {
            levelNodesList = nodesList.get(currLevel);
            levelNodesPosList = nextPosList.get(currLevel);
            // column of this level's j-th position: 2^(maxLevel-level) * (2j+1) - 1
            int tmp = maxLevel - currLevel;
            int coeff = 1;
            for (int i = 0; i < tmp; i++) {
                coeff <<= 1;
            }
            for (int k = 0; k < levelNodesList.size(); k++) {
                int j = levelNodesPosList.get(k);
                int col = coeff * (2 * j + 1) - 1;
                tree[currLevel][col] = levelNodesList.get(k);
            }
        }

        System.out.format("%n");
        for (int i = 0; i <= maxLevel; i++) {
            for (int j = 0; j < cols; j++) {
                SplayNode<?, ?> node = tree[i][j];
                if (node == null) {
                    System.out.format("  "); // same width as a key cell
                } else {
                    System.out.format("%2s", node.key); // %s works for any key type
                }
            }
            System.out.format("%n");
        }
    }

    /** Test utility: prints the tree in order and as a level diagram. */
    public static void printAfterSplayed(SplayTreeTemplate<?, ?> splayTree) {
        SplayNode<?, ?> root = splayTree.getRoot();
        System.out.format("%nAfter being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        System.out.format("%n%n%nAfter being splayed, the tree is:");
        SplayTreeTemplate.displayBinaryTree(root, splayTree.getSize());
    }

    public static void main(String[] args) throws Exception {
        Comparator<Integer> comparator = new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                // Compare values: "o1 == o2" compares references, which is false for equal
                // Integers outside the boxing cache range (-128..127).
                return o1.compareTo(o2);
            }
        };
        // test1(comparator);
        // test2(comparator);
        // test3(comparator);
        test4(comparator);
    }

    private static void test4(Comparator<Integer> comparator) throws Exception {
        SplayTreeTemplate<Integer, String> splayTree;
        System.out.format("************************************");
        System.out.format("%nTest case 4 - priority queue:%n");
        splayTree = new SplayTreeTemplate<Integer, String>(comparator);
        Random random = new Random(); // one generator for the whole run
        long current = System.currentTimeMillis();
        for (int i = 0; i < 1000000; i++) {
            Integer key = i + random.nextInt(10);
            String val = "v_" + key;
            splayTree.insert(key, val);
        }
        long duration = System.currentTimeMillis() - current;
        System.out.println("完成插入: " + duration);
        current = System.currentTimeMillis();
        for (int i = 0; i < 1000000; i++) {
            Integer key = i + random.nextInt(10);
            splayTree.find(key);
        }
        duration = System.currentTimeMillis() - current;
        System.out.println("完成查詢:" + duration);
        current = System.currentTimeMillis();
        while (!splayTree.isEmpty()) {
            splayTree.deleteMax();
        }
        duration = System.currentTimeMillis() - current;
        System.out.println("完成刪除:" + duration);
    }

    private static void test3(Comparator<Integer> comparator) throws Exception {
        SplayNode<Integer, String> root;
        SplayTreeTemplate<Integer, String> splayTree;
        String max;
        Integer newKey;
        String newVal;
        System.out.format("************************************");
        System.out.format("%nTest case 3 - priority queue:%n");
        SplayNode<Integer, String> m1 = new SplayNode<Integer, String>(1, "1");
        SplayNode<Integer, String> m4 = new SplayNode<Integer, String>(4, "4");
        SplayNode<Integer, String> m7 = new SplayNode<Integer, String>(7, "7");
        SplayNode<Integer, String> m9 = new SplayNode<Integer, String>(9, "9");
        SplayNode<Integer, String> m20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> m22 = new SplayNode<Integer, String>(22, "22");
        SplayNode<Integer, String> m26 = new SplayNode<Integer, String>(26, "26");
        SplayNode<Integer, String> m29 = new SplayNode<Integer, String>(29, "29");
        SplayNode<Integer, String> m30 = new SplayNode<Integer, String>(30, "30");
        SplayNode<Integer, String> m36 = new SplayNode<Integer, String>(36, "36");
        // a right-leaning chain: 1 -> 4 -> 7 -> ... -> 36
        m1.right = m4;
        m4.right = m7;
        m7.right = m9;
        m9.right = m20;
        m20.right = m22;
        m22.right = m26;
        m26.right = m29;
        m29.right = m30;
        m30.right = m36;
        root = m1;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        newKey = 16;
        newVal = "16";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        newKey = 12;
        newVal = "12";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
    }

    private static void test2(Comparator<Integer> comparator) throws Exception {
        SplayNode<Integer, String> root;
        SplayTreeTemplate<Integer, String> splayTree;
        System.out.format("************************************");
        System.out.format("%nTest case 2 - splaytree's operations:%n");
        /*
         *        13
         *      /    \
         *    10      25
         *      \    /  \
         *      12  20   35
         *              /
         *             29
         */
        SplayNode<Integer, String> n13 = new SplayNode<Integer, String>(13, "13");
        SplayNode<Integer, String> n10 = new SplayNode<Integer, String>(10, "10");
        SplayNode<Integer, String> n25 = new SplayNode<Integer, String>(25, "25");
        SplayNode<Integer, String> n12 = new SplayNode<Integer, String>(12, "12");
        SplayNode<Integer, String> n20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> n35 = new SplayNode<Integer, String>(35, "35");
        SplayNode<Integer, String> n29 = new SplayNode<Integer, String>(29, "29");
        n13.left = n10;
        n13.right = n25;
        n10.right = n12;
        n25.left = n20;
        n25.right = n35;
        n35.left = n29;
        root = n13;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);
        int val = 25;
        boolean found = splayTree.find(val);
        System.out.format("%n*****%d is in the tree? [%s]*****%n", val, found);
        printAfterSplayed(splayTree);
        String max = splayTree.findMax();
        System.out.format("%n*****max value=%s*****%n", max);
        printAfterSplayed(splayTree);
        String min = splayTree.findMin();
        System.out.format("%n*****min value=%s*****%n", min);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        min = splayTree.deleteMin();
        System.out.format("%n*****deleted min value: %s*****%n", min);
        printAfterSplayed(splayTree);
        Integer newKey = 24;
        String newVal = "24";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        Integer removeVal = 12;
        splayTree.remove(removeVal);
        System.out.format("%n*****remove value %d*****%n", removeVal);
        printAfterSplayed(splayTree);
    }

    private static void test1(Comparator<Integer> comparator) {
        System.out.format("%nTest case 1 - splay opeartion:%n");
        SplayNode<Integer, String> nn12 = new SplayNode<Integer, String>(12, "12");
        SplayNode<Integer, String> nn5 = new SplayNode<Integer, String>(5, "5");
        SplayNode<Integer, String> nn25 = new SplayNode<Integer, String>(25, "25");
        SplayNode<Integer, String> nn20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> nn30 = new SplayNode<Integer, String>(30, "30");
        SplayNode<Integer, String> nn15 = new SplayNode<Integer, String>(15, "15");
        SplayNode<Integer, String> nn24 = new SplayNode<Integer, String>(24, "24");
        SplayNode<Integer, String> nn13 = new SplayNode<Integer, String>(13, "13");
        SplayNode<Integer, String> nn18 = new SplayNode<Integer, String>(18, "18");
        SplayNode<Integer, String> nn16 = new SplayNode<Integer, String>(16, "16");
        nn12.left = nn5;
        nn12.right = nn25;
        nn25.left = nn20;
        nn25.right = nn30;
        nn20.left = nn15;
        nn20.right = nn24;
        nn15.left = nn13;
        nn15.right = nn18;
        nn18.left = nn16;
        SplayNode<Integer, String> root = nn12;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        SplayTreeTemplate<Integer, String> splayTree =
                new SplayTreeTemplate<Integer, String>(root, comparator);
        splayTree.splay(19);
        System.out.format("%n*****splay the node with value=19*****%n");
        printAfterSplayed(splayTree);
    }
}