吳恩達深度學習課後作業第一課第二週-邏輯迴歸的拓展,自己做資料來進行預測是否是貓
阿新 • • 發佈:2021-01-10
首先宣告:本文在參考[https://blog.csdn.net/u013733326/article/details/79639509](https://blog.csdn.net/u013733326/article/details/79639509)的部落格基礎上,增加了自己的圖片來進行預測。目的為了讓大家更能深入的理解怎麼製作資料集,以及如何使用自己的圖片來用一個跑好的模型。 相關原始碼,請到我github主頁下載需要的相關原始碼:[Confused-Pig](https://github.com/Confused-Pig)
本文如有不對之處,還請指出。
1.建立自己的資料集
我選取了幾張貓的圖片,把他們的畫素點提取,存進h5檔案中,以方便後面使用,具體程式碼如下:
注:需要下載opencv,numpy,h5py等庫,如果沒有,請自行下載,推薦在cmd視窗中pip。
import os
import numpy as np
import cv2
import h5py
def ResizePic(path,save_path,dim):
i = 1
file_name = os.listdir(path)
for img in file_name:
image = cv2.imread(path + img)
image_size = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
cv2.imwrite(save_path + str(i) + '.jpg', image_size)
i = i + 1
def save_image_to_h5py(path):
img_list = []
label_list = []
for child_dir in os.listdir(path):
img = cv2.imread(os.path.join (path,child_dir))
img_list.append(img)
label_list.append(1)
img_np = np.array(img_list)
label_np = np.array(label_list)
print('資料集標籤順序:\n',label_np)
#'a' ,如果已經有這個名字的h5檔案存在將不會開啟,目的為了防止誤刪資訊。
#‘w' ,如果有同名檔案也能開啟,但會覆蓋上次的內容。
with h5py.File('datasets/test_cat.h5','a') as f:
f.create_dataset('test_cat',data = img_np)
f.create_dataset('test_label',data = label_np)
f.close()
path = 'Image/pic/'
save_path = 'Image/test_pic/'
dim = (64,64)
d = ResizePic(path,save_path,dim)
b = save_image_to_h5py(save_path)
最終生成的效果如圖:
具體的對h5檔案的建立和讀取,請參考我的另外兩篇博文:
h5py檔案的建立和讀取,資料集的製作也不算很難;
如何用Python來建立一個深度學習的圖片集,改變畫素和自動排序
2.進行預測
我們先看看模型最後返回的引數是什麼:
def model(X_train,Y_train,X_test,Y_test,num_iterations=2000,learning_rate=0.5,print_cost=False):
w,b = initalize_parameters_zero(X_train.shape[0])
parameters,grads,costs = optimize(w,b,X_train,Y_train,num_iterations,learning_rate,print_cost)
w,b = parameters['w'],parameters['b']
Y_predict_test = predict(w,b,X_test)
Y_predict_train = predict(w,b,X_train)
print('訓練集準確度: ',format(100 - np.mean(np.abs(Y_predict_train - Y_train)) * 100), '%')
print('測試集準確度: ', format(100 - np.mean(np.abs(Y_predict_test - Y_test)) * 100), '%')
dic = {'costs':costs,
'Y_predict_train':Y_predict_train,
'Y_predict_test':Y_predict_test,
'w':w,
'b':b,
'learning_rate':learning_rate,
'num_iterations':num_iterations}
return dic
可以很清楚的看到,模型最後返回的是一個dic的字典,裡面儲存的是一系列引數。而我們在預測需要用的是w和b。
我們再來看看預測函式是哈樣的:
def predict(w,b,X):
m = X.shape[1]
Y_predict = np.zeros((1,m))
w = w.reshape(X.shape[0],1)
A = sigmoid(np.dot(w.T,X)+b)
for i in range(A.shape[1]):
Y_predict[0,i] = 1 if A[0,i] > 0.5 else 0 #判斷閾值是否大於0.5,大於則是1,否則是0
assert (Y_predict.shape == (1,m))
return Y_predict
預測函式裡需要三個引數,w、b和X,X就是我們要預測的圖片資訊。
接下來就是把我們的h5檔案中的圖片資訊拿出來,再結合dic字典裡的w,b進行預測(就是這麼簡單):
w = dic['w']
b = dic['b']
test = h5py.File('datasets/test_cat.h5','r')
test_cat_x_orig = np.array(test['test_cat'][:])
test_cat_x_flatten = test_cat_x_orig.reshape(test_cat_x_orig.shape[0],-1).T
test_cat = test_cat_x_flatten/255
pre = predict(w,b,test_cat)
print(pre)
執行一下看看結果:
輸出1則是貓,0不是貓。我所選的5張圖片全是貓,但結果卻只有三張是貓,所以精度不高,還需要改進。如何改進,請大家自己也動動大腦,和我一起想想方法。
結語
這有太多地方需要改進,精度遠遠不夠,奈何博主也是小新一枚,學力有限,希望大家能和我多多討論,在深度學習的道路上越走越長,謝謝大家。
原始碼我已經放到我的giuhub上了,希望大家去看看原始碼,在結合我講的,很容易懂。再次感謝大家閱讀。