tf.contrib.slim.arg_scope 完整
緣由
最近一直在看深度學習的代碼,又一次看到了slim.arg_scope()的嵌套使用,具體代碼如下:
with slim.arg_scope( [slim.conv2d, slim.separable_conv2d], weights_initializer=tf.truncated_normal_initializer( stddev=weights_initializer_stddev), activation_fn=activation_fn, normalizer_fn=slim.batch_norm if use_batch_norm elseNone): with slim.arg_scope([slim.batch_norm], **batch_norm_params): with slim.arg_scope( [slim.conv2d], weights_regularizer=slim.l2_regularizer(weight_decay)): with slim.arg_scope( [slim.separable_conv2d], weights_regularizer=depthwise_regularizer) as arg_sc:return arg_sc
由上述代碼可以看到,第一層argscope有slim.conv2d參數,第三層也有這個參數,那麽不同層的參數是如何相互補充,作用到之後的代碼塊中,就是這篇博文的出發點。
準備工作
我們先看一下arg_scope的函數聲明:
@tf_contextlib.contextmanager def arg_scope(list_ops_or_scope, **kwargs):
有函數修飾符@tf_contextlib.contextmanager修飾arg_scope函數,我們先研究下這個函數修飾符。
@的作用
@之後一般接一個可調用對象(tf_contextlib.contextmanager),一起構成函數修飾符(裝飾器),這個可調用對象將被修飾函數(arg_scope)作為參數,執行一系列輔助操作,我們來看一個demo:
import time def my_time(func): print(time.ctime()) return func() @my_time # 從這裏可以看出@time 等價於 time(xxx()),但是這種寫法你得考慮python代碼的執行順序 def xxx(): print(‘Hello world!‘) 運行結果: Wed Jul 26 23:01:21 2017 Hello world!
在這個例子中,xxx函數實現我們的主要功能,打印Hello world!,但我們想給xxx函數添加一些輔助操作,於是我們用函數修飾符@my_time,使xxx函數先打印時間。整個例子的執行流程為調用my_time可調用對象,它接受xxx函數作為參數,先打印時間,再執行xxx函數。
上下文管理器
既然arg_scope函數存在裝飾器,那麽我們應該了解一下裝飾器提供了什麽輔助功能,代碼為:
import contextlib as _contextlib from tensorflow.python.util import tf_decorator def contextmanager(target): """A tf_decorator-aware wrapper for `contextlib.contextmanager`. Usage is identical to `contextlib.contextmanager`. Args: target: A callable to be wrapped in a contextmanager. Returns: A callable that can be used inside of a `with` statement. """ context_manager = _contextlib.contextmanager(target) return tf_decorator.make_decorator(target, context_manager, ‘contextmanager‘)
可以看到導入了contextlib庫,這個庫提供了contextmanager函數,這也是一個裝飾器,它使被修飾的函數具有上下文管理器的功能。上下文管理器的功能是在我們執行一段代碼塊之前做一些準備工作,執行完代碼塊之後做一些收尾工作,同樣先來看一個上下文管理器的例子:
import time class MyTimer(object): def __init__(self, verbose = False): self.verbose = verbose def __enter__(self): self.start = time.time() return self def __exit__(self, *unused): self.end = time.time() self.secs = self.end - self.start self.msecs = self.secs * 1000 if self.verbose: print "elapsed time: %f ms" %self.msecs
with MyTimer(True):
print(‘Hello world!‘)
類MyTimer中的__enter__和__exit__方法分別是準備工作和收尾工作。整個代碼的執行過程為:先執行__enter__方法,__enter__方法中的返回值(這個例子中是self)可以用到代碼塊中,再執行語句塊,這個例子中是print函數,最後執行__exit__方法,更多關於上下文管理器的內容可以看這,我的例子也是從那copy的。contextlib中實現上下文管理器稍有不同,一樣來看個例子:
from contextlib import contextmanager @contextmanager def tag(name): print "<%s>" % name yield print "</%s>" % name >>> with tag("h1"): ... print "foo"
運行結果: <h1> foo </h1>
tag函數中yield之前的代碼相當於__enter__方法,yield產生的生成器相當於__enter__方法的返回值,yield之後的代碼相當於__exit__方法。
arg_scope方法
這裏我把arg_scope方法中代碼稍微做了一些精簡,代碼如下:
arg_scope = [{}]
@tf_contextlib.contextmanager def arg_scope(list_ops_or_scope, **kwargs):try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key = arg_scope_func_key(op) # op的代號 if not has_arg_scope(op): # op是否用@slim.add_arg_scope修飾,這會在下一篇中介紹 raise ValueError(‘%s is not decorated with @add_arg_scope‘, _name_op(op)) if key in current_scope: current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) current_scope[key] = current_kwargs else: current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: _get_arg_stack().pop()
# demo
with slim.arg_scope(
[slim.conv2d, slim.separable_conv2d],
weights_initializer=tf.truncated_normal_initializer(
stddev=weights_initializer_stddev),
activation_fn=activation_fn,
normalizer_fn=slim.batch_norm if use_batch_norm else None):
with slim.arg_scope([slim.batch_norm], **batch_norm_params):
with slim.arg_scope(
[slim.conv2d],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.separable_conv2d],
weights_regularizer=depthwise_regularizer) as arg_sc:
return arg_sc
我們沿著demo一步步看,其中arg_scope是一個棧。先看第一層,current_arg_scope()函數返回棧中最後一個元素,此時是空字典{},由於字典為空,所以會把conv2d和separable_conv2d加入字典,此時棧為[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs}],然後執行接下來的代碼塊,即第二層with,finally中函數要在代碼塊執行完後再執行;第二層執行完後棧為[{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs},{‘conv2d‘: kargs, ‘separable_conv2d‘: kargs, ‘batch_norm‘: batch_norm_params}],可以看到是將第一層的字典復制之後檢查其中是否有與第二層相同的op,相同的op就把參數更新,不同的op就增加鍵值對,如這裏的batch_norm。
回到我們開頭提到的問題,不同層的參數是如何互相補充的?現在我們可以看到,參數存儲在棧中,每疊加一層,就在原有參數基礎上把新參數添加上去。
最後編輯於20:54:35 2018-07-23
tf.contrib.slim.arg_scope 完整