Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1cac460

Browse files
Kh4L authored and apeforest committed
Make mrcnn_mask_target arg mask_size a 2d tuple (#16567)
Signed-off-by: Serge Panev <[email protected]>
1 parent c4580ae commit 1cac460

File tree

3 files changed

+83
-22
lines changed

3 files changed

+83
-22
lines changed

src/operator/contrib/mrcnn_target-inl.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,16 +45,17 @@ namespace mrcnn_index {
4545
struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> {
4646
int num_rois;
4747
int num_classes;
48-
int mask_size;
4948
int sample_ratio;
49+
mxnet::TShape mask_size;
5050

5151
DMLC_DECLARE_PARAMETER(MRCNNTargetParam) {
5252
DMLC_DECLARE_FIELD(num_rois)
5353
.describe("Number of sampled RoIs.");
5454
DMLC_DECLARE_FIELD(num_classes)
5555
.describe("Number of classes.");
5656
DMLC_DECLARE_FIELD(mask_size)
57-
.describe("Size of the pooled masks.");
57+
.set_expect_ndim(2).enforce_nonzero()
58+
.describe("Size of the pooled masks height and width: (h, w).");
5859
DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
5960
.describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
6061
}
@@ -91,7 +92,8 @@ inline bool MRCNNTargetShape(const NodeAttrs& attrs,
9192
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";
9293

9394
// out: 2 * (B, N, C, MS_H, MS_W)
94-
auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size);
95+
auto oshape = Shape5(batch_size, num_rois, param.num_classes,
96+
param.mask_size[0], param.mask_size[1]);
9597
out_shape->clear();
9698
out_shape->push_back(oshape);
9799
out_shape->push_back(oshape);

src/operator/contrib/mrcnn_target.cu

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -183,28 +183,29 @@ __device__ void RoIAlignForward(
183183

184184

185185
template<typename DType>
186-
__global__ void MRCNNTargetKernel(const DType *rois,
187-
const DType *gt_masks,
188-
const DType *matches,
189-
const DType *cls_targets,
190-
DType* sampled_masks,
191-
DType* mask_cls,
192-
const int total_out_el,
193-
int batch_size,
194-
int num_classes,
195-
int num_rois,
196-
int num_gtmasks,
197-
int gt_height,
198-
int gt_width,
199-
int mask_size,
200-
int sample_ratio) {
186+
__global__ void MRCNNMaskTargetKernel(const DType *rois,
187+
const DType *gt_masks,
188+
const DType *matches,
189+
const DType *cls_targets,
190+
DType* sampled_masks,
191+
DType* mask_cls,
192+
const int total_out_el,
193+
int batch_size,
194+
int num_classes,
195+
int num_rois,
196+
int num_gtmasks,
197+
int gt_height,
198+
int gt_width,
199+
int mask_size_h,
200+
int mask_size_w,
201+
int sample_ratio) {
201202
// computing sampled_masks
202203
RoIAlignForward(gt_masks, rois, matches, total_out_el,
203-
num_classes, gt_height, gt_width, mask_size, mask_size,
204+
num_classes, gt_height, gt_width, mask_size_h, mask_size_w,
204205
sample_ratio, num_rois, num_gtmasks, sampled_masks);
205206
// computing mask_cls
206207
int num_masks = batch_size * num_rois * num_classes;
207-
int mask_vol = mask_size * mask_size;
208+
int mask_vol = mask_size_h * mask_size_w;
208209
for (int mask_idx = blockIdx.x; mask_idx < num_masks; mask_idx += gridDim.x) {
209210
int cls_idx = mask_idx % num_classes;
210211
int roi_idx = (mask_idx / num_classes) % num_rois;
@@ -252,8 +253,9 @@ void MRCNNTargetRun<gpu>(const MRCNNTargetParam& param, const std::vector<TBlob>
252253
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
253254
out_masks.dptr_, out_mask_cls.dptr_,
254255
num_el, batch_size, param.num_classes, param.num_rois,
255-
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
256-
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel);
256+
num_gtmasks, gt_height, gt_width,
257+
param.mask_size[0], param.mask_size[1], param.sample_ratio);
258+
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
257259
});
258260
}
259261

tests/python/unittest/test_contrib_operator.py

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -351,6 +351,63 @@ def test_box_decode_op():
351351
assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \
352352
[0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5)
353353

354+
@with_seed()
355+
def test_op_mrcnn_mask_target():
356+
if default_context().device_type != 'gpu':
357+
return
358+
359+
num_rois = 2
360+
num_classes = 4
361+
mask_size = (3, 3)
362+
ctx = mx.gpu(0)
363+
# (B, N, 4)
364+
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],
365+
[3.5, 5.5, 0.9, 2.4]]], ctx=ctx)
366+
gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32)
367+
368+
# (B, N)
369+
matches = mx.nd.array([[2, 0]], ctx=ctx)
370+
# (B, N)
371+
cls_targets = mx.nd.array([[2, 1]], ctx=ctx)
372+
373+
mask_targets, mask_cls = mx.nd.contrib.mrcnn_mask_target(rois, gt_masks, matches, cls_targets,
374+
num_rois=num_rois,
375+
num_classes=num_classes,
376+
mask_size=mask_size)
377+
378+
# Ground truth outputs were generated with GluonCV's target generator
379+
# gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size)
380+
gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ],
381+
[2204.0667 , 2204.4 , 2204.7334 ],
382+
[2214.7334 , 2215.0667 , 2215.4 ]],
383+
[[2193.4 , 2193.7332 , 2194.0667 ],
384+
[2204.0667 , 2204.4 , 2204.7334 ],
385+
[2214.7334 , 2215.0667 , 2215.4 ]],
386+
[[2193.4 , 2193.7332 , 2194.0667 ],
387+
[2204.0667 , 2204.4 , 2204.7334 ],
388+
[2214.7334 , 2215.0667 , 2215.4 ]],
389+
[[2193.4 , 2193.7332 , 2194.0667 ],
390+
[2204.0667 , 2204.4 , 2204.7334 ],
391+
[2214.7334 , 2215.0667 , 2215.4 ]]],
392+
[[[ 185. , 185.33334, 185.66667],
393+
[ 195.66667, 196.00002, 196.33334],
394+
[ 206.33333, 206.66666, 207. ]],
395+
[[ 185. , 185.33334, 185.66667],
396+
[ 195.66667, 196.00002, 196.33334],
397+
[ 206.33333, 206.66666, 207. ]],
398+
[[ 185. , 185.33334, 185.66667],
399+
[ 195.66667, 196.00002, 196.33334],
400+
[ 206.33333, 206.66666, 207. ]],
401+
[[ 185. , 185.33334, 185.66667],
402+
[ 195.66667, 196.00002, 196.33334],
403+
[ 206.33333, 206.66666, 207. ]]]]])
404+
405+
gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]])
406+
gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3))
407+
408+
assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy())
409+
assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy())
410+
354411
if __name__ == '__main__':
355412
import nose
356413
nose.runmodule()

0 commit comments

Comments
 (0)