
Commit 4e63cb2

Merge 0f0f35a into 892f9e1 (2 parents: 892f9e1 + 0f0f35a)

File tree

8 files changed: +577 −2 lines

configs/ddrnet/README.md

Lines changed: 48 additions & 0 deletions
# DDRNet

> [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](http://arxiv.org/abs/2101.06085)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/ydhongHIT/DDRNet">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Semantic segmentation is a key technology for autonomous vehicles to understand the surrounding scenes. The appealing performances of contemporary models usually come at the expense of heavy computations and lengthy inference time, which is intolerable for self-driving. Using lightweight architectures (encoder-decoder or two-pathway) or reasoning on low-resolution images, recent methods realize very fast scene parsing, even running at more than 100 FPS on a single 1080Ti GPU. However, there is still a significant gap in performance between these real-time methods and the models based on dilation backbones. To tackle this problem, we propose a family of efficient backbones specially designed for real-time semantic segmentation. The proposed deep dual-resolution networks (DDRNets) are composed of two deep branches between which multiple bilateral fusions are performed. Additionally, we design a new contextual information extractor named Deep Aggregation Pyramid Pooling Module (DAPPM) to enlarge effective receptive fields and fuse multi-scale context based on low-resolution feature maps. Our method achieves a new state-of-the-art trade-off between accuracy and speed on both the Cityscapes and CamVid datasets. In particular, on a single 2080Ti GPU, DDRNet-23-slim yields 77.4% mIoU at 102 FPS on the Cityscapes test set and 74.7% mIoU at 230 FPS on the CamVid test set. With widely used test augmentation, our method is superior to most state-of-the-art models and requires much less computation. Code and trained models are available online.

<!-- [IMAGE] -->

<!-- <div align=center>
<img src="https://raw.githubusercontent.com/ydhongHIT/DDRNet/main/figs/DDRNet_seg.png" width="60%"/>
</div> -->
## Results and models

### Cityscapes

| Method | Backbone      | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device   | mIoU  | mIoU (ms+flip) | config       | download     |
| ------ | ------------- | --------- | ------- | -------- | -------------- | -------- | ----- | -------------- | ------------ | ------------ |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000  |          | 85.85          | RTX 8000 | 77.85 | 79.80          | [config](<>) | model \| log |
| DDRNet | DDRNet23      | 1024x1024 | 120000  |          | 33.41          | RTX 8000 | 79.53 | 80.98          | [config](<>) | model \| log |
| DDRNet | DDRNet39      | 1024x1024 | 120000  |          |                | RTX 8000 |       |                | [config](<>) | model \| log |
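The config and checkpoint links in the table are still placeholders. For reference, single-image inference with a trained model would look roughly like the sketch below, assuming MMSegmentation 1.x's `init_model`/`inference_model` API; the paths are hypothetical stand-ins for the missing links.

```python
# Minimal inference sketch (MMSegmentation 1.x). The config and checkpoint
# paths are hypothetical placeholders; the table above does not link them yet.
from mmseg.apis import inference_model, init_model

config_file = 'configs/ddrnet/ddrnet23-slim_cityscapes.py'  # hypothetical path
checkpoint_file = 'ddrnet23-slim_cityscapes.pth'            # hypothetical path

model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # any Cityscapes-style image
print(result.pred_sem_seg.data.shape)  # per-pixel class indices, shape (1, H, W)
```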
## Notes

The pretrained weights in the config files are converted from [the official repo](https://github.com/ydhongHIT/DDRNet#pretrained-models).
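The conversion script itself is not part of this commit. Below is a hedged sketch of how such an `*_in1k_mmseg.pth` file could be produced from the official weights; the file names and the key-renaming step are hypothetical placeholders, since the actual mapping depends on how mmseg's DDRNet names its modules.

```python
# Hedged conversion sketch, NOT the script used for this commit.
# File names and key renames are hypothetical placeholders.
import torch

ckpt = torch.load('DDRNet23s_imagenet.pth', map_location='cpu')  # official file (hypothetical name)
state_dict = ckpt.get('state_dict', ckpt)  # unwrap if the checkpoint nests a state_dict

converted = {}
for key, value in state_dict.items():
    new_key = key  # apply prefix/name renames here to match mmseg's DDRNet modules
    converted[new_key] = value

# mmseg's Pretrained init_cfg accepts a {'state_dict': ...}-wrapped checkpoint
torch.save({'state_dict': converted}, 'pretrained/ddrnet23s_in1k_mmseg.pth')
```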
## Citation

```bibtex
@misc{hong2021ddrnet,
  title={Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes},
  author={Hong, Yuanduo and Pan, Huihui and Sun, Weichao and Gao, Huijun},
  year={2021},
  eprint={2101.06085},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
}
```
DDRNet-23-slim Cityscapes config (new file)

Lines changed: 95 additions & 0 deletions
```python
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='DDRNet',
        in_channels=3,
        channels=32,
        ppm_channels=128,
        norm_cfg=norm_cfg,
        align_corners=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrained/ddrnet23s_in1k_mmseg.pth')),
    decode_head=dict(
        type='DDRHead',
        in_channels=32 * 4,
        channels=64,
        dropout_ratio=0.,
        num_classes=19,
        align_corners=False,
        norm_cfg=norm_cfg,
        loss_decode=[
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=0.4),
        ]),

    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]

# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
```
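To run this schedule, one would typically pass the config path to `tools/train.py`; programmatically, a minimal launch sketch with the MMEngine Runner looks like the following (the config file name is a hypothetical placeholder, since it is not shown in this excerpt):

```python
# Minimal training-launch sketch (MMSegmentation 1.x / MMEngine).
# The config path is a hypothetical placeholder.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/ddrnet/ddrnet23-slim_cityscapes.py')  # hypothetical
cfg.work_dir = 'work_dirs/ddrnet23-slim'  # logs and checkpoints land here

runner = Runner.from_cfg(cfg)  # builds the model, loops, and hooks from the dicts above
runner.train()  # runs the 120k-iteration IterBasedTrainLoop
```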
DDRNet-23 Cityscapes config (new file)

Lines changed: 95 additions & 0 deletions
```python
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='DDRNet',
        in_channels=3,
        channels=64,
        ppm_channels=128,
        norm_cfg=norm_cfg,
        align_corners=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrained/ddrnet23_in1k_mmseg.pth')),
    decode_head=dict(
        type='DDRHead',
        in_channels=64 * 4,
        channels=128,
        dropout_ratio=0.,
        num_classes=19,
        align_corners=False,
        norm_cfg=norm_cfg,
        loss_decode=[
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=0.4),
        ]),

    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]

# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
```
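This config differs from the slim variant only in the backbone width (64 vs 32 channels), the decode-head width (128 vs 64, with matching `in_channels`), and the pretrained checkpoint. A hedged alternative, not part of this commit, would express it as a delta over the slim config via `_base_` inheritance; the base file name below is a hypothetical placeholder.

```python
# Hedged alternative, NOT part of this commit: the DDRNet-23 config as a
# delta over the slim variant. The base file name is a hypothetical placeholder.
_base_ = './ddrnet23-slim_cityscapes.py'  # hypothetical

model = dict(
    backbone=dict(
        channels=64,  # slim variant uses 32
        init_cfg=dict(checkpoint='pretrained/ddrnet23_in1k_mmseg.pth')),
    decode_head=dict(in_channels=64 * 4, channels=128))  # slim uses 32 * 4 / 64
```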

mmseg/models/backbones/__init__.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -3,6 +3,7 @@
 from .bisenetv1 import BiSeNetV1
 from .bisenetv2 import BiSeNetV2
 from .cgnet import CGNet
+from .ddrnet import DDRNet
 from .erfnet import ERFNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
@@ -28,5 +29,6 @@
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
+    'DDRNet'
 ]
```
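The import matters beyond namespacing: importing `DDRNet` here triggers its registration with mmseg's model registry, which is what lets the configs above refer to it as `type='DDRNet'`. A minimal sketch, assuming MMSegmentation 1.x's registry API:

```python
# Minimal registry sketch (MMSegmentation 1.x): the dict below mirrors the
# backbone settings from the DDRNet-23-slim config above.
import mmseg.models  # running __init__.py registers DDRNet with the registry
from mmseg.registry import MODELS

backbone = MODELS.build(
    dict(
        type='DDRNet',  # resolved through the registry, hence the import above
        in_channels=3,
        channels=32,
        ppm_channels=128,
        norm_cfg=dict(type='BN', requires_grad=True),  # plain BN for single-GPU use
        align_corners=False))
print(type(backbone).__name__)  # DDRNet
```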
