1. 程式人生 > >神經網路與深度學習 自制MNIST測試資料供神經網路測試

神經網路與深度學習 自制MNIST測試資料供神經網路測試

一、利用windows自帶畫圖工具


畫布大小為28*28,用刷子工具,顏色為黑色,寫字並儲存。

二、python程式碼將其轉為灰度圖

from PIL import Image
import numpy as np

class Data2:
    def getTestPicArray(self,filename):
        im = Image.open(filename)
        x_s = 28
        y_s = 28
        out = im.resize((x_s, y_s), Image.ANTIALIAS)
        im_arr = np.array(out.convert('L'))
        num0 = 0
        num255 = 0
        threshold = 100

        for x in range(x_s):
            for y in range(y_s):
                if im_arr[x][y] > threshold:
                    num255 = num255 + 1
                else:
                    num0 = num0 + 1

        if (num255 > num0):
            #print("convert!")
            for x in range(x_s):
                for y in range(y_s):
                    im_arr[x][y] = 255 - im_arr[x][y]
                    if (im_arr[x][y] < threshold):  im_arr[x][y] = 0
                    # if(im_arr[x][y] > threshold) : im_arr[x][y] = 0
                    # else : im_arr[x][y] = 255
                    # if(im_arr[x][y] < threshold): im_arr[x][y] = im_arr[x][y] - im_arr[x][y] / 2

        out = Image.fromarray(np.uint8(im_arr))
        out.save("C:\\Users\\Administrator\\Desktop\\out\\" + filename)
        # print im_arr                               
        nm = im_arr.reshape((1, 784))

        nm = nm.astype(np.float32)
        nm = np.multiply(nm, 1.0 / 255.0)

       # print(nm.reshape((784,1)).shape)
        #print(nm.shape)
        return nm.reshape((784,1))
    #   return nm

    #getTestPicArray(r"C:\Users\Administrator\Desktop\2.png")
#Data2().getTestPicArray(r"2.png")

其中用到了第三方庫

三、在原神經網路程式碼中加入自定義測試資料(詳細見上一文章)
import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
training_data = list(training_data)

import network
net = network.Network([784, 30, 10])
net.SGD(training_data, 30, 10, 2.0, test_data=test_data)

import image_data
import test4
for i in range(0,10):
    file = str(i)+".png"
    #data = image_data.ImageData().getArrayFromImage(file)
    data = test4.Data2().getTestPicArray(file)
    print("預測:",file,"的結果為",net.cjtest(data))
print("-----------------------------------------------------------------------")
for i in range(0,10):
    file = str(i)+".png"
    data = image_data.ImageData().getArrayFromImage(file)
    #data = test4.Data2().getTestPicArray(file)
    print("預測:",file,"的結果為",net.cjtest(data))


四、執行結果(準確性還有待提升)