Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7458af5

Browse files
committed
simplify implementation of second order gradient for FC
1 parent 745ea66 commit 7458af5

File tree

1 file changed

+1
-17
lines changed

1 file changed

+1
-17
lines changed

src/operator/nn/fully_connected.cc

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -176,22 +176,6 @@ struct FullyConnectedGrad {
176176
}
177177
};
178178

179-
std::vector<nnvm::NodeEntry> FullyConnectedBackwardGrad(
180-
const nnvm::NodePtr& n,
181-
const std::vector<nnvm::NodeEntry>& ograds) {
182-
// Note this is not strictly correct but we don't expect inputs to depend on weights at the
183-
// moment. If you find such a case, please contribute a more elaborate implementation.
184-
std::vector<nnvm::NodeEntry> ret;
185-
size_t i = 0;
186-
for (const auto& x : n->inputs) {
187-
std::ostringstream os;
188-
os << n->attrs.name << "_backward_" << i;
189-
ret.emplace_back(MakeNode("zeros_like", os.str(), {x}, nullptr, &n));
190-
++i;
191-
}
192-
return ret;
193-
}
194-
195179
inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
196180
const int dev_mask,
197181
DispatchMode* dispatch_mode,
@@ -341,7 +325,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
341325
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
342326
return std::vector<std::pair<int, int> >{{1, 0}};
343327
})
344-
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedBackwardGrad)
328+
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
345329
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
346330
.set_attr_parser(ParamParser<FullyConnectedParam>)
347331
#if MXNET_USE_MKLDNN == 1

0 commit comments

Comments
 (0)