KNN演算法,KD樹實現
阿新 • • 發佈:2019-01-05
自己實現的KD樹KNN演算法,和其他人不太一樣,歡迎批評指正
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * Exact nearest-neighbour lookup over a k-d tree of {@link Node} points.
 *
 * <p>Every tree node stores one point. At each level the points are split on a
 * single coordinate axis, and the axis advances by one (wrapping around the
 * number of coordinates) from one level to the next. Points whose coordinate is
 * strictly smaller than the splitting node's go to the left subtree; all others
 * (including equal values) go to the right subtree.
 */
public class KdKnn {

    /**
     * Recursively builds a k-d tree from the given points.
     *
     * <p>The list is sorted in place along the splitting axis, the median
     * element becomes the subtree root, and the remaining points are
     * distributed to the left/right subtrees. The {@code left}/{@code right}
     * links of every node in the list are overwritten.
     *
     * @param nodeList points to organise; all points are expected to have the
     *                 same number of coordinates
     * @param dimen    non-negative index of the axis used to split at this
     *                 level; it is wrapped around the number of coordinates
     * @return the root of the (sub)tree, or {@code null} when the list is
     *         {@code null} or empty
     * @throws IllegalArgumentException if the points have no coordinates
     */
    private Node buildKDTree(List<Node> nodeList, int dimen) {
        if (nodeList == null || nodeList.isEmpty()) {
            return null;
        }
        int dims = nodeList.get(0).getData().length;
        if (dims == 0) {
            throw new IllegalArgumentException("points must have at least one coordinate");
        }
        int axis = dimen % dims;

        // Sorting is only needed to pick the median as the subtree root.
        sortByAxis(nodeList, axis);
        Node root = nodeList.get(nodeList.size() / 2);
        double split = root.getIndex(axis);

        List<Node> leftRange = new ArrayList<Node>();
        List<Node> rightRange = new ArrayList<Node>();
        for (Node node : nodeList) {
            if (node == root) {
                continue;
            }
            // Strictly smaller goes left; ties go right. The search relies on this.
            if (node.getIndex(axis) < split) {
                leftRange.add(node);
            } else {
                rightRange.add(node);
            }
        }

        // The recursive call wraps the axis, so a plain increment is enough here.
        int nextAxis = axis + 1;
        root.setLeft(buildKDTree(leftRange, nextAxis));
        root.setRight(buildKDTree(rightRange, nextAxis));
        return root;
    }

    /**
     * Sorts the points in place by their coordinate on the given axis.
     *
     * @param nodeList points to sort
     * @param axis     index of the coordinate used as the sort key
     */
    private void sortByAxis(List<Node> nodeList, final int axis) {
        Collections.sort(nodeList, new Comparator<Node>() {
            @Override
            public int compare(Node a, Node b) {
                return Double.compare(a.getIndex(axis), b.getIndex(axis));
            }
        });
    }

    /**
     * Finds the point in the tree that is closest (Euclidean distance) to the
     * target.
     *
     * <p>The search descends towards the target first and then backtracks,
     * visiting the opposite subtree of a node only when the splitting plane is
     * no farther from the target than the best distance found so far. Every
     * visited node, including the returned one, has its {@code distance}
     * field set to its distance from the target.
     *
     * @param root   root of a tree produced by {@link #buildKDTree}
     * @param target query point; must have as many coordinates as the tree points
     * @param dimen  non-negative splitting axis of {@code root}; must be the
     *               same value that was passed to {@link #buildKDTree}
     * @return the nearest node, or {@code null} when the tree is empty
     * @throws IllegalArgumentException if the points have no coordinates or the
     *                                  target's dimensionality differs
     */
    private Node searchKNN(Node root, Node target, int dimen) {
        if (root == null) {
            return null;
        }
        int dims = root.getData().length;
        if (dims == 0) {
            throw new IllegalArgumentException("points must have at least one coordinate");
        }
        return searchSubtree(root, target, dimen % dims, dims, null);
    }

    /**
     * Depth-first nearest-neighbour search of one subtree with plane pruning.
     *
     * @param node   root of the subtree to examine, may be {@code null}
     * @param target query point
     * @param axis   splitting axis of {@code node}
     * @param dims   number of coordinates per point
     * @param best   closest node found so far (its {@code distance} field holds
     *               its distance to the target), or {@code null} if none yet
     * @return the closest node among {@code best} and the subtree
     */
    private Node searchSubtree(Node node, Node target, int axis, int dims, Node best) {
        if (node == null) {
            return best;
        }

        // Internal nodes carry points too, so every visited node is a candidate.
        double dist = distance(target, node);
        node.setDistance(dist);
        if (best == null || dist < best.getDistance()) {
            best = node;
        }

        double diff = target.getIndex(axis) - node.getIndex(axis);
        Node near = diff < 0 ? node.getLeft() : node.getRight();
        Node far = diff < 0 ? node.getRight() : node.getLeft();
        int nextAxis = (axis + 1) % dims;

        best = searchSubtree(near, target, nextAxis, dims, best);
        // Anything on the far side is at least |diff| away from the target, so
        // that side can only help if the plane is within the current best radius.
        if (Math.abs(diff) <= best.getDistance()) {
            best = searchSubtree(far, target, nextAxis, dims, best);
        }
        return best;
    }

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param a first point
     * @param b second point
     * @return the Euclidean distance between {@code a} and {@code b}
     * @throws IllegalArgumentException if the points differ in dimensionality
     */
    private double distance(Node a, Node b) {
        double[] first = a.getData();
        double[] second = b.getData();
        if (first.length != second.length) {
            throw new IllegalArgumentException(
                    "dimension mismatch: " + first.length + " vs " + second.length);
        }
        double sum = 0;
        for (int i = 0; i < first.length; i++) {
            double diff = first[i] - second[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Demo: builds a tree from six 2-D points and prints the coordinates of
     * the point nearest to (2.1, 3.1), one per line, followed by its distance.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        List<Node> nodeList = new ArrayList<Node>();
        nodeList.add(new Node(new double[]{2, 3}));
        nodeList.add(new Node(new double[]{5, 4}));
        nodeList.add(new Node(new double[]{9, 6}));
        nodeList.add(new Node(new double[]{4, 7}));
        nodeList.add(new Node(new double[]{8, 1}));
        nodeList.add(new Node(new double[]{7, 2}));

        KdKnn nd = new KdKnn();
        Node root = nd.buildKDTree(nodeList, 0);
        Node target = new Node(new double[]{2.1, 3.1});

        // One search is enough: the returned node already carries its distance.
        Node nearest = nd.searchKNN(root, target, 0);
        double[] nea = nearest.getData();
        for (int i = 0; i < nea.length; i++) {
            System.out.println(nea[i]);
        }
        System.out.println(nearest.getDistance());
    }
}
/**
 * A point that can be linked into a binary tree.
 *
 * <p>Holds the point's coordinates, references to a left and a right child,
 * and a mutable {@code distance} value that callers may use to record a
 * distance associated with this point (in this file {@code KdKnn} stores the
 * distance to the query point there).
 */
public class Node {

    /** Coordinates of the point; the array is stored as given, not copied. */
    private double[] coordinates;

    /** Left subtree, or {@code null} if there is none. */
    private Node leftChild;

    /** Right subtree, or {@code null} if there is none. */
    private Node rightChild;

    /** Distance value recorded by callers; 0 until set. */
    private double recordedDistance;

    /**
     * Creates a node for the given coordinates.
     *
     * @param data coordinates of the point; the array reference is retained
     */
    public Node(double[] data) {
        this.coordinates = data;
    }

    /**
     * Returns a single coordinate of the point.
     *
     * @param index zero-based position of the coordinate
     * @return the coordinate at {@code index}
     */
    public double getIndex(int index) {
        return this.coordinates[index];
    }

    /**
     * @return the coordinate array held by this node (not a copy)
     */
    public double[] getData() {
        return this.coordinates;
    }

    /**
     * Replaces the coordinate array held by this node.
     *
     * @param data new coordinates; the array reference is retained
     */
    public void setData(double[] data) {
        this.coordinates = data;
    }

    /**
     * @return the left subtree, or {@code null} if there is none
     */
    public Node getLeft() {
        return this.leftChild;
    }

    /**
     * @param left new left subtree, may be {@code null}
     */
    public void setLeft(Node left) {
        this.leftChild = left;
    }

    /**
     * @return the right subtree, or {@code null} if there is none
     */
    public Node getRight() {
        return this.rightChild;
    }

    /**
     * @param right new right subtree, may be {@code null}
     */
    public void setRight(Node right) {
        this.rightChild = right;
    }

    /**
     * @return the distance value last stored with {@link #setDistance}
     */
    public double getDistance() {
        return this.recordedDistance;
    }

    /**
     * @param distance distance value to store on this node
     */
    public void setDistance(double distance) {
        this.recordedDistance = distance;
    }
}
參考: