1. 程式人生 > 其它 >機器學習筆記(二十三)——Tensorflow 2(視覺化)

機器學習筆記(二十三)——Tensorflow 2(視覺化)

本部落格僅用於個人學習,不用於傳播教學,主要是記自己能夠看得懂的筆記(

學習知識來自:【吳恩達團隊Tensorflow2.0實踐系列課程第一課】TensorFlow2.0中基於TensorFlow2.0的人工智慧、機器學習和深度學習簡介及基礎程式設計_嗶哩嗶哩_bilibili

上次鑑別了一下人與馬,這次換了一個數據集,鑑別貓與狗。方法與上次一毛一樣,不過這次後面要加一個視覺化操作,來看看我們的圖片經過卷積和池化之後的有什麼變化,有什麼突出的地方。

這次為了方便,用的是jupyter notebook編輯的(之前使用VScode),資料集下載地址:https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip

視覺化的話,就在我上次寫的程式碼後面加上下面這些,就可以了。另外,程式碼中的plt.show()在jupyter notebook中可以刪除。

import random
from tensorflow.keras.preprocessing.image import img_to_array,load_img
import matplotlib.pyplot as plt

s_outputs=[layer.output for layer in model.layers[1:]] #儲存每一層的輸出
v_model=tf.keras.models.Model(inputs=model.input,outputs=s_outputs) #
建立新的模型 for root,dirs,catsnam in os.walk(filepath+'/tmp/train/cats'): used_up_variable=0 for root,dirs,dogsnam in os.walk(filepath+'/tmp/train/dogs'): used_up_variable=0 catsnam=[filepath+'/tmp/train/cats/'+nam for nam in catsnam] dogsnam=[filepath+'/tmp/train/dogs/'+nam for nam in dogsnam] #獲取所有檔案的絕對路徑 img_path
=random.choice(catsnam+dogsnam) #隨機取一個圖片 img=load_img(img_path,target_size=(150,150)) #以150*150載入圖片 plt.imshow(img) plt.show() x=img_to_array(img) x=x.reshape((1,)+x.shape) #變為(1,150,150,3) x/=255.0 #歸一化 maps=v_model.predict(x) #生成結果 ans=model.predict(x,batch_size=10) #預測結果 print(ans[0]) if ans[0]<0.5: print('This is a cat.') else: print('This is a dog.') layernams=[layer.name for layer in model.layers] #獲取每一層的名字 for layernam,map in zip(layernams,maps): if len(map.shape)==4: #輸出Flatten之前的卷積層和池化層的影象 tunnel=map.shape[-1] #獲取特徵數 size=map.shape[1] #獲取輸出影象的邊長 d_grid=np.zeros((size,size*tunnel)) #建立0矩陣,之後將輸出影象放置在其中,有tunnel張圖 for i in range(tunnel): #以下為影象美化處理,我也不知道什麼原理 x=map[0,:,:,i] x-=x.mean() x=x/x.std() x*=64 x+=128 x=np.clip(x,0,255).astype('uint8') d_grid[:,i*size:(i+1)*size]=x #併入到矩陣中 scale=20.0/tunnel #總長:20 plt.figure(figsize=(scale*tunnel,scale)) #輸出大小:20*something plt.title(layernam) plt.grid(False) plt.gray() plt.imshow(d_grid,aspect='auto',cmap='viridis') #見參考部落格 plt.show()

得到結果:

<matplotlib.image.AxesImage at 0x1ae073faf10>

[8.4550436e-07]
This is a cat.
<ipython-input-21-db08195e54f6>:22: RuntimeWarning: invalid value encountered in true_divide
  x=x/x.std()

參考部落格:

Python zip() 函式 | 菜鳥教程 (runoob.com)

jupyter notebook 設定縮排為tab製表符_ksx_120999的部落格-CSDN部落格

Python Matplotlib.pyplot.gray()用法及程式碼示例 - 純淨天空 (vimsky.com)

matplotlib中cmap_matplotlib基礎繪圖命令之imshow的使用_weixin_39812577的部落格-CSDN部落格