
KNN Implementation (CIFAR-10 dataset)

1. Reading the dataset

import pickle

# Read one training batch to check the shape of the raw data
with open('data_batch_2', 'rb') as f:
    # x = pickle.load(f, encoding='bytes')   # alternative: keys and strings come back as bytes
    x = pickle.load(f, encoding='latin1')
    print(x['data'].shape)

# Output: (10000, 3072) -- 10000 images, each flattened to 32*32*3 = 3072 values
  • The CIFAR dataset is serialized with pickle, and the way it is read differs between Python 2 and Python 3; Python 3 is used here. The encoding can be either bytes or latin1, and I have not yet worked out why both work; a sketch of the difference follows.
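A minimal sketch of the difference, assuming data_batch_1 is in the working directory: the batch files were pickled by Python 2, so loading with encoding='bytes' leaves every pickled string as bytes (the dictionary keys become b'data', b'labels', ...), while encoding='latin1' decodes them back to str; latin1 can decode any byte value, so the raw image bytes also survive.

import pickle

with open('data_batch_1', 'rb') as f:
    raw = pickle.load(f, encoding='bytes')
with open('data_batch_1', 'rb') as f:
    txt = pickle.load(f, encoding='latin1')

print(b'data' in raw)   # True -- keys are bytes objects
print('data' in txt)    # True -- keys are ordinary strings
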
def cifarLoad():
    # requires numpy (np) and the unpickle() helper defined in the full listing below
    file = 'data_batch_'
    train_data = []
    train_label = []
    val_data = []
    val_label = []
    for i in range(1, 6):
        filename = file + str(i)
        data_batch = unpickle(filename)
        # first 9000 rows of each batch go to the training set, the last 1000 to the validation set
        train_data.extend(data_batch['data'][0:9000])
        train_label.extend(data_batch['labels'][0:9000])
        val_data.extend(data_batch['data'][9000:, :])
        val_label.extend(data_batch['labels'][9000:])

    return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label)
  • The data are split into a training set and a validation set (cross-validation is not used this time; a run with cross-validation is planned later, see the sketch after the full listing). A quick shape check of the split is shown below.
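A quick sanity check of the split, assuming the five batch files are in the working directory: each of the five batches contributes 9000 training rows and 1000 validation rows, so the result should be 45000 training and 5000 validation samples.

train_data, train_label, val_data, val_label = cifarLoad()
print(train_data.shape, train_label.shape)   # expected: (45000, 3072) (45000,)
print(val_data.shape, val_label.shape)       # expected: (5000, 3072) (5000,)
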
class NearestNeighbor(object):
    def __init__(self):
        self.xtr = None
        self.ytr = None
        self.dist = None

    def train(self, x, y):
        # KNN "training" just memorizes the data; cast to float so the uint8
        # pixel values do not wrap around when subtracted
        self.xtr = x.astype(np.float32)
        self.ytr = y

    def predict(self, test_X, k, distance):
        num_test = test_X.shape[0]
        pre = []
        for i in range(num_test):
            if distance == 'L1':
                self.dist = self.L1Distance(test_X[i])
            if distance == 'L2':
                self.dist = self.L2Distance(test_X[i])
            # indices of the k nearest training samples
            distArgSort = np.argsort(self.dist)[0:k]
            # majority vote over the labels of those k neighbours
            classSort = self.ytr[distArgSort]
            classCount = np.bincount(classSort)
            predict = np.argmax(classCount)
            pre.append(predict)
        return np.array(pre)

    def L1Distance(self, x):
        dist = np.sum(abs(self.xtr - x), axis=1)
        return dist

    def L2Distance(self, x):
        dist = np.sqrt(np.sum(np.square(self.xtr - x), axis=1))
        return dist
  • KNN code. The predict loop computes one distance vector per test sample, which is slow on 45000 x 3072 data; a vectorized alternative is sketched below, followed by the complete script.
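A fully vectorized L2 distance is a common alternative to the per-sample loop (an extra sketch, not part of the original code): it expands (a - b)^2 into a^2 - 2ab + b^2 and computes all test-to-train distances with one matrix product.

import numpy as np

def l2_distance_matrix(test_X, train_X):
    # inputs are assumed to be float arrays of shape (num_test, D) and (num_train, D);
    # returns a (num_test, num_train) matrix of Euclidean distances without explicit loops
    test_sq = np.sum(np.square(test_X), axis=1, keepdims=True)   # (num_test, 1)
    train_sq = np.sum(np.square(train_X), axis=1)                # (num_train,)
    cross = test_X.dot(train_X.T)                                # (num_test, num_train)
    sq = test_sq - 2.0 * cross + train_sq                        # broadcast to (num_test, num_train)
    return np.sqrt(np.maximum(sq, 0.0))                          # clip tiny negatives caused by rounding

The k nearest neighbours of every test sample can then be read off with np.argsort(dists, axis=1)[:, 0:k].
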
import pickle
import numpy as np

def unpickle(filename):
    with open(filename, 'rb') as f:
        cifar = pickle.load(f, encoding='latin1')
    return cifar

def cifarLoad():
    file = 'data_batch_'
    train_data = []
    train_label = []
    val_data = []
    val_label = []
    for i in range(1, 6):
        filename = file + str(i)
        data_batch = unpickle(filename)
        # first 9000 rows of each batch go to the training set, the last 1000 to the validation set
        train_data.extend(data_batch['data'][0:9000])
        train_label.extend(data_batch['labels'][0:9000])
        val_data.extend(data_batch['data'][9000:, :])
        val_label.extend(data_batch['labels'][9000:])

    return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label)

class NearestNeighbor(object):
    def __init__(self):
        self.xtr = None
        self.ytr = None
        self.dist = None

    def train(self, x, y):
        # KNN "training" just memorizes the data; cast to float so the uint8
        # pixel values do not wrap around when subtracted
        self.xtr = x.astype(np.float32)
        self.ytr = y

    def predict(self, test_X, k, distance):
        num_test = test_X.shape[0]
        pre = []
        for i in range(num_test):
            if distance == 'L1':
                self.dist = self.L1Distance(test_X[i])
            if distance == 'L2':
                self.dist = self.L2Distance(test_X[i])
            # indices of the k nearest training samples
            distArgSort = np.argsort(self.dist)[0:k]
            # majority vote over the labels of those k neighbours
            classSort = self.ytr[distArgSort]
            classCount = np.bincount(classSort)
            predict = np.argmax(classCount)
            pre.append(predict)
        return np.array(pre)

    def L1Distance(self, x):
        dist = np.sum(abs(self.xtr-x), axis=1)
        return dist

    def L2Distance(self, x):
        dist = np.sqrt(np.sum(np.square(self.xtr-x), axis=1))
        return dist

def CrossValidation():
    # placeholder -- a k-fold cross-validation sketch is given after the script
    pass



if __name__ == '__main__':
    train_data, train_label, val_data, val_label = cifarLoad()
    clf = NearestNeighbor()
    clf.train(train_data, train_label)               # train() returns nothing; it just stores the data
    pre = clf.predict(val_data, k=20, distance='L2')
    acc = np.mean(pre == val_label)                  # fraction of correctly classified validation samples
    print(acc)
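
For reference, a minimal sketch of how the CrossValidation placeholder could be filled in to pick k with 5-fold cross-validation over the training set. The fold count, the candidate k values, and the exact signature are illustrative assumptions rather than part of the original post, and running it with this brute-force classifier is slow.

def CrossValidation(train_data, train_label, num_folds=5, k_choices=(1, 3, 5, 10, 20, 50)):
    # split the training set into num_folds roughly equal folds
    data_folds = np.array_split(train_data, num_folds)
    label_folds = np.array_split(train_label, num_folds)
    accuracies = {}
    for k in k_choices:
        scores = []
        for i in range(num_folds):
            # fold i is held out for validation, the remaining folds form the training data
            x_val, y_val = data_folds[i], label_folds[i]
            x_tr = np.concatenate(data_folds[:i] + data_folds[i + 1:])
            y_tr = np.concatenate(label_folds[:i] + label_folds[i + 1:])
            clf = NearestNeighbor()
            clf.train(x_tr, y_tr)
            pred = clf.predict(x_val, k=k, distance='L2')
            scores.append(np.mean(pred == y_val))
        accuracies[k] = np.mean(scores)
    # keep the k with the best mean validation accuracy
    best_k = max(accuracies, key=accuracies.get)
    return best_k, accuracies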