用KNN分類器進行貓狗分類
1. KNN簡單介紹
KNN名字是K-nearest neighbors。Nearest neighbors是最鄰近的,K是指數量。其思想大概是,在空間中先放置好所有用於訓練的樣品,把測試樣品置於該空間中。用距離公式計算出離測試樣品最近的K個樣品,假如K個樣品中屬於A類的最多,那測試樣品也算輸入A類。
下圖中,白色框圖片是已經正確識別的。紅色框內的圖片是需要進行分類的。這裡取K=4,與目標圖片最接近的有兩張熊貓和兩張貓。
用於計算距離的公式有兩個,第一個是尤拉公式,也稱為L2-距離。
第二個是the Manhattan/city block,也稱為L1-距離。
公式計算完成,就是根據這K個樣品做判斷了。在這裡例子中,結果會比較爭議,因為該樣品和兩個種類集合的距離一樣近……
2. KNN貓狗分類試驗
2.1 環境搭建
我的電腦是桌上型電腦 系統:Ubuntu18.04 64bit CPU: i3-6100 無外接顯示卡 RAM:8G
需要使用OpenCV+Python OpenCV因為事前已經安裝好了,是從原始碼編譯安裝的。這裡不提供步驟介紹了。
但需要安裝python的必備模組。假如電腦未安裝(如筆者的電腦……)scipy、numpy、sklearn和imutils,請執行:
pip install scipy
# 自動下載了numpy
pip install sklearn
pip install imutils
不過也沒大礙。這幾個模組假如沒安裝,在執行本文程式碼時候,系統會提醒缺少這些包的。
2.1 下載貓狗資料集
因為下載速度快,就選擇了這個微軟的下載包。必應一下子就搜到了這個包,而且不用註冊即可下載。感覺快於Kaggle。Kaggle Cats and Dogs Dataset
不過使用起來其實這是個坑。我下載完成解壓後,執行程式碼,用命令列除錯才發現有部分圖片打不開。可以手動瀏覽檔案,看圖片有無損壞,另外也可以在python讀取圖片時,把無法識別的圖片刪除。下文會提到。
2.2 原始碼
本試驗在knnClassifier資料夾中建立了三個py檔案 |--- knnClassifier | |--- simpledatasetloader.py | |--- simplepreprocessor.py | |--- knn.py
2.2.1 simplepreprocessor.py
本檔案提供了SimplePreprocessor這個class供外部檔案使用。功能是根據尺寸,把輸入影象壓縮。因為KNN要把多個樣品讀取到RAM中,以便對測試樣品進行分類。因此,先把所有讀取的影象進行壓縮,便於讀取到有限的RAM中,同時也能減少判斷演算法的執行時間。
import cv2
class SimplePreprocessor:
def __init__(self, width, height, inter=cv2.INTER_AREA):
self.width = width
self.height = height
self.inter = inter
def preprocess(self, image):
return cv2.resize(image, (self.width, self.height), interpolation=self.inter)
# print(image.size)
if __name__ == '__main__':
s = SimplePreprocessor(32, 32)
img = cv2.imread('/home/xxjian/DeepLearningMaterial/kagglecatsanddogs_3367a/PetImages/Cat/9759.jpg')
# print(img)
cv2.imshow('src', img)
cv2.imshow("resize", s.preprocess(img))
#print(img.size)
cv2.waitKey(0)
# cv2.destroyallWindows()
檔案中帶有main函式,就是當使用python simplepreprocessor.py指令執行本檔案時,會執行main函式內的程式碼。可以這樣先測試本檔案的函式功能。
2.2.2 simpledatasetloader.py
import numpy as np
import cv2
import os
class SimplePreprocessor:
def __init__(self, width, height, inter=cv2.INTER_AREA):
self.width = width
self.height = height
self.inter = inter
def preprocess(self, image):
return cv2.resize(image, (self.width, self.height), interpolation=self.inter)
class SimpleDatasetLoader:
def __init__(self, preprocessors=None):
self.preprocessors = preprocessors
if self.preprocessors is None:
self.preprocessors = []
def load(self, imagePaths, verbose=-1):
data = []
labels = []
# print(imagePaths)
for (i, imagePath) in enumerate(imagePaths):
image = cv2.imread(imagePath)
label = imagePath.split(os.path.sep)[-2]
if self.preprocessors is not None:
for p in self.preprocessors:
if(image is None):
print(i)
os.remove(imagePaths[i])
print('file: ')
print(imagePaths[i])
print('is removed.')
continue
image = p.preprocess(image)
data.append(image)
labels.append(label)
if verbose > 0 and i > 0 and (i + 1)%verbose == 0:
print('[INFO] processed {}/{}'.format(i+1, len(imagePaths)))
return (np.array(data), np.array(labels))
if __name__ == '__main__':
imagePaths = '/home/xxjian/DeepLearningMaterial/pet_sample/'
sp = SimplePreprocessor(32, 32)
sdl = SimpleDatasetLoader(preprocessors=[sp])
#(data, labels) = sdl.load(imagePaths, verbose=10)
#data = data.reshape((data.shape[0], 3072))
上面的這幾行程式碼:
if(image is None): print(i) os.remove(imagePaths[i]) print('file: ') print(imagePaths[i]) print('is removed.') continue
用於刪除資料集中無法識別的檔案,可能是由於網站提供的包有問題,也可能是我下載時用下載程式在資料傳輸或者重構檔案時遇到了錯誤,總之裡面有損壞的格式檔案。如果不加這幾行,在讀取到錯誤檔案時會有NoneType錯誤(本程式中cv2.read()函式讀取損壞的檔案會輸出NoneType變數)。
2.2.3 knn.py
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from simplepreprocessor import SimplePreprocessor
from simpledatasetloader import SimpleDatasetLoader
from imutils import paths
import argparse
if __name__ == '__main__':
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True, help="path to input dataset")
ap.add_argument("-k", "--neighbors", type=int, default=1, help="of nearest neighbors for classification")
ap.add_argument("-j", "--jobs", type=int, help="of jobs for K-NN distance (-1 uses all variables cores)")
args = vars(ap.parse_args())
print("[INFO] loading images...")
imagePaths = list(paths.list_images(args["dataset"]))
sp = SimplePreprocessor(32, 32)
sdl = SimpleDatasetLoader(preprocessors=[sp])
(data, labels) = sdl.load(imagePaths, verbose=100)
data = data.reshape((data.shape[0], 3072))
print("[INFO] features matrix:{:.1f}MB".format(data.nbytes / (1024*1000.0)))
le = LabelEncoder()
labels = le.fit_transform(labels)
(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)
print("[INFO] evaluating K-NN classifier...")
#model = KNeighborsClassifier(n_neighbors=args["neighbors"], n_jobs=args["jobs"])
model=KNeighborsClassifier(n_neighbors=3)
model.fit(trainX, trainY)
print(classification_report(testY, model.predict(testX), target_names=le.classes_))
#knn=KNeighborsClassifier(n_neighbors=3)
#knn.fit(trainX,trainY)
#prediction=model.predict(testX)
#print(classification_report(testY, prediction, target_names=le.classes_))
我把原始碼註釋掉了,本試驗中的程式碼源自pyimagesearch作者出的一本書。但是我的環境中不能直接執行,故做了小量修改。
3 試驗結果
在knnClassifier資料夾內執行下面的命令:
python3 knn.py --dataset /home/xxjian/DeepLearningMaterial/kagglecatsanddogs_3367a/PetImages/
會輸出以下的資訊:
[INFO] loading images...
[INFO] processed 100/24944
[INFO] processed 200/24944
[INFO] processed 300/24944
[INFO] processed 400/24944
[INFO] processed 500/24944
[INFO] processed 600/24944
[INFO] processed 700/24944
[INFO] processed 800/24944
[INFO] processed 900/24944
[INFO] processed 1000/24944
...
[INFO] processed 24600/24944
[INFO] processed 24700/24944
[INFO] processed 24800/24944
[INFO] processed 24900/24944
[INFO] features matrix:74.8MB
[INFO] evaluating K-NN classifier...
執行到evaluating K-NN classifier...這裡會比價耗時,可以開啟gnone-system-monitor可以看到,CPU4完全被python3這程式佔用了。耐心等候,最後輸出了以下結果:
precision recall f1-score support
Cat 0.56 0.65 0.60 3124
Dog 0.58 0.49 0.53 3112
avg / total 0.57 0.57 0.57 6236
這裡比較有用的資訊是precision精度。support是樣品數目。
程式碼設定了樣品中的0.25作為測試樣。24944的1/4就是6236。精度是正確識別的樣品除以測試樣品。
這是KNN的效能了。對本資料集,識別貓狗、平均精度0.57。一半對一半錯。
參考:
2. Deep.Learning.for.Computer.Vision.with.Python.Starter.Bundle.2017.9.pdf的第七章