1. 程式人生 > >cs231n nn分類

cs231n nn分類

#python3
import numpy as np
def unpickle(file):
    """Load one pickled batch file and return the object stored in it.

    The file is opened in binary mode and unpickled with
    ``encoding='bytes'``, so 8-bit strings pickled by Python 2 are read
    back as ``bytes`` objects.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only
    call this on files that come from a trusted source.

    Args:
        file: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into ``file``.
    """
    import pickle

    with open(file, 'rb') as batch_file:
        return pickle.load(batch_file, encoding='bytes')

def load_CIFAR10(file):
    """Load the training and test batches found in the directory ``file``.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` from the
    given directory with ``unpickle``. The batches are unpickled with
    bytes keys, so the entries are looked up with a ``b`` prefix
    (``b"data"`` and ``b"labels"``).

    Args:
        file: Path of the directory that contains the batch files.

    Returns:
        A tuple ``(dataTrain, labelTrain, dataTest, labelTest)`` of four
        lists holding, in file order, every item of the ``b"data"`` /
        ``b"labels"`` entries of the training batches and of the test batch.
    """
    # Local import, in the same style as unpickle(). os.path.join uses the
    # platform's separator instead of a hard-coded backslash, which only
    # resolved correctly on Windows.
    import os

    # Get the training data.
    dataTrain = []
    labelTrain = []
    for i in range(1, 6):
        dic = unpickle(os.path.join(file, "data_batch_" + str(i)))
        dataTrain.extend(dic[b"data"])
        labelTrain.extend(dic[b"labels"])

    # Get the test data.
    dataTest = []
    labelTest = []
    dic = unpickle(os.path.join(file, "test_batch"))
    dataTest.extend(dic[b"data"])
    labelTest.extend(dic[b"labels"])

    return dataTrain, labelTrain, dataTest, labelTest


Xtr, Ytr, Xte, Yte = load_CIFAR10('tedata/cifar-10-batches-py')
# Convert the plain lists returned by the loader into NumPy arrays.
Xtr = np.asarray(Xtr)
Xte = np.asarray(Xte)
Ytr = np.asarray(Ytr)
Yte = np.asarray(Yte)
# Alternative way to get the row layout (pick one of the two approaches):
# Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)  # Xtr_rows becomes 50000 x 3072
# Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)  # Xte_rows becomes 10000 x 3072


class NearestNeighbor(object):
    """1-nearest-neighbour classifier using Euclidean (L2) distance."""

    def __init__(self):
        pass

    def train(self, X, y):
        """Remember the training data; no computation happens here.

        Args:
            X: Training examples, one example per row.
            y: Labels; ``y[k]`` is the label of row ``k`` of ``X``.
        """
        self.xtr = X
        self.ytr = y

    def predict(self, X):
        """Return the label of the closest training row for each row of X.

        Args:
            X: 2-D array of test examples, one example per row.

        Returns:
            A 1-D array with one predicted label per test row, using the
            same dtype as the training labels.
        """
        num_test = X.shape[0]
        # Make sure that the output type matches the label type.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # Loop over all test rows.
        for i in range(num_test):
            # Compute the difference in float64. Subtracting and squaring in
            # the input dtype wraps around for unsigned / small integer
            # arrays (e.g. uint8 pixel values), which gives wrong distances.
            diffs = np.subtract(self.xtr, X[i, :], dtype=np.float64)
            np.square(diffs, out=diffs)  # in place: avoids a second large temporary
            distances = np.sqrt(np.sum(diffs, axis=1))
            min_index = np.argmin(distances)  # index with the smallest distance
            Ypred[i] = self.ytr[min_index]  # label of the nearest example

        return Ypred


nn = NearestNeighbor()  # create a Nearest Neighbor classifier
nn.train(Xtr, Ytr)  # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte)  # predict labels on the test images
# Now print the classification accuracy, which is the fraction of test
# examples that are correctly predicted (i.e. the label matches).
print('accuracy: %f' % (np.mean(Yte_predict == Yte)))

這段程式碼執行起來比較費時,我跑了一個小時才得到結果。主要原因是 predict 過程中要對 10000 筆測試資料逐筆與全部 50000 筆訓練資料計算距離。