Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 90954ec

Browse files
dtraczptrendx
authored andcommitted
make TransposeShape infer shape form both sides (#15713)
* make TransposeShape infer shape form both sides * small fixes * remove redundant lines * unit tests
1 parent b3064c5 commit 90954ec

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,19 +344,34 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
344344
CHECK_EQ(in_attrs->size(), 1U);
345345
CHECK_EQ(out_attrs->size(), 1U);
346346
mxnet::TShape& shp = (*in_attrs)[0];
347+
mxnet::TShape& out_shp = (*out_attrs)[0];
347348
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
348-
mxnet::TShape ret(shp.ndim(), -1);
349+
CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0";
350+
CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0";
351+
if (shp.ndim() == -1 && out_shp.ndim() == -1)
352+
return false; // none of the shapes is known
353+
if (out_shp.ndim() > 0 && shp.ndim() > 0)
354+
CHECK_EQ(out_shp.ndim(), shp.ndim());
355+
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
356+
mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1);
349357
if (param.axes.ndim() == 0) {
350358
for (int i = 0; i < shp.ndim(); ++i) {
351359
ret[i] = shp[shp.ndim()-1-i];
352360
}
361+
for (int i = 0; i < out_shp.ndim(); ++i) {
362+
get[shp.ndim()-1-i] = out_shp[i];
363+
}
353364
} else {
354-
CHECK_EQ(shp.ndim(), param.axes.ndim());
365+
CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim());
355366
for (int i = 0; i < shp.ndim(); ++i) {
356367
CHECK(param.axes[i] < static_cast<int64_t>(shp.ndim()));
357368
ret[i] = shp[param.axes[i]];
358369
}
370+
for (int i = 0; i < out_shp.ndim(); ++i) {
371+
get[param.axes[i]] = out_shp[i];
372+
}
359373
}
374+
SHAPE_ASSIGN_CHECK(*in_attrs, 0, get);
360375
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
361376
return shape_is_known(ret);
362377
}

tests/python/unittest/test_operator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8970,6 +8970,26 @@ def test_get_operator_arguments():
89708970
ok_(operator_arguments.narg == 2)
89718971

89728972

8973+
def test_transpose_infer_shape_back():
8974+
o1 = mx.sym.ones(shape=[2,3])
8975+
o2 = mx.sym.ones(shape=[-1,-1])
8976+
t = mx.sym.transpose(o2)
8977+
b = o1 + t
8978+
x = b.bind(mx.cpu(), args={})
8979+
y = x.forward()
8980+
assert(y[0].shape == (2,3))
8981+
8982+
8983+
def test_transpose_infer_shape_mixed():
8984+
o1 = mx.sym.ones(shape=[2,-1])
8985+
o2 = mx.sym.ones(shape=[3,-1])
8986+
t = mx.sym.transpose(o2)
8987+
b = o1 + t
8988+
x = b.bind(mx.cpu(), args={})
8989+
y = x.forward()
8990+
assert(y[0].shape == (2,3))
8991+
8992+
89738993
if __name__ == '__main__':
89748994
import nose
89758995
nose.runmodule()

0 commit comments

Comments
 (0)