Train Model
There are three main files to look at when training a model with this repo:
- configs/masktrack_rcnn_r50_fpn_1x_youtubevos.py
- mmdet/datasets/ytvos.py
- tools/train.py
config file
This file is where you define the model, set dataset, and control other hyperparameters.
The dataset_type is the class name of the target dataset, which is defined in mmdet/datasets and explained in the next section. The data_root, ann_file, and img_prefix need to be updated for your dataset. The code snippet below is adapted to the KITTI MOTS dataset.
dataset_type = 'KittiDataset'
data_root = 'data/MOTS/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train_sub_ped+car.json',
        img_prefix=data_root + 'images/image_combine',
        img_scale=(640, 360),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=True,
        with_crowd=True,
        with_label=True,
        with_track=True),
    val=dict(
        # omitted for readability
    ),
    test=dict(
        # omitted for readability
        with_mask=False,
        with_label=False,
        test_mode=True,
        with_track=True))
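Because the config builds paths by plain string concatenation (data_root + 'annotations/...'), a missing trailing slash or a misnamed folder tends to surface only when the dataset is built at the start of training. A quick pre-flight check can catch this earlier. This is a minimal sketch, not part of the repo: check_dataset is a hypothetical helper, and it assumes the annotation file uses the YouTube-VIS style JSON layout, with a top-level 'categories' list whose entries have a 'name' field.

```python
import json
import os


def check_dataset(data_root, ann_file, img_prefix, expected_classes):
    """Return a list of problems with the dataset paths/annotations (empty if OK).

    Hypothetical helper. Assumes a YouTube-VIS style annotation JSON with a
    top-level 'categories' list whose entries carry a 'name' field.
    """
    problems = []
    # The config joins paths with '+', so data_root must end with a slash.
    if not data_root.endswith('/'):
        problems.append('data_root should end with "/": ' + data_root)
    if not os.path.isdir(img_prefix):
        problems.append('image directory not found: ' + img_prefix)
    if not os.path.isfile(ann_file):
        problems.append('annotation file not found: ' + ann_file)
        return problems  # nothing more to check without the annotations
    with open(ann_file) as f:
        ann = json.load(f)
    # Compare category names, in file order, against the expected class tuple.
    names = tuple(c['name'] for c in ann.get('categories', []))
    if names != tuple(expected_classes):
        problems.append('categories %r do not match CLASSES %r'
                        % (names, tuple(expected_classes)))
    return problems


if __name__ == '__main__':
    data_root = 'data/MOTS/'
    print(check_dataset(
        data_root,
        data_root + 'annotations/instances_train_sub_ped+car.json',
        data_root + 'images/image_combine',
        ('person', 'car')))
```

An empty list means the paths exist and the category names line up with the classes you plan to train on.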
In the following part of the config you can control the optimizer, the pre-trained model (load_from), and where checkpoints are saved (work_dir).
optimizer = dict(type='SGD', lr=0.002, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[16, 48])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 100
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/masktrack_rcnn_r50_fpn_1x_kitti_Aug_24_ped+car_resize'
load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth'
resume_from = None
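The lr_config above combines a linear warmup with step decay. The sketch below approximates how that schedule behaves, to make the numbers concrete. lr_at is a hypothetical helper written for illustration; it assumes a decay factor gamma of 0.1 per step (not set in the config above, and assumed here to be the framework default), counts epochs from 0, and counts warmup iterations from the start of training.

```python
def lr_at(epoch, global_iter, base_lr=0.002, steps=(16, 48), gamma=0.1,
          warmup_iters=500, warmup_ratio=1.0 / 3):
    """Approximate learning rate for a 'step' policy with linear warmup.

    Hypothetical helper. epoch is 0-indexed; global_iter counts iterations
    since training started. gamma=0.1 is an assumed default decay factor.
    """
    # Step decay: multiply by gamma once for every milestone already reached.
    exp = sum(1 for s in steps if epoch >= s)
    lr = base_lr * gamma ** exp
    # Linear warmup: start at warmup_ratio * lr and rise to lr
    # over the first warmup_iters iterations.
    if global_iter < warmup_iters:
        k = (1 - global_iter / warmup_iters) * (1 - warmup_ratio)
        lr = lr * (1 - k)
    return lr


for epoch, it in [(0, 0), (0, 250), (0, 500), (16, 20000), (48, 60000)]:
    print(epoch, it, lr_at(epoch, it))
```

Under these assumptions the rate starts at about 0.002 / 3, reaches 0.002 at iteration 500, drops to 0.0002 from epoch 16 and to 0.00002 from epoch 48. With total_epochs = 100, the last 52 epochs all run at the smallest rate, which is worth keeping in mind when choosing step.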
You also need to set num_classes in both bbox_head and mask_head to the number of classes + 1 (the extra class is the background). Here we are working with two classes, pedestrian and vehicle, so num_classes is set to 3.
bbox_head=dict(
    type='SharedFCBBoxHead',
    num_fcs=2,
    in_channels=256,
    fc_out_channels=1024,
    roi_feat_size=7,
    num_classes=3,
    target_means=[0., 0., 0., 0.],
    target_stds=[0.1, 0.1, 0.2, 0.2],
    reg_class_agnostic=False),
...
mask_head=dict(
    type='FCNMaskHead',
    num_convs=4,
    in_channels=256,
    conv_out_channels=256,
    num_classes=3))
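The +1 rule can be written down explicitly. Since the config is a plain Python file, one option (a suggestion, not how the original config is written) is to derive num_classes from the class tuple instead of hard-coding 3 in two places. head_num_classes is a hypothetical helper, and CLASSES here simply mirrors the tuple set in the dataset class.

```python
# Class names used by the dataset class (see the dataset class section).
CLASSES = ('person', 'car')


def head_num_classes(classes):
    """Hypothetical helper: one output per object class, plus one for background."""
    return len(classes) + 1


num_classes = head_num_classes(CLASSES)
print(num_classes)  # 3
```

You would then write num_classes=num_classes inside both bbox_head and mask_head, so the two heads cannot drift apart when you add or remove a class.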
dataset class
This file defines how the dataset is loaded. To reproduce it for KITTI, make a copy of ytvos.py named kitti.py, then make the following modifications.
In mmdet/datasets/kitti.py
- change the class name to KittiDataset
- set CLASSES to ('person', 'car')
- at line 372, cast each bbox element to float:
x1, y1, w, h = list(map(float, bbox))
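The source does not say why the cast is needed; one plausible reason (an assumption) is that the converted KITTI MOTS annotations store box coordinates as strings or mixed types, in which case the later box arithmetic fails. parse_bbox is a hypothetical standalone version of that step, converting [x, y, w, h] into corner coordinates. The "- 1" in the corner computation follows the inclusive-pixel convention of older mmdetection parsers; treat it as an assumption and match whatever your copy of ytvos.py does.

```python
def parse_bbox(bbox):
    """Convert an [x, y, w, h] box (possibly stored as strings) to [x1, y1, x2, y2].

    Hypothetical helper. The float cast mirrors the change at line 372;
    the '- 1' assumes inclusive pixel coordinates, as in older mmdetection code.
    """
    x1, y1, w, h = list(map(float, bbox))
    return [x1, y1, x1 + w - 1, y1 + h - 1]


print(parse_bbox(['10', '20', '30', '40']))  # [10.0, 20.0, 39.0, 59.0]
```

Without the cast, '10' + '30' would concatenate to '1030' and the following "- 1" would raise a TypeError.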
In mmdet/datasets/__init__.py
- import KittiDataset from kitti.py in ./mmdet/datasets/__init__.py (from .kitti import KittiDataset), and add 'KittiDataset' to __all__.
Then run the following command from the repository root to apply the changes:
pip install -v -e .
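The config refers to the dataset by the string 'KittiDataset', so after reinstalling it is worth confirming that the class is actually reachable from mmdet.datasets. is_registered is a hypothetical helper built on the standard importlib module; it only checks that the module imports and exposes the name, and that the name is listed in __all__ when the module defines one.

```python
import importlib


def is_registered(module_name, class_name):
    """Hypothetical helper: True if class_name is importable from module_name."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    # If the module defines __all__, also require the name to be listed there.
    exported = getattr(module, '__all__', [class_name])
    return hasattr(module, class_name) and class_name in exported


print(is_registered('mmdet.datasets', 'KittiDataset'))
```

Run it inside the MaskTrackRCNN environment. A False result likely means the import line or the __all__ entry is missing, or the package was not reinstalled after the edit.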
train.py
This Python script takes one argument, the path to the config file. Make sure everything in the config file is correct and that the MaskTrackRCNN conda environment is activated before running train.py. The command is:
python tools/train.py configs/masktrack_rcnn_r50_fpn_1x_kitti.py
Checkpoints are saved to the work_dir specified in the config file. Note that the default setting, checkpoint_config = dict(interval=1), saves the model every epoch; increase interval if you don't have enough disk space to hold many checkpoints.