Migration from MMClassification 0.x
We introduced some modifications in MMClassification 1.x, and some of them are BC-breaking (they break backward compatibility). To migrate your projects from MMClassification 0.x smoothly, please read this tutorial.
New dependencies
MMClassification 1.x depends on some new packages. You can prepare a new clean environment and install again according to the install tutorial, or install the packages below manually.
- MMEngine: MMEngine is the core of the OpenMMLab 2.0 architecture, and we split many components unrelated to computer vision from MMCV into MMEngine.
- MMCV: The computer vision package of OpenMMLab. This is not a new dependency, but you need to upgrade it to version 2.0.0rc1 or above.
- rich: A terminal formatting package, which we use to beautify some outputs in the terminal.
Configuration files
In MMClassification 1.x, we refactored the structure of the configuration files, so the original files are no longer usable.
In this section, we introduce all changes to the configuration files. We assume you are already familiar with the config files.
Model settings¶
There are no changes in the model.backbone, model.neck and model.head fields.
Changes in model.train_cfg:
- BatchMixup is renamed to Mixup.
- BatchCutMix is renamed to CutMix.
- BatchResizeMix is renamed to ResizeMix.
- The prob argument is removed from all augment settings. Instead, use the probs field in train_cfg to specify the probability of every augment. If there is no probs field, one augment is chosen at random, with equal probability for each.
Original:
model = dict(
    ...
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ])
)
New:
model = dict(
    ...
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0),
    ])
)
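The renames above are mechanical, so they can be sketched as a small Python helper. This is a minimal sketch, not an MMClassification API. The helper name migrate_train_cfg is hypothetical. It assumes, following the example above, that num_classes can also be dropped from each augment.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
RENAMED_AUGMENTS = {
    'BatchMixup': 'Mixup',
    'BatchCutMix': 'CutMix',
    'BatchResizeMix': 'ResizeMix',
}


def migrate_train_cfg(old_train_cfg):
    """Convert a 0.x model.train_cfg dict to the 1.x layout.

    The input dict is not modified.
    """
    new_augments = []
    probs = []
    for aug in old_train_cfg.get('augments', []):
        new_aug = dict(aug)  # shallow copy so the old config stays intact
        new_aug['type'] = RENAMED_AUGMENTS.get(aug['type'], aug['type'])
        # `prob` is no longer accepted by the augment itself.
        probs.append(new_aug.pop('prob', None))
        # Assumption: like the example above, drop `num_classes` as well.
        new_aug.pop('num_classes', None)
        new_augments.append(new_aug)

    new_train_cfg = dict(augments=new_augments)
    # Emit `probs` only if every augment had an explicit probability;
    # without `probs`, 1.x picks one augment with equal probability.
    if probs and all(p is not None for p in probs):
        new_train_cfg['probs'] = probs
    return new_train_cfg


old_train_cfg = dict(augments=[
    dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
    dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5),
])
print(migrate_train_cfg(old_train_cfg))
```

For this input the sketch keeps probs=[0.5, 0.5]. Two equal probabilities give the same behavior as the default random choice, which is why the example above can omit probs.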
Data settings
Changes in data:
- The original data field is split into train_dataloader, val_dataloader and test_dataloader. This allows each of them to be configured separately. For example, you can specify a different sampler and batch size for training and testing.
- samples_per_gpu is renamed to batch_size.
- workers_per_gpu is renamed to num_workers.
Original:
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(...),
    val=dict(...),
    test=dict(...),
)
New:
train_dataloader = dict(
    batch_size=32,
    num_workers=2,
    dataset=dict(...),
    sampler=dict(type='DefaultSampler', shuffle=True)  # necessary
)
val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    dataset=dict(...),
    sampler=dict(type='DefaultSampler', shuffle=False)  # necessary
)
test_dataloader = val_dataloader
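The split of the data field can be sketched as a hypothetical helper, split_data_field, which is not an MMClassification API. It applies the two renames and adds the required sampler to each dataloader. The dataset dicts in the usage lines are placeholders.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
def split_data_field(data):
    """Split a 0.x `data` dict into 1.x train/val/test dataloader dicts."""

    def make_loader(split, shuffle):
        return dict(
            batch_size=data['samples_per_gpu'],   # was `samples_per_gpu`
            num_workers=data['workers_per_gpu'],  # was `workers_per_gpu`
            dataset=data[split],
            # Every dataloader needs an explicit sampler in 1.x.
            sampler=dict(type='DefaultSampler', shuffle=shuffle),
        )

    return (
        make_loader('train', shuffle=True),
        make_loader('val', shuffle=False),
        make_loader('test', shuffle=False),
    )


# Placeholder dataset configs, for illustration only.
old_data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    train=dict(type='ImageNet'),
    val=dict(type='ImageNet'),
    test=dict(type='ImageNet'),
)
train_dataloader, val_dataloader, test_dataloader = split_data_field(old_data)
```

Unlike the example above, this sketch builds test_dataloader from the old test entry instead of reusing val_dataloader. That keeps any difference between the two splits.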
Changes in pipeline:
- The original formatting transforms ToTensor, ImageToTensor and Collect are combined as PackClsInputs.
- We don't recommend doing Normalize in the dataset pipeline. Please remove it from the pipelines and set it in the data_preprocessor field instead.
- The argument flip_prob in RandomFlip is renamed to prob.
- The argument size in RandomCrop is renamed to crop_size.
- The argument size in RandomResizedCrop is renamed to scale.
- The argument size in Resize is renamed to scale. Resize no longer supports sizes like (256, -1); please use ResizeEdge instead.
- The argument policies in AutoAugment and RandAugment supports using a string to specify preset policies. AutoAugment supports "imagenet" and RandAugment supports "timm_increasing".
- RandomResizedCrop and CenterCrop no longer support efficientnet_style; please use EfficientNetRandomCrop and EfficientNetCenterCrop instead.
Note
Some of the work done by data transforms, such as normalization, has moved to the data preprocessor. See the documentation for more details.
Original:
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
New:
data_preprocessor = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackClsInputs'),
]
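The pipeline rules above can be sketched as a hypothetical helper, migrate_pipeline, which is not part of MMClassification. It drops the formatting transforms and appends PackClsInputs. It lifts Normalize into a data_preprocessor dict and renames the listed arguments. Where no automatic rewrite is given (Resize with -1, efficientnet_style), it raises an error instead.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
FORMATTING_TRANSFORMS = {'ImageToTensor', 'ToTensor', 'Collect'}
RENAMED_ARGS = {  # transform type -> (0.x argument, 1.x argument)
    'RandomFlip': ('flip_prob', 'prob'),
    'RandomCrop': ('size', 'crop_size'),
    'RandomResizedCrop': ('size', 'scale'),
    'Resize': ('size', 'scale'),
}


def migrate_pipeline(old_pipeline):
    """Return (new_pipeline, data_preprocessor) for a 0.x pipeline."""
    new_pipeline = []
    data_preprocessor = None
    for old_step in old_pipeline:
        step = dict(old_step)  # shallow copy; the input is not modified
        kind = step['type']
        if kind in FORMATTING_TRANSFORMS:
            continue  # replaced by a single PackClsInputs at the end
        if kind == 'Normalize':
            # Normalization moves from the pipeline to the data preprocessor.
            data_preprocessor = {k: v for k, v in step.items() if k != 'type'}
            continue
        if step.get('efficientnet_style'):
            raise ValueError(
                f'{kind}: efficientnet_style was removed; use '
                'EfficientNetRandomCrop or EfficientNetCenterCrop instead')
        if kind in RENAMED_ARGS:
            old_name, new_name = RENAMED_ARGS[kind]
            if old_name in step:
                step[new_name] = step.pop(old_name)
        scale = step.get('scale')
        if kind == 'Resize' and isinstance(scale, (tuple, list)) and -1 in scale:
            raise ValueError('Resize no longer accepts sizes like (256, -1); '
                             'use ResizeEdge instead')
        new_pipeline.append(step)
    new_pipeline.append(dict(type='PackClsInputs'))
    return new_pipeline, data_preprocessor


img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
old_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label']),
]
train_pipeline, data_preprocessor = migrate_pipeline(old_pipeline)
```

Raising an error, rather than guessing, keeps the sketch from silently producing a pipeline whose semantics differ from the 0.x one.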
Changes in evaluation:
- The evaluation field is split into val_evaluator and test_evaluator, and it no longer supports the interval and save_best arguments. The interval is moved to train_cfg.val_interval (see the schedule settings), and save_best is moved to default_hooks.checkpoint.save_best (see the runtime settings).
- The 'accuracy' metric is renamed to Accuracy.
- 'precision', 'recall', 'f1-score' and 'support' are combined as SingleLabelMetric; use the items argument to choose which metrics to calculate.
- 'mAP' is renamed to AveragePrecision.
- 'CP', 'CR', 'CF1', 'OP', 'OR' and 'OF1' are combined as MultiLabelMetric; use the items and average arguments to choose which metrics to calculate.
Original:
evaluation = dict(
    interval=1,
    metric='accuracy',
    metric_options=dict(topk=(1, 5))
)
New:
val_evaluator = dict(type='Accuracy', topk=(1, 5))
test_evaluator = val_evaluator
Original:
evaluation = dict(
    interval=1,
    metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1'],
    metric_options=dict(thr=0.5),
)
New:
val_evaluator = [
    dict(type='AveragePrecision'),
    dict(type='MultiLabelMetric',
         items=['precision', 'recall', 'f1-score'],
         average='both',
         thr=0.5),
]
test_evaluator = val_evaluator
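The metric mapping can be sketched as a hypothetical helper, convert_evaluation, which is not part of MMClassification. The sketch makes two assumptions. First, CP/CR/CF1 are the class-wise ("macro") metrics and OP/OR/OF1 are the overall ("micro") metrics. Second, a missing interval defaults to 1. It also returns interval and save_best so they can be moved to their new places.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
# Assumption: CP/CR/CF1 are class-wise ("macro") metrics and
# OP/OR/OF1 are overall ("micro") metrics.
MULTI_LABEL = {
    'CP': ('precision', 'macro'), 'CR': ('recall', 'macro'),
    'CF1': ('f1-score', 'macro'), 'OP': ('precision', 'micro'),
    'OR': ('recall', 'micro'), 'OF1': ('f1-score', 'micro'),
}
SINGLE_LABEL = ('precision', 'recall', 'f1-score', 'support')


def convert_evaluation(evaluation):
    """Return (val_evaluator, val_interval, save_best) for a 0.x dict."""
    metrics = evaluation['metric']
    if isinstance(metrics, str):
        metrics = [metrics]
    options = evaluation.get('metric_options', {})
    evaluators = []

    if 'accuracy' in metrics:
        accuracy = dict(type='Accuracy')
        if 'topk' in options:
            accuracy['topk'] = options['topk']
        evaluators.append(accuracy)
    single_items = [m for m in metrics if m in SINGLE_LABEL]
    if single_items:
        evaluators.append(dict(type='SingleLabelMetric', items=single_items))
    if 'mAP' in metrics:
        evaluators.append(dict(type='AveragePrecision'))
    multi = [MULTI_LABEL[m] for m in metrics if m in MULTI_LABEL]
    if multi:
        wanted = {item for item, _ in multi}
        modes = {mode for _, mode in multi}
        metric = dict(
            type='MultiLabelMetric',
            # Stable order; this may compute a superset of the requested
            # combinations (e.g. CP + OR also yields CR and OP).
            items=[i for i in ('precision', 'recall', 'f1-score') if i in wanted],
            average=modes.pop() if len(modes) == 1 else 'both',
        )
        if 'thr' in options:
            metric['thr'] = options['thr']
        evaluators.append(metric)

    val_evaluator = evaluators[0] if len(evaluators) == 1 else evaluators
    # interval -> train_cfg.val_interval (assumed default 1);
    # save_best -> default_hooks.checkpoint.save_best.
    return val_evaluator, evaluation.get('interval', 1), evaluation.get('save_best')


single, interval, save_best = convert_evaluation(
    dict(interval=1, metric='accuracy', metric_options=dict(topk=(1, 5))))
multi_label, _, _ = convert_evaluation(
    dict(interval=1, metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1'],
         metric_options=dict(thr=0.5)))
```

For both original examples above, the sketch reproduces the corresponding new val_evaluator.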
Schedule settings
Changes in optimizer and optimizer_config:
- We now use the optim_wrapper field to specify all configuration of the optimization process, and optimizer is now a sub-field of optim_wrapper.
- paramwise_cfg is also a sub-field of optim_wrapper, instead of optimizer.
- optimizer_config is removed, and all of its configuration is moved to optim_wrapper.
- grad_clip is renamed to clip_grad.
Original:
optimizer = dict(
    type='AdamW',
    lr=0.0015,
    weight_decay=0.3,
    paramwise_cfg=dict(
        norm_decay_mult=0.0,
        bias_decay_mult=0.0,
    ))
optimizer_config = dict(grad_clip=dict(max_norm=1.0))
New:
optim_wrapper = dict(
    optimizer=dict(type='AdamW', lr=0.0015, weight_decay=0.3),
    paramwise_cfg=dict(
        norm_decay_mult=0.0,
        bias_decay_mult=0.0,
    ),
    clip_grad=dict(max_norm=1.0),
)
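The merge of optimizer and optimizer_config can be sketched as a hypothetical helper, convert_optimizer, which is not part of MMClassification. It only handles the keys mentioned above (paramwise_cfg and grad_clip). Other optimizer_config keys are not covered by this sketch.

```python
import copy


# Hypothetical migration helper; it is NOT part of MMClassification.
def convert_optimizer(optimizer, optimizer_config=None):
    """Merge 0.x `optimizer` and `optimizer_config` into `optim_wrapper`."""
    optimizer = copy.deepcopy(optimizer)  # leave the old config untouched
    optim_wrapper = dict()
    paramwise_cfg = optimizer.pop('paramwise_cfg', None)
    optim_wrapper['optimizer'] = optimizer
    if paramwise_cfg is not None:
        # paramwise_cfg now lives in optim_wrapper, not in optimizer.
        optim_wrapper['paramwise_cfg'] = paramwise_cfg
    grad_clip = (optimizer_config or {}).get('grad_clip')
    if grad_clip is not None:
        # grad_clip is renamed to clip_grad.
        optim_wrapper['clip_grad'] = copy.deepcopy(grad_clip)
    # Other optimizer_config keys are not handled in this sketch.
    return optim_wrapper


optim_wrapper = convert_optimizer(
    dict(type='AdamW', lr=0.0015, weight_decay=0.3,
         paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0)),
    dict(grad_clip=dict(max_norm=1.0)),
)
```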
Changes in lr_config:
- The lr_config field is removed, and we use the new param_scheduler to replace it.
- The warmup-related arguments are removed, since we implement this functionality with a combination of schedulers.
The new scheduler combination mechanism is very flexible, and you can use it to design many kinds of learning rate / momentum curves. See the tutorial for more details.
Original:
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=5,
    warmup_ratio=0.01,
    warmup_by_epoch=True)
New:
param_scheduler = [
    # warmup
    dict(
        type='LinearLR',
        start_factor=0.01,
        by_epoch=True,
        end=5,
        # Update the learning rate after every iteration.
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(type='CosineAnnealingLR', by_epoch=True, begin=5),
]
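For the common case in the example above, the conversion can be sketched as a hypothetical helper, convert_cosine_lr_config, which is not part of MMClassification. It only supports the CosineAnnealing policy with an optional epoch-based linear warmup, and it assumes that a non-zero min_lr maps to eta_min. Anything else raises NotImplementedError rather than guessing.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
# It only covers CosineAnnealing with an optional epoch-based linear warmup.
def convert_cosine_lr_config(lr_config):
    if lr_config.get('policy') != 'CosineAnnealing':
        raise NotImplementedError('only CosineAnnealing is sketched here')
    param_scheduler = []
    begin = 0
    if lr_config.get('warmup') is not None:
        if lr_config['warmup'] != 'linear' or not lr_config.get('warmup_by_epoch'):
            raise NotImplementedError('only epoch-based linear warmup is sketched here')
        warmup_epochs = lr_config['warmup_iters']
        param_scheduler.append(dict(
            type='LinearLR',
            start_factor=lr_config['warmup_ratio'],
            by_epoch=True,
            end=warmup_epochs,
            # Update the learning rate after every iteration.
            convert_to_iter_based=True))
        # The cosine schedule starts where the warmup ends.
        begin = warmup_epochs
    cosine = dict(type='CosineAnnealingLR', by_epoch=True, begin=begin)
    if lr_config.get('min_lr'):
        # Assumption: a non-zero 0.x min_lr maps to eta_min.
        cosine['eta_min'] = lr_config['min_lr']
    param_scheduler.append(cosine)
    return param_scheduler


param_scheduler = convert_cosine_lr_config(dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=5,
    warmup_ratio=0.01,
    warmup_by_epoch=True))
```

The warmup and the main schedule become two separate entries, and the begin of the second equals the end of the first. This chaining is the scheduler combination mechanism described above.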
Changes in runner:
Most of the configuration in the original runner field is moved to train_cfg, val_cfg and test_cfg, which configure the loops for training, validation and test.
Original:
runner = dict(type='EpochBasedRunner', max_epochs=100)
New:
# The `val_interval` is the original `evaluation.interval`.
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()  # Use the default validation loop.
test_cfg = dict()  # Use the default test loop.
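The runner conversion can be sketched as a hypothetical helper, convert_runner, which is not part of MMClassification. The epoch-based branch follows the example above. The iteration-based branch is an assumption: it relies on MMEngine building an iteration-based training loop when by_epoch=False and max_iters is given.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
def convert_runner(runner, val_interval=1):
    """Return (train_cfg, val_cfg, test_cfg) for a 0.x `runner` dict.

    `val_interval` should be the old `evaluation.interval`.
    """
    if runner['type'] == 'EpochBasedRunner':
        train_cfg = dict(by_epoch=True, max_epochs=runner['max_epochs'],
                         val_interval=val_interval)
    elif runner['type'] == 'IterBasedRunner':
        # Assumption: MMEngine builds an iteration-based loop when by_epoch=False.
        train_cfg = dict(by_epoch=False, max_iters=runner['max_iters'],
                         val_interval=val_interval)
    else:
        raise ValueError(f"unsupported runner type: {runner['type']}")
    # Empty dicts select the default validation and test loops.
    return train_cfg, dict(), dict()


train_cfg, val_cfg, test_cfg = convert_runner(
    dict(type='EpochBasedRunner', max_epochs=100), val_interval=1)
```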
In fact, in OpenMMLab 2.0 we introduced Loop to control the behaviors in training, validation and test, and the functionalities of Runner have also changed. You can find more details in the MMEngine tutorials.
Runtime settings
Changes in checkpoint_config and log_config:
The checkpoint_config is moved to default_hooks.checkpoint, and the log_config is moved to default_hooks.logger.
We also moved many hook settings from the script code to the default_hooks field in the runtime configuration.
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),
    # print log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),
    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),
    # save checkpoint per epoch, and automatically save the best checkpoint.
    checkpoint=dict(type='CheckpointHook', interval=1, save_best='auto'),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type='DistSamplerSeedHook'),
    # validation results visualization, set True to enable it.
    visualization=dict(type='VisualizationHook', enable=False),
)
In addition, we split the original logger into a logger and a visualizer. The logger records information, and the visualizer displays the logged information through different backends, like the terminal, TensorBoard and Wandb.
Original:
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ])
New:
default_hooks = dict(
    ...
    logger=dict(type='LoggerHook', interval=100),
)
visualizer = dict(
    type='ClsVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
        dict(type='TensorboardVisBackend'),
    ],
)
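Both runtime changes can be sketched together as a hypothetical helper, convert_runtime, which is not part of MMClassification. The mapping from 0.x logger hooks to visualization backends is an assumption taken from the example above. The sketch also assumes that the remaining checkpoint_config keys keep their names. The returned hooks should still be merged with the other entries of the default_hooks listing above.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
# Assumption: each 0.x logger hook maps to the backend paired with it
# in the example above.
LOGGER_TO_VIS_BACKEND = {
    'TextLoggerHook': 'LocalVisBackend',
    'TensorboardLoggerHook': 'TensorboardVisBackend',
}


def convert_runtime(checkpoint_config, log_config, save_best=None):
    """Return (default_hooks, visualizer) built from 0.x runtime fields."""
    # Assumption: the remaining checkpoint_config keys keep their names.
    checkpoint = dict(type='CheckpointHook', **checkpoint_config)
    if save_best is not None:
        # The old evaluation.save_best moves here.
        checkpoint['save_best'] = save_best
    default_hooks = dict(
        checkpoint=checkpoint,
        logger=dict(type='LoggerHook', interval=log_config['interval']),
    )
    vis_backends = [
        dict(type=LOGGER_TO_VIS_BACKEND[hook['type']])
        for hook in log_config.get('hooks', [])
        if hook['type'] in LOGGER_TO_VIS_BACKEND
    ]
    visualizer = dict(type='ClsVisualizer', vis_backends=vis_backends)
    return default_hooks, visualizer


default_hooks, visualizer = convert_runtime(
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=100, hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ]),
    save_best='auto',
)
```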
Changes in load_from and resume_from:
- resume_from is removed, and we use resume and load_from to replace it.
  - If resume=True and load_from is not None, resume training from the checkpoint in load_from.
  - If resume=True and load_from is None, try to resume from the latest checkpoint in the work directory.
  - If resume=False and load_from is not None, only load the checkpoint, without resuming training.
  - If resume=False and load_from is None, neither load nor resume.
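The four cases above can be sketched as two hypothetical helpers, neither of which is part of MMClassification. convert_resume_from maps an old config to the new fields. describe_startup spells out which case applies. The sketch assumes that in 0.x, resume_from took priority over load_from when both were set. The checkpoint paths are placeholders.

```python
# Hypothetical helpers; they are NOT part of MMClassification.
def convert_resume_from(resume_from=None, load_from=None):
    """Map the 0.x resume_from/load_from pair to resume/load_from."""
    if resume_from is not None:
        # Assumption: as in 0.x, resume_from takes priority over load_from.
        return dict(resume=True, load_from=resume_from)
    return dict(resume=False, load_from=load_from)


def describe_startup(resume, load_from):
    """Spell out the four resume/load_from cases."""
    if resume and load_from is not None:
        return f'resume training from {load_from}'
    if resume:
        return 'resume from the latest checkpoint in the work directory'
    if load_from is not None:
        return f'load {load_from} without resuming training'
    return 'neither load nor resume'


print(describe_startup(**convert_resume_from(resume_from='ckpt.pth')))
```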
Changes in dist_params: The dist_params field is now a sub-field of env_cfg, named dist_cfg. There are also some new configurations in env_cfg.
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)
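Building env_cfg from an old config can be sketched as a hypothetical helper, build_env_cfg, which is not part of MMClassification. It assumes that the 0.x config may hold dist_params, cudnn_benchmark, mp_start_method and opencv_num_threads as top-level keys. Missing keys fall back to the values in the env_cfg example above.

```python
# Hypothetical migration helper; it is NOT part of MMClassification.
def build_env_cfg(old_cfg):
    """Build a 1.x env_cfg from top-level 0.x config keys.

    Assumption: the 0.x config may hold dist_params, cudnn_benchmark,
    mp_start_method and opencv_num_threads as top-level keys; missing
    keys fall back to the values in the env_cfg example.
    """
    return dict(
        cudnn_benchmark=old_cfg.get('cudnn_benchmark', False),
        mp_cfg=dict(
            mp_start_method=old_cfg.get('mp_start_method', 'fork'),
            opencv_num_threads=old_cfg.get('opencv_num_threads', 0),
        ),
        # The old dist_params becomes env_cfg.dist_cfg.
        dist_cfg=dict(old_cfg.get('dist_params', dict(backend='nccl'))),
    )


env_cfg = build_env_cfg(dict(dist_params=dict(backend='nccl')))
```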
Changes in workflow: workflow-related functionalities are removed.
New field visualizer: The visualizer is a new design in the OpenMMLab 2.0 architecture. The runner uses a visualizer instance to handle result and log visualization and to save them to different backends.
See the MMEngine tutorial for more details.
visualizer = dict(
    type='ClsVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
        # Uncomment the below line to save the log and visualization results to TensorBoard.
        # dict(type='TensorboardVisBackend')
    ]
)
New field default_scope: The starting point from which all registries search for modules. The default_scope in MMClassification is mmcls. See the registry tutorial for more details.
Packages
mmcls.apis
The documentation can be found here.
| Function | Changes |
|---|---|
| | No changes |
| | No changes |
| | Removed, use |
| | Removed, use |
| | Removed, use |
| | Waiting for support. |
| | Removed, use |
| | Removed, use |
mmcls.core
The mmcls.core package is renamed to mmcls.engine.
| Sub package | Changes |
|---|---|
| | Removed, use the metrics in |
| | Moved to |
| | Moved to |
| | Removed, the distributed environment related functions can be found in the |
| | Removed, the related functionalities are implemented in |
The MMClsWandbHook in the hooks package is waiting for implementation.
The CosineAnnealingCooldownLrUpdaterHook in the hooks package is removed; we support this functionality through a combination of parameter schedulers. See the tutorial.
mmcls.datasets
The documentation can be found here.
| Dataset class | Changes |
|---|---|
| | Add |
| | Same as |
| | Same as |
| | The |
| | The |
| | Requires |
| | Requires |
The mmcls.datasets.pipelines package is renamed to mmcls.datasets.transforms.
| Transform class | Changes |
|---|---|
| | Removed, use |
| | Removed, use |
| | The argument |
| | The argument |
| | Removed, use |
| | Removed, use |
| | The argument |
| | Removed, use |
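The package rename can be applied to existing source code with a naive textual rewrite. Below is a minimal sketch using a hypothetical helper, rename_pipelines_package, which is not part of MMClassification. It only rewrites the dotted package path. It does not handle forms like "from mmcls.datasets import pipelines". The names inside the package may still have changed, as the table above shows.

```python
import re

# Hypothetical helper; it is NOT part of MMClassification. It only renames
# the dotted package path; class names inside it may still differ in 1.x.
_PIPELINES = re.compile(r'\bmmcls\.datasets\.pipelines\b')


def rename_pipelines_package(source):
    """Rewrite mmcls.datasets.pipelines to mmcls.datasets.transforms."""
    return _PIPELINES.sub('mmcls.datasets.transforms', source)


print(rename_pipelines_package('import mmcls.datasets.pipelines as P'))
```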
mmcls.models
The documentation can be found here. The interfaces of all backbones, necks and losses are unchanged.
Changes in ImageClassifier:
| Method of classifiers | Changes |
|---|---|
| | No changes |
| | Now only accepts three arguments: |
| | Replaced by |
| | Replaced by |
| | The |
| | The original |
| | New method, and it's the same as |
Changes in heads:
| Method of heads | Changes |
|---|---|
| | No changes |
| | Replaced by |
| | Replaced by |
| | It accepts |
| | New method; it returns the output of the classification head without any post-processing like softmax or sigmoid. |
mmcls.utils
| Function | Changes |
|---|---|
| | No changes |
| | Removed, use |
| | Waiting for support |
| | Removed, use |
| | Removed, the runner wraps the model automatically. |
| | Removed, the runner wraps the model automatically. |
| | Removed, the runner selects the device automatically. |
Other changes
We moved the definitions of all registries from the different packages to the mmcls.registry package.