1. 程式人生 > 程式設計 >keras K.function獲取某層的輸出操作

keras K.function獲取某層的輸出操作

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image,(224,224,3))
image_arr = np.expand_dims(image_arr,axis=0)
layer_1 = K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

載入訓練好並儲存的網路模型

載入資料(影象),並將資料處理成array形式

指定輸出層

將處理後的資料輸入,然後獲取輸出

其中,K.function有兩種不同的寫法:

1. 獲取名為layer_name的層的輸出

layer_1 = K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) #指定輸出層的名稱

2. 獲取第n層的輸出

layer_1 = K.function([model.get_input_at(0)],[model.layers[5].output]) #指定輸出層的序號(層號從0開始)

另外,需要注意的是,書寫不規範會導致報錯:

報錯:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

將該句:

f1 = layer_1(image_arr)[0]

修改為:

f1 = layer_1([image_arr])[0]

補充知識:keras.backend.function()

如下所示:

def function(inputs,outputs,updates=None,**kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs,updates=updates,**kwargs)

這是keras.backend.function()的原始碼。其中函式定義開頭的註釋就是官方文件對該函式的解釋。

我們可以發現function()函式返回的是一個Function物件。下面是Function類的定義。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self,inputs,name=None,**session_kwargs):
  updates = updates or []
  if not isinstance(inputs,(list,tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs,tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates,tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update,tuple):
     p,new_p = update
     updates_ops.append(state_ops.assign(p,new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self,inputs):
  if not isinstance(inputs,tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor,value in zip(self.inputs,inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row,1),np.expand_dims(sparse_coo.col,1)),1)
    value = (indices,sparse_coo.data,sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函式利用我們之前已經建立好的comuptation graph。遵循計算圖,從輸入到定義的輸出。這也是為什麼該函式經常用於提取中間層結果。

以上這篇keras K.function獲取某層的輸出操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。