@@ -139,7 +139,33 @@ The storage type of ``tan`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_tan"});

- MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>);
+ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>)
+ .set_attr<nnvm::FGradient>("FGradient",
+     [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+       // NodeEntry{n} : y_grad * f'(x)
+       // n->inputs[0] : y_grad (dL/dy)
+       // n->inputs[1] : y = f(x) = tan(x) (ElemwiseGradUseOut)
+       // ograds[0]    : head_grads (dL/dx_grad)
+       // f'(x) = sec^2(x)
+       // f''(x) = 2 * f'(x) * f(x)
+       //
+       // Note: when building the gradient graph, the backward node of n->inputs[1]
+       // is added to the graph again, so the f'(x) factor is multiplied in
+       // automatically. Hence we only need to compute: 2 * f(x) * dL/dx_grad * y_grad
+       const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
+       auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
+       auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+                                     {n->inputs[0], nnvm::NodeEntry{two_y}}, nullptr, &n);
+       auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                            {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+       std::vector<nnvm::NodeEntry> ret;
+       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                 {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                 {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+       return ret;
+     });

// arcsin
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsin, cpu, mshadow_op::arcsin)
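For reference, the identity behind the tan double-backward above, written in terms of the stored output y = f(x) = tan(x) (a derivation, not part of the patched file):

    f'(x)  = sec^2(x) = 1 + tan^2(x) = 1 + y^2
    f''(x) = d/dx sec^2(x) = 2 sec^2(x) tan(x) = 2 f'(x) f(x)

The first returned entry is dL/dx_grad * f'(x), the gradient with respect to y_grad; dydx recovers f'(x) by dividing the node's own output, y_grad * f'(x), by y_grad. The second entry is dL/dx_grad * y_grad * 2y, the gradient with respect to y; the remaining f'(x) factor from f''(x) = 2 f'(x) f(x) is supplied automatically when autodiff chains through the backward node of y itself, as the code comment notes.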
@@ -290,7 +316,34 @@ The storage type of ``tanh`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_tanh"});

- MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>);
+ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>)
+ .set_attr<nnvm::FGradient>("FGradient",
+     [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+       // NodeEntry{n} : y_grad * f'(x)
+       // n->inputs[0] : y_grad (dL/dy)
+       // n->inputs[1] : y = f(x) = tanh(x) (ElemwiseGradUseOut)
+       // ograds[0]    : head_grads (dL/dx_grad)
+       // f'(x) = sech^2(x)
+       // f''(x) = -2 * f'(x) * f(x)
+       //
+       // Note: when building the gradient graph, the backward node of n->inputs[1]
+       // is added to the graph again, so the f'(x) factor is multiplied in
+       // automatically. Hence we only need to compute: -2 * f(x) * dL/dx_grad * y_grad
+       const std::unordered_map<std::string, std::string> args = {{"scalar", "-2.0"}};
+       auto neg_two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_neg_two",
+                                 {n->inputs[1]}, &args, &n);
+       auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
+                                     {n->inputs[0], nnvm::NodeEntry{neg_two_y}}, nullptr, &n);
+       auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                            {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+       std::vector<nnvm::NodeEntry> ret;
+       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                 {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                 {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+       return ret;
+     });

// arcsinh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsinh, cpu, mshadow_op::arcsinh)
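A minimal sketch of how these registrations are exercised from Python, assuming an MXNet build that includes this patch; the input values and variable names are illustrative only, and the analytic reference follows from f''(x) = -2 f'(x) f(x) with f'(x) = 1 - tanh^2(x):

    import mxnet as mx
    from mxnet import nd, autograd

    x = nd.array([0.1, 0.5, 1.0])
    x.attach_grad()
    with autograd.record():
        y = nd.tanh(x)
        # First-order gradient, kept differentiable (create_graph=True) so that
        # the _backward_tanh FGradient registered above can be invoked.
        x_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    x_grad.backward()

    # Analytic second derivative: d^2/dx^2 tanh(x) = -2 * tanh(x) * (1 - tanh(x)^2)
    expected = -2 * nd.tanh(x) * (1 - nd.tanh(x) ** 2)
    print(x.grad.asnumpy(), expected.asnumpy())

The second call into autodiff (x_grad.backward()) is what triggers the lambda above: ograds[0] carries the incoming head gradient, and the two returned entries propagate it to y_grad and y respectively.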