程式人生 > KNN演算法，KD樹實現

KNN演算法,KD樹實現

自己實現的KD樹KNN演算法，和其他人不太一樣，歡迎批評指正。

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;



/**
 * Exact nearest-neighbour search over a k-d tree.
 *
 * <p>{@link #buildKDTree} indexes a list of {@link Node} points and {@link #searchKNN} returns
 * the indexed point closest (Euclidean distance) to a query point. Points may have any number of
 * coordinates, as long as every indexed point and the query have the same number.
 */
public class KdKnn {

    /**
     * Builds a k-d tree over {@code nodeList} and returns its root.
     *
     * <p>The split axis of the root level is {@code dimen}; each deeper level moves to the next
     * coordinate, wrapping around after the last one. At each level the points are sorted on the
     * current axis and the median becomes the subtree root. Points strictly smaller on that axis
     * go to the left subtree; all others (including ties) go to the right subtree.
     *
     * <p>Side effects: {@code nodeList} is sorted in place, and the left/right links of its nodes
     * are overwritten.
     *
     * @param nodeList points to index; may be empty
     * @param dimen    split axis of the root level; pass the same value to {@link #searchKNN}
     * @return the tree root, or {@code null} if {@code nodeList} is empty
     */
    Node buildKDTree(List<Node> nodeList, int dimen) {
        if (nodeList.isEmpty()) {
            return null;
        }
        int dimensions = nodeList.get(0).getData().length;
        int axis = dimen % dimensions;
        nodeList.sort(Comparator.comparingDouble(node -> node.getIndex(axis)));
        Node root = nodeList.get(nodeList.size() / 2);
        double split = root.getIndex(axis);

        List<Node> leftRange = new ArrayList<>();
        List<Node> rightRange = new ArrayList<>();
        for (Node node : nodeList) {
            if (node == root) {
                continue;
            }
            if (node.getIndex(axis) < split) {
                leftRange.add(node);
            } else {
                rightRange.add(node);
            }
        }

        int nextAxis = (axis + 1) % dimensions;
        root.setLeft(buildKDTree(leftRange, nextAxis));
        root.setRight(buildKDTree(rightRange, nextAxis));
        return root;
    }

    /**
     * Returns the point in the tree rooted at {@code root} that is closest to {@code target}.
     *
     * <p>This is an exact 1-nearest-neighbour search. It descends towards the region containing
     * {@code target}, treating every visited node (internal nodes and the root as well as leaves)
     * as a candidate. On the way back up it explores the opposite subtree of a node only when that
     * node's splitting plane is closer to {@code target} than the best distance found so far.
     *
     * <p>The returned node's distance field is set to its Euclidean distance from {@code target};
     * no other node is modified.
     *
     * @param root   root returned by {@link #buildKDTree}; may be {@code null}
     * @param target query point with the same number of coordinates as the indexed points
     * @param dimen  split axis of the root level; must match the value given to buildKDTree
     * @return the nearest node, or {@code null} if the tree is empty
     */
    Node searchKNN(Node root, Node target, int dimen) {
        if (root == null) {
            return null;
        }
        int dimensions = target.getData().length;
        Candidate best = new Candidate();
        searchNearest(root, target, dimen % dimensions, dimensions, best);
        best.node.setDistance(best.distance);
        return best.node;
    }

    /** Best match found so far during one nearest-neighbour search. */
    private static final class Candidate {
        Node node;
        double distance = Double.POSITIVE_INFINITY;
    }

    /** Recursively updates {@code best} with the closest point in the subtree at {@code node}. */
    private void searchNearest(Node node, Node target, int axis, int dimensions, Candidate best) {
        if (node == null) {
            return;
        }
        double d = distance(node, target);
        // The null check guarantees a result even if every distance is NaN or infinite.
        if (best.node == null || d < best.distance) {
            best.distance = d;
            best.node = node;
        }

        // Signed offset of the target from this node's splitting plane.
        double offset = target.getIndex(axis) - node.getIndex(axis);
        Node nearSide = offset < 0 ? node.getLeft() : node.getRight();
        Node farSide = offset < 0 ? node.getRight() : node.getLeft();
        int nextAxis = (axis + 1) % dimensions;

        searchNearest(nearSide, target, nextAxis, dimensions, best);
        // Every point on the far side is at least |offset| away, so only visit it if it could win.
        if (Math.abs(offset) < best.distance) {
            searchNearest(farSide, target, nextAxis, dimensions, best);
        }
    }

    /** Returns the Euclidean distance between {@code a} and {@code b}, over a's coordinates. */
    private double distance(Node a, Node b) {
        double[] p = a.getData();
        double[] q = b.getData();
        double sum = 0;
        for (int i = 0; i < p.length; i++) {
            double diff = p[i] - q[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Demo: indexes six 2-D points and prints the nearest one to (2.1, 3.1) and its distance. */
    public static void main(String[] args) {
        double[][] samples = {{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
        List<Node> nodeList = new ArrayList<>();
        for (double[] sample : samples) {
            nodeList.add(new Node(sample));
        }
        KdKnn nd = new KdKnn();
        Node root = nd.buildKDTree(nodeList, 0);
        Node target = new Node(new double[] {2.1, 3.1});
        Node nearest = nd.searchKNN(root, target, 0);
        for (double coordinate : nearest.getData()) {
            System.out.println(coordinate);
        }
        System.out.println(nearest.getDistance());
    }
}

/**
 * A point stored in a k-d tree: its coordinates, links to its two subtrees, and the distance
 * most recently recorded for it by a search.
 *
 * <p>Package-private so that it can share this source file with the public {@link KdKnn} class.
 */
class Node {
    private double[] data;
    private Node left;  // as built by KdKnn: points strictly smaller on this node's split axis
    private Node right; // as built by KdKnn: points greater than or equal on the split axis
    private double distance;

    public Node(double[] data) {
        this.data = data;
    }

    /** Returns the coordinate at position {@code index}. */
    public double getIndex(int index) {
        return data[index];
    }

    public double[] getData() {
        return data;
    }

    public void setData(double[] data) {
        this.data = data;
    }

    public Node getLeft() {
        return left;
    }

    public void setLeft(Node left) {
        this.left = left;
    }

    public Node getRight() {
        return right;
    }

    public void setRight(Node right) {
        this.right = right;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }
}

參考: