KNN實現(資料集cifar10)
阿新 • 發佈:2018-12-09
1. 讀取資料集
import pickle
with open('data_batch_2', 'rb') as f:
#x = pic.load(f, encoding='bytes')
x = pickle.load(f, encoding='latin1')
print(x['data'].shape)
#shape(10000, 3072)
- cifar資料集是用pickle序列化儲存,讀取方式python2和python3不同,此處採用的python3。encoding可以是bytes,也可以是latin1:因為資料集是在python2下用pickle儲存的,python3反序列化時,encoding='bytes'會把字典的鍵保留為位元組串(如b'data'),而encoding='latin1'會把它們解碼成普通字串('data'),兩種方式都能讀出資料,只是鍵的型別不同。
def cifarLoad(): file = 'data_batch_' train_data = [] train_label = [] val_data = [] val_label = [] for i in range(1, 6): filename = file + str(i) data_batch = unpickle(filename) train_data.extend(list(data_batch['data'])[0:9000]) list(data_batch['data']) train_label.extend(data_batch['labels'][0:9000]) val_data.extend(data_batch['data'][9000:, :]) val_label.extend(data_batch['labels'][9000:]) return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label)
- 分成驗證集和訓練集(本次沒有采用交叉驗證,後面會採用交叉驗證再試一次)。
class NearestNeighbor(object): def __init__(self): self.X = None self.y = None self.dist = 0 def train(self, x, y): self.xtr = x self.ytr = y def predict(self, test_X, k, distance): num_test = test_X.shape[0] pre = [] for i in range(num_test): if distance == 'L1': self.dist = self.L1Distance(test_X[i]) if distance == 'L2': self.dist = self.L2Distance(test_X[i]) distArgSort = np.argsort(self.dist)[0:k] classSort = self.ytr[distArgSort] classCount = np.bincount(classSort) predict = np.argmax(classCount) #print(predict.dtype) pre.append(predict) return np.array(pre) def L1Distance(self, x): dist = np.sum(abs(self.xtr-x), axis=1) return dist def L2Distance(self, x): dist = np.sqrt(np.sum(np.square(self.xtr-x), axis=1)) return dist
- KNN程式碼
import pickle import numpy as np def unpickle(filename): with open(filename, 'rb') as f: cifar = pickle.load(f, encoding='latin1') return cifar def cifarLoad(): file = 'data_batch_' train_data = [] train_label = [] val_data = [] val_label = [] for i in range(1, 6): filename = file + str(i) data_batch = unpickle(filename) train_data.extend(list(data_batch['data'])[0:9000]) list(data_batch['data']) train_label.extend(data_batch['labels'][0:9000]) val_data.extend(data_batch['data'][9000:, :]) val_label.extend(data_batch['labels'][9000:]) return np.array(train_data), np.array(train_label), np.array(val_data), np.array(val_label) class NearestNeighbor(object): def __init__(self): self.X = None self.y = None self.dist = 0 def train(self, x, y): self.xtr = x self.ytr = y def predict(self, test_X, k, distance): num_test = test_X.shape[0] pre = [] for i in range(num_test): if distance == 'L1': self.dist = self.L1Distance(test_X[i]) if distance == 'L2': self.dist = self.L2Distance(test_X[i]) distArgSort = np.argsort(self.dist)[0:k] classSort = self.ytr[distArgSort] classCount = np.bincount(classSort) predict = np.argmax(classCount) #print(predict.dtype) pre.append(predict) return np.array(pre) def L1Distance(self, x): dist = np.sum(abs(self.xtr-x), axis=1) return dist def L2Distance(self, x): dist = np.sqrt(np.sum(np.square(self.xtr-x), axis=1)) return dist def CrossValidation(): pass if __name__ == '__main__': train_data, train_label, val_data, val_label = cifarLoad() clf = NearestNeighbor() train = clf.train(train_data, train_label) pre = clf.predict(val_data, k=20, distance='L2') arr = np.mean(pre == val_label) print(arr)