1. 程式人生 > >機器學習 — 再認識資料集

機器學習 — 再認識資料集

  做了一些簡單機器學習任務後,發現必須要對資料集有足夠的瞭解才能動手做一些事,這是無法避免的,否則可能連在幹嘛都不知道,而一些官方例程並不會對資料集做過多解釋,你甚至連它長什麼樣都不知道。。。

  以sklearn的手寫數字識別為例,例子中,一句

digits = datasets.load_digits()

  就拿到資料了,然後又幾句

images_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Training: %i' % label)

# To apply a classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

  就把資料集劃分好了,對初學者來說,可能都不知道幹了些啥。。。當然更重要的是,跑一邊程式看到效果不錯,想要用訓練好的模型玩玩自己的資料集,卻無從下手。。。於是,下面就以這個例子來說一下,如何基本的瞭解資料集,以及如何構造資料集,或許還會談談為什麼要這樣構造。。。

  1.認識資料集。

  看程式碼,我們發現,該資料集主要由兩個部分組成:

    1).images

    2).target

  target 的劃分看起來不復雜,所以可以直接看看其中的部分內容:

>>> print(digits.images.shape)
# (1797,)
>>> print(digits.target[:10])
# [0 1 2 3 4 5 6 7 8 9]
>>> print(digits.target[-10:])
# [5 4 8 8 4 9 0 8 9 8]

  含義是:target是一個形狀為長度為1797的行向量,共有1797個(0~9)數字。

  images 還需要做一些處理才能使用fit介面,但我們也先看看原本長什麼樣:

>>> print(digits.image.shape)
# (1797, 8, 8)
>>> print(digits.images[0].shape)
# (8, 8)
>>> print(digits.images[0])
'''
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]
'''

  再畫出來看看:

>>> import matplotlib.pyplot as plt
>>> plt.axis('off')
>>> plt.title('label: %i' % digits.target[0])
>>> plt.imshow(digits.images[0], cmap='gray_r')
>>> plt.show()

  

  含義是:images是由1797張尺寸為8*8的單通道圖片組成,而圖片內容對應每一張標籤的數字的手寫數字。

  於是,這下我們瞭解了資料集了,但別急,圖片集還要做點處理才能使用:

>>> data = digits.images.reshape((n_samples, -1))
>>> print(data.shape)
# (1797, 64)
>>> print(data[0])
'''
[ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15.  5.  0.  0.  3.
 15.  2.  0. 11.  8.  0.  0.  4. 12.  0.  0.  8.  8.  0.  0.  5.  8.  0.
  0.  9.  8.  0.  0.  4. 11.  0.  1. 12.  7.  0.  0.  2. 14.  5. 10. 12.
  0.  0.  0.  0.  6. 13. 10.  0.  0.  0.]
'''

  把原圖片集形狀 (numbers, w, h) 變成了 (numbers, w * h),也就是把2維陣列變為一維陣列來儲存,我個人認為是為了效率...處理一維陣列的效率比二維陣列高很多。(使用深度學習,我們可以利用神經網路自己構造輸入形狀和輸出形狀,便利許多。)

  現在我們很清楚模型要輸入什麼樣的資料才能進行訓練了。

  2.訓練模型。

  該例子使用svm,不同問題的選擇不一,而是根據對演算法的理解、經驗和觀察最終訓練效果選擇合適的演算法。

from sklearn import datasets, svm


digits = datasets.load_digits()
n_samples = len(digits.images)
train_x = digits.images.reshape((n_samples, -1))
train_y = digits.target

model = svm.SVC(gamma=0.001)
model.fit(train_x, train_y,)

  3.評估模型的效果。

from sklearn import metrics

y_real = dateset.target
...
y_pred = model.predict(test_x)
print(metrics.accuracy_score(y_real, y_pred))

  4.儲存和載入模型。

  儲存模型很簡單,sklearn有專門提供便利的方法來儲存和載入模型:

from sklearn.externals import joblib

joblib.dump(model, 'mnist.m')

  載入模型:

model = joblib.load('mnist.m')
y_pred = model.predict(test_x)

  5.最後,部署模型。

  上面看到,圖片的形狀必須為8*8畫素大小的單通道圖片,假如我們有一批50*50的手寫數字圖片集,想用該模型測試一下效果怎麼辦,我們只需要改變一下圖片解析度,把形狀變為8*8即可。這樣,我們才能用自己的資料集來進行測試,或者部署該模型以提供給別人使用。

  關於如何部署到web,可以參考前一篇隨筆

  下面是一個例子,使用了一點opencv來把RGB圖片轉為灰度圖、修改圖片尺寸以及一些簡單的額外處理:

from sklearn import datasets, svm, metrics
import matplotlib.pyplot as plt
from sklearn.externals import joblib
import numpy as np
import cv2 as cv


digits = datasets.load_digits()
n_samples = len(digits.images)
train_x = digits.images.reshape((n_samples, -1))
train_y = digits.target

model = svm.SVC(gamma=0.001)
model.fit(train_x, train_y,)
joblib.dump(model, 'mnist.m')

w, h = 8, 8
labels = [0, 1, 4, 5]
lenght = len(labels)
images = np.zeros((lenght, h, w), np.uint8)
imgs = []
for i, name in enumerate(labels):
    img = cv.imread('digits/{}.png'.format(name), cv.IMREAD_GRAYSCALE)
    img = cv.resize(img, (h, w), interpolation=cv.INTER_CUBIC)
    for r in range(img.shape[0]):
        for c in range(img.shape[1]):
            if np.all(img[r, c] <= [251, 251, 251]):
                img[r, c] = (0, 0, 0)
    imgs.append(img)
    # gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    # lap = cv.Laplacian(gray, cv.CV_64F)
    images[i] = img
images = images.reshape((lenght, -1))
# model = joblib.load('mnist.m')
pred = model.predict(images)
print(metrics.accuracy_score(labels, pred))
for index, (image, label) in enumerate(list(zip(imgs, pred))):
    plt.subplot(1, lenght, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap='gray_r', interpolation='nearest')
    plt.title('pred: %i' % label)
plt.show()

  嘛,雖然最後結果很糟。。4張圖片識別率只有25%,唯一一張識別成功的,還是因為,資料全部被識別為1,也不知道為啥。。。

  自己斷斷續續玩了也有一段時間了,能懂如何生成資料,如何構造模型,如何部署模型等,嘛,很糟糕的說,也算有點成長吧。。。