程式人生 > 用 Java 泛化寫的伸展樹

用Java泛化寫的伸展樹

此程式對原作者的實現進行了泛化（泛型化），並修改了 findMax 和 findMin 兩個方法，使用起來方便了許多。詳情可參見原作者的文章（點選開啟連結）。

附上原始碼：

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

/**
 * A generic splay tree (top-down splaying) that maps keys to values.
 * <p>
 * Notes:
 * <ol>
 * <li>Lookups are mutating operations: {@link #find}, {@link #findMax()} and
 * {@link #findMin()} restructure the tree so that the accessed node (or the
 * last node visited) becomes the root.</li>
 * <li>No operation is thread-safe, because any operation can change the
 * tree's shape.</li>
 * </ol>
 * Keys are ordered by the comparator passed to the constructor; a
 * {@code null} comparator selects the keys' natural ordering (the keys must
 * then implement {@link Comparable}), as in {@link java.util.TreeMap}.
 *
 * @param <K> key type
 * @param <V> value type
 * @author clj 2014-03-31
 */
public class SplayTreeTemplate<K, V> {
	/**
	 * A tree node holding one key/value pair plus links to its two children.
	 */
	static class SplayNode<K, V> {
		K key;
		V val;
		SplayNode<K, V> left;
		SplayNode<K, V> right;

		public SplayNode(K key, V val) {
			this.key = key;
			this.val = val;
		}

		/** Returns the string form of the node's value (not its key). */
		@Override
		public String toString() {
			return String.valueOf(val);
		}
	}

	/** Key ordering; {@code null} means the keys' natural ordering. */
	private final Comparator<? super K> comparator;
	private SplayNode<K, V> root;
	private int count = 0;

	/**
	 * Wraps an existing binary search tree.
	 *
	 * @param root root of a tree already ordered consistently with {@code c};
	 *             may be {@code null} for an empty tree
	 * @param c    key ordering, or {@code null} for natural ordering
	 */
	public SplayTreeTemplate(SplayNode<K, V> root, Comparator<? super K> c) {
		this.root = root;
		this.comparator = c;
		this.count = countOfNodes(root);
	}

	/**
	 * Creates an empty tree.
	 *
	 * @param c key ordering, or {@code null} for natural ordering
	 */
	public SplayTreeTemplate(Comparator<? super K> c) {
		this.comparator = c;
	}

	/**
	 * Compares two keys with the configured comparator, falling back to the
	 * keys' natural ordering when no comparator was supplied.
	 */
	@SuppressWarnings("unchecked") // natural ordering requires K to be Comparable
	private int compareKeys(K a, K b) {
		if (comparator != null) {
			return comparator.compare(a, b);
		}
		return ((Comparable<? super K>) a).compareTo(b);
	}

	/**
	 * Counts the nodes reachable from {@code root}. Iterative so that a
	 * degenerate (list-shaped) tree passed to the constructor cannot overflow
	 * the call stack.
	 */
	private int countOfNodes(SplayNode<K, V> root) {
		if (root == null) {
			return 0;
		}
		int nodes = 0;
		Deque<SplayNode<K, V>> pending = new ArrayDeque<SplayNode<K, V>>();
		pending.push(root);
		while (!pending.isEmpty()) {
			SplayNode<K, V> node = pending.pop();
			nodes++;
			if (node.left != null) {
				pending.push(node.left);
			}
			if (node.right != null) {
				pending.push(node.right);
			}
		}
		return nodes;
	}

	/** Returns the current root node (exposes internal structure; may be null). */
	public SplayNode<K, V> getRoot() {
		return root;
	}

	/**
	 * Looks up {@code key}, splaying the tree around it.
	 *
	 * @return {@code true} if the key is present (it is then at the root)
	 * @throws NoSuchElementException if the tree is empty
	 */
	public boolean find(K key) throws Exception {
		if (root == null) {
			throw new NoSuchElementException("tree is empty.");
		}
		splay(key);
		return compareKeys(root.key, key) == 0;
	}

	/**
	 * Splays the maximum node of the subtree {@code rootNode} to its top.
	 *
	 * @return the new subtree root, which is the maximum and has no right child
	 * @throws NoSuchElementException if {@code rootNode} is null
	 */
	private SplayNode<K, V> findMax(SplayNode<K, V> rootNode) {
		if (rootNode == null) {
			throw new NoSuchElementException("tree or subtree is empty.");
		}
		// Header of the "left tree": every node detached on the way down is
		// smaller than the maximum; header.right becomes that tree's root.
		SplayNode<K, V> header = new SplayNode<K, V>(null, null);
		SplayNode<K, V> leftMax = header;

		SplayNode<K, V> t = rootNode;
		while (true) {
			SplayNode<K, V> parent = t.right;
			if (parent == null) {
				break; // t has no right child, so t is the maximum
			}
			if (parent.right == null) {
				// zag: parent is the maximum; hang t on the left tree
				t.right = null;
				leftMax.right = t;
				leftMax = t;
				t = parent;
				break;
			}
			// zag-zag: rotate parent over t, then hang parent on the left tree
			SplayNode<K, V> next = parent.right;
			rotateRightChild(t, parent); // leaves parent.right == null
			leftMax.right = parent;
			leftMax = parent;
			t = next;
		}
		// Re-assemble: t's left subtree goes under the left tree's largest
		// node, and the left tree becomes t's left child (t.right is null).
		leftMax.right = t.left;
		t.left = header.right;
		return t;
	}

	/**
	 * Splays the maximum entry to the root and returns its value.
	 *
	 * @throws NoSuchElementException if the tree is empty
	 */
	public V findMax() throws Exception {
		root = findMax(root);
		return root.val;
	}

	/**
	 * Splays the minimum node of the subtree {@code rootNode} to its top.
	 *
	 * @return the new subtree root, which is the minimum and has no left child
	 * @throws NoSuchElementException if {@code rootNode} is null
	 */
	private SplayNode<K, V> findMin(SplayNode<K, V> rootNode) {
		if (rootNode == null) {
			throw new NoSuchElementException("tree or subtree is empty.");
		}
		// Header of the "right tree": every node detached on the way down is
		// larger than the minimum; header.left becomes that tree's root.
		SplayNode<K, V> header = new SplayNode<K, V>(null, null);
		SplayNode<K, V> rightMin = header;

		SplayNode<K, V> t = rootNode;
		while (true) {
			SplayNode<K, V> parent = t.left;
			if (parent == null) {
				break; // t has no left child, so t is the minimum
			}
			if (parent.left == null) {
				// zig: parent is the minimum; hang t on the right tree
				t.left = null;
				rightMin.left = t;
				rightMin = t;
				t = parent;
				break;
			}
			// zig-zig: rotate parent over t, then hang parent on the right tree
			SplayNode<K, V> next = parent.left;
			rotateLeftChild(t, parent); // leaves parent.left == null
			rightMin.left = parent;
			rightMin = parent;
			t = next;
		}
		// Re-assemble: t's right subtree goes under the right tree's smallest
		// node, and the right tree becomes t's right child (t.left is null).
		rightMin.left = t.right;
		t.right = header.left;
		return t;
	}

	/**
	 * Splays the minimum entry to the root and returns its value.
	 *
	 * @throws NoSuchElementException if the tree is empty
	 */
	public V findMin() throws Exception {
		root = findMin(root);
		return root.val;
	}

	/**
	 * Removes the entry with the largest key and returns its value.
	 *
	 * @throws NoSuchElementException if the tree is empty
	 */
	public V deleteMax() throws Exception {
		V max = findMax();
		// after findMax() the root has no right child
		root = root.left;
		count--;
		return max;
	}

	/**
	 * Removes the entry with the smallest key and returns its value.
	 *
	 * @throws NoSuchElementException if the tree is empty
	 */
	public V deleteMin() throws Exception {
		V min = findMin();
		// after findMin() the root has no left child
		root = root.right;
		count--;
		return min;
	}

	/**
	 * Inserts {@code key -> val}; if the key already exists its value is
	 * replaced. The affected node ends up at the root.
	 */
	public void insert(K key, V val) throws Exception {
		if (root == null) {
			root = new SplayNode<K, V>(key, val);
			count++;
			return;
		}
		splay(key);
		int cmp = compareKeys(key, root.key);
		if (cmp == 0) {
			root.val = val; // existing key: replace the value in place
			return;
		}
		SplayNode<K, V> node = new SplayNode<K, V>(key, val);
		if (cmp < 0) {
			// key lies between root.left and root: the new node takes the
			// left subtree, and the old root (with its right subtree) goes right
			node.left = root.left;
			node.right = root;
			root.left = null;
		} else {
			// key lies between root and root.right: mirror image of the above
			node.left = root;
			node.right = root.right;
			root.right = null;
		}
		root = node;
		count++;
	}

	/**
	 * Removes {@code key} and returns its value.
	 *
	 * @throws NoSuchElementException if the tree is empty or the key is absent
	 */
	public V remove(K key) throws Exception {
		if (root == null) {
			throw new NoSuchElementException("tree is empty.");
		}
		splay(key);
		if (compareKeys(root.key, key) != 0) {
			throw new NoSuchElementException("value not found.");
		}
		SplayNode<K, V> removed = root;
		if (removed.left == null) {
			// nothing smaller than the removed key: the right subtree is the tree
			root = removed.right;
		} else {
			// splay the maximum of the left subtree to its top; it has no right
			// child, so the right subtree can be attached there directly
			SplayNode<K, V> joined = findMax(removed.left);
			joined.right = removed.right;
			root = joined;
		}
		count--;
		return removed.val;
	}

	/**
	 * Rotates {@code parent} (the left child of {@code grandparent}) above it,
	 * then cuts {@code parent}'s left link. Callers save {@code parent.left}
	 * beforehand because it becomes the next node to examine.
	 */
	private void rotateLeftChild(SplayNode<K, V> grandparent,
			SplayNode<K, V> parent) {
		grandparent.left = parent.right;
		parent.right = grandparent;
		// detach parent from the middle tree
		parent.left = null;
	}

	/**
	 * Mirror of {@link #rotateLeftChild}: rotates {@code parent} (the right
	 * child of {@code grandparent}) above it and cuts {@code parent}'s right
	 * link.
	 */
	private void rotateRightChild(SplayNode<K, V> grandparent,
			SplayNode<K, V> parent) {
		grandparent.right = parent.left;
		parent.left = grandparent;
		// detach parent from the middle tree
		parent.right = null;
	}

	/**
	 * Splays the tree around {@code key}: afterwards the root holds
	 * {@code key} if present, otherwise the last node visited while searching.
	 */
	public void splay(K key) {
		this.root = splay(this.root, key);
	}

	/**
	 * Top-down splay of the subtree rooted at {@code rootNode}.
	 * <p>
	 * While descending, nodes greater than {@code key} are collected into a
	 * "right tree" and nodes smaller than it into a "left tree"; both hang off
	 * one header node (header.left / header.right). Zig-zag and zag-zig steps
	 * are simplified into a single zig / zag. At the end the pieces are
	 * re-assembled around the new root.
	 *
	 * @return the new root of the subtree, or {@code null} if it was empty
	 */
	private SplayNode<K, V> splay(SplayNode<K, V> rootNode, K key) {
		if (rootNode == null) {
			return null;
		}

		SplayNode<K, V> header = new SplayNode<K, V>(null, null);
		// largest node of the left tree (next smaller node attaches to its right)
		SplayNode<K, V> leftMax = header;
		// smallest node of the right tree (next larger node attaches to its left)
		SplayNode<K, V> rightMin = header;

		SplayNode<K, V> t = rootNode;
		while (true) {
			int comp = compareKeys(key, t.key);
			if (comp == 0) {
				break;
			} else if (comp < 0) {
				SplayNode<K, V> parent = t.left;
				if (parent == null) {
					break;
				}
				if (compareKeys(key, parent.key) < 0) {
					if (parent.left == null) {
						// zig: parent is the last node on the search path
						t.left = null;
						rightMin.left = t;
						rightMin = t;
						t = parent;
						break;
					}
					// zig-zig
					SplayNode<K, V> next = parent.left;
					rotateLeftChild(t, parent); // leaves parent.left == null
					rightMin.left = parent;
					rightMin = parent;
					t = next;
				} else {
					// key >= parent.key: zig, or zig-zag simplified to zig
					t.left = null;
					rightMin.left = t;
					rightMin = t;
					t = parent;
				}
			} else {
				SplayNode<K, V> parent = t.right;
				if (parent == null) {
					break;
				}
				if (compareKeys(key, parent.key) > 0) {
					if (parent.right == null) {
						// zag: parent is the last node on the search path
						t.right = null;
						leftMax.right = t;
						leftMax = t;
						t = parent;
						break;
					}
					// zag-zag
					SplayNode<K, V> next = parent.right;
					rotateRightChild(t, parent); // leaves parent.right == null
					leftMax.right = parent;
					leftMax = parent;
					t = next;
				} else {
					// key <= parent.key: zag, or zag-zig simplified to zag
					t.right = null;
					leftMax.right = t;
					leftMax = t;
					t = parent;
				}
			}
		}
		// Re-assemble. This also works when the loop exits immediately:
		// both trackers are still the header, so t keeps its own children.
		leftMax.right = t.left;
		rightMin.left = t.right;
		t.left = header.right; // root of the left tree
		t.right = header.left; // root of the right tree
		return t;
	}

	/** Returns the number of entries in the tree. */
	public int getSize() {
		return count;
	}

	/** Returns {@code true} if the tree holds no entries. */
	public boolean isEmpty() {
		return count == 0;
	}

	/** Test helper: prints the values of the subtree in key order. */
	public static void recursiveInOrderTraverse(SplayNode<?, ?> root) {
		if (root == null) {
			return;
		}
		recursiveInOrderTraverse(root.left);
		System.out.format(" %s", root.val);
		recursiveInOrderTraverse(root.right);
	}

	/**
	 * Test helper: prints the keys of the tree level by level, each key at
	 * its column in a complete-binary-tree layout. The layout width grows as
	 * 2^(depth+1) - 1 columns, so this is only practical for shallow trees.
	 *
	 * @param root root of the tree to print
	 * @param n    number of nodes in the tree (currently unused)
	 */
	public static void displayBinaryTree(SplayNode<?, ?> root, int n) {
		if (root == null) {
			return;
		}

		LinkedList<SplayNode<?, ?>> queue = new LinkedList<SplayNode<?, ?>>();

		// all nodes in each level
		List<List<SplayNode<?, ?>>> nodesList = new ArrayList<List<SplayNode<?, ?>>>();

		// the positions (in a complete binary tree) of each level's nodes
		List<List<Integer>> nextPosList = new ArrayList<List<Integer>>();

		queue.add(root);
		int levelNodes = 1;

		int nextLevelNodes = 0;
		List<SplayNode<?, ?>> levelNodesList = new ArrayList<SplayNode<?, ?>>();
		List<Integer> nextLevelNodesPosList = new ArrayList<Integer>();

		int pos; // the position of the current node within its level
		List<Integer> levelNodesPosList = new ArrayList<Integer>();
		levelNodesPosList.add(0); // root position
		nextPosList.add(levelNodesPosList);
		int levelNodesTotal = 1;
		while (!queue.isEmpty()) {
			SplayNode<?, ?> node = queue.remove();

			if (levelNodes == 0) {
				// the previous level is finished; start the next one
				nodesList.add(levelNodesList);
				nextPosList.add(nextLevelNodesPosList);
				levelNodesPosList = nextLevelNodesPosList;

				levelNodesList = new ArrayList<SplayNode<?, ?>>();
				nextLevelNodesPosList = new ArrayList<Integer>();

				levelNodes = nextLevelNodes;
				levelNodesTotal = nextLevelNodes;

				nextLevelNodes = 0;
			}
			levelNodesList.add(node);

			pos = levelNodesPosList.get(levelNodesTotal - levelNodes);
			if (node.left != null) {
				queue.add(node.left);
				nextLevelNodes++;
				nextLevelNodesPosList.add(2 * pos);
			}

			if (node.right != null) {
				queue.add(node.right);
				nextLevelNodes++;
				nextLevelNodesPosList.add(2 * pos + 1);
			}

			levelNodes--;
		}
		// save the last level's nodes list
		nodesList.add(levelNodesList);

		int maxLevel = nodesList.size() - 1;

		// expected number of columns: 2^(maxLevel+1) - 1
		int cols = 1;
		for (int i = 0; i <= maxLevel; i++) {
			cols <<= 1;
		}
		cols--;
		SplayNode<?, ?>[][] tree = new SplayNode<?, ?>[maxLevel + 1][cols];

		// load the tree into an array for later display
		for (int currLevel = 0; currLevel <= maxLevel; currLevel++) {
			levelNodesList = nodesList.get(currLevel);
			levelNodesPosList = nextPosList.get(currLevel);
			// column of this level's j-th slot: 2^(maxLevel-level)*(2*j+1) - 1
			int shift = maxLevel - currLevel;
			int coeff = 1;
			for (int i = 0; i < shift; i++) {
				coeff <<= 1;
			}
			for (int k = 0; k < levelNodesList.size(); k++) {
				int j = levelNodesPosList.get(k);
				int col = coeff * (2 * j + 1) - 1;
				tree[currLevel][col] = levelNodesList.get(k);
			}
		}

		// display the binary search tree; %s (not %d) so any key type prints
		System.out.format("%n");
		for (int i = 0; i <= maxLevel; i++) {
			for (int j = 0; j < cols; j++) {
				SplayNode<?, ?> node = tree[i][j];
				if (node == null) {
					System.out.format("  ");
				} else {
					System.out.format("%2s", node.key);
				}
			}
			System.out.format("%n");
		}
	}

	/** Test helper: prints the tree in order and as a level diagram. */
	public static void printAfterSplayed(SplayTreeTemplate<?, ?> splayTree) {
		SplayNode<?, ?> root = splayTree.getRoot();
		System.out.format("%nAfter being splayed, in-order BST:%n");
		SplayTreeTemplate.recursiveInOrderTraverse(root);

		System.out.format("%n%n%nAfter being splayed, the tree is:");
		SplayTreeTemplate.displayBinaryTree(root, splayTree.getSize());
	}

	public static void main(String[] args) throws Exception {
		// Compare by value: '==' on boxed Integers compares references and
		// fails for values outside the Integer cache (-128..127).
		Comparator<Integer> comparator = new Comparator<Integer>() {
			@Override
			public int compare(Integer o1, Integer o2) {
				return o1.compareTo(o2);
			}
		};

		// test1(comparator);

		// test2(comparator);

		// test3(comparator);

		test4(comparator);
	}

	/** Benchmark: one million inserts, lookups and deleteMax calls. */
	private static void test4(Comparator<Integer> comparator) throws Exception {
		SplayTreeTemplate<Integer, String> splayTree;
		System.out.format("************************************");
		System.out.format("%nTest case 4 - priority queue:%n");

		splayTree = new SplayTreeTemplate<Integer, String>(comparator);
		// one generator for the whole run instead of one per iteration
		Random random = new Random();
		long current = System.currentTimeMillis();
		for (int i = 0; i < 1000000; i++) {
			Integer key = i + random.nextInt(10);
			String val = "v_" + key;
			splayTree.insert(key, val);
		}
		long duration = System.currentTimeMillis() - current;
		System.out.println("完成插入: " + duration);
		current = System.currentTimeMillis();
		for (int i = 0; i < 1000000; i++) {
			Integer key = i + random.nextInt(10);
			splayTree.find(key);
		}
		duration = System.currentTimeMillis() - current;
		System.out.println("完成查詢:" + duration);
		current = System.currentTimeMillis();
		while (!splayTree.isEmpty()) {
			splayTree.deleteMax();
		}
		duration = System.currentTimeMillis() - current;
		System.out.println("完成刪除:" + duration);
	}

	/** Priority-queue scenario starting from a right-leaning chain 1..36. */
	private static void test3(Comparator<Integer> comparator) throws Exception {
		SplayNode<Integer, String> root;
		SplayTreeTemplate<Integer, String> splayTree;
		String max;
		Integer newKey;
		String newVal;
		System.out.format("************************************");
		System.out.format("%nTest case 3 - priority queue:%n");

		SplayNode<Integer, String> m1 = new SplayNode<Integer, String>(1, "1");
		SplayNode<Integer, String> m4 = new SplayNode<Integer, String>(4, "4");
		SplayNode<Integer, String> m7 = new SplayNode<Integer, String>(7, "7");
		SplayNode<Integer, String> m9 = new SplayNode<Integer, String>(9, "9");
		SplayNode<Integer, String> m20 = new SplayNode<Integer, String>(20, "20");
		SplayNode<Integer, String> m22 = new SplayNode<Integer, String>(22, "22");
		SplayNode<Integer, String> m26 = new SplayNode<Integer, String>(26, "26");
		SplayNode<Integer, String> m29 = new SplayNode<Integer, String>(29, "29");
		SplayNode<Integer, String> m30 = new SplayNode<Integer, String>(30, "30");
		SplayNode<Integer, String> m36 = new SplayNode<Integer, String>(36, "36");

		m1.right = m4;
		m4.right = m7;
		m7.right = m9;
		m9.right = m20;
		m20.right = m22;
		m22.right = m26;
		m26.right = m29;
		m29.right = m30;
		m30.right = m36;

		root = m1;
		System.out.format("%nBefore being splayed, in-order BST:%n");
		SplayTreeTemplate.recursiveInOrderTraverse(root);

		splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		newKey = 16;
		newVal = "16";
		splayTree.insert(newKey, newVal);
		System.out.format("%n*****insert new value %d*****%n", newKey);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		newKey = 12;
		newVal = "12";
		splayTree.insert(newKey, newVal);
		System.out.format("%n*****insert new value %d*****%n", newKey);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);
	}

	/** Exercises find, findMax/Min, deleteMax/Min, insert and remove. */
	private static void test2(Comparator<Integer> comparator) throws Exception {
		SplayNode<Integer, String> root;
		SplayTreeTemplate<Integer, String> splayTree;
		System.out.format("************************************");
		System.out.format("%nTest case 2 - splaytree's operations:%n");
		/*
		 * Initial tree, level by level: 13 | 10 25 | 12 20 35 | 29
		 */

		SplayNode<Integer, String> n13 = new SplayNode<Integer, String>(13, "13");
		SplayNode<Integer, String> n10 = new SplayNode<Integer, String>(10, "10");
		SplayNode<Integer, String> n25 = new SplayNode<Integer, String>(25, "25");
		SplayNode<Integer, String> n12 = new SplayNode<Integer, String>(12, "12");
		SplayNode<Integer, String> n20 = new SplayNode<Integer, String>(20, "20");
		SplayNode<Integer, String> n35 = new SplayNode<Integer, String>(35, "35");
		SplayNode<Integer, String> n29 = new SplayNode<Integer, String>(29, "29");

		n13.left = n10;
		n13.right = n25;
		n10.right = n12;
		n25.left = n20;
		n25.right = n35;
		n35.left = n29;

		root = n13;
		System.out.format("%nBefore being splayed, in-order BST:%n");
		SplayTreeTemplate.recursiveInOrderTraverse(root);

		splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);
		int val = 25;
		boolean found = splayTree.find(val);
		System.out.format("%n*****%d is in the tree? [%s]*****%n", val, found);
		printAfterSplayed(splayTree);

		String max = splayTree.findMax();
		System.out.format("%n*****max value=%s*****%n", max);
		printAfterSplayed(splayTree);

		String min = splayTree.findMin();
		System.out.format("%n*****min value=%s*****%n", min);
		printAfterSplayed(splayTree);

		max = splayTree.deleteMax();
		System.out.format("%n*****deleted max value: %s*****%n", max);
		printAfterSplayed(splayTree);

		min = splayTree.deleteMin();
		System.out.format("%n*****deleted min value: %s*****%n", min);
		printAfterSplayed(splayTree);

		Integer newKey = 24;
		String newVal = "24";
		splayTree.insert(newKey, newVal);
		System.out.format("%n*****insert new value %d*****%n", newKey);
		printAfterSplayed(splayTree);

		Integer removeVal = 12;
		splayTree.remove(removeVal);
		System.out.format("%n*****remove value %d*****%n", removeVal);
		printAfterSplayed(splayTree);
	}

	/** Splays a hand-built tree around a key (19) that is not present. */
	private static void test1(Comparator<Integer> comparator) {
		System.out.format("%nTest case 1 - splay opeartion:%n");

		SplayNode<Integer, String> nn12 = new SplayNode<Integer, String>(12, "12");
		SplayNode<Integer, String> nn5 = new SplayNode<Integer, String>(5, "5");
		SplayNode<Integer, String> nn25 = new SplayNode<Integer, String>(25, "25");
		SplayNode<Integer, String> nn20 = new SplayNode<Integer, String>(20, "20");
		SplayNode<Integer, String> nn30 = new SplayNode<Integer, String>(30, "30");
		SplayNode<Integer, String> nn15 = new SplayNode<Integer, String>(15, "15");
		SplayNode<Integer, String> nn24 = new SplayNode<Integer, String>(24, "24");
		SplayNode<Integer, String> nn13 = new SplayNode<Integer, String>(13, "13");
		SplayNode<Integer, String> nn18 = new SplayNode<Integer, String>(18, "18");
		SplayNode<Integer, String> nn16 = new SplayNode<Integer, String>(16, "16");

		nn12.left = nn5;
		nn12.right = nn25;
		nn25.left = nn20;
		nn25.right = nn30;
		nn20.left = nn15;
		nn20.right = nn24;
		nn15.left = nn13;
		nn15.right = nn18;
		nn18.left = nn16;

		SplayNode<Integer, String> root = nn12;
		System.out.format("%nBefore being splayed, in-order BST:%n");
		SplayTreeTemplate.recursiveInOrderTraverse(root);

		SplayTreeTemplate<Integer, String> splayTree = new SplayTreeTemplate<Integer, String>(
				root, comparator);
		splayTree.splay(19);
		System.out.format("%n*****splay the node with value=19*****%n");
		printAfterSplayed(splayTree);
	}
}