1. 程式人生 > 其它 >TensorFlow模型儲存和載入方法彙總

TensorFlow模型儲存和載入方法彙總

技術標籤:tensorflow

目錄


『TensorFlow』第七彈_儲存&載入會話_霸王回馬

回到頂部

一、TensorFlow常規模型載入方法

儲存模型

tf.train.Saver()類,.save(sess, ckpt檔案目錄)方法

引數名稱功能說明預設值
var_listSaver中儲存變數集合全域性變數集合
reshape載入時是否恢復變數形狀True
sharded是否將變數輪循放在所有裝置上True
max_to_keep保留最近檢查點個數5
restore_sequentially是否按順序恢復變數,模型較大時順序恢復記憶體消耗小True

var_list是字典形式{變數名字串: 變數符號},相對應的restore也根據同樣形式的字典將ckpt中的字串對應的變數載入給程式中的符號。

如果Saver給定了字典作為載入方式,則按照字典來,如:saver=tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個變數尋找自己的name屬性在ckpt中的對應值進行載入。

載入模型

當我們基於checkpoint檔案(ckpt)載入引數時,實際上我們使用Saver.restore取代了initializer的初始化

checkpoint檔案會記錄儲存資訊,通過它可以定位最新儲存的模型:

1

2

ckpt=tf.train.get_checkpoint_state('./model/')

print(ckpt.model_checkpoint_path)

.meta檔案儲存了當前圖結構

.data檔案儲存了當前引數名和值

.index檔案儲存了輔助索引資訊

.data檔案可以查詢到引數名和引數值,使用下面的命令可以查詢儲存在檔案中的全部變數{名:值}對,

1

2

fromtensorflow.python.tools.inspect_checkpointimportprint_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

tf.train.import_meta_graph函式給出model.ckpt-n.meta的路徑後會載入圖結構,並返回saver物件

1

ckpt=tf.train.get_checkpoint_state('./model/')

tf.train.Saver函式會返回載入預設圖的saver物件,saver物件初始化時可以指定變數對映方式,根據名字對映變數(『TensorFlow』滑動平均)

1

saver=tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函式給出model.ckpt-n的路徑後會自動尋找引數名-值檔案進行載入

1

2

saver.restore(sess,'./model/model.ckpt-0')

saver.restore(sess,ckpt.model_checkpoint_path)

1.不載入圖結構,只加載引數

由於實際上我們引數儲存的都是Variable變數的值,所以其他的引數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個預設圖(建議使用上下文環境),利用這個新圖修改和變數無關的參值大小,從而達到目的。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

'''

使用原網路儲存的模型載入到自己重新定義的圖上

可以使用python變數名載入模型,也可以使用節點名

'''

importAlexNet as Net

importAlexNet_train as train

importrandom

importtensorflow as tf

IMAGE_PATH='./flower_photos/daisy/5673728_71b8cb57eb.jpg'

with tf.Graph().as_default() as g:

x=tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1],3])

y=Net.inference_1(x, N_CLASS=5, train=False)

with tf.Session() as sess:

# 程式前面得有 Variable 供 save or restore 才不報錯

# 否則會提示沒有可儲存的變數

saver=tf.train.Saver()

ckpt=tf.train.get_checkpoint_state('./model/')

img_raw=tf.gfile.FastGFile(IMAGE_PATH,'rb').read()

img=sess.run(tf.expand_dims(tf.image.resize_images(

tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

ifckptandckpt.model_checkpoint_path:

print(ckpt.model_checkpoint_path)

saver.restore(sess,'./model/model.ckpt-0')

global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

res=sess.run(y, feed_dict={x: img})

print(global_step,sess.run(tf.argmax(res,1)))

2.載入圖結構和引數

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

'''

直接使用使用儲存好的圖

無需載入python定義的結構,直接使用節點名稱載入模型

由於節點形狀已經定下來了,所以有不便之處,placeholder定義batch後單張傳會報錯

現階段不推薦使用,以後如果理解深入了可能會找到使用方法

'''

importAlexNet_train as train

importrandom

importtensorflow as tf

IMAGE_PATH='./flower_photos/daisy/5673728_71b8cb57eb.jpg'

ckpt=tf.train.get_checkpoint_state('./model/')# 通過檢查點檔案鎖定最新的模型

saver=tf.train.import_meta_graph(ckpt.model_checkpoint_path+'.meta')# 載入圖結構,儲存在.meta檔案中

with tf.Session() as sess:

saver.restore(sess,ckpt.model_checkpoint_path)# 載入引數,引數儲存在兩個檔案中,不過restore會自己尋找

img_raw=tf.gfile.FastGFile(IMAGE_PATH,'rb').read()

img=sess.run(tf.image.resize_images(

tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0,3)))

imgs=[]

foriinrange(128):

imgs.append(img)

print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))

'''

img=sess.run(tf.expand_dims(tf.image.resize_images(

tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0,3)),0))

print(img)

imgs=[]

foriinrange(128):

imgs.append(img)

print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),

feed_dict={'Placeholder:0':img}))

注意,在所有兩種方式中都可以通過呼叫節點名稱使用節點輸出張量,節點.name屬性返回節點名稱。

3.簡化版本

1

2

3

4

5

6

7

8

9

10

11

12

# 連同圖結構一同載入

ckpt=tf.train.get_checkpoint_state('./model/')

saver=tf.train.import_meta_graph(ckpt.model_checkpoint_path+'.meta')

with tf.Session() as sess:

saver.restore(sess,ckpt.model_checkpoint_path)

# 只加載資料,不載入圖結構,可以在新圖中改變batch_size等的值

# 不過需要注意,Saver物件例項化之前需要定義好新的圖結構,否則會報錯

saver=tf.train.Saver()

with tf.Session() as sess:

ckpt=tf.train.get_checkpoint_state('./model/')

saver.restore(sess,ckpt.model_checkpoint_path)

回到頂部

二、TensorFlow二進位制模型載入方法

這種載入方法一般是對應網上各大公司已經訓練好的網路模型進行修改的工作

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

# 新建空白圖

self.graph=tf.Graph()

# 空白圖列為預設圖

withself.graph.as_default():

# 二進位制讀取模型檔案

with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:

# 新建GraphDef檔案,用於臨時載入模型中的圖

graph_def=tf.GraphDef()

# GraphDef載入模型中的圖

graph_def.ParseFromString(f.read())

# 在空白圖中載入GraphDef中的圖

tf.import_graph_def(graph_def,name='')

# 在圖中獲取張量需要使用graph.get_tensor_by_name加張量名

# 這裡的張量可以直接用於session的run方法求值了

# 補充一個基礎知識,形如'conv1'是節點名稱,而'conv1:0'是張量名稱,表示節點的第一個輸出張量

self.input_tensor=self.graph.get_tensor_by_name(self.input_tensor_name)

self.layer_tensors=[self.graph.get_tensor_by_name(name+':0')fornameinself.layer_operation_names]

『TensorFlow』遷移學習_他山之石,可以攻玉

『cs231n』通過程式碼理解風格遷移

上面兩篇都使用了二進位制載入模型的方式

回到頂部

三、二進位制模型製作

這節是關於tensorflow的Freezing,字面意思是冷凍,可理解為整合合併;整合什麼呢,就是將模型檔案和權重檔案整合合併為一個檔案,主要用途是便於釋出。

tensorflow在訓練過程中,通常不會將權重資料儲存的格式檔案裡(這裡我理解是模型檔案),反而是分開儲存在一個叫checkpoint的檢查點檔案裡,當初始化時,再通過模型檔案裡的變數Op節點來從checkoupoint檔案讀取資料並初始化變數。這種模型和權重資料分開儲存的情況,使得釋出產品時不是那麼方便,我們可以將tf的圖和引數檔案整合進一個字尾為pb的二進位制檔案中,由於整合過程回將變數轉化為常量,所以我們在日後讀取模型檔案時不能夠進行訓練,僅能向前傳播,而且我們在儲存時需要指定節點名稱。

將圖變數轉換為常量的API:tf.graph_util.convert_variables_to_constants

轉換後的graph_def物件轉換為二進位制資料(graph_def.SerializeToString())後,寫入pb即可。

1

2

3

4

5

6

7

8

9

10

11

12

13

importtensorflow as tf

v1=tf.Variable(tf.constant(1.0, shape=[1]), name='v1')

v2=tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

result=v1+v2

saver=tf.train.Saver()

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

saver.save(sess,'./tmodel/test_model.ckpt')

gd=tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), ['add'])

with tf.gfile.GFile('./tmodel/model.pb','wb') as f:

f.write(gd.SerializeToString())

我們可以直接檢視gd:

node {
name: "v1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 1.0
}
}
}
}
……
node {
name: "add"
op: "Add"
input: "v1/read"
input: "v2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}

回到頂部

四、從圖上讀取張量

上面的程式碼實際上已經包含了本小節的內容,但是由於從圖上讀取特定的張量是如此的重要,所以我仍然單獨的補充上這部分的內容。

無論如何,想要獲取特定的張量我們必須要有張量的名稱圖的控制代碼,比如 'import/pool_3/_reshape:0' 這種,有了張量名和圖,索引就很簡單了。

從二進位制模型載入張量

第二小節的程式碼很好的展示了這種情況

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

BOTTLENECK_TENSOR_NAME='pool_3/_reshape:0'# 瓶頸層輸出張量名稱

JPEG_DATA_TENSOR_NAME='DecodeJpeg/contents:0'# 輸入層張量名稱

MODEL_DIR='./inception_dec_2015'# 模型存放資料夾

MODEL_FILE='tensorflow_inception_graph.pb'# 模型名

# 載入模型

# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f: # 閱讀器上下文

withopen(os.path.join(MODEL_DIR, MODEL_FILE),'rb') as f:# 閱讀器上下文

graph_def=tf.GraphDef()# 生成圖

graph_def.ParseFromString(f.read())# 圖載入模型

# 載入圖上節點張量(按照控制代碼理解)

bottleneck_tensor, jpeg_data_tensor=tf.import_graph_def(# 從圖上讀取張量,同時匯入預設圖

graph_def,

return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

從當前圖中獲取對應張量

這個就是很普通的情況,從我們當前操作的圖中獲取某個張量,用於feed啦或者用於輸出等操作,API也很簡單,用法如下:

g.get_tensor_by_name('import/pool_3/_reshape:0')

g表示當前圖控制代碼,可以簡單的使用 g = tf.get_default_graph() 獲取。

從圖中獲取節點資訊

有的時候我們對於模型中的節點並不夠了解,此時我們可以通過圖控制代碼來查詢圖的構造:

1

2

g=tf.get_default_graph()

print(g.as_graph_def().node)

這個操作將返回圖的構造結構。從這裡,對比前面的程式碼,我們也可以瞭解到:graph_def 實際就是圖的結構資訊儲存形式,我們可以將之還原為圖(二進位制模型載入程式碼中展示了),也可以從圖中將之提取出來(本部分程式碼)。

轉自https://www.cnblogs.com/hellcat/p/6925757.html