1. 程式人生 > 其它 >opencv(python)使用svm演算法識別手寫數字

opencv(python)使用svm演算法識別手寫數字

技術標籤:OPenCVopencvpythonsvm手寫數字識別

svm演算法是一種使用超平面將資料進行分類的演算法。
關於mnist資料的解析,讀者可以自己從網上下載相應壓縮檔案,用python自己編寫解析程式碼,由於這裡主要研究knn演算法,為了圖簡單,直接使用Keras的mnist手寫數字解析模組。
本次程式碼執行環境為:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
以下程式碼為使用svm演算法進行訓練模型:

import cv2
import numpy as np
from keras.datasets import
mnist from keras import utils if __name__=='__main__': #直接使用Keras載入的訓練資料(60000, 28, 28) (60000,) (train_images,train_labels),(test_images,test_labels)=mnist.load_data() #變換資料的形狀並歸一化 train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784) train_images=train_images.astype('float32'
)/255 test_images=test_images.reshape(test_images.shape[0],-1) test_images=test_images.astype('float32')/255 #將標籤資料轉為int32 並且形狀為(60000,1) train_labels=train_labels.astype(np.int32) test_labels=test_labels.astype(np.int32) train_labels=train_labels.reshape(-1,1) test_labels=
test_labels.reshape(-1,1) #建立svm模型 svm = cv2.ml.SVM_create() #設定型別為SVM_C_SVC代表分類 svm.setType(cv2.ml.SVM_C_SVC) #設定核函式 svm.setKernel(cv2.ml.SVM_POLY) #設定其它屬性 svm.setGamma(3) svm.setDegree(3) #設定迭代終止條件 svm.setTermCriteria((cv2.TermCriteria_MAX_ITER,300,1e-3)) #訓練 svm.train(train_images,cv2.ml.ROW_SAMPLE,train_labels) svm.save('mnist_svm.xml') #在測試資料上計算準確率 #進行模型準確率的測試 結果是一個元組 第一個值為資料1的結果 test_pre=svm.predict(test_images) test_ret=test_pre[1] #計算準確率 test_ret=test_ret.reshape(-1,) test_labels=test_labels.reshape(-1,) test_sum=(test_ret==test_labels) acc=test_sum.mean() print(acc)

訓練了300次的準確率為0.9687,如果增大這個迭代次數,準確率還會升高。
生成的模型檔案為十多兆,這個相比於knn演算法,模型檔案小很多。
接下來使用svm模型進行手寫數字的測試識別,程式碼如下:
在這裡插入圖片描述

import cv2
import numpy as np

if __name__=='__main__':
    #讀取圖片
    img=cv2.imread('shuzi.jpg',0)
    img_sw=img.copy()

    #將資料型別由uint8轉為float32
    img=img.astype(np.float32)
    #圖片形狀由(28,28)轉為(784,)
    img=img.reshape(-1,)
    #增加一個維度變為(1,784)
    img=img.reshape(1,-1)
    #圖片資料歸一化
    img=img/255

    #載入svm模型
    svm=cv2.ml.SVM_load('mnist_svm.xml')
    #進行預測
    img_pre=svm.predict(img)
    print(img_pre[1])

    cv2.imshow('test',img_sw)
    cv2.waitKey(0)

執行程式,結果如下,可見成功識別了該圖片。
在這裡插入圖片描述