Commit 4bc265c

kshitij12345 authored and Ubuntu committed
[MXNET-978] Add higher order gradient support tan, tanh (apache#15253)
* init to reset
* issue: higher order backward sigmoid
* update gradient code; update code as per apache#15288
* undo changes
* relax tolerance of gradient mismatch for tanh
* update comments
1 parent 0f35cdf commit 4bc265c

File tree

2 files changed: +93 -4 lines changed
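What this commit enables, in user-facing terms: the backward nodes of tan and tanh now register their own FGradient, so autograd can differentiate through them a second time. A minimal sketch of the call pattern the new tests exercise, assuming MXNet's imperative autograd API (nd, autograd.record, autograd.grad); the array values are illustrative:

from mxnet import nd, autograd

x = nd.array([0.1, 0.5, 1.0])
x.attach_grad()
with autograd.record():
    y = nd.tan(x)
    # First-order gradient dy/dx, kept in the graph so it can be differentiated again.
    y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
# Second order: backprop through the first-order gradient.
y_grad.backward()
# x.grad now holds f''(x); analytically this is 2 * tan(x) / cos(x)**2.
print(x.grad)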

src/operator/tensor/elemwise_unary_op_trig.cc

Lines changed: 55 additions & 2 deletions
@@ -139,7 +139,33 @@ The storage type of ``tan`` output depends upon the input storage type:
 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tan" });

-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad (dL/dy)
+      // n->inputs[1] : y = f(x) = tan(x) (ElemwiseGradUseOut)
+      // ograds[0]    : head_grads (dL/dx_grad)
+      // f'(x) = sec^2(x)
+      // f''(x) = 2 * f'(x) * f(x)
+      //
+      // Note: when building the gradient graph, the backward node of n->inputs[1] will be
+      // added to the graph again, therefore f'(x) will be multiplied in there.
+      // So we only need to compute -> 2 * f(x) * dL/dx_grad * y_grad
+      const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
+      auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
+      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+                                    {n->inputs[0], nnvm::NodeEntry{two_y}}, nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+      return ret;
+    });

 // arcsin
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsin, cpu, mshadow_op::arcsin)
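For reference, the identity used in the comments above is plain single-variable calculus rather than anything introduced by this diff:

\[
f(x) = \tan x, \qquad f'(x) = \sec^2 x, \qquad
f''(x) = \frac{d}{dx}\sec^2 x = 2\sec x \cdot \sec x \tan x = 2\,f(x)\,f'(x).
\]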
@@ -290,7 +316,34 @@ The storage type of ``tanh`` output depends upon the input storage type:
 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tanh" });

-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>);
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad (dL/dy)
+      // n->inputs[1] : y = f(x) = tanh(x) (ElemwiseGradUseOut)
+      // ograds[0]    : head_grads (dL/dx_grad)
+      // f'(x) = sech^2(x)
+      // f''(x) = -2 * f'(x) * f(x)
+      //
+      // Note: when building the gradient graph, the backward node of n->inputs[1] will be
+      // added to the graph again, therefore f'(x) will be multiplied in there.
+      // So we only need to compute -> -2 * f(x) * dL/dx_grad * y_grad
+      const std::unordered_map<std::string, std::string> args = {{"scalar", "-2.0"}};
+      auto neg_two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_neg_two",
+                                {n->inputs[1]}, &args, &n);
+      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+                                    {n->inputs[0], nnvm::NodeEntry{neg_two_y}}, nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+      return ret;
+    });

 // arcsinh
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsinh, cpu, mshadow_op::arcsinh)
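The analogous identity for tanh, f''(x) = -2 * f(x) * f'(x) with f'(x) = sech^2(x), is easy to sanity-check numerically. A throwaway NumPy sketch, not part of the commit:

import numpy as np

x = np.linspace(-2.0, 2.0, 101)
h = 1e-4
sech2 = 1.0 / np.cosh(x) ** 2                # f'(x) = sech^2(x)
analytic = -2.0 * np.tanh(x) * sech2         # f''(x) = -2 * f(x) * f'(x)
# Central finite difference of f'(x) as an independent estimate of f''(x).
numeric = (1.0 / np.cosh(x + h) ** 2 - 1.0 / np.cosh(x - h) ** 2) / (2.0 * h)
assert np.allclose(analytic, numeric, rtol=1e-5, atol=1e-7)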

tests/python/unittest/test_higher_order_grad.py

Lines changed: 38 additions & 2 deletions
@@ -50,6 +50,41 @@ def grad_grad_op(x):
         check_second_order_unary(array, cos, grad_grad_op)


+@with_seed()
+def test_tan():
+    def tan(x):
+        return nd.tan(x)
+
+    def grad_op(x):
+        return 1 / nd.cos(x)**2
+
+    def grad_grad_op(x):
+        return 2 * tan(x) * grad_op(x)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        check_second_order_unary(array, tan, grad_grad_op)
+
+
+@with_seed()
+def test_tanh():
+    def tanh(x):
+        return nd.tanh(x)
+
+    def grad_op(x):
+        return 1 / nd.cosh(x)**2
+
+    def grad_grad_op(x):
+        return -2 * tanh(x) * grad_op(x)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        check_second_order_unary(
+            array, tanh, grad_grad_op, rtol=1e-6, atol=1e-6)
+
+
 @with_seed()
 def test_relu():
     def relu(x):
@@ -150,7 +185,7 @@ def grad_grad_op(x):
         check_second_order_unary(array, sigmoid, grad_grad_op)


-def check_second_order_unary(x, op, grad_grad_op):
+def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
     x = nd.array(x)
     grad_grad_x = grad_grad_op(x)
     x.attach_grad()
@@ -171,7 +206,8 @@ def check_second_order_unary(x, op, grad_grad_op):
     y_grad.asnumpy()

     # Validate the gradients.
-    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
+    assert_almost_equal(expected_grad_grad,
+                        x.grad.asnumpy(), rtol=rtol, atol=atol)


 if __name__ == '__main__':
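A plausible reason the tanh test passes rtol/atol while the others do not: tanh saturates, so for moderately large |x| the expected second derivative -2 * tanh(x) * sech^2(x) is very close to zero, and a purely relative comparison against near-zero values is brittle; the absolute tolerance absorbs that. Illustrated below with numpy.testing semantics (MXNet's assert_almost_equal takes the same rtol/atol keywords, though its internal comparison may differ in detail):

import numpy as np

# Near-zero expected second derivatives, as in the saturated region of tanh.
expected = np.array([-1.0e-7, 2.0e-7])
actual = np.array([-1.1e-7, 1.9e-7])
# Passes when |actual - expected| <= atol + rtol * |expected| elementwise;
# with atol=0 the same comparison would fail despite the tiny absolute error.
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)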
