Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Numpy compatible max min #16046

Merged
merged 2 commits into from
Aug 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
}
};

struct NumpyReduceAxesNoDTypeParam : public dmlc::Parameter<NumpyReduceAxesNoDTypeParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
bool keepdims;
dmlc::optional<double> initial;
DMLC_DECLARE_PARAMETER(NumpyReduceAxesNoDTypeParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
.describe("Axis or axes along which a sum is performed. The default, axis=None, will sum "
"all of the elements of the input array. If axis is negative it counts from the "
"last to the first axis.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(initial).set_default(dmlc::optional<double>())
.describe("Starting value for the sum.");
}
};

inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
const dmlc::optional<mxnet::Tuple<int>>& axis,
bool keepdims) {
Expand Down Expand Up @@ -152,6 +170,39 @@ inline bool NumpyReduceAxesShape(const nnvm::NodeAttrs& attrs,
return shape_is_known(out_attrs->at(0));
}

inline bool NumpyReduceAxesNoDTypeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
// check the case where the reduction axis should not be zero
bool is_all_reducded_axes_not_zero = true;
const TShape& ishape = (*in_attrs)[0];
if (param.axis.has_value()) {
const mxnet::Tuple<int>& axes = param.axis.value();
for (int i = 0; i < axes.ndim(); ++i) {
if (ishape[axes[i]] == 0) {
is_all_reducded_axes_not_zero = false;
break;
}
}
} else {
if (ishape.Size() == 0) {
// global reduction should excuted only when input have size more than 0
is_all_reducded_axes_not_zero = false;
}
}
CHECK(is_all_reducded_axes_not_zero)
<< "zero-size array to reduction operation maximum which has no identity";
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims));
return shape_is_known(out_attrs->at(0));
}

template<bool safe_acc_hint = false>
inline bool NeedSafeAcc(int itype, int otype) {
bool rule = (itype != otype) || (itype != mshadow::kFloat32 && itype != mshadow::kFloat64);
Expand Down Expand Up @@ -186,6 +237,30 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename reducer, typename OP = op::mshadow_op::identity>
void NumpyReduceAxesNoDTypeCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
if (inputs[0].shape_.Size() == 0U || outputs[0].shape_.Size() == 0U) return; // zero-size tensor
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
TShape small;
if (param.keepdims) {
small = outputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(inputs[0].shape_, param.axis, true);
}
ReduceAxesComputeImpl<xpu, reducer, false, false, OP>(ctx, inputs, req, outputs, small);
}


template<typename xpu, bool normalize = false>
inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -273,6 +348,24 @@ void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename OP>
void NumpyReduceAxesNoDTypeBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyReduceAxesNoDTypeParam& param = nnvm::get<NumpyReduceAxesNoDTypeParam>(attrs.parsed);
TShape small;
if (param.keepdims) {
small = inputs[0].shape_;
} else {
small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true);
}
ReduceAxesBackwardUseInOutImpl<xpu, OP, false>(ctx, small, inputs, req, outputs);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_
66 changes: 66 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(NumpyReduceAxesParam);
DMLC_REGISTER_PARAMETER(NumpyReduceAxesNoDTypeParam);

inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
Expand Down Expand Up @@ -74,6 +75,71 @@ NNVM_REGISTER_OP(_backward_np_sum)
.set_num_inputs(1)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesBackwardUseNone<cpu>);

inline bool NumpyReduceAxesNoDTypeType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_np_max)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::maximum>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_max"});

NNVM_REGISTER_OP(_backward_np_max)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_min)
.describe(R"code()code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxesNoDTypeShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyReduceAxesNoDTypeType)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
})
.add_argument("a", "NDArray-or-Symbol", "The input")
.add_arguments(NumpyReduceAxesNoDTypeParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeCompute<cpu, mshadow::red::minimum>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_min"});

NNVM_REGISTER_OP(_backward_np_min)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReduceAxesNoDTypeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_num_inputs(3)
.set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesNoDTypeBackward<cpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_prod)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
12 changes: 12 additions & 0 deletions src/operator/numpy/np_broadcast_reduce_op_value.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ NNVM_REGISTER_OP(_np_sum)
NNVM_REGISTER_OP(_backward_np_sum)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);

NNVM_REGISTER_OP(_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::maximum>);

NNVM_REGISTER_OP(_backward_np_max)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_min)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeCompute<gpu, mshadow::red::minimum>);

NNVM_REGISTER_OP(_backward_np_min)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesNoDTypeBackward<gpu, mshadow_op::eq>);

NNVM_REGISTER_OP(_np_prod)
.set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::product, true>);

Expand Down
113 changes: 113 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,119 @@ def is_int(dtype):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_max_min():
class TestMax(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestMax, self).__init__()
self._axis = axis
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.max(a, axis=self._axis, keepdims=self._keepdims)

class TestMin(HybridBlock):
def __init__(self, axis=None, keepdims=False):
super(TestMin, self).__init__()
self._axis = axis
self._keepdims = keepdims

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.min(a, axis=self._axis, keepdims=self._keepdims)

def is_int(dtype):
return 'int' == dtype

def get_grad(axis, func_name):
index = -1 if func_name == 'max' else 0
if axis == ():
return _np.ones((2,3,4,5))
else:
temp = _np.zeros((2,3,4,5))
if axis == 0:
temp[index,:,:,:] = 1
return temp
elif axis == 1:
temp[:,index,:,:] = 1
return temp
elif axis == 2:
temp[:,:,index,:] = 1
return temp
elif axis == 3:
temp[:,:,:,index] = 1
return temp
elif not axis:
temp[index,index,index,index] = 1
return temp
raise ValueError('axis should be int or None or ()')

def _test_np_exception(func, shape, dim):
x = _np.random.uniform(-1.0, 1.0, shape)
x = mx.nd.array(x).as_np_ndarray()
if func == 'max':
out = mx.np.max(x)
else:
out = mx.np.min(x)
assert out.ndim == dim, 'dimension mismatch, output.ndim={}, dim={}'.format(output.ndim, dim)

in_data_dim = random.choice([2, 3, 4])
shape = rand_shape_nd(in_data_dim, dim=3)
for func in ['max', 'min']:
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64', 'int']:
# test gluon
if func == 'max':
test_gluon = TestMax(axis=axis, keepdims=keepdims)
else:
test_gluon = TestMin(axis=axis, keepdims=keepdims)
if hybridize:
test_gluon.hybridize()
if is_int(itype):
x = mx.nd.arange(120).reshape((2, 3, 4, 5))
x = mx.nd.array(x)
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x.attach_grad()
if func == 'max':
expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
else:
expected_ret = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
with mx.autograd.record():
y = test_gluon(x)
assert y.shape == expected_ret.shape
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if itype == 'float16' else 1e-3,
atol=1e-5 if itype == 'float16' else 1e-5)
y.backward()
# only check the gradient with hardcoded input
if is_int(itype):
assert same(x.grad.asnumpy(), get_grad(axis, func)), \
'x={}\ny={}\nx.grad={}\nnumpy={}'.format(x.asnumpy(), y.asnumpy(), x.grad.asnumpy(), get_grad(axis))

# test imperative
if func == 'max':
mx_out = np.max(x, axis=axis, keepdims=keepdims)
np_out = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
else:
mx_out = np.min(x, axis=axis, keepdims=keepdims)
np_out = _np.amin(x.asnumpy(), axis=axis, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

# test zero and zero dim
shapes = [(), (0), (2, 0), (0, 2, 1)]
exceptions = [False, True, True, True]
dims = [0] * len(shapes)
for func in ['max', 'min']:
for shape, exception, dim in zip(shapes, exceptions, dims):
if exception:
assertRaises(MXNetError, _test_np_exception, func, shape, dim)
else:
_test_np_exception(func, shape, dim)


@with_seed()
@use_np
def test_np_linspace():
Expand Down