tensorflow2caffe(2) : 如何在tensorflow中取出模型引數
經過一段時間,筆者又回來了!
從2017年7月開始,筆者結束了研一繁重的考試月任務(因為平常不常上課,突擊考試大家懂的),然後就進入了研二生涯。在七八月筆者就開始了自己的課題實驗(由於太忙累壞了身體,先是眼睛不太好,然後又犯了胃病被迫去做胃鏡,在此深切提醒各位,工作的時候不能太拼,身體很重要)。研二生涯筆者在實驗室又要做學術又要兼顧專案,因此相當繁忙。但是筆者向大家做下承諾,一定不放棄部落格的更新,筆者會在第一時間將自己的實驗經驗與小技巧科普給大家。
廢話少說,下面開啟正題!
還記得我們的上一篇部落格嗎?講述了caffemodel到底記錄了什麼,那麼,筆者帶領大家回顧一下我們目前做的是什麼工程:目前是將一個模型在tensorflow下面訓練,然後轉換到caffe框架上執行,這樣,模型就可以在嵌入式平臺上面運作了(nvidia TX2,nvidia TK1等)。我們在tensorflow2caffe(1)中,已經講述瞭如何搞清楚caffemodel裡面記錄了什麼,也就弄明白了我們的資料轉換目標。那麼,待轉換的資料從哪裡來呢?從tensorflow中來對吧。那麼,針對tensoeflow訓練所得的結果,我們來看一下我們得到了什麼:
從上圖中,我們看到,tensorflow框架訓練得到的結果裡面,首先有一個checkpoint,這是一個訓練結果的索引,從這個裡面,我們可以看到我們儲存了哪些模型?(對應儲存模型的程式在tensorflow的訓練程式中,筆者就比較愛使用tf.Train.Saver.save介面儲存模型引數),ckeckpoint檔案可以被開啟,我們可以看看裡面記錄了什麼。
大家可以看到,裡面記錄的是這個我們在訓練不同階段儲存的模型。首先大家要記住,這是一個很有用的檔案。
然後,在圖1中,我們發現了還有其他型別的檔案,首先有一個.meta檔案,瞭解tensorflow的同學知道,這是一個記錄Graph的檔案,在tensorflow中,Graph記錄了所有資料的流向,規定了整個模型的結構。然後有一種data-00000-of-00001結構的檔案,這個檔案比較大(往往幾十上百M),這個檔案是記錄了我們訓練得到的資料
那麼,到這裡,我們就明白瞭如何去提取出tensorflow的訓練引數。
(1)載入資料流圖
(2)通過checkpoint找到目前最新的訓練儲存結果
(3)提取訓練得到的所有引數
具體程式碼如下:
#!/usr/bin/python import tensorflow as tf import numpy as np with tf.Session() as sess: new_saver = tf.train.import_meta_graph('model.ckpt-189200.meta') #load graph for var in tf.trainable_variables(): #get the param names print var.name #print parameters' names new_saver.restore(sess, tf.train.latest_checkpoint('./')) #find the newest training result all_vars = tf.trainable_variables() for v in all_vars: v_4d = np.array(sess.run(v)) #get the real parameters
在上面的程式碼中,tf.train.import_meta_graph匯入了訓練資料流圖,tf,train.latest_checkpoint找到了最新的訓練資料,還列印了引數名字,並且最後使用sess.run(v)得到了各個引數。
我們執行一下程式碼,大家可以看到,引數名稱源源不斷地打印出來。
在我們找到引數之後,下面,我們就要把tensorflow框架格式規定的引數轉化為caffe框架規定格式規範的引數,具體該怎麼做,且聽下回分解。
歡迎閱讀筆者後續部落格,各位讀者朋友的支援與鼓勵是我最大的動力!
written by jiong
豈能盡如人意,但求無愧於心