Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 69f1ad9

Browse files
committed
Make mrcnn_mask_target arg mask_size a 2d tuple
Signed-off-by: Serge Panev <[email protected]>
1 parent 746cbc5 commit 69f1ad9

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

src/operator/contrib/mrcnn_mask_target-inl.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,17 @@ namespace mrcnn_index {
4545
struct MRCNNMaskTargetParam : public dmlc::Parameter<MRCNNMaskTargetParam> {
4646
int num_rois;
4747
int num_classes;
48-
int mask_size;
4948
int sample_ratio;
49+
mxnet::TShape mask_size;
5050

5151
DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) {
5252
DMLC_DECLARE_FIELD(num_rois)
5353
.describe("Number of sampled RoIs.");
5454
DMLC_DECLARE_FIELD(num_classes)
5555
.describe("Number of classes.");
5656
DMLC_DECLARE_FIELD(mask_size)
57-
.describe("Size of the pooled masks.");
57+
.set_expect_ndim(2).enforce_nonzero()
58+
.describe("Size of the pooled masks height and width: (h, w).");
5859
DMLC_DECLARE_FIELD(sample_ratio).set_default(2)
5960
.describe("Sampling ratio of ROI align. Set to -1 to use adaptative size.");
6061
}
@@ -91,7 +92,8 @@ inline bool MRCNNMaskTargetShape(const NodeAttrs& attrs,
9192
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs.";
9293

9394
// out: 2 * (B, N, C, MS, MS)
94-
auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size);
95+
auto oshape = Shape5(batch_size, num_rois, param.num_classes,
96+
param.mask_size[0], param.mask_size[1]);
9597
out_shape->clear();
9698
out_shape->push_back(oshape);
9799
out_shape->push_back(oshape);

src/operator/contrib/mrcnn_mask_target.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,16 @@ __global__ void MRCNNMaskTargetKernel(const DType *rois,
196196
int num_gtmasks,
197197
int gt_height,
198198
int gt_width,
199-
int mask_size,
199+
int mask_size_h,
200+
int mask_size_w,
200201
int sample_ratio) {
201202
// computing sampled_masks
202203
RoIAlignForward(gt_masks, rois, matches, total_out_el,
203-
num_classes, gt_height, gt_width, mask_size, mask_size,
204+
num_classes, gt_height, gt_width, mask_size_h, mask_size_w,
204205
sample_ratio, num_rois, num_gtmasks, sampled_masks);
205206
// computing mask_cls
206207
int num_masks = batch_size * num_rois * num_classes;
207-
int mask_vol = mask_size * mask_size;
208+
int mask_vol = mask_size_h * mask_size_w;
208209
for (int mask_idx = blockIdx.x; mask_idx < num_masks; mask_idx += gridDim.x) {
209210
int cls_idx = mask_idx % num_classes;
210211
int roi_idx = (mask_idx / num_classes) % num_rois;
@@ -252,7 +253,8 @@ void MRCNNMaskTargetRun<gpu>(const MRCNNMaskTargetParam& param, const std::vecto
252253
(rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_,
253254
out_masks.dptr_, out_mask_cls.dptr_,
254255
num_el, batch_size, param.num_classes, param.num_rois,
255-
num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio);
256+
num_gtmasks, gt_height, gt_width,
257+
param.mask_size[0], param.mask_size[1], param.sample_ratio);
256258
MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel);
257259
});
258260
}

tests/python/unittest/test_contrib_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_op_mrcnn_mask_target():
359359

360360
num_rois = 2
361361
num_classes = 4
362-
mask_size = 3
362+
mask_size = (3, 3)
363363
ctx = mx.gpu(0)
364364
# (B, N, 4)
365365
rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3],

0 commit comments

Comments
 (0)