1. 程式人生 > >Coursera普林斯頓大學演算法Week5: Kd-Trees 線段樹

Coursera普林斯頓大學演算法Week5: Kd-Trees 線段樹

本任務的PointSET比較好實現,借用給的Point2D API比較容易

而Kdtree任務比較複雜。主要是針對邊界問題比較複雜,需要分清待插入節點的父節點是位於偶數層還是位於奇數層,根據不同的層數具有不同的點比較方案。

private int compare(Node pNode, Point2D thisPoint) 
	{
		if (pNode == null)
			throw new java.lang.IllegalArgumentException("the Node object is null");
		if (thisPoint == null) 
			throw new java.lang.IllegalArgumentException("the Point2D object is null");
		
		if (thisPoint.compareTo(pNode.point2d) == 0)
			return 0;
	
		if (pNode.depth % 2 != 0) // 父節點在奇數層,看放父節點的左右側
		{
			if (Double.compare(pNode.point2d.x(), thisPoint.x()) == 1) // 小於0右側
				return 1;
			else
				return -1;
		}
		else  // 父節點在偶數層,看放在父節點的上下側
		{
			if (Double.compare(pNode.point2d.y(), thisPoint.y()) == 1) // 小於0上側
				return 1;
			else
				return -1;
		}
	}

其中Node是定義的私有類,主要有幾個成員

private class Node {
		Point2D point2d;  // 分割矩形的點
		RectHV rectHV;   // 分割矩形
		Node leftNode;   // 左子樹節點
		Node rigthtNode;  // 右子樹節點
		int depth;   // 節點的層數
		
		public Node(Point2D point2d, RectHV rectHV, int depth) {
			this.point2d = point2d;
			this.rectHV = rectHV;
			this.depth = depth;
		}
	}

另外一個難點在於查詢給定點的最近點。其主要思路是,先查詢位於查詢點同一側的子節點,對於另一側的子節點,若查詢點據另一側點的矩形最近距離小於當前最近距離才有希望能在另一側找到最近點,才去查詢另一側的子節點,減少時間複雜度。

對於rang函式查詢給定矩形內的點集,也是當子節點的矩形與給定節點相交時才有希望在子節點上查詢到點落入給定矩形內。減少時間複雜度。

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.SET;

public class PointSET {
	private SET<Point2D> points;
	
	public PointSET()    // 構造一個空點集
	{
		points = new SET<Point2D>();
	}
	public boolean isEmpty()  // 這個集合是空的嗎? 
	{
		return points.isEmpty();
	}
	public int size()  // 集合中的點數 
	{
		return points.size();
	}
	public void insert(Point2D p)   // 將該點新增到集合中(如果它尚未在集合中)
	{
		if (p == null) 
			throw new java.lang.IllegalArgumentException("this Point2D is null");
		
		points.add(p);
	}
	public boolean contains(Point2D p) // 集合是否包含點P?
	{
		if (p == null) 
			throw new java.lang.IllegalArgumentException("this Point2D is null");
		
		return points.contains(p);
	}
	public void draw()  // 把所有點畫成標準畫
	{
		for (Point2D point2d : points) {
			point2d.draw();
		}
	}
	public Iterable<Point2D> range(RectHV rect)  // 在矩形(或邊界)內的所有點
	{
		if (rect == null) 
			throw new java.lang.IllegalArgumentException("The RectHV is null");
		
		Queue<Point2D> queue = new Queue<>();  //  佇列用於儲存在矩形內(包含邊界)的點
		
		for (Point2D point2d : points) {
			if (rect.contains(point2d)) 
				queue.enqueue(point2d);  // 進佇列
		}
		
		return queue;
	}
	public Point2D nearest(Point2D p) // 集合為點p的最近鄰;如果集合為空,則為null。
	{
		if (points == null) // 集合為空
			return null;
		
		Point2D point2dNearest = null;
		double distanceMin = Double.POSITIVE_INFINITY;  // 兩點間歐式距離平方
		
		for (Point2D point2d : points) 
		{
			double distanceCurrent = point2d.distanceSquaredTo(p);
			if (distanceCurrent < distanceMin)  // 遍歷找到距離最小的點 
			{ 
				point2dNearest = point2d;
				distanceMin = distanceCurrent;
			}
		}
		
		return point2dNearest;
	}
	public static void main(String[] args)  // 單元測試的方法(可選) 
	{
		System.out.println(Double.compare(0.2, 0.3));
		PointSET pointSET = new PointSET();
		
		Point2D [] point2ds = new Point2D[8];
		for (int i = 0; i < point2ds.length; i++) {
			point2ds[i] = new Point2D(i/10.0, (i+1)/10.0);
			pointSET.insert(point2ds[i]);
		}
		
		System.out.println(pointSET.size());
		
		System.out.println(pointSET.contains(new Point2D(0.3, 0.3)));
		System.out.println(pointSET.nearest(new Point2D(0.3, 0.6)));
		RectHV rectHV = new RectHV(0.2, 0.2, 0.6, 0.9);
		Iterable<Point2D> pQueue = pointSET.range(rectHV); 
		
		for (Point2D point2d : pQueue) {
			System.out.println(point2d);
		}
		
	}
}

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

public class KdTree {
	
	private Node root;  // 根節點
	private int size; // 節點個數
	
	private class Node {
		Point2D point2d;  // 分割矩形的點
		RectHV rectHV;   // 分割矩形
		Node leftNode;   // 左子樹節點
		Node rigthtNode;  // 右子樹節點
		int depth;   // 節點的層數
		
		public Node(Point2D point2d, RectHV rectHV, int depth) {
			this.point2d = point2d;
			this.rectHV = rectHV;
			this.depth = depth;
		}
	}
	
	public KdTree()    // 構造一個空點集
	{
		root = null;
		size = 0;
	}
	public boolean isEmpty()  // 這個集合是空的嗎? 
	{
		return size == 0;
	}
	public int size()  // 集合中的點數 
	{
		return size;
	}
	
	private Node insert(Node insertPalceNode, Node perNode, Point2D thisPoint)
	{
		if (insertPalceNode == null)
		{
			if (size == 0)  // 原集合中無元素
				return new Node(thisPoint, new RectHV(0, 0, 1, 1), 1);
		
			else // 原集合中有元素,查詢其父節點
			{
				int cmp = compare(perNode, thisPoint);
				RectHV rectHV = null;  
				
				if (perNode.depth % 2 == 0)  // 父節點在偶數層,在上下側插入
				{
					if (cmp > 0) // 下方,同xmin,ymin,xmax;ymax = perNode.point.y
						rectHV = new RectHV(perNode.rectHV.xmin(), perNode.rectHV.ymin(),
								perNode.rectHV.xmax(), perNode.point2d.y());
					
					if (cmp < 0) // 上方,同xmax,ymax,xmin;ymin = perNode.point.y
						rectHV = new RectHV(perNode.rectHV.xmin(), perNode.point2d.y(), 
								perNode.rectHV.xmax(), perNode.rectHV.ymax()); 
				}
				else // 父節點在奇數層,在左右側插入 
				{
					if (cmp > 0)  // 左側, 同xmin,ymin,ymax;xmax = perNode.point.x
						rectHV = new RectHV(perNode.rectHV.xmin(), perNode.rectHV.ymin(),
								perNode.point2d.x(), perNode.rectHV.ymax());
					
					if (cmp < 0) // 右側,同xmax,ymax,ymin;xmin = perNode.point.x 
						rectHV = new RectHV(perNode.point2d.x(), perNode.rectHV.ymin(),
								perNode.rectHV.xmax(), perNode.rectHV.ymax());
				}
				return new Node(thisPoint, rectHV, perNode.depth + 1);
			}
		}

		else  // insertPalceNode != null 
		{
			int cmp = compare(insertPalceNode, thisPoint);
			
			if (cmp > 0) // 下方或左側,左子樹
				insertPalceNode.leftNode = insert(insertPalceNode.leftNode, insertPalceNode, thisPoint);
			if (cmp < 0)  // 上方或右側,右子樹
				insertPalceNode.rigthtNode = insert(insertPalceNode.rigthtNode, insertPalceNode, thisPoint);
			return insertPalceNode;
		}
	}
	
	private int compare(Node pNode, Point2D thisPoint) 
	{
		if (pNode == null)
			throw new java.lang.IllegalArgumentException("the Node object is null");
		if (thisPoint == null) 
			throw new java.lang.IllegalArgumentException("the Point2D object is null");
		
		if (thisPoint.compareTo(pNode.point2d) == 0)
			return 0;
	
		if (pNode.depth % 2 != 0) // 父節點在奇數層,看放父節點的左右側
		{
			if (Double.compare(pNode.point2d.x(), thisPoint.x()) == 1) // 小於0右側
				return 1;
			else
				return -1;
		}
		else  // 父節點在偶數層,看放在父節點的上下側
		{
			if (Double.compare(pNode.point2d.y(), thisPoint.y()) == 1) // 小於0上側
				return 1;
			else
				return -1;
		}
	}
	
	public void insert(Point2D p)   // 將該點新增到集合中(如果它尚未在集合中)
	{
		if (p == null) 
			throw new java.lang.IllegalArgumentException("the Point2D is null");
		
		if (contains(p)) 
			return;
		root = insert(root, null, p);
		size++;
	}
	
	private boolean containsP(Point2D p, Node cmpNoe) 
	{
		if (cmpNoe == null) 
			return false;
		int cmp = compare(cmpNoe, p);   
		if (cmp > 0)  // 左子樹
			return containsP(p, cmpNoe.leftNode);
		if (cmp < 0)  // 右子樹
			return containsP(p, cmpNoe.rigthtNode);
		
		return true;
		
			
	}
	public boolean contains(Point2D p) // 集合是否包含點P?
	{
		if (p == null) 
			throw new java.lang.IllegalArgumentException("the Point2D is null");
		return containsP(p, root);
	}
	
	
	public void draw() // 把所有點畫成標準畫
    {
        draw(root);
    }

    private void draw(Node x)
    {
        if (x == null) return; 
        draw(x.leftNode);
        draw(x.rigthtNode);
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        x.point2d.draw();
        StdDraw.setPenRadius();
        // draw the splitting line segment
        if (x.depth % 2 == 0) 
        {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(x.point2d.x(), x.rectHV.ymin(), x.point2d.x(), x.rectHV.ymax());   
        }
        else
        {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(x.rectHV.xmin(), x.point2d.y(), x.rectHV.xmax(), x.point2d.y());   
        }
    } 

	private void range(Node currentNode, Queue<Point2D> queue, RectHV rectHV)
	{
		if (rectHV.contains(currentNode.point2d)) // 矩形中包含點
			queue.enqueue(currentNode.point2d);
	
		if (currentNode.leftNode != null && currentNode.leftNode.rectHV.intersects(rectHV)) {
			range(currentNode.leftNode, queue, rectHV);  //左子樹
		}
		
		if (currentNode.rigthtNode != null && currentNode.rigthtNode.rectHV.intersects(rectHV)) {
			range(currentNode.rigthtNode, queue, rectHV); // 右子樹
		}
		
	}
	
	public Iterable<Point2D> range(RectHV rect)  // 在矩形(或邊界)內的所有點
	{
		if (rect == null) 
			throw new java.lang.IllegalArgumentException("The RectHV is null");
		
		Queue<Point2D> queue = new Queue<>();  //  佇列用於儲存在矩形內(包含邊界)的點
		if (root == null) {
			return queue;
		}
		range(root, queue, rect);   //從根節點開始遍歷
		
		return queue;
	}
	
	
	private Node nearest(Node currentNode, Node nearestNode, Point2D p)
	{
		if (currentNode == null)    
			return nearestNode;
	
		double nearstDistance = Double.POSITIVE_INFINITY;   // 當前最短距離
		double currentDistance = p.distanceSquaredTo(currentNode.point2d); // 當前節點距離
		
		if (nearestNode != null) {  // 根據當前最近節點計算最短距離
			nearstDistance = p.distanceSquaredTo(nearestNode.point2d);
		}
		else {  // 無當前最近節點,當前節點即為最近節點
			nearstDistance = currentDistance;
			nearestNode = currentNode;
		}
		
		if (currentDistance < nearstDistance)  // 更改最近節點資訊
		{
			nearestNode = currentNode;
			nearstDistance = currentDistance;
		}
		
		int cmp = compare(currentNode, p);
		
		if (cmp > 0)  // 點位於當前節點的左子樹
		{
			nearestNode = nearest(currentNode.leftNode, nearestNode, p);
				
			// p點距離該節點水平線的垂直距離小於當前最短距離,在當前點的上側才有機會存在最近的點 
			if (currentNode.rigthtNode != null && 
					currentNode.rigthtNode.rectHV.distanceSquaredTo(p) < p.distanceSquaredTo(nearestNode.point2d))  
				nearestNode = nearest(currentNode.rigthtNode, nearestNode, p);
		}	
		else if (cmp < 0)  // 點位於當前節點的右子樹
		{
			nearestNode = nearest(currentNode.rigthtNode, nearestNode, p);
				
			// p點距離該節點水平線的垂直距離小於當前最短距離,在當前點的下側才有機會存在最近的點
			if (currentNode.leftNode != null && 
					currentNode.leftNode.rectHV.distanceSquaredTo(p) < p.distanceSquaredTo(nearestNode.point2d)) 
				nearestNode = nearest(currentNode.leftNode, nearestNode, p);
		}
		return nearestNode;
	}
	
	public Point2D nearest(Point2D p) // 集合為點p的最近鄰;如果集合為空,則為null。
	{
		if (root == null) 
			return null;
	
		return nearest(root, null, p).point2d;
	}
	
	
	public static void main(String[] args)  // 單元測試的方法(可選) 
	{
		KdTree kdTree = new KdTree();
		Point2D [] point2ds = new Point2D[5];
		point2ds[0] = new Point2D(0.7, 0.2);
		point2ds[1] = new Point2D(0.5, 0.4);
		point2ds[2] = new Point2D(0.2, 0.3);
		point2ds[3] = new Point2D(0.4, 0.7);
		point2ds[4] = new Point2D(0.9, 0.6);
		
		for (int i = 0; i < point2ds.length; i++) {
			System.out.println("*************i=" + i + "***************");
			kdTree.insert(point2ds[i]);
			System.out.println(kdTree.size());
		}
		
		System.out.println(kdTree.contains(point2ds[4]));
		
		Iterable<Point2D> iterable = kdTree.range(new RectHV(0, 0, 1, 1));
		
		for (Point2D point2d : iterable) {
			System.out.println(point2d.toString());
		}
		System.out.println();
		System.out.println(kdTree.root.point2d.toString());
		System.out.println(kdTree.root.leftNode.point2d.toString());
		System.out.println(kdTree.root.rigthtNode.point2d.toString());
		System.out.println(kdTree.root.leftNode.leftNode.point2d.toString());
		System.out.println(kdTree.root.leftNode.rigthtNode.point2d.toString());
		
		System.out.println(kdTree.size);
		System.out.println(kdTree.nearest(new Point2D(0.111, 0.494)));
	}
}