keras K.function獲取某層的輸出操作
如下所示:
from keras import backend as K from keras.models import load_model models = load_model('models.hdf5') image=r'image.png' images=cv2.imread(r'image.png') image_arr = process_image(image,(224,224,3)) image_arr = np.expand_dims(image_arr,axis=0) layer_1 = K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) f1 = layer_1([image_arr])[0]
載入訓練好並儲存的網路模型
載入資料(影象),並將資料處理成array形式
指定輸出層
將處理後的資料輸入,然後獲取輸出
其中,K.function有兩種不同的寫法:
1. 獲取名為layer_name的層的輸出
layer_1 = K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) #指定輸出層的名稱
2. 獲取第n層的輸出
layer_1 = K.function([model.get_input_at(0)],[model.layers[5].output]) #指定輸出層的序號(層號從0開始)
另外,需要注意的是,書寫不規範會導致報錯:
報錯:
TypeError: inputs to a TensorFlow backend function should be a list or tuple
將該句:
f1 = layer_1(image_arr)[0]
修改為:
f1 = layer_1([image_arr])[0]
補充知識:keras.backend.function()
如下所示:
def function(inputs,outputs,updates=None,**kwargs): """Instantiates a Keras function. Arguments: inputs: List of placeholder tensors. outputs: List of output tensors. updates: List of update ops. **kwargs: Passed to `tf.Session.run`. Returns: Output values as Numpy arrays. Raises: ValueError: if invalid kwargs are passed in. """ if kwargs: for key in kwargs: if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and key not in tf_inspect.getargspec(Function.__init__)[0]): msg = ('Invalid argument "%s" passed to K.function with Tensorflow ' 'backend') % key raise ValueError(msg) return Function(inputs,updates=updates,**kwargs)
這是keras.backend.function()的原始碼。其中函式定義開頭的註釋就是官方文件對該函式的解釋。
我們可以發現function()函式返回的是一個Function物件。下面是Function類的定義。
class Function(object): """Runs a computation graph. Arguments: inputs: Feed placeholders to the computation graph. outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. name: a name to help users identify what this function does. """ def __init__(self,inputs,name=None,**session_kwargs): updates = updates or [] if not isinstance(inputs,(list,tuple)): raise TypeError('`inputs` to a TensorFlow backend function ' 'should be a list or tuple.') if not isinstance(outputs,tuple)): raise TypeError('`outputs` of a TensorFlow backend function ' 'should be a list or tuple.') if not isinstance(updates,tuple)): raise TypeError('`updates` in a TensorFlow backend function ' 'should be a list or tuple.') self.inputs = list(inputs) self.outputs = list(outputs) with ops.control_dependencies(self.outputs): updates_ops = [] for update in updates: if isinstance(update,tuple): p,new_p = update updates_ops.append(state_ops.assign(p,new_p)) else: # assumed already an op updates_ops.append(update) self.updates_op = control_flow_ops.group(*updates_ops) self.name = name self.session_kwargs = session_kwargs def __call__(self,inputs): if not isinstance(inputs,tuple)): raise TypeError('`inputs` should be a list or tuple.') feed_dict = {} for tensor,value in zip(self.inputs,inputs): if is_sparse(tensor): sparse_coo = value.tocoo() indices = np.concatenate((np.expand_dims(sparse_coo.row,1),np.expand_dims(sparse_coo.col,1)),1) value = (indices,sparse_coo.data,sparse_coo.shape) feed_dict[tensor] = value session = get_session() updated = session.run( self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs) return updated[:len(self.outputs)]
所以,function函式利用我們之前已經建立好的comuptation graph。遵循計算圖,從輸入到定義的輸出。這也是為什麼該函式經常用於提取中間層結果。
以上這篇keras K.function獲取某層的輸出操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。