1. 程式人生 > >Tensorflow學習筆記(三)--變數作用域

Tensorflow學習筆記(三)--變數作用域

變數作用域機制主要由兩個函式實現:

tf.get_variable(<name>, <shape>, <initializer>)
tf.variable_scope(<scope_name>)

常用的initializer有

tf.constant_initializer(value) # 初始化一個常量值,
tf.random_uniform_initializer(a, b) # 從a到b均勻分佈的初始化,
tf.random_normal_initializer(mean, stddev) # 用所給平均值和標準差初始化正態分佈.

對如下例項,

with tf.variable_scope(conv1):
    # Variables created here will be named "conv1/weights".
    weights = tf.get_variable('weights',kernel_shape,
              initializer=tf.random_normal_initializer())

    # Variables created here will be named "conv1/biases".
    biases = tf.get_variable('biases'
,biases_shape, initializer=tf.constant_intializer(0.0))

變數作用域的tf.variable_scope()帶有一個名稱,它將會作為字首用於變數名,並且帶有一個重用標籤(後面會說到)來區分以上的兩種情況。巢狀的作用域附加名字所用的規則和檔案目錄的規則很類似。

對於採用了變數作用域的網路結構,結構虛擬碼如下:

def conv_relu(input, kernel_shape, bias_shape):
    # Create variable named "weights".
    weights = tf.get_variable("weights"
, kernel_shape, initializer=tf.random_normal_initializer()) # Create variable named "biases". biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_intializer(0.0)) conv = tf.nn.conv2d(input, weights, strides=[1, 1, 1, 1], padding='SAME') return tf.nn.relu(conv + biases) def my_image_filter(input_images): with tf.variable_scope("conv1"): # Variables created here will be named "conv1/weights", "conv1/biases". relu1 = conv_relu(input_images, [5, 5, 32, 32], [32]) with tf.variable_scope("conv2"): # Variables created here will be named "conv2/weights", "conv2/biases". return conv_relu(relu1, [5, 5, 32, 32], [32])

如果連續呼叫兩次my_image_filter()將會報出ValueError:

result1 = my_image_filter(image1)
result2 = my_image_filter(image2)
# Raises ValueError(... conv1/weights already exists ...)

若不在網路架構中採用變數作用域則不會報錯,但是會產生兩組變數,而不是共享變數。

變數作用域是怎麼工作的?

理解tf.get_variable()

情況1:當tf.get_variable_scope().reuse == False時,該方法用來建立新變數。

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
assert v.name == "foo/v:0"

該情況下方法會生成一個“foo/v”,並檢查確保沒有其他變數使用該全稱。如果該全程已經有其他的變數在使用了,則會丟擲ValueError。

情況2:當tf.get_variable_scope().reuse == True時,該方法是為重用變數所設定

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
assert v1 == v

該情況下會搜尋一個已存在的“foo/v”並將該變數的值賦給v1,若找不到“foo/v”變數則會丟擲ValueError。

注意reuse標籤可以被手動設定為True,但不能手動設定為False。reuse 引數是不可繼承的,所以當你設定一個變數作用域為重用作用域時,那麼其所有的子作用域也將會被重用。