EMAHook
- class mmcls.engine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, evaluate_on_ema=True, evaluate_on_origin=False, **kwargs)
A hook to apply Exponential Moving Average (EMA) on the model during training.
Compared with mmengine.hooks.EMAHook, this hook accepts the evaluate_on_ema and evaluate_on_origin arguments. By default, evaluate_on_ema is enabled. To run validation and testing on both the original and the EMA models, set both arguments to True.
Note
EMAHook takes priority over CheckpointHook.
In saved checkpoints, the original model parameters are actually stored in the EMA field (ema_state_dict), while the state_dict field holds the EMA parameters.
begin_iter and begin_epoch cannot be set at the same time.
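For illustration, a minimal config sketch that enables the hook through custom_hooks and evaluates on both the EMA and the original models. The momentum value and the priority setting are assumptions chosen for illustration: momentum is not an EMAHook argument itself and is forwarded through **kwargs to the averaged model (ExponentialMovingAverage by default).

```python
# Sketch of an MMClassification-style config fragment (values are illustrative).
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',
        momentum=1e-4,            # assumption: forwarded to the EMA model via **kwargs
        evaluate_on_ema=True,     # validate/test on the EMA weights
        evaluate_on_origin=True,  # also validate/test on the original weights
        priority='ABOVE_NORMAL',  # assumption: hook priority chosen for illustration
    )
]
```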
- Parameters
ema_type (str) – The type of EMA strategy to use. You can find the supported strategies in mmengine.model.averaged_model. Defaults to 'ExponentialMovingAverage'.
strict_load (bool) – Whether to strictly enforce that the keys of state_dict in the checkpoint match the keys returned by self.module.state_dict. Defaults to False. Changed in v0.3.0.
begin_iter (int) – The iteration from which EMAHook is enabled. Defaults to 0.
begin_epoch (int) – The epoch from which EMAHook is enabled. Defaults to 0.
evaluate_on_ema (bool) – Whether to evaluate (validate and test) on the EMA model during the val loop and the test loop. Defaults to True.
evaluate_on_origin (bool) – Whether to evaluate (validate and test) on the original model during the val loop and the test loop. Defaults to False.
**kwargs – Keyword arguments passed to subclasses of BaseAveragedModel.
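The following pure-Python sketch illustrates how begin_iter gates the averaging, with plain dictionaries standing in for model parameters. It assumes the update rule documented for mmengine's ExponentialMovingAverage, averaged = (1 - momentum) * averaged + momentum * source, and assumes the EMA copy simply mirrors the source weights until begin_iter is reached. ema_update and train_with_ema are hypothetical helpers, and momentum=0.5 is far larger than a realistic value; it is used only to keep the numbers readable.

```python
from typing import Dict, List, Optional


def ema_update(ema: Dict[str, float], src: Dict[str, float],
               momentum: float) -> None:
    """In-place update: ema = (1 - momentum) * ema + momentum * src."""
    for k in ema:
        ema[k] = (1 - momentum) * ema[k] + momentum * src[k]


def train_with_ema(src_values: List[float], begin_iter: int = 0,
                   momentum: float = 0.5) -> Dict[str, float]:
    """Track a single scalar 'parameter' w with an EMA copy.

    Hypothetical helper: before begin_iter (and on the very first
    iteration) the EMA copy mirrors the source; afterwards it is averaged.
    """
    ema: Optional[Dict[str, float]] = None
    for it, value in enumerate(src_values):
        src = {'w': value}        # source "model" after this training iteration
        if ema is None or it < begin_iter:
            ema = dict(src)       # EMA not started yet: copy the source weights
        else:
            ema_update(ema, src, momentum)
    return ema


print(train_with_ema([0.0, 2.0, 4.0], begin_iter=0))  # {'w': 2.5}
print(train_with_ema([0.0, 2.0, 4.0], begin_iter=2))  # {'w': 3.0}
```

With begin_iter=2 the first two iterations are only copied, so the average starts from 2.0 instead of 0.0.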
- after_load_checkpoint(runner, checkpoint)
Resume EMA parameters from the checkpoint.
- Parameters
runner (Runner) – The runner that loads the checkpoint.
checkpoint (dict) – The loaded checkpoint.
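Because the original parameters end up in the EMA field of a saved checkpoint, resuming has to swap them back. The sketch below is modeled on how mmengine's EMAHook exchanges checkpoint['state_dict'] and checkpoint['ema_state_dict'] before saving and again when resuming. Plain floats stand in for tensors, the 'module.' prefix on EMA keys reflects that the averaged model wraps the source model as .module, and swap_ema_state_dict is written here as a hypothetical standalone helper.

```python
from typing import Any, Dict


def swap_ema_state_dict(checkpoint: Dict[str, Dict[str, Any]]) -> None:
    """Exchange values between checkpoint['state_dict'] and
    checkpoint['ema_state_dict'] in place.

    EMA keys carry a 'module.' prefix; keys without it (for example,
    a step counter) are left untouched.
    """
    model_state = checkpoint['state_dict']
    ema_state = checkpoint['ema_state_dict']
    for key in ema_state:
        if key.startswith('module.'):
            src_key = key[len('module.'):]
            ema_state[key], model_state[src_key] = (
                model_state[src_key], ema_state[key])


# During training: source weight 1.0, EMA weight 0.8.
ckpt = {
    'state_dict': {'fc.weight': 1.0},
    'ema_state_dict': {'module.fc.weight': 0.8, 'steps': 10},
}
swap_ema_state_dict(ckpt)      # before saving
print(ckpt['state_dict'])      # {'fc.weight': 0.8}  (EMA weights)
print(ckpt['ema_state_dict'])  # {'module.fc.weight': 1.0, 'steps': 10}  (original weights)
swap_ema_state_dict(ckpt)      # when resuming: restore the original layout
```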
- after_test_epoch(runner, metrics=None)
Recover the source model's parameters from the EMA model after testing.
- after_val_epoch(runner, metrics=None)
Recover the source model's parameters from the EMA model after validation.
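How evaluate_on_ema, evaluate_on_origin, and the recovery step in after_val_epoch and after_test_epoch fit together can be outlined as follows. This is a simplified, hypothetical sketch rather than the mmcls implementation: dictionaries stand in for model parameters, evaluate stands in for a full val or test loop, and run_val_epoch and swap_params are illustrative names.

```python
from typing import Callable, Dict

Params = Dict[str, float]


def swap_params(model: Params, ema: Params) -> None:
    """Exchange the values of the source model and the EMA copy in place."""
    for k in model:
        model[k], ema[k] = ema[k], model[k]


def run_val_epoch(model: Params, ema: Params,
                  evaluate: Callable[[Params], float],
                  evaluate_on_ema: bool = True,
                  evaluate_on_origin: bool = False) -> Dict[str, float]:
    """Hypothetical outline of one validation (or test) epoch under the hook."""
    results: Dict[str, float] = {}
    if evaluate_on_ema:
        swap_params(model, ema)           # before the epoch: load EMA weights
        results['ema'] = evaluate(model)  # the loop runs on the EMA weights
        swap_params(model, ema)           # after the epoch: recover source weights
        if evaluate_on_origin:
            results['origin'] = evaluate(model)  # extra pass on the original weights
    else:
        results['origin'] = evaluate(model)      # plain loop on the original weights
    return results


model, ema = {'w': 1.0}, {'w': 0.5}
print(run_val_epoch(model, ema, lambda p: p['w'], evaluate_on_origin=True))
# {'ema': 0.5, 'origin': 1.0}
print(model)  # {'w': 1.0}: the source weights are recovered
```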