cs231n nn分類
阿新 • • 發佈:2018-11-01
#python3
import numpy as np
def unpickle(file):
    """Load one pickled batch file and return the unpickled object.

    This is the Python 3 variant of the dataset's loading example.

    Args:
        file: Path of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object the file contains. Because ``encoding='bytes'`` is
        used, strings pickled by Python 2 come back as ``bytes``, which is
        why callers index the result with keys such as ``b"data"``.
    """
    import pickle

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file -- only use this on dataset files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` so the builtin is not shadowed.
        batch = pickle.load(fo, encoding='bytes')
    return batch
def load_CIFAR10(file):
    """Read the CIFAR-10 training and test batches from a directory.

    Args:
        file: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.

    Returns:
        A tuple ``(dataTrain, labelTrain, dataTest, labelTest)`` of plain
        lists: the entries of each batch's ``b"data"`` and ``b"labels"``
        fields, concatenated in batch order for the training set.
    """
    import os

    # The batches were pickled with bytes keys, hence the b"..." lookups.
    dataTrain = []
    labelTrain = []
    for i in range(1, 6):
        # os.path.join picks the right separator for the host OS; the
        # previous hard-coded "\\" only formed a valid path on Windows.
        dic = unpickle(os.path.join(file, "data_batch_" + str(i)))
        dataTrain.extend(dic[b"data"])
        labelTrain.extend(dic[b"labels"])

    # The test set lives in a single batch file.
    dic = unpickle(os.path.join(file, "test_batch"))
    dataTest = list(dic[b"data"])
    labelTest = list(dic[b"labels"])
    return dataTrain, labelTrain, dataTest, labelTest
# Load CIFAR-10 and turn each of the four returned lists into a NumPy array.
Xtr, Ytr, Xte, Yte = [
    np.asarray(part) for part in load_CIFAR10('tedata/cifar-10-batches-py')
]
# Alternative (use one approach or the other): explicitly flatten every image
# into a single row before handing the data to the classifier.
#Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
#Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072
class NearestNeighbor(object):
    """1-nearest-neighbour classifier using Euclidean (L2) distance."""

    def __init__(self):
        # Both are filled in by train(); None until then.
        self.xtr = None
        self.ytr = None

    def train(self, X, y):
        """Memorise the training set.

        Args:
            X: 2-D array with one training example per row.
            y: 1-D array holding the label of each row of ``X``.
        """
        self.xtr = X
        self.ytr = y

    def predict(self, X):
        """Predict a label for every row of ``X``.

        Each test row receives the label of the training row with the
        smallest L2 distance to it (the first such row on ties).

        Args:
            X: 2-D array with one test example per row and the same number
                of columns as the training data.

        Returns:
            1-D array of predicted labels with the same dtype as the
            training labels.
        """
        num_test = X.shape[0]
        # Make sure the output type matches the label type.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # Do the arithmetic in float64. With unsigned integer inputs (e.g.
        # uint8 pixels) both the subtraction and the squaring would silently
        # wrap around and yield wrong distances. The training set is
        # converted once here instead of once per test row; this costs a
        # float64 copy of the training data.
        train = np.asarray(self.xtr, dtype=np.float64)

        for i in range(num_test):
            diff = train - np.asarray(X[i, :], dtype=np.float64)
            np.square(diff, out=diff)  # in place: avoids a second big temporary
            # sqrt is monotonic, so the argmin of the squared distances
            # identifies the same nearest neighbour without computing it.
            sq_distances = np.sum(diff, axis=1)
            min_index = np.argmin(sq_distances)
            Ypred[i] = self.ytr[min_index]
        return Ypred
# Build a nearest-neighbour classifier and let it memorise the training
# images together with their labels.
nn = NearestNeighbor()
nn.train(Xtr, Ytr)
# Label every test image with the label of its closest training image.
Yte_predict = nn.predict(Xte)
# Report the classification accuracy: the fraction of test examples whose
# predicted label matches the true label.
print('accuracy: %f' % (Yte_predict == Yte).mean())
這段程式碼執行起來比較費時,我跑了約一個小時才得到結果,主要原因是在 predict 過程中,要對 10000 筆測試資料逐一計算它與全部訓練資料的距離。