
mmdetection Source Code Reading Notes

MMDet

For work I took a brief look at the source code, focusing on the training and inference paths, which involve the Registry, Runner, and Hook components.

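Core libraries

The core libraries are MMDetection, MMSegmentation, MMDetection3d, and MMCV.

MMDetection3d: models and datasets for 3D object detection

MMDetection & MMSegmentation: models for general object detection and segmentation

MMCV: the foundation library of the MM series. It provides: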

  • Universal IO APIs
  • Image/Video processing
  • Image and annotation visualization
  • Useful utilities (progress bar, timer, ...)
  • PyTorch runner with hooking mechanism
  • Various CNN architectures
  • High-quality implementation of common CUDA ops

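Core components

Training and inference with MMDet involve two core components: MMCV/Registry and MMCV/Runner. Hooks are equally important; the source shows that the developers use them extensively.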

Registry

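Registry is the tool used to instantiate all objects in MMDet, including models, datasets, optimizers, and so on.

Overall flow

Examples:

1. Register a class: a mapping from class name to class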
VOXEL_ENCODERS = Registry('voxel_encoder')

@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
  ...
  
2. Instantiate the class: parse the config, then instantiate the class
def build_voxel_encoder(cfg):
    """Build voxel encoder."""
    return build(cfg, VOXEL_ENCODERS)

def build(cfg, registry, default_args=None):
    ...
    return build_from_cfg(cfg, registry, default_args)

def build_from_cfg(cfg, registry, default_args=None):
    ...
    args = cfg.copy()
    obj_type = args.pop('type')  # the class name to look up in the registry
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)
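
The two steps above can be sketched end to end with a toy registry. This is a minimal illustration rather than MMCV's implementation: `MiniRegistry`, `build_from_cfg_sketch`, and `Linear1D` are hypothetical names, and error handling is reduced to a single KeyError.

```python
class MiniRegistry:
    """Toy stand-in for mmcv's Registry: maps class names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Used as a decorator: @REG.register_module()
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(
                    f'{cls.__name__} is already registered in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg_sketch(cfg, registry):
    """Look up cfg['type'] in the registry and call it with the other keys."""
    args = cfg.copy()
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return obj_cls(**args)


LAYERS = MiniRegistry('layer')


@LAYERS.register_module()
class Linear1D:
    def __init__(self, weight, bias=0.0):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight * x + self.bias


layer = build_from_cfg_sketch(
    dict(type='Linear1D', weight=2.0, bias=1.0), LAYERS)
print(layer(3.0))  # prints 7.0
```

The point of the pattern is that the config only carries a string and keyword arguments, so swapping one component for another is a config change and needs no code change.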

Core components

MMCV/utils/registry.py mainly uses a dictionary, self._module_dict, to store the mapping from class names to classes.

class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """
    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
        
    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

Registering classes

VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')

@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...

build

def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        # Note: this simply chains the individual sub-modules together
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
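
One detail of build_from_cfg that is easy to miss is how default_args is merged: setdefault only fills keys that cfg did not set, so values in cfg always win. The sketch below reproduces just that merge step on plain dicts; `merge_args` is a hypothetical helper, and the config keys are illustrative.

```python
def merge_args(cfg, default_args=None):
    """Reproduce the merge step of build_from_cfg on plain dicts."""
    args = cfg.copy()                      # never mutate the caller's config
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)   # only fills keys cfg did not set
    obj_type = args.pop('type')            # 'type' is not a constructor argument
    return obj_type, args


cfg = dict(type='ResNet', depth=50)
obj_type, args = merge_args(cfg, default_args=dict(depth=18, norm_eval=True))
print(obj_type, args)  # prints: ResNet {'depth': 50, 'norm_eval': True}
```

Because the function works on a copy, the original cfg still contains its 'type' key afterwards and can be reused to build another instance.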

Runner

MMCV/runner

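The directory layout of runner shows that it is mainly responsible for implementing checkpointing, train, val, the optimizer, and hooks.

The train call chain across mmdet3d, mmdet, and mmcv can be summarized as: mmdet3d/train.py -> mmdet/train_detector() -> mmcv/runner.run() -> mmcv/epoch_based_runner.py/train()

Code implementation

mmcv/epoch_based_runner.py/train()

The train function is where the loop over the dataloader actually happens. It also calls many hooks along the way. These hooks were already registered with the runner when it was set up; the register_hook method in BaseRunner, the parent class of EpochBasedRunner, is responsible for that.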

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

mmcv/epoch_based_runner.py/run_iter()

def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs

Hook

  1. Hook base class

mmcv/runner/hooks/hook.py

from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters
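
The base class is designed so that a subclass can override a generic method such as after_iter and have it run in both train and val mode, because the mode-specific methods fall back to the generic ones. The sketch below illustrates this with a trimmed copy of the base class; `RecordEveryN` is a hypothetical hook, and the runner is a stand-in object exposing only an `iter` attribute.

```python
from types import SimpleNamespace


class Hook:
    """Trimmed copy of the base class above: only what this demo needs."""

    def after_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


class RecordEveryN(Hook):
    """Hypothetical hook: records the iteration index every `interval` iters."""

    def __init__(self, interval):
        self.interval = interval
        self.fired_at = []

    def after_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            self.fired_at.append(runner.iter)


hook = RecordEveryN(interval=3)
runner = SimpleNamespace(iter=0)   # stand-in exposing only `iter`
for _ in range(7):
    hook.after_train_iter(runner)  # train entry point falls back to after_iter
    runner.iter += 1
hook.after_val_iter(runner)        # val entry point reaches the same override
print(hook.fired_at)  # prints [2, 5]
```

Note that every_n_iters tests `runner.iter + 1`, so with an interval of 3 the hook fires on the zero-based iterations 2 and 5, that is, after the 3rd and 6th iterations.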

All hooks

__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook'
]
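
  1. Embedding hooks into the runner

    Instantiated hook objects are stored in a priority queue so that hooks are called in priority order. The priority levels are defined as follows: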

        """Hook priority levels.
        +------------+------------+
        | Level      | Value      |
        +============+============+
        | HIGHEST    | 0          |
        +------------+------------+
        | VERY_HIGH  | 10         |
        +------------+------------+
        | HIGH       | 30         |
        +------------+------------+
        | NORMAL     | 50         |
        +------------+------------+
        | LOW        | 70         |
        +------------+------------+
        | VERY_LOW   | 90         |
        +------------+------------+
        | LOWEST     | 100        |
        +------------+------------+
        """
    

    The runner offers two ways to register a hook:

    1. register_hook
    2. register_hook_from_cfg

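    Both methods are implemented in the runner base class, mmcv/runner/base_runner.py. register_hook walks the queue from back to front; when it finds a hook whose priority is higher than or equal to that of the new hook (a numerically smaller or equal value), it inserts the new hook right after it. If no such hook is found, the new hook goes to the front of the queue.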

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.
        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority.
        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)
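
The backward scan in register_hook can be tried in isolation on a plain list. The following is a minimal sketch, not the library code: `insert_by_priority` and the hook names are hypothetical, hooks are represented as (value, name) tuples, and the `PRIORITY` dict simply copies the level table shown earlier.

```python
PRIORITY = {'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'NORMAL': 50,
            'LOW': 70, 'VERY_LOW': 90, 'LOWEST': 100}


def insert_by_priority(hooks, name, priority='NORMAL'):
    """Insert (value, name) into `hooks` using register_hook's backward scan."""
    value = PRIORITY[priority]
    for i in range(len(hooks) - 1, -1, -1):
        if value >= hooks[i][0]:
            # Found an entry that must run before (or tie with) the new one.
            hooks.insert(i + 1, (value, name))
            return
    hooks.insert(0, (value, name))  # nothing outranks it: goes first


hooks = []
insert_by_priority(hooks, 'checkpoint', 'LOW')
insert_by_priority(hooks, 'optimizer', 'NORMAL')
insert_by_priority(hooks, 'custom_a', 'NORMAL')
insert_by_priority(hooks, 'lr', 'VERY_HIGH')
print([name for _, name in hooks])
# prints ['lr', 'optimizer', 'custom_a', 'checkpoint']
```

Scanning from the back with `>=` is what keeps the order stable: `custom_a` has the same priority as `optimizer` but was registered later, so it lands after it.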
    
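  2. Calling hooks in the runner

    The hooks are traversed in queue order, which preserves their priority. The implementation shows that every call to call_hook reaches every hook in the queue, and each hook runs its own implementation of the fn_name method.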

    def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)
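
To see the resulting call order, the dispatch can be reproduced with a toy runner. This is a sketch under simplifying assumptions: `ToyRunner` and `TraceHook` are hypothetical, there is no model or data, and the hook list is assumed to be already sorted by priority.

```python
class ToyRunner:
    """Hypothetical runner that keeps only the hook-dispatch logic."""

    def __init__(self, hooks):
        self._hooks = hooks  # assumed to be already sorted by priority

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self, num_iters):
        self.call_hook('before_train_epoch')
        for _ in range(num_iters):
            self.call_hook('before_train_iter')
            self.call_hook('after_train_iter')
        self.call_hook('after_train_epoch')


class TraceHook:
    """Hypothetical hook that logs every stage it is asked to run."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def __getattr__(self, fn_name):
        # Only reached for attributes not found normally, i.e. stage names.
        return lambda runner: self.log.append(f'{self.name}.{fn_name}')


log = []
ToyRunner([TraceHook('lr', log), TraceHook('optim', log)]).train(num_iters=1)
for entry in log:
    print(entry)
```

For each stage, every hook runs before the runner moves on to the next stage, so the log interleaves the two hooks stage by stage, with `lr` always ahead of `optim`.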
    
  3. Registering hooks before training

    After the runner object is instantiated, the hooks it uses are registered:

        # register hooks
        runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                       cfg.checkpoint_config, cfg.log_config,
                                       cfg.get('momentum_config', None))
        if distributed:
            if isinstance(runner, EpochBasedRunner):
                runner.register_hook(DistSamplerSeedHook())
        # register eval hooks
        if validate:
            # Support batch_size > 1 in validation
            val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
            if val_samples_per_gpu > 1:
                # Replace 'ImageToTensor' to 'DefaultFormatBundle'
                cfg.data.val.pipeline = replace_ImageToTensor(
                    cfg.data.val.pipeline)
            val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
            val_dataloader = build_dataloader(
                val_dataset,
                samples_per_gpu=val_samples_per_gpu,
                workers_per_gpu=cfg.data.workers_per_gpu,
                dist=distributed,
                shuffle=False)
            eval_cfg = cfg.get('evaluation', {})
            eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
            eval_hook = DistEvalHook if distributed else EvalHook
            runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
    
        # user-defined hooks
        if cfg.get('custom_hooks', None):
            custom_hooks = cfg.custom_hooks
            assert isinstance(custom_hooks, list), \
                f'custom_hooks expect list type, but got {type(custom_hooks)}'
            for hook_cfg in cfg.custom_hooks:
                assert isinstance(hook_cfg, dict), \
                    'Each item in custom_hooks expects dict type, but got ' \
                    f'{type(hook_cfg)}'
                hook_cfg = hook_cfg.copy()
                priority = hook_cfg.pop('priority', 'NORMAL')
                hook = build_from_cfg(hook_cfg, HOOKS)
                runner.register_hook(hook, priority=priority)
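
From the user's side, the custom_hooks branch above is driven by a list of dicts in the config file. The fragment below is a hypothetical example: `MyLoggingHook` and its `interval` argument are made-up names, and the loop only mirrors the priority handling without building any hook.

```python
# Hypothetical config fragment; 'MyLoggingHook' and `interval` are made up.
custom_hooks = [
    dict(type='MyLoggingHook', interval=50, priority='LOW'),
    dict(type='MyLoggingHook', interval=10),  # no priority given
]

parsed = []
for hook_cfg in custom_hooks:
    hook_cfg = hook_cfg.copy()                     # keep the config intact
    priority = hook_cfg.pop('priority', 'NORMAL')  # not a constructor argument
    parsed.append((priority, hook_cfg))

for priority, hook_cfg in parsed:
    print(priority, hook_cfg)
```

The 'priority' key is removed before the dict reaches build_from_cfg, which is why a hook class must not expect a `priority` argument in its own `__init__`.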
    

    mmcv/runner/base_runner.py

        def register_training_hooks(self,
                                    lr_config,
                                    optimizer_config=None,
                                    checkpoint_config=None,
                                    log_config=None,
                                    momentum_config=None,
                                    timer_config=dict(type='IterTimerHook'),
                                    custom_hooks_config=None):
            """Register default and custom hooks for training.
            Default and custom hooks include:
              Hooks                 Priority
            - LrUpdaterHook         10
            - MomentumUpdaterHook   30
            - OptimizerStepperHook  50
            - CheckpointSaverHook   70
            - IterTimerHook         80
            - LoggerHook(s)         90
            - CustomHook(s)         50 (default)
            """
            self.register_lr_hook(lr_config)
            self.register_momentum_hook(momentum_config)
            self.register_optimizer_hook(optimizer_config)
            self.register_checkpoint_hook(checkpoint_config)
            self.register_timer_hook(timer_config)
            self.register_logger_hooks(log_config)
            self.register_custom_hooks(custom_hooks_config)
    
  4. Calling hooks during training and inference

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')
    

    Examples:

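    During training, the call self.call_hook('after_train_iter') is where the backward pass and the parameter update should happen. In other words, optimizer_hook should implement an after_train_iter method that performs .backward() and optimizer.step():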

    @HOOKS.register_module()
    class OptimizerHook(Hook):
    
        def __init__(self, grad_clip=None):
            self.grad_clip = grad_clip
    
        def clip_grads(self, params):
            params = list(
                filter(lambda p: p.requires_grad and p.grad is not None, params))
            if len(params) > 0:
                return clip_grad.clip_grad_norm_(params, **self.grad_clip)
    
        def after_train_iter(self, runner):
            runner.optimizer.zero_grad()
            runner.outputs['loss'].backward()
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.optimizer.step()