1. 程式人生 > 程式設計 >keras中的loss、optimizer、metrics用法

keras中的loss、optimizer、metrics用法

用keras搭好模型架構之後的下一步,就是執行編譯操作。在編譯時,經常需要指定三個引數

loss

optimizer

metrics

這三個引數有兩類選擇:

使用字串

使用識別符號,如keras.losses,keras.optimizers,metrics包下面的函式

例如:

sgd = SGD(lr=0.01,decay=1e-6,momentum=0.9,nesterov=True)
model.compile(loss='categorical_crossentropy',optimizer=sgd,metrics=['accuracy'])

因為有時可以使用字串,有時可以使用識別符號,令人很想知道背後是如何操作的。下面分別針對optimizer,loss,metrics三種物件的獲取進行研究。

optimizer

一個模型只能有一個optimizer,在執行編譯的時候只能指定一個optimizer。

在keras.optimizers.py中,有一個get函式,用於根據使用者傳進來的optimizer引數獲取優化器的例項:

def get(identifier):
 # 如果後端是tensorflow並且使用的是tensorflow自帶的優化器例項,可以直接使用tensorflow原生的優化器 
 if K.backend() == 'tensorflow':
 # Wrap TF optimizer instances
 if isinstance(identifier,tf.train.Optimizer):
  return TFOptimizer(identifier)
 # 如果以json串的形式定義optimizer並進行引數配置
 if isinstance(identifier,dict):
 return deserialize(identifier)
 elif isinstance(identifier,six.string_types):
 # 如果以字串形式指定optimizer,那麼使用優化器的預設配置引數
 config = {'class_name': str(identifier),'config': {}}
 return deserialize(config)
 if isinstance(identifier,Optimizer):
 # 如果使用keras封裝的Optimizer的例項
 return identifier
 else:
 raise ValueError('Could not interpret optimizer identifier: ' +
    str(identifier))

其中,deserilize(config)函式的作用就是把optimizer反序列化製造一個例項。

loss

keras.losses函式也有一個get(identifier)方法。其中需要注意以下一點:

如果identifier是可呼叫的一個函式名,也就是一個自定義的損失函式,這個損失函式返回值是一個張量。這樣就輕而易舉的實現了自定義損失函式。除了使用str和dict型別的identifier,我們也可以直接使用keras.losses包下面的損失函式。

def get(identifier):
 if identifier is None:
 return None
 if isinstance(identifier,six.string_types):
 identifier = str(identifier)
 return deserialize(identifier)
 if isinstance(identifier,dict):
 return deserialize(identifier)
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'loss function identifier:',identifier)

metrics

在model.compile()函式中,optimizer和loss都是單數形式,只有metrics是複數形式。因為一個模型只能指明一個optimizer和loss,卻可以指明多個metrics。metrics也是三者中處理邏輯最為複雜的一個。

在keras最核心的地方keras.engine.train.py中有如下處理metrics的函式。這個函式其實就做了兩件事:

根據輸入的metric找到具體的metric對應的函式

計算metric張量

在尋找metric對應函式時,有兩種步驟:

使用字串形式指明準確率和交叉熵

使用keras.metrics.py中的函式

def handle_metrics(metrics,weights=None):
 metric_name_prefix = 'weighted_' if weights is not None else ''

 for metric in metrics:
 # 如果metrics是最常見的那種:accuracy,交叉熵
 if metric in ('accuracy','acc','crossentropy','ce'):
  # custom handling of accuracy/crossentropy
  # (because of class mode duality)
  output_shape = K.int_shape(self.outputs[i])
  # 如果輸出維度是1或者損失函式是二分類損失函式,那麼說明是個二分類問題,應該使用二分類的accuracy和二分類的的交叉熵
  if (output_shape[-1] == 1 or
  self.loss_functions[i] == losses.binary_crossentropy):
  # case: binary accuracy/crossentropy
  if metric in ('accuracy','acc'):
   metric_fn = metrics_module.binary_accuracy
  elif metric in ('crossentropy','ce'):
   metric_fn = metrics_module.binary_crossentropy
  # 如果損失函式是sparse_categorical_crossentropy,那麼目標y_input就不是one-hot的,所以就需要使用sparse的多類準去率和sparse的多類交叉熵
  elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
  # case: categorical accuracy/crossentropy
  # with sparse targets
  if metric in ('accuracy','acc'):
   metric_fn = metrics_module.sparse_categorical_accuracy
  elif metric in ('crossentropy','ce'):
   metric_fn = metrics_module.sparse_categorical_crossentropy
  else:
  # case: categorical accuracy/crossentropy
  if metric in ('accuracy','acc'):
   metric_fn = metrics_module.categorical_accuracy
  elif metric in ('crossentropy','ce'):
   metric_fn = metrics_module.categorical_crossentropy
  if metric in ('accuracy','acc'):
   suffix = 'acc'
  elif metric in ('crossentropy','ce'):
   suffix = 'ce'
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  metric_name = metric_name_prefix + suffix
 else:
  # 如果輸入的metric不是字串,那麼就呼叫metrics模組獲取
  metric_fn = metrics_module.get(metric)
  weighted_metric_fn = weighted_masked_objective(metric_fn)
  # Get metric name as string
  if hasattr(metric_fn,'name'):
  metric_name = metric_fn.name
  else:
  metric_name = metric_fn.__name__
  metric_name = metric_name_prefix + metric_name

 with K.name_scope(metric_name):
  metric_result = weighted_metric_fn(y_true,y_pred,weights=weights,mask=masks[i])

 # Append to self.metrics_names,self.metric_tensors,# self.stateful_metric_names
 if len(self.output_names) > 1:
  metric_name = self.output_names[i] + '_' + metric_name
 # Dedupe name
 j = 1
 base_metric_name = metric_name
 while metric_name in self.metrics_names:
  metric_name = base_metric_name + '_' + str(j)
  j += 1
 self.metrics_names.append(metric_name)
 self.metrics_tensors.append(metric_result)

 # Keep track of state updates created by
 # stateful metrics (i.e. metrics layers).
 if isinstance(metric_fn,Layer) and metric_fn.stateful:
  self.stateful_metric_names.append(metric_name)
  self.stateful_metric_functions.append(metric_fn)
  self.metrics_updates += metric_fn.updates

無論怎麼使用metric,最終都會變成metrics包下面的函式。當使用字串形式指明accuracy和crossentropy時,keras會非常智慧地確定應該使用metrics包下面的哪個函式。因為metrics包下的那些metric函式有不同的使用場景,例如:

有的處理的是one-hot形式的y_input(資料的類別),有的處理的是非one-hot形式的y_input

有的處理的是二分類問題的metric,有的處理的是多分類問題的metric

當使用字串“accuracy”和“crossentropy”指明metric時,keras會根據損失函式、輸出層的shape來確定具體應該使用哪個metric函式。在任何情況下,直接使用metrics下面的函式名是總不會出錯的。

keras.metrics.py檔案中也有一個get(identifier)函式用於獲取metric函式。

def get(identifier):
 if isinstance(identifier,dict):
 config = {'class_name': str(identifier),'config': {}}
 return deserialize(config)
 elif isinstance(identifier,six.string_types):
 return deserialize(str(identifier))
 elif callable(identifier):
 return identifier
 else:
 raise ValueError('Could not interpret '
    'metric function identifier:',identifier)

如果identifier是字串或者字典,那麼會根據identifier反序列化出一個metric函式。

如果identifier本身就是一個函式名,那麼就直接返回這個函式名。這種方式就為自定義metric提供了巨大便利。

keras中的設計哲學堪稱完美。

以上這篇keras中的loss、optimizer、metrics用法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。