程式人生 > k最近鄰(k-nn)

k最近鄰(k-nn)



源資料如下(raw_data),每行格式為「標籤:特徵1 特徵2」,標籤 0 表示「好」、1 表示「壞」:

1:7 7
1:7 4
0:3 4
0:1 4

程式碼如下(main.py)

#encoding=utf-8
def get_distance(list1, list2):
    """Return the sum of squared element-wise differences of two sequences.

    This is the squared Euclidean distance: no square root is taken.

    Args:
        list1: Sequence of numbers; its length decides how many components
            are compared.
        list2: Sequence of numbers with at least ``len(list1)`` elements;
            any extra trailing elements are ignored.

    Returns:
        The sum of ``(list1[i] - list2[i]) ** 2`` over every index of
        ``list1`` (0 when ``list1`` is empty).

    Raises:
        IndexError: If ``list2`` has fewer elements than ``list1``.
    """
    # zip() would silently truncate to the shorter input, so keep the
    # explicit failure that indexing list2 past its end used to produce.
    if len(list2) < len(list1):
        raise IndexError("list index out of range")
    # zip/sum instead of an xrange index loop: xrange does not exist on
    # Python 3, while this form runs unchanged on Python 2 and 3.
    return sum((a - b) ** 2 for a, b in zip(list1, list2))

predict_data = [3, 7]  # feature vector whose class is to be predicted

# Load the training records, one per line, formatted "tag:f1 f2 ...".
# The file is only read, so open it read-only (the previous "r+" mode also
# demanded write permission) and let the context manager close it.
with open("raw_data", "r") as raw_data:
    raw_list = raw_data.readlines()

# Each entry of map_list is [id, f1, ..., fn, distance, tag].
# With the sample data, before sorting, this is
# [[0, 7, 7, 16, 1], [1, 7, 4, 25, 1], [2, 3, 4, 9, 0], [3, 1, 4, 13, 0]].
map_list = []
for raw_line in raw_list:
    raw_line = raw_line.strip()
    if not raw_line:
        # A blank line carries no record; skipping it avoids an IndexError
        # on the missing ":" separator.
        continue
    split_list = raw_line.split(":")
    # Parse the features once and reuse them for both the distance
    # computation and the stored record.
    feature_list_int = [int(token) for token in split_list[1].split()]
    # get_distance yields the sum of squared differences to predict_data.
    distance = get_distance(predict_data, feature_list_int)
    tag = int(split_list[0].strip())
    # The id is the record's position among the parsed records.
    map_list.append([len(map_list)] + feature_list_int + [distance, tag])

# Nearest records first. The distance is addressed from the end of the
# record so the sort stays correct for any number of features.
map_list.sort(key=lambda record: record[-2])

# Persist the processed records, one Python list literal per line. The
# context manager guarantees the data is flushed and the file closed.
with open("map_data", "w") as map_file:
    for record in map_list:
        map_file.write(str(record) + "\n")

# k-NN vote: the majority tag among the k nearest records wins.
class1 = 0  # tag meaning "good"
class2 = 1  # tag meaning "bad"
k = 3
top_k_neighbor = map_list[0:k]  # the k nearest neighbours
# The tag is always the last field of a record.
class1_no = sum(1 for record in top_k_neighbor if record[-1] == class1)
class2_no = sum(1 for record in top_k_neighbor if record[-1] == class2)

# print(...) with a single argument behaves the same on Python 2 and 3.
if class1_no > class2_no:
    print("good")
elif class1_no < class2_no:
    print("bad")
else:
    # A tie cannot be resolved by majority vote.
    print("choose another k")


資料處理結果如下(map_data),各欄依序為 id、特徵1、特徵2、與預測點 [3,7] 的距離(平方差之和)、標籤,並已依距離由小到大排序:

[2, 3, 4, 9, 0]
[3, 1, 4, 13, 0]
[0, 7, 7, 16, 1]
[1, 7, 4, 25, 1]