1. 程式人生 > 其它 >吳恩達深度學習課後作業第一課第二週-邏輯迴歸的拓展,自己做資料來進行預測是否是貓

吳恩達深度學習課後作業第一課第二週-邏輯迴歸的拓展,自己做資料來進行預測是否是貓

技術標籤:吳恩達深度學習課後作業python深度學習


首先宣告:本文在參考[https://blog.csdn.net/u013733326/article/details/79639509](https://blog.csdn.net/u013733326/article/details/79639509)的部落格基礎上,增加了自己的圖片來進行預測。目的為了讓大家更能深入的理解怎麼製作資料集,以及如何使用自己的圖片來用一個跑好的模型。 相關原始碼,請到我github主頁下載需要的相關原始碼:[Confused-Pig](https://github.com/Confused-Pig)
本文如有不對之處,還請指出。

1.建立自己的資料集

我選取了幾張貓的圖片,把他們的畫素點提取,存進h5檔案中,以方便後面使用,具體程式碼如下:
注:需要下載opencv,numpy,h5py等庫,如果沒有,請自行下載,推薦在cmd視窗中pip。

import os
import numpy as np
import cv2
import h5py

def ResizePic(path,save_path,dim):
    i = 1
    file_name = os.listdir(path)
    for img in file_name:
        image = cv2.imread(path + img)
image_size = cv2.resize(image, dim, interpolation=cv2.INTER_AREA) cv2.imwrite(save_path + str(i) + '.jpg', image_size) i = i + 1 def save_image_to_h5py(path): img_list = [] label_list = [] for child_dir in os.listdir(path): img = cv2.imread(os.path.join
(path,child_dir)) img_list.append(img) label_list.append(1) img_np = np.array(img_list) label_np = np.array(label_list) print('資料集標籤順序:\n',label_np) #'a' ,如果已經有這個名字的h5檔案存在將不會開啟,目的為了防止誤刪資訊。 #‘w' ,如果有同名檔案也能開啟,但會覆蓋上次的內容。 with h5py.File('datasets/test_cat.h5','a') as f: f.create_dataset('test_cat',data = img_np) f.create_dataset('test_label',data = label_np) f.close() path = 'Image/pic/' save_path = 'Image/test_pic/' dim = (64,64) d = ResizePic(path,save_path,dim) b = save_image_to_h5py(save_path)

最終生成的效果如圖:
在這裡插入圖片描述
具體的對h5檔案的建立和讀取,請參考我的另外兩篇博文:
h5py檔案的建立和讀取,資料集的製作也不算很難;
如何用Python來建立一個深度學習的圖片集,改變畫素和自動排序

2.進行預測

我們先看看模型最後返回的引數是什麼:

def model(X_train,Y_train,X_test,Y_test,num_iterations=2000,learning_rate=0.5,print_cost=False):
 
    w,b = initalize_parameters_zero(X_train.shape[0])
    parameters,grads,costs = optimize(w,b,X_train,Y_train,num_iterations,learning_rate,print_cost)
    w,b = parameters['w'],parameters['b']

    Y_predict_test = predict(w,b,X_test)
    Y_predict_train = predict(w,b,X_train)

    print('訓練集準確度: ',format(100 - np.mean(np.abs(Y_predict_train - Y_train)) * 100), '%')
    print('測試集準確度: ', format(100 - np.mean(np.abs(Y_predict_test - Y_test)) * 100), '%')

    dic = {'costs':costs,
           'Y_predict_train':Y_predict_train,
           'Y_predict_test':Y_predict_test,
           'w':w,
           'b':b,
           'learning_rate':learning_rate,
           'num_iterations':num_iterations}

    return dic

可以很清楚的看到,模型最後返回的是一個dic的字典,裡面儲存的是一系列引數。而我們在預測需要用的是w和b。


我們再來看看預測函式是哈樣的:

def predict(w,b,X):
 
    m = X.shape[1]
    Y_predict = np.zeros((1,m))
    w = w.reshape(X.shape[0],1)

    A = sigmoid(np.dot(w.T,X)+b)

    for i in range(A.shape[1]):
        Y_predict[0,i] = 1 if A[0,i] > 0.5 else 0      #判斷閾值是否大於0.5,大於則是1,否則是0

    assert (Y_predict.shape == (1,m))

    return Y_predict

預測函式裡需要三個引數,w、b和X,X就是我們要預測的圖片資訊。


接下來就是把我們的h5檔案中的圖片資訊拿出來,再結合dic字典裡的w,b進行預測(就是這麼簡單):
w = dic['w']
b = dic['b']

test = h5py.File('datasets/test_cat.h5','r')
test_cat_x_orig = np.array(test['test_cat'][:])
test_cat_x_flatten = test_cat_x_orig.reshape(test_cat_x_orig.shape[0],-1).T
test_cat = test_cat_x_flatten/255

pre = predict(w,b,test_cat)
print(pre)

執行一下看看結果:
在這裡插入圖片描述
輸出1則是貓,0不是貓。我所選的5張圖片全是貓,但結果卻只有三張是貓,所以精度不高,還需要改進。如何改進,請大家自己也動動大腦,和我一起想想方法。


結語

這有太多地方需要改進,精度遠遠不夠,奈何博主也是小新一枚,學力有限,希望大家能和我多多討論,在深度學習的道路上越走越長,謝謝大家。
原始碼我已經放到我的giuhub上了,希望大家去看看原始碼,在結合我講的,很容易懂。再次感謝大家閱讀。

點選這裡,原始碼奉上,感謝使用