
How the Formatted Loss Output Is Implemented in Keras

TensorFlow and Keras are installed under Python 3.6.1 (Anaconda) on 64-bit Windows 7, with TensorFlow as the Keras backend. While debugging Mask R-CNN, I wanted to find out where the Loss output format shown in the bottom window of PyCharm (a Python IDE) is defined, as highlighted in the red box below:

[Figure 1: Formatted Loss output during training]

So where is the output format in that red box defined? One thing is clear: the content in the red box is printed during training, so let's start with how Mask R-CNN trains. Keras takes NumPy arrays as its input data and labels, and a model is normally trained with the fit function. The Mask R-CNN training data, however, is far too large to load all at once without exhausting memory, so a generator is used instead: each batch is loaded only when it is actually needed. The training call therefore looks like this:

self.keras_model.fit_generator(
    train_generator,
    initial_epoch=self.epoch,
    epochs=epochs,
    steps_per_epoch=self.config.STEPS_PER_EPOCH,
    callbacks=callbacks,
    validation_data=val_generator,
    validation_steps=self.config.VALIDATION_STEPS,
    max_queue_size=100,
    workers=workers,
    use_multiprocessing=False,
)
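The train_generator passed in is a Python generator that yields one batch at a time. As a rough sketch of the idea (not the real Mask R-CNN data_generator, which also builds RPN targets, masks and image metadata on the fly), a batch generator suitable for fit_generator could look like this:

import numpy as np

def simple_batch_generator(images, labels, batch_size):
    # Minimal illustrative generator: loop forever, yielding one
    # randomly sampled (inputs, targets) batch per iteration.
    num_samples = len(images)
    while True:
        idx = np.random.choice(num_samples, batch_size, replace=False)
        yield images[idx], labels[idx]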

Here the model is trained with fit_generator instead of fit. Note the callbacks=callbacks argument; it plays a key part in producing the output in the red box. The value of callbacks is:

# Callbacks
callbacks = [
    keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                histogram_freq=0,
                                write_graph=True,
                                write_images=False),
    keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                    verbose=0,
                                    save_weights_only=True),
]
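Only TensorBoard and ModelCheckpoint are registered here; the ProgbarLogger we will reach below is one of the callbacks Keras adds internally. If you want to hook into the same mechanism yourself, you can append your own callback to this list. The LossPrinter class below is a hypothetical sketch (not part of Mask R-CNN) that prints the loss for every batch from the same logs dictionary that feeds the built-in progress bar:

import keras

class LossPrinter(keras.callbacks.Callback):
    # Minimal custom callback: print the loss reported after each batch.
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        print('batch %d - loss: %.4f' % (batch, logs.get('loss', 0.0)))

# Hypothetical usage: callbacks.append(LossPrinter())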

The data needed for the content in the red box is all saved under self.log_dir. Stepping into self.keras_model.fit_generator, we first land in the legacy_support(func) function of keras.legacy.interfaces, shown below:

 def legacy_support(func):
  @six.wraps(func)
  def wrapper(*args,**kwargs):
   if object_type == 'class':
    object_name = args[0].__class__.__name__
   else:
    object_name = func.__name__
   if preprocessor:
    args,kwargs,converted = preprocessor(args,kwargs)
   else:
    converted = []
   if check_positional_args:
    if len(args) > len(allowed_positional_args) + 1:
     raise TypeError('`' + object_name +
         '` can accept only ' +
         str(len(allowed_positional_args)) +
         ' positional arguments ' +
         str(tuple(allowed_positional_args)) +
         ',but you passed the following '
         'positional arguments: ' +
         str(list(args[1:])))
   for key in value_conversions:
    if key in kwargs:
     old_value = kwargs[key]
     if old_value in value_conversions[key]:
      kwargs[key] = value_conversions[key][old_value]
   for old_name,new_name in conversions:
    if old_name in kwargs:
     value = kwargs.pop(old_name)
     if new_name in kwargs:
      raise_duplicate_arg_error(old_name,new_name)
     kwargs[new_name] = value
     converted.append((new_name,old_name))
   if converted:
    signature = '`' + object_name + '('
    for i,value in enumerate(args[1:]):
     if isinstance(value,six.string_types):
      signature += '"' + value + '"'
     else:
      if isinstance(value,np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
     if i < len(args[1:]) - 1 or kwargs:
      signature += ','
     for i,(name,value) in enumerate(kwargs.items()):
      signature += name + '='
      if isinstance(value,np.ndarray):
       str_val = 'array'
      else:
       str_val = str(value)
      if len(str_val) > 10:
       str_val = str_val[:10] + '...'
      signature += str_val
      if i < len(kwargs) - 1:
       signature += ','
    signature += ')`'
    warnings.warn('Update your `' + object_name +
        '` call to the Keras 2 API: ' + signature,stacklevel=2)
   return func(*args,**kwargs)
  wrapper._original_function = func
  return wrapper
 return legacy_support
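At its core this wrapper is a decorator that renames legacy keyword arguments to their Keras 2 names (and warns about them) before invoking the real function. A stripped-down sketch of that pattern, using an invented rename_kwargs helper, looks like this:

import functools
import warnings

def rename_kwargs(conversions):
    # Decorator factory: map old keyword-argument names to new ones.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, new_name in conversions:
                if old_name in kwargs:
                    warnings.warn('`%s` is deprecated, use `%s`' % (old_name, new_name))
                    kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@rename_kwargs([('nb_epoch', 'epochs')])
def fit(epochs=1):
    print('training for %d epochs' % epochs)

fit(nb_epoch=5)  # warns about nb_epoch, then prints: training for 5 epochs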

At return func(*args,**kwargs), the fourth line from the bottom of the code above, the wrapped function is invoked; here func is fit_generator, defined in the keras.engine.training module. Stepping into fit_generator, execution reaches the call to callbacks.on_epoch_begin(epoch), shown below:

# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
    for m in self.stateful_metric_functions:
        m.reset_states()
    callbacks.on_epoch_begin(epoch)

Stepping into callbacks.on_epoch_begin(epoch) brings us to the on_epoch_begin method shown below:

def on_epoch_begin(self,epoch,logs=None):
  """Called at the start of an epoch.
  # Arguments
   epoch: integer,index of epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  for callback in self.callbacks:
   callback.on_epoch_begin(epoch,logs)
  self._delta_t_batch = 0.
  self._delta_ts_batch_begin = deque([],maxlen=self.queue_length)
  self._delta_ts_batch_end = deque([],maxlen=self.queue_length)

Stepping from on_epoch_begin into callback.on_epoch_begin(epoch,logs) takes us to the on_epoch_begin method defined in the ProgbarLogger(Callback) class:

class ProgbarLogger(Callback):
 """Callback that prints metrics to stdout.
 # Arguments
  count_mode: One of "steps" or "samples".
   Whether the progress bar should
   count samples seen or steps (batches) seen.
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over an epoch.
   Metrics in this list will be logged as-is.
   All others will be averaged over time (e.g. loss,etc).
 # Raises
  ValueError: In case of invalid `count_mode`.
 """
 
 def __init__(self,count_mode='samples',stateful_metrics=None):
  super(ProgbarLogger,self).__init__()
  if count_mode == 'samples':
   self.use_steps = False
  elif count_mode == 'steps':
   self.use_steps = True
  else:
   raise ValueError('Unknown `count_mode`: ' + str(count_mode))
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
 def on_train_begin(self,logs=None):
  self.verbose = self.params['verbose']
  self.epochs = self.params['epochs']
 
 def on_epoch_begin(self,epoch,logs=None):
  if self.verbose:
   print('Epoch %d/%d' % (epoch + 1,self.epochs))
   if self.use_steps:
    target = self.params['steps']
   else:
    target = self.params['samples']
   self.target = target
   self.progbar = Progbar(target=self.target,verbose=self.verbose,stateful_metrics=self.stateful_metrics)
  self.seen = 0

The line

print('Epoch %d/%d' % (epoch + 1,self.epochs))

in the code above prints "Epoch 1/40", the first line of the content shown in the red box.
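As a quick check of that format string (the numbers here are only for illustration):

epoch, epochs = 0, 40
print('Epoch %d/%d' % (epoch + 1, epochs))  # prints: Epoch 1/40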

Execution then returns to fit_generator in the keras.engine.training module and reaches the call to self.train_on_batch, shown below:

outs = self.train_on_batch(x,y,
                           sample_weight=sample_weight,
                           class_weight=class_weight)

if not isinstance(outs,list):
    outs = [outs]
for l,o in zip(out_labels,outs):
    batch_logs[l] = o

callbacks.on_batch_end(batch_index,batch_logs)

batch_index += 1
steps_done += 1
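To make the zip above concrete, here is an illustration with invented numbers (for Mask R-CNN, out_labels includes the individual loss names, and outs holds the values train_on_batch returned for the current batch):

out_labels = ['loss', 'rpn_class_loss', 'rpn_bbox_loss',
              'mrcnn_class_loss', 'mrcnn_bbox_loss', 'mrcnn_mask_loss']
outs = [2.0477, 0.0287, 0.5642, 0.2720, 0.5399, 0.5923]  # invented values

batch_logs = {}
for l, o in zip(out_labels, outs):
    batch_logs[l] = o
# batch_logs now maps each loss name to its value for this batch;
# these are exactly the (name, value) pairs the progress bar prints later.
print(batch_logs)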

Stepping into callbacks.on_batch_end(batch_index,batch_logs) in the code above brings us to the on_batch_end method, defined as follows:

 def on_batch_end(self,batch,logs=None):
  """Called at the end of a batch.
  # Arguments
   batch: integer,index of batch within the current epoch.
   logs: dictionary of logs.
  """
  logs = logs or {}
  if not hasattr(self,'_t_enter_batch'):
   self._t_enter_batch = time.time()
  self._delta_t_batch = time.time() - self._t_enter_batch
  t_before_callbacks = time.time()
  for callback in self.callbacks:
   callback.on_batch_end(batch,logs)
  self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
  delta_t_median = np.median(self._delta_ts_batch_end)
  if (self._delta_t_batch > 0. and
   (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
   warnings.warn('Method on_batch_end() is slow compared '
       'to the batch update (%f). Check your callbacks.'
       % delta_t_median)

Continuing, stepping into callback.on_batch_end(batch,logs) leads to the on_batch_end method defined in the ProgbarLogger(Callback) class:

def on_batch_end(self,batch,logs=None):
  logs = logs or {}
  batch_size = logs.get('size',0)
  if self.use_steps:
   self.seen += 1
  else:
   self.seen += batch_size
 
  for k in self.params['metrics']:
   if k in logs:
    self.log_values.append((k,logs[k]))
 
  # Skip progbar update for the last batch;
  # will be handled by on_epoch_end.
  if self.verbose and self.seen < self.target:
   self.progbar.update(self.seen,self.log_values)

Execution then reaches the last line of the code above, self.progbar.update(self.seen,self.log_values). Stepping into update, we land in the update method of the Progbar(object) class, defined in the keras.utils.generic_utils module. The class and its methods are shown below:

class Progbar(object):
 """Displays a progress bar.
 # Arguments
  target: Total number of steps expected,None if unknown.
  width: Progress bar width on screen.
  verbose: Verbosity mode,0 (silent),1 (verbose),2 (semi-verbose)
  stateful_metrics: Iterable of string names of metrics that
   should *not* be averaged over time. Metrics in this list
   will be displayed as-is. All others will be averaged
   by the progbar before display.
  interval: Minimum visual progress update interval (in seconds).
 """
 
 def __init__(self,target,width=30,verbose=1,interval=0.05,stateful_metrics=None):
  self.target = target
  self.width = width
  self.verbose = verbose
  self.interval = interval
  if stateful_metrics:
   self.stateful_metrics = set(stateful_metrics)
  else:
   self.stateful_metrics = set()
 
  self._dynamic_display = ((hasattr(sys.stdout,'isatty') and
         sys.stdout.isatty()) or
         'ipykernel' in sys.modules)
  self._total_width = 0
  self._seen_so_far = 0
  self._values = collections.OrderedDict()
  self._start = time.time()
  self._last_update = 0
 
 def update(self,current,values=None):
  """Updates the progress bar.
  # Arguments
   current: Index of current step.
   values: List of tuples:
    `(name,value_for_last_step)`.
    If `name` is in `stateful_metrics`,`value_for_last_step` will be displayed as-is.
    Else,an average of the metric over time will be displayed.
  """
  values = values or []
  for k,v in values:
   if k not in self.stateful_metrics:
    if k not in self._values:
     self._values[k] = [v * (current - self._seen_so_far),current - self._seen_so_far]
    else:
     self._values[k][0] += v * (current - self._seen_so_far)
     self._values[k][1] += (current - self._seen_so_far)
   else:
    # Stateful metrics output a numeric value. This representation
    # means "take an average from a single value" but keeps the
    # numeric formatting.
    self._values[k] = [v,1]
  self._seen_so_far = current
 
  now = time.time()
  info = ' - %.0fs' % (now - self._start)
  if self.verbose == 1:
   if (now - self._last_update < self.interval and
     self.target is not None and current < self.target):
    return
 
   prev_total_width = self._total_width
   if self._dynamic_display:
    sys.stdout.write('\b' * prev_total_width)
    sys.stdout.write('\r')
   else:
    sys.stdout.write('\n')
 
   if self.target is not None:
    numdigits = int(np.floor(np.log10(self.target))) + 1
    barstr = '%%%dd/%d [' % (numdigits,self.target)
    bar = barstr % current
    prog = float(current) / self.target
    prog_width = int(self.width * prog)
    if prog_width > 0:
     bar += ('=' * (prog_width - 1))
     if current < self.target:
      bar += '>'
     else:
      bar += '='
    bar += ('.' * (self.width - prog_width))
    bar += ']'
   else:
    bar = '%7d/Unknown' % current
 
   self._total_width = len(bar)
   sys.stdout.write(bar)
 
   if current:
    time_per_unit = (now - self._start) / current
   else:
    time_per_unit = 0
   if self.target is not None and current < self.target:
    eta = time_per_unit * (self.target - current)
    if eta > 3600:
     eta_format = '%d:%02d:%02d' % (eta // 3600,(eta % 3600) // 60,eta % 60)
    elif eta > 60:
     eta_format = '%d:%02d' % (eta // 60,eta % 60)
    else:
     eta_format = '%ds' % eta
 
    info = ' - ETA: %s' % eta_format
   else:
    if time_per_unit >= 1:
     info += ' %.0fs/step' % time_per_unit
    elif time_per_unit >= 1e-3:
     info += ' %.0fms/step' % (time_per_unit * 1e3)
    else:
     info += ' %.0fus/step' % (time_per_unit * 1e6)
 
   for k in self._values:
    info += ' - %s:' % k
    if isinstance(self._values[k],list):
     avg = np.mean(
      self._values[k][0] / max(1,self._values[k][1]))
     if abs(avg) > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    else:
     info += ' %s' % self._values[k]
 
   self._total_width += len(info)
   if prev_total_width > self._total_width:
    info += (' ' * (prev_total_width - self._total_width))
 
   if self.target is not None and current >= self.target:
    info += '\n'
 
   sys.stdout.write(info)
   sys.stdout.flush()
 
  elif self.verbose == 2:
   if self.target is None or current >= self.target:
    for k in self._values:
     info += ' - %s:' % k
     avg = np.mean(
      self._values[k][0] / max(1,self._values[k][1]))
     if avg > 1e-3:
      info += ' %.4f' % avg
     else:
      info += ' %.4e' % avg
    info += '\n'
 
    sys.stdout.write(info)
    sys.stdout.flush()
 
  self._last_update = now
 
 def add(self,n,values=None):
  self.update(self._seen_so_far + n,values)
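Progbar can also be driven by hand, which makes the formatting easy to experiment with outside of training. A minimal sketch, assuming the Keras 2.x import path used above and invented loss values:

import time
from keras.utils.generic_utils import Progbar

progbar = Progbar(target=20)  # 20 steps expected
for step in range(20):
    time.sleep(0.05)  # stand-in for one training step
    # (name, value) tuples are averaged over time, just as during training.
    progbar.update(step + 1, values=[('loss', 2.0 / (step + 1)),
                                     ('mrcnn_mask_loss', 0.6)])
# Each call rewrites a single console line such as:
# 20/20 [==============================] - 1s 52ms/step - loss: ... - mrcnn_mask_loss: 0.6000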

The key piece is the update(self,current,values=None) method above; set a breakpoint inside it and the debugger will stop there. The statements that actually produce the output are:

1. sys.stdout.write('\n')  # line break

2. sys.stdout.write(bar)  # writes the progress bar itself, e.g. [..............................], where bar holds that string

3. sys.stdout.write(info)  # writes the loss text, where info starts with ' - ETA: ...' followed by the ' - loss: ...' entries

4. sys.stdout.flush()  # flushes the buffer so the output appears immediately
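Put together, the same write/flush pattern can be reproduced with a few lines of plain Python. This is only a rough sketch of what Progbar does (bar width, loss values and timings are made up):

import sys
import time

target, width = 50, 30
for step in range(1, target + 1):
    time.sleep(0.02)  # stand-in for one training step
    filled = int(width * step / target)
    arrow = '>' if step < target else '='
    bar = '%3d/%d [%s%s%s]' % (step, target, '=' * (filled - 1), arrow, '.' * (width - filled))
    info = ' - loss: %.4f' % (1.0 / step)  # invented loss value
    sys.stdout.write('\r' + bar + info)  # '\r' rewrites the same console line
    sys.stdout.flush()  # make it appear immediately
sys.stdout.write('\n')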

Stepping through the Mask R-CNN code shows that the formatted Loss output during training, highlighted in the red box of Figure 1, is produced by Keras' built-in modules. To obtain similar formatted output, the key is to pass the callbacks argument to self.keras_model.fit_generator and to define the contents of those callbacks.
