MNIST手寫數字分類——KNN的MATLAB實現
以下直接給出KNN分類器的MATLAB實現。
其中訓練資料60000條，測試資料10000條。
% KNN classification of the MNIST test set using L1 (Manhattan) distance.
% Each test image gets the most frequent label among its K nearest
% training images, and the overall test error rate is reported.
% Requires the four MNIST idx files in the current folder plus the
% loadMNISTImages / loadMNISTLabels functions.

trainImages = loadMNISTImages('train-images.idx3-ubyte'); % pixels x numTrain
trainLabels = loadMNISTLabels('train-labels.idx1-ubyte'); % numTrain x 1
testImages  = loadMNISTImages('t10k-images.idx3-ubyte');  % pixels x numTest
testLabels  = loadMNISTLabels('t10k-labels.idx1-ubyte');  % numTest x 1

K = 100; % number of neighbours; can be any other positive value

% Images are stored one per column, so count columns. length() would
% return the pixel count instead if there were fewer images than pixels.
trainLength = size(trainImages, 2);
testLength  = size(testImages, 2);
K = min(K, trainLength); % cannot vote with more neighbours than exist

testResults = zeros(testLength, 1);

tic;
for i = 1:testLength
    % L1 distance from test image i to every training image. bsxfun
    % broadcasts the column without building a repmat copy of it.
    dist = sum(abs(bsxfun(@minus, trainImages, testImages(:, i))), 1);
    [~, ind] = sort(dist);
    % Majority vote among the K nearest labels; mode() resolves ties
    % in favour of the smallest label.
    testResults(i) = mode(trainLabels(ind(1:K)));
    fprintf('%d: predicted %d, true %d\n', i, testResults(i), testLabels(i));
end
elapsed = toc;

% Fraction of test images whose predicted label differs from the truth.
numErrors = sum(testResults ~= testLabels(:));
errorRate = numErrors / testLength;
fprintf('Test error rate: %.4f\n', errorRate);
fprintf('Elapsed time: %.2f s\n', elapsed);
function images = loadMNISTImages(filename)
%loadMNISTImages Read an MNIST idx3-ubyte image file.
%   images = loadMNISTImages(filename) returns a (numRows*numCols) x
%   numImages double matrix. Each column is one image, flattened in
%   column-major order of its numRows x numCols pixel grid, with byte
%   values 0..255 rescaled to [0,1].
%
%   Errors if the file cannot be opened, if its magic number is not
%   2051, or if it holds fewer pixel bytes than its header declares.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file on every exit path, including failed assertions below.
closer = onCleanup(@() fclose(fp));

% Header fields are big-endian 32-bit integers.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);
numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

% Read exactly the declared number of pixel bytes (fread returns double).
numPixels = numRows * numCols * numImages;
[images, count] = fread(fp, numPixels, 'unsigned char');
assert(count == numPixels, ['Truncated image data in ', filename, '']);

% Pixels are stored row by row, so fill numCols x numRows x numImages and
% swap the first two dimensions to obtain numRows x numCols images.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);

% Reshape to #pixels x #examples and rescale to [0,1].
images = reshape(images, numRows * numCols, numImages) / 255;
end
function labels = loadMNISTLabels(filename)
%loadMNISTLabels Read an MNIST idx1-ubyte label file.
%   labels = loadMNISTLabels(filename) returns a numLabels x 1 double
%   column vector holding one value per byte of the file body.
%
%   Errors if the file cannot be opened, if its magic number is not
%   2049, or if the number of label bytes differs from the header count.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file on every exit path, including failed assertions below.
closer = onCleanup(@() fclose(fp));

% Header fields are big-endian 32-bit integers.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

% Read the whole body so both missing and extra bytes are detected.
labels = fread(fp, inf, 'unsigned char');
assert(size(labels,1) == numLabels, 'Mismatch in label count');
end
