
Counting the trainable_variables in a TensorFlow graph

Original article: https://blog.csdn.net/shwan_ma/article/details/78879620. Copyright belongs to the original author.

The original post is well written, so I am recording its commonly used methods here for future study and reference.

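import tensorflow as tf  # TensorFlow 1.x API

# Initialize all variables (sess is an existing tf.Session),
# then print the names of every trainable variable in the default graph.
sess.run(tf.global_variables_initializer())
variable_names = [v.name for v in tf.trainable_variables()]
print(variable_names)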
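To see what tf.trainable_variables() actually filters, here is a minimal sketch, assuming TensorFlow 2.x is installed and the 1.x graph API is reached through tf.compat.v1. The variables w, b and global_step are hypothetical; the point is that a variable created with trainable=False still appears in the global variable list but not in the trainable one.

```python
import tensorflow as tf

# Build an isolated graph using the TF 1.x API (tf.compat.v1 under TensorFlow 2.x).
graph = tf.Graph()
with graph.as_default():
    # Hypothetical layer parameters: variables are trainable by default.
    tf.compat.v1.get_variable("w", shape=[3, 2])
    tf.compat.v1.get_variable("b", shape=[2])
    # A step counter is typically excluded from training with trainable=False.
    tf.compat.v1.Variable(0, trainable=False, name="global_step")

    # Both calls read collections of the current default graph (here: `graph`).
    trainable_names = [v.name for v in tf.compat.v1.trainable_variables()]
    all_names = [v.name for v in tf.compat.v1.global_variables()]

print("trainable:", trainable_names)
print("all:", all_names)
```

Only the variables in the trainable list are updated by an optimizer's minimize() by default, which is why counting this list, rather than all global variables, gives the model's learnable parameters.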


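# Fetch the current value of each trainable variable and print its name, shape and contents
variable_names = [v.name for v in tf.trainable_variables()]
values = sess.run(variable_names)  # sess.run accepts tensor names such as "w:0" as fetches
for k, v in zip(variable_names, values):
    print("Variable: ", k)
    print("Shape: ", v.shape)
    print(v)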
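A self-contained version of the value-printing snippet, as a sketch assuming TensorFlow 2.x with the tf.compat.v1 graph API; the variables w and b are hypothetical. It passes the Variable objects themselves to sess.run instead of their name strings, a common TF 1.x idiom that keeps the fetch independent of how a name string is resolved.

```python
import tensorflow as tf

# Build a tiny graph with the TF 1.x API (tf.compat.v1 under TensorFlow 2.x).
graph = tf.Graph()
with graph.as_default():
    # Hypothetical model parameters; both are trainable by default.
    tf.compat.v1.get_variable("w", shape=[3, 2])
    tf.compat.v1.get_variable(
        "b", shape=[2], initializer=tf.compat.v1.zeros_initializer())
    trainable = tf.compat.v1.trainable_variables()
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)
    # Passing Variable objects as fetches returns their current values as NumPy arrays.
    values = sess.run(trainable)

# Record and print each variable's name, shape and value.
shapes = {}
for var, value in zip(trainable, values):
    shapes[var.name] = value.shape
    print("Variable:", var.name)
    print("Shape:", value.shape)
    print(value)
```

The returned list is in the same order as the fetch list, so zipping it with the variables pairs each name with its value.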


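# Count the total number of trainable parameters in the graph
total_parameters = 0
for variable in tf.trainable_variables():
    shape = variable.get_shape()  # a TensorShape, e.g. (784, 256)
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value  # Dimension.value (TF 1.x)
    total_parameters += variable_parameters
print(total_parameters)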
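The counting loop multiplies the dimensions of each variable's shape and sums the products. Here is the same arithmetic in plain Python (3.8+ for math.prod), as a sketch assuming the shapes have already been extracted as lists of ints; the helper count_parameters and the two-layer 784→256→10 network are hypothetical examples.

```python
from math import prod
from typing import Optional, Sequence


def count_parameters(shapes: Sequence[Sequence[Optional[int]]]) -> int:
    """Sum the element counts of a list of variable shapes (hypothetical helper)."""
    total = 0
    for shape in shapes:
        # A None dimension means the shape is not fully known, so it cannot be counted.
        if any(dim is None for dim in shape):
            raise ValueError(f"shape {list(shape)} is not fully defined")
        # prod(()) == 1, so a scalar variable counts as one parameter
        total += prod(shape)
    return total


# Hypothetical two-layer MLP, 784 -> 256 -> 10, each layer with a weight matrix and a bias:
# 784*256 + 256 + 256*10 + 10 = 203530
mlp_shapes = [[784, 256], [256], [256, 10], [10]]
print(count_parameters(mlp_shapes))
```

In a TensorFlow graph, the input could be built with [v.get_shape().as_list() for v in tf.trainable_variables()]. Using as_list() instead of dim.value also matters in TensorFlow 2.x, where iterating a TensorShape yields plain ints (or None), so dim.value no longer works there.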