機器學習 — 再認識資料集
做了一些簡單機器學習任務後,發現必須要對資料集有足夠的瞭解才能動手做一些事,這是無法避免的,否則可能連在幹嘛都不知道,而一些官方例程並不會對資料集做過多解釋,你甚至連它長什麼樣都不知道。。。
以sklearn的手寫數字識別為例,例子中,一句
digits = datasets.load_digits()
就拿到資料了,然後又幾句
images_and_labels = list(zip(digits.images, digits.target)) for index, (image, label) in enumerate(images_and_labels[:4]): plt.subplot(2, 4, index + 1) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Training: %i' % label) # To apply a classifier on this data, we need to flatten the image, to # turn the data in a (samples, feature) matrix: n_samples = len(digits.images) data = digits.images.reshape((n_samples, -1))
就把資料集劃分好了,對初學者來說,可能都不知道幹了些啥。。。當然更重要的是,跑一邊程式看到效果不錯,想要用訓練好的模型玩玩自己的資料集,卻無從下手。。。於是,下面就以這個例子來說一下,如何基本的瞭解資料集,以及如何構造資料集,或許還會談談為什麼要這樣構造。。。
1.認識資料集。
看程式碼,我們發現,該資料集主要由兩個部分組成:
1).images
2).target
target 的劃分看起來不復雜,所以可以直接看看其中的部分內容:
>>> print(digits.images.shape) # (1797,) >>> print(digits.target[:10]) # [0 1 2 3 4 5 6 7 8 9] >>> print(digits.target[-10:]) # [5 4 8 8 4 9 0 8 9 8]
含義是:target是一個形狀為長度為1797的行向量,共有1797個(0~9)數字。
images 還需要做一些處理才能使用fit介面,但我們也先看看原本長什麼樣:
>>> print(digits.image.shape) # (1797, 8, 8) >>> print(digits.images[0].shape) # (8, 8) >>> print(digits.images[0]) ''' [[ 0. 0. 5. 13. 9. 1. 0. 0.] [ 0. 0. 13. 15. 10. 15. 5. 0.] [ 0. 3. 15. 2. 0. 11. 8. 0.] [ 0. 4. 12. 0. 0. 8. 8. 0.] [ 0. 5. 8. 0. 0. 9. 8. 0.] [ 0. 4. 11. 0. 1. 12. 7. 0.] [ 0. 2. 14. 5. 10. 12. 0. 0.] [ 0. 0. 6. 13. 10. 0. 0. 0.]] '''
再畫出來看看:
>>> import matplotlib.pyplot as plt >>> plt.axis('off') >>> plt.title('label: %i' % digits.target[0]) >>> plt.imshow(digits.images[0], cmap='gray_r') >>> plt.show()
含義是:images是由1797張尺寸為8*8的單通道圖片組成,而圖片內容對應每一張標籤的數字的手寫數字。
於是,這下我們瞭解了資料集了,但別急,圖片集還要做點處理才能使用:
>>> data = digits.images.reshape((n_samples, -1)) >>> print(data.shape) # (1797, 64) >>> print(data[0]) ''' [ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3. 15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0. 0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12. 0. 0. 0. 0. 6. 13. 10. 0. 0. 0.] '''
把原圖片集形狀 (numbers, w, h) 變成了 (numbers, w * h),也就是把2維陣列變為一維陣列來儲存,我個人認為是為了效率...處理一維陣列的效率比二維陣列高很多。(使用深度學習,我們可以利用神經網路自己構造輸入形狀和輸出形狀,便利許多。)
現在我們很清楚模型要輸入什麼樣的資料才能進行訓練了。
2.訓練模型。
該例子使用svm,不同問題的選擇不一,而是根據對演算法的理解、經驗和觀察最終訓練效果選擇合適的演算法。
from sklearn import datasets, svm digits = datasets.load_digits() n_samples = len(digits.images) train_x = digits.images.reshape((n_samples, -1)) train_y = digits.target model = svm.SVC(gamma=0.001) model.fit(train_x, train_y,)
3.評估模型的效果。
from sklearn import metrics y_real = dateset.target ... y_pred = model.predict(test_x) print(metrics.accuracy_score(y_real, y_pred))
4.儲存和載入模型。
儲存模型很簡單,sklearn有專門提供便利的方法來儲存和載入模型:
from sklearn.externals import joblib joblib.dump(model, 'mnist.m')
載入模型:
model = joblib.load('mnist.m') y_pred = model.predict(test_x)
5.最後,部署模型。
上面看到,圖片的形狀必須為8*8畫素大小的單通道圖片,假如我們有一批50*50的手寫數字圖片集,想用該模型測試一下效果怎麼辦,我們只需要改變一下圖片解析度,把形狀變為8*8即可。這樣,我們才能用自己的資料集來進行測試,或者部署該模型以提供給別人使用。
關於如何部署到web,可以參考前一篇隨筆。
下面是一個例子,使用了一點opencv來把RGB圖片轉為灰度圖、修改圖片尺寸以及一些簡單的額外處理:
from sklearn import datasets, svm, metrics import matplotlib.pyplot as plt from sklearn.externals import joblib import numpy as np import cv2 as cv digits = datasets.load_digits() n_samples = len(digits.images) train_x = digits.images.reshape((n_samples, -1)) train_y = digits.target model = svm.SVC(gamma=0.001) model.fit(train_x, train_y,) joblib.dump(model, 'mnist.m') w, h = 8, 8 labels = [0, 1, 4, 5] lenght = len(labels) images = np.zeros((lenght, h, w), np.uint8) imgs = [] for i, name in enumerate(labels): img = cv.imread('digits/{}.png'.format(name), cv.IMREAD_GRAYSCALE) img = cv.resize(img, (h, w), interpolation=cv.INTER_CUBIC) for r in range(img.shape[0]): for c in range(img.shape[1]): if np.all(img[r, c] <= [251, 251, 251]): img[r, c] = (0, 0, 0) imgs.append(img) # gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY) # lap = cv.Laplacian(gray, cv.CV_64F) images[i] = img images = images.reshape((lenght, -1)) # model = joblib.load('mnist.m') pred = model.predict(images) print(metrics.accuracy_score(labels, pred)) for index, (image, label) in enumerate(list(zip(imgs, pred))): plt.subplot(1, lenght, index + 1) plt.axis('off') plt.imshow(image, cmap='gray_r', interpolation='nearest') plt.title('pred: %i' % label) plt.show()
嘛,雖然最後結果很糟。。4張圖片識別率只有25%,唯一一張識別成功的,還是因為,資料全部被識別為1,也不知道為啥。。。
自己斷斷續續玩了也有一段時間了,能懂如何生成資料,如何構造模型,如何部署模型等,嘛,很糟糕的說,也算有點成長吧。。。