1. 程式人生 > >tensorflow 模型權重匯出至h5檔案

tensorflow 模型權重匯出至h5檔案

tensorflow在儲存權重模型時多使用tf.train.Saver().save 函式進行權重儲存,儲存的ckpt檔案無法直接開啟,不利於將模型權重匯入到其他框架使用(如Caffe、Keras等)。

好在tensorflow提供了相關函式 tf.train.NewCheckpointReader 可以對ckpt檔案進行權重檢視,因此可以通過該函式進行資料匯出。

 1 import tensorflow as tf
 2 import h5py
 3 
 4 cpktLogFileName = r'./checkpoint/checkpoint'  #cpkt 檔案路徑
 5 with open(cpktLogFileName, '
r') as f: 6 #權重節點往往會保留多個epoch的資料,此處獲取最後的權重資料 7 cpktFileName = f.readline().split('"')[1] 8 9 h5FileName = r'./model/net_classification.h5' 10 11 reader = tf.train.NewCheckpointReader(cpktFileName) 12 f = h5py.File(h5FileName, 'w') 13 t_g = None 14 for key in sorted(reader.get_variable_to_shape_map()):
15 # 權重名稱需根據自己網路名稱自行修改 16 if key.endswith('w') or key.endswith('biases'): 17 keySplits = key.split(r'/') 18 keyDict = keySplits[1] + '/' + keySplits[1] + '/' + keySplits[2] 19 f[keyDict] = reader.get_tensor(key)