1. 程式人生 > >theano學習之Shared變數

theano學習之Shared變數

Shared 變數,意思是這些變數可以在運算過程中,不停地進行交換和更新值。 在定義 weights 和 bias 的情況下,會需要用到這樣的變數。

import numpy as np
import theano
import theano.tensor as T

#---------------------------shared是用來存放變數的,會不斷更新數值---------------------#

state = theano.shared(np.array(0,dtype=np.float64),'state')#用np.array給state賦初值,名字為state
inc = T.scalar('inc',dtype=state.dtype)#定義一個容器,名字為inc,值的型別為state型別,不弄用np.float64
accumulator = theano.function([inc],state,updates=[(state,state+inc)])#定義一個函式,傳過來的數為inc的值,結果為state,更新方法為state加inc

#--------輸出不能使用print(state),而要用state.get_value()來獲取state中個值---------#

print(state.get_value())#不傳值的話state裡面的值為賦初值裡的值
accumulator(1)#把1傳到累加器裡面去
print(state.get_value())#state的值更新了,變為1
accumulator(10)#把10傳過去
print(state.get_value())#state的值更新了,變為11

#---------------可以用set_value來改變state裡的值----------------------------------------#
state.set_value(-1)
accumulator(3)
print(state.get_value())#輸出值為2

#-------------------------------------臨時使用-------------------------------------------#
#有時只是想暫時使用 Shared 變數,並不需要把它更新: 這時我們可以定義一個 a 來臨時代替 state,注意定義 a 的時候也要統一dtype
tmp_func = state*2 + inc
a = T.scalar(dtype=state.dtype)
skip_shared = theano.function([inc,a],tmp_func,givens=[(state,a)])#忽略掉 Shared 變數自己的運算,輸入值是 [inc,a],相當於把 a 代入 state,輸出是 tmp_func,givens 就是想把什麼替換成什麼。 這樣的話,在呼叫 skip_shared 函式後,state 並沒有被改變。
print(skip_shared(2,3))#借用了一下share變數state
print(state.get_value())#原始值還是2



結果:

來源