Train Model

There are three major parts to look at when training a model with this repo.

  • configs/masktrack_rcnn_r50_fpn_1x_youtubevos.py
  • mmdet/datasets/ytvos.py
  • tools/train.py

config file

This file is where you define the model, set the dataset, and control the other hyperparameters.

The dataset_type is the class name of the target dataset, which is defined in mmdet/datasets. It is explained in the next section. The data_root, ann_file, and img_prefix need to be updated according to your dataset. The code snippet below is adapted to the KITTI MOTS dataset.

dataset_type = 'KittiDataset'
data_root = 'data/MOTS/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train_sub_ped+car.json',
        img_prefix=data_root + 'images/image_combine',
        img_scale=(640, 360),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=True,
        with_crowd=True,
        with_label=True,
        with_track=True),
    val=dict(
        # hidden for readability
    ),
    test=dict(
        # hidden for readability
        with_mask=False,
        with_label=False,
        test_mode=True,
        with_track=True
    ))
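
Before training, it can help to confirm that the three paths in the config actually exist. The snippet below is a minimal sketch of such a check; check_dataset_paths is a hypothetical helper, not part of this repo, and the paths are simply the ones from the config above.

```python
import os


def check_dataset_paths(data_root, ann_file, img_prefix):
    """Return a list of problems; the list is empty if all paths exist.

    Hypothetical helper, not part of the repo.
    """
    problems = []
    if not os.path.isdir(data_root):
        problems.append('data_root is not a directory: ' + data_root)
    if not os.path.isfile(ann_file):
        problems.append('ann_file not found: ' + ann_file)
    if not os.path.isdir(img_prefix):
        problems.append('img_prefix is not a directory: ' + img_prefix)
    return problems


if __name__ == '__main__':
    # Same values as in the config snippet above.
    data_root = 'data/MOTS/'
    for problem in check_dataset_paths(
            data_root,
            data_root + 'annotations/instances_train_sub_ped+car.json',
            data_root + 'images/image_combine'):
        print(problem)
```

Run it from the repository root, since the paths in the config are relative.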

Here you can control the optimizer, the pre-trained model (load_from), and where the checkpoints are saved (work_dir).

optimizer = dict(type='SGD', lr=0.002, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[16, 48])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 100
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/masktrack_rcnn_r50_fpn_1x_kitti_Aug_24_ped+car_resize'
load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/mask_rcnn_r50_fpn_1x_20181010-069fa190.pth'
resume_from = None
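
To get a feel for what lr_config does, the sketch below approximates the learning rate produced by a step policy with linear warmup. It is an illustration, not the training code itself: lr_at is a hypothetical function, and the decay factor gamma=0.1 is an assumption, since the config above does not set it.

```python
def lr_at(epoch, cur_iter, base_lr=0.002, steps=(16, 48), gamma=0.1,
          warmup_iters=500, warmup_ratio=1.0 / 3):
    """Approximate learning rate for a step policy with linear warmup.

    gamma=0.1 is an assumption; it is not set in the config above.
    """
    # Step decay: multiply by gamma once for every milestone already reached.
    num_decays = sum(1 for s in steps if epoch >= s)
    lr = base_lr * gamma ** num_decays
    # Linear warmup: ramp from warmup_ratio * lr up to lr over warmup_iters.
    if cur_iter < warmup_iters:
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        lr = lr * (1 - k)
    return lr


if __name__ == '__main__':
    print(lr_at(epoch=0, cur_iter=0))       # start of warmup
    print(lr_at(epoch=0, cur_iter=500))     # warmup finished
    print(lr_at(epoch=16, cur_iter=10000))  # after the first step
    print(lr_at(epoch=48, cur_iter=50000))  # after the second step
```

Under these assumptions, training starts at one third of the base rate, reaches 0.002 after 500 iterations, and drops by a factor of 10 at epochs 16 and 48.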

You also need to set num_classes in bbox_head and mask_head to the number of classes + 1. Here we work with two classes, pedestrian and vehicle, so num_classes is set to 3.

    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=3,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False),
    ...
    mask_head=dict(
        type='FCNMaskHead',
        num_convs=4,
        in_channels=256,
        conv_out_channels=256,
        num_classes=3))
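
The "+ 1" rule can be written down as a one-liner, which avoids the two heads drifting out of sync with the class list. This is a minimal sketch; treating the extra slot as the background class is an assumption based on how older mmdetection versions count classes.

```python
# Foreground classes for the two-class setup described above.
CLASSES = ('person', 'car')

# One extra slot, assumed here to be the background class.
num_classes = len(CLASSES) + 1

print(num_classes)  # 3
```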

dataset class

This file defines the dataset loader. To reproduce it for KITTI, copy ytvos.py and rename the copy to kitti.py. Then make the following modifications.

In mmdet/datasets/kitti.py

  • change the class name to KittiDataset
  • change CLASSES to ('person', 'car')
  • in line 372, cast the bbox elements to float: x1, y1, w, h = list(map(float, bbox))
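
The float cast from the last bullet can be tried in isolation. The sketch below assumes the bbox arrives in [x, y, w, h] order with entries that are not yet floats (for example strings), which is one plausible reason the cast is needed. xywh_to_xyxy is a hypothetical name, and the inclusive "- 1" corner convention is an assumption rather than something taken from this document.

```python
def xywh_to_xyxy(bbox):
    """Convert an [x, y, w, h] box to [x1, y1, x2, y2] (hypothetical helper)."""
    # Cast first: the values may not be floats in the annotation file.
    x1, y1, w, h = list(map(float, bbox))
    # Inclusive-pixel corner convention (assumed here).
    return [x1, y1, x1 + w - 1, y1 + h - 1]


if __name__ == '__main__':
    print(xywh_to_xyxy(['10', '20', '30', '40']))  # [10.0, 20.0, 39.0, 59.0]
```

Without the cast, string-valued entries would fail on the arithmetic in the last line of the function.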

In mmdet/datasets/__init__.py

  • import KittiDataset from kitti.py and add it to __all__.

Then run pip install -v -e . from the repository root to apply the changes.

train.py

This Python script takes one argument, the path to the config file. Make sure everything in the config file is correct and that you are in the MaskTrackRCNN conda environment before running train.py. The command is:

python tools/train.py configs/masktrack_rcnn_r50_fpn_1x_kitti.py

The checkpoints are saved in the work_dir you specified in the config file. Note that the default setting, checkpoint_config = dict(interval=1), saves the model every epoch; increase the interval if you don't have enough space to hold many checkpoints.
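
As a rough guide to choosing the interval, the sketch below counts how many epoch checkpoints a run would write. num_checkpoints is a hypothetical helper, and interval=10 is only an example value; extra files the trainer may write, such as a pointer to the latest checkpoint, are not counted.

```python
def num_checkpoints(total_epochs, interval):
    """Number of epoch checkpoints written when saving every `interval` epochs."""
    return total_epochs // interval


total_epochs = 100
# Example value only: save every 10 epochs instead of every epoch.
checkpoint_config = dict(interval=10)

print(num_checkpoints(total_epochs, 1))                              # 100
print(num_checkpoints(total_epochs, checkpoint_config['interval']))  # 10
```

Multiply the count by the size of one checkpoint file to estimate the disk space a full run needs.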