How the Formatted Loss Output Is Implemented in Keras
On 64-bit Windows 7, with Python 3.6.1 installed through Anaconda, I installed TensorFlow and Keras, using TensorFlow as the Keras backend. While debugging Mask R-CNN I wanted to find out where the loss format printed in the bottom pane of PyCharm (the Python IDE) is defined, as shown in the red box below:
Figure 1. Formatted loss output during training
So where is the loss output format in the red box defined? One thing is clear: the content in the red box is printed during training, so let us start from the Mask R-CNN training procedure. Keras takes NumPy arrays as input data and labels, and a model is usually trained with the fit function. However, the Mask R-CNN training data is far too large to load all at once without exhausting memory, so a generator is used instead: one batch is loaded at a time, and only when that batch is actually needed. The training call therefore looks like this:
self.keras_model.fit_generator(
    train_generator,
    initial_epoch=self.epoch,
    epochs=epochs,
    steps_per_epoch=self.config.STEPS_PER_EPOCH,
    callbacks=callbacks,
    validation_data=val_generator,
    validation_steps=self.config.VALIDATION_STEPS,
    max_queue_size=100,
    workers=workers,
    use_multiprocessing=False,
)
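For context, fit_generator expects train_generator to be a Python generator that yields one (inputs, targets) batch at a time and loops forever. Below is a minimal sketch of such a generator; data_generator and load_fn are hypothetical placeholders standing in for Mask R-CNN's actual data-loading code (its real generator yields several input and target arrays per batch).

import numpy as np

def data_generator(image_ids, batch_size, load_fn, shuffle=True):
    """Yield one (inputs, targets) batch at a time so the whole
    dataset never has to be resident in memory."""
    image_ids = list(image_ids)
    while True:  # fit_generator expects an endless generator
        if shuffle:
            np.random.shuffle(image_ids)
        for start in range(0, len(image_ids), batch_size):
            # load_fn is a placeholder for the per-sample loading code
            batch = [load_fn(i) for i in image_ids[start:start + batch_size]]
            images, targets = zip(*batch)
            yield np.stack(images), np.stack(targets)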
Here the model is trained with the fit_generator function. Note the callbacks=callbacks argument: it plays the key role in producing the output in the red box. The value of callbacks is:
# Callbacks
callbacks = [
    keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                histogram_freq=0,
                                write_graph=True,
                                write_images=False),
    keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                    verbose=0,
                                    save_weights_only=True),
]
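It is worth noticing that neither TensorBoard nor ModelCheckpoint prints anything to the console; the ProgbarLogger traced later in this article is appended by fit_generator itself whenever verbose is non-zero. A rough, from-memory sketch of how Keras 2.x assembles the full callback list inside fit_generator (paraphrased, not a verbatim quote) looks like this:

# Rough sketch (not verbatim Keras source): fit_generator combines its own
# logging callbacks with the user-supplied ones. `cbks` is keras.callbacks.
_callbacks = [cbks.BaseLogger()]
if verbose:
    # The callback that prints "Epoch x/y" and drives the progress bar.
    _callbacks.append(cbks.ProgbarLogger(count_mode='steps'))
_callbacks += (callbacks or []) + [cbks.History()]
callbacks = cbks.CallbackList(_callbacks)

So the console formatting traced below lives entirely inside Keras, not in the two callbacks defined above.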
The data these callbacks record is written under self.log_dir. Now step into self.keras_model.fit_generator with the debugger; it first enters the legacy_support(func) decorator in keras.legacy.interfaces, shown below:
def legacy_support(func):
    @six.wraps(func)
    def wrapper(*args, **kwargs):
        if object_type == 'class':
            object_name = args[0].__class__.__name__
        else:
            object_name = func.__name__
        if preprocessor:
            args, kwargs, converted = preprocessor(args, kwargs)
        else:
            converted = []
        if check_positional_args:
            if len(args) > len(allowed_positional_args) + 1:
                raise TypeError('`' + object_name +
                                '` can accept only ' +
                                str(len(allowed_positional_args)) +
                                ' positional arguments ' +
                                str(tuple(allowed_positional_args)) +
                                ', but you passed the following '
                                'positional arguments: ' +
                                str(list(args[1:])))
        for key in value_conversions:
            if key in kwargs:
                old_value = kwargs[key]
                if old_value in value_conversions[key]:
                    kwargs[key] = value_conversions[key][old_value]
        for old_name, new_name in conversions:
            if old_name in kwargs:
                value = kwargs.pop(old_name)
                if new_name in kwargs:
                    raise_duplicate_arg_error(old_name, new_name)
                kwargs[new_name] = value
                converted.append((new_name, old_name))
        if converted:
            signature = '`' + object_name + '('
            for i, value in enumerate(args[1:]):
                if isinstance(value, six.string_types):
                    signature += '"' + value + '"'
                else:
                    if isinstance(value, np.ndarray):
                        str_val = 'array'
                    else:
                        str_val = str(value)
                    if len(str_val) > 10:
                        str_val = str_val[:10] + '...'
                    signature += str_val
                if i < len(args[1:]) - 1 or kwargs:
                    signature += ','
            for i, (name, value) in enumerate(kwargs.items()):
                signature += name + '='
                if isinstance(value, np.ndarray):
                    str_val = 'array'
                else:
                    str_val = str(value)
                if len(str_val) > 10:
                    str_val = str_val[:10] + '...'
                signature += str_val
                if i < len(kwargs) - 1:
                    signature += ','
            signature += ')`'
            warnings.warn('Update your `' + object_name +
                          '` call to the Keras 2 API: ' + signature,
                          stacklevel=2)
        return func(*args, **kwargs)
    wrapper._original_function = func
    return wrapper
# (the enclosing generate_legacy_interface factory then returns legacy_support)
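In practice this wrapper maps old Keras 1 argument names onto the Keras 2 API and warns the caller about it. To make the mechanism concrete, here is a small self-contained decorator in the same spirit (a simplified stand-in, not the Keras implementation): it renames deprecated keyword arguments and emits a warning before delegating to the wrapped function.

import functools
import warnings

def renamed_kwargs(conversions):
    """Simplified stand-in for the Keras legacy interface: map old
    keyword-argument names to new ones and warn the caller."""
    def legacy_support(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, new_name in conversions:
                if old_name in kwargs:
                    kwargs[new_name] = kwargs.pop(old_name)
                    warnings.warn('`%s` is deprecated, use `%s` instead'
                                  % (old_name, new_name), stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return legacy_support

@renamed_kwargs([('nb_epoch', 'epochs')])  # hypothetical conversion table
def train(epochs=1):
    print('training for %d epochs' % epochs)

train(nb_epoch=3)  # warns, then prints "training for 3 epochs"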
The return func(*args, **kwargs) statement near the end of the code above invokes the wrapped function; here func is the real fit_generator, defined in the keras.engine.training module. Step into it with the debugger and continue to the call callbacks.on_epoch_begin(epoch), shown below:
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
    for m in self.stateful_metric_functions:
        m.reset_states()
    callbacks.on_epoch_begin(epoch)
Stepping into callbacks.on_epoch_begin(epoch) lands in the on_epoch_begin method of the CallbackList container, shown below:
def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    # Arguments
        epoch: integer, index of epoch.
        logs: dictionary of logs.
    """
    logs = logs or {}
    for callback in self.callbacks:
        callback.on_epoch_begin(epoch, logs)
    self._delta_t_batch = 0.
    self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
    self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
Inside on_epoch_begin above, step into callback.on_epoch_begin(epoch, logs); this dispatches to the on_epoch_begin defined in the ProgbarLogger(Callback) class, shown below:
class ProgbarLogger(Callback):
    """Callback that prints metrics to stdout.

    # Arguments
        count_mode: One of "steps" or "samples".
            Whether the progress bar should
            count samples seen or steps (batches) seen.
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is.
            All others will be averaged over time (e.g. loss, etc).

    # Raises
        ValueError: In case of invalid `count_mode`.
    """

    def __init__(self, count_mode='samples', stateful_metrics=None):
        super(ProgbarLogger, self).__init__()
        if count_mode == 'samples':
            self.use_steps = False
        elif count_mode == 'steps':
            self.use_steps = True
        else:
            raise ValueError('Unknown `count_mode`: ' + str(count_mode))
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

    def on_train_begin(self, logs=None):
        self.verbose = self.params['verbose']
        self.epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        if self.verbose:
            print('Epoch %d/%d' % (epoch + 1, self.epochs))
            if self.use_steps:
                target = self.params['steps']
            else:
                target = self.params['samples']
            self.target = target
            self.progbar = Progbar(target=self.target,
                                   verbose=self.verbose,
                                   stateful_metrics=self.stateful_metrics)
        self.seen = 0
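ProgbarLogger does not receive these values directly; self.params is filled in by fit_generator through CallbackList.set_params before training starts. Roughly (a from-memory sketch of the Keras 2.x call, not a verbatim quote), the dictionary looks like this:

# Rough sketch of the params dict fit_generator passes to its callbacks;
# 'metrics' lists the names later used as keys of batch_logs / epoch_logs.
callbacks.set_params({
    'epochs': epochs,
    'steps': steps_per_epoch,        # used as the progress-bar target
    'verbose': verbose,
    'do_validation': do_validation,
    'metrics': callback_metrics,     # e.g. ['loss', 'rpn_class_loss', ...]
})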
The statement

print('Epoch %d/%d' % (epoch + 1, self.epochs))

in the code above is what prints

Epoch 1/40 (the first line of the content shown in the red box).
Back in fit_generator in keras.engine.training, execution then reaches the call to self.train_on_batch, shown below:
outs = self.train_on_batch(x, y,
                           sample_weight=sample_weight,
                           class_weight=class_weight)

if not isinstance(outs, list):
    outs = [outs]
for l, o in zip(out_labels, outs):
    batch_logs[l] = o

callbacks.on_batch_end(batch_index, batch_logs)

batch_index += 1
steps_done += 1
Step into callbacks.on_batch_end(batch_index, batch_logs) in the code above; the on_batch_end it lands in (again on the CallbackList container) is defined as follows:
def on_batch_end(self, batch, logs=None):
    """Called at the end of a batch.

    # Arguments
        batch: integer, index of batch within the current epoch.
        logs: dictionary of logs.
    """
    logs = logs or {}
    if not hasattr(self, '_t_enter_batch'):
        self._t_enter_batch = time.time()
    self._delta_t_batch = time.time() - self._t_enter_batch
    t_before_callbacks = time.time()
    for callback in self.callbacks:
        callback.on_batch_end(batch, logs)
    self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
    delta_t_median = np.median(self._delta_ts_batch_end)
    if (self._delta_t_batch > 0. and
            (delta_t_median > 0.95 * self._delta_t_batch and
             delta_t_median > 0.1)):
        warnings.warn('Method on_batch_end() is slow compared '
                      'to the batch update (%f). Check your callbacks.'
                      % delta_t_median)
Continue stepping into callback.on_batch_end(batch, logs) in the code above; this dispatches to the on_batch_end defined in the ProgbarLogger(Callback) class, shown below:
def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    if self.use_steps:
        self.seen += 1
    else:
        self.seen += batch_size

    for k in self.params['metrics']:
        if k in logs:
            self.log_values.append((k, logs[k]))

    # Skip progbar update for the last batch;
    # will be handled by on_epoch_end.
    if self.verbose and self.seen < self.target:
        self.progbar.update(self.seen, self.log_values)
Execution then reaches the last line above, self.progbar.update(self.seen, self.log_values). Step into update; it is a method of the Progbar(object) class defined in the keras.utils.generic_utils module. The class and its methods are shown below:
class Progbar(object):
    """Displays a progress bar.

    # Arguments
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over time. Metrics in this list
            will be displayed as-is. All others will be averaged
            by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05,
                 stateful_metrics=None):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()

        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules)
        self._total_width = 0
        self._seen_so_far = 0
        self._values = collections.OrderedDict()
        self._start = time.time()
        self._last_update = 0

    def update(self, current, values=None):
        """Updates the progress bar.

        # Arguments
            current: Index of current step.
            values: List of tuples: `(name, value_for_last_step)`.
                If `name` is in `stateful_metrics`, `value_for_last_step`
                will be displayed as-is. Else, an average of the metric
                over time will be displayed.
        """
        values = values or []
        for k, v in values:
            if k not in self.stateful_metrics:
                if k not in self._values:
                    self._values[k] = [v * (current - self._seen_so_far),
                                       current - self._seen_so_far]
                else:
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            if (now - self._last_update < self.interval and
                    self.target is not None and
                    current < self.target):
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                numdigits = int(np.floor(np.log10(self.target))) + 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (eta // 3600,
                                                   (eta % 3600) // 60,
                                                   eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta

                info = ' - ETA: %s' % eta_format
            else:
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self._values:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self._values[k]

            self._total_width += len(info)
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            if self.target is None or current >= self.target:
                for k in self._values:
                    info += ' - %s:' % k
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        self.update(self._seen_so_far + n, values)
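Since Progbar lives in keras.utils.generic_utils, the same formatted output can be produced directly, without going through fit_generator. A small usage sketch with dummy loss values standing in for real metrics:

import time
import numpy as np
from keras.utils.generic_utils import Progbar

steps = 20
progbar = Progbar(target=steps)  # the same class ProgbarLogger drives internally
for step in range(1, steps + 1):
    time.sleep(0.05)                              # stand-in for train_on_batch
    loss = 1.0 / step + np.random.rand() * 0.01   # dummy metric value
    progbar.update(step, values=[('loss', loss)])
# prints something like:
# 20/20 [==============================] - 1s 52ms/step - loss: 0.1776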
The key method in the code above is update(self, current, values=None); setting a breakpoint inside it lets you step straight into it. The statements that actually produce the output deserve a closer look:
1. sys.stdout.write('\n')  # writes a line break
2. sys.stdout.write(bar)   # writes the progress bar; bar holds a string like " 23/1000 [=>............................]"
3. sys.stdout.write(info)  # writes the timing and loss text; info starts with ' - ETA: ...' followed by a ' - name: value' entry per metric
4. sys.stdout.flush()      # flushes the buffer so the output appears immediately
Debugging through the Mask R-CNN code shows that the formatted loss output in the red box of Figure 1 is produced by Keras's built-in callback machinery (ProgbarLogger driving Progbar), not by Mask R-CNN itself. To obtain similar formatted output, the key lies in the callbacks passed to (and added by) self.keras_model.fit_generator and in how their contents are defined.
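For instance, a custom callback can piggyback on the same mechanism and emit its own formatted loss line. The sketch below is a minimal illustration, not part of Mask R-CNN: it reads the per-batch logs dictionary that Keras fills in (the same one ProgbarLogger consumes) and prints the loss in a custom format.

import keras

class LossPrinter(keras.callbacks.Callback):
    """Minimal example callback: print the loss in a custom format
    from the same `logs` dict that ProgbarLogger reads."""

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        print('\nepoch %d, batch %d -> loss = %.4f'
              % (self.epoch + 1, batch, logs.get('loss', float('nan'))))

# Appending it to the existing list makes fit_generator invoke it every batch:
# callbacks.append(LossPrinter())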