Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 776d9c4

Browse files
haohuanw and KellenSunderland
authored and committed
add deconv in TRT subgraph (#15666)
1 parent f0c69f5 commit 776d9c4

File tree

4 files changed

+116
-14
lines changed

4 files changed

+116
-14
lines changed

src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,16 @@ namespace mxnet {
4141
namespace op {
4242
namespace nnvm_to_onnx {
4343

44+
enum ConvDeconvType {Convolution, Deconvolution};
45+
4446
using namespace nnvm;
4547
using namespace ::onnx;
4648
using int64 = ::google::protobuf::int64;
4749

4850
std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
4951
const nnvm::IndexedGraph& ig);
5052

51-
std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector&
52-
dtype_inputs,
53+
std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
5354
const nnvm::IndexedGraph& ig);
5455

5556
std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
@@ -74,12 +75,25 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
7475
const nnvm::IndexedGraph &ig,
7576
const array_view<IndexedGraph::NodeEntry> &inputs);
7677

78+
template <class ConvDeconvParam>
79+
void ConvDeconvConvertHelper(NodeProto *node_proto,
80+
const NodeAttrs &attrs,
81+
const nnvm::IndexedGraph &ig,
82+
const array_view<IndexedGraph::NodeEntry> &inputs,
83+
const ConvDeconvParam& param,
84+
ConvDeconvType type);
85+
7786
// Forward declarations
7887
void ConvertConvolution(NodeProto *node_proto,
7988
const NodeAttrs &attrs,
8089
const nnvm::IndexedGraph &ig,
8190
const array_view<IndexedGraph::NodeEntry> &inputs);
8291

92+
void ConvertDeconvolution(NodeProto *node_proto,
93+
const NodeAttrs &attrs,
94+
const nnvm::IndexedGraph &ig,
95+
const array_view<IndexedGraph::NodeEntry> &inputs);
96+
8397
void ConvertPooling(NodeProto *node_proto,
8498
const NodeAttrs &attrs,
8599
const nnvm::IndexedGraph &ig,
@@ -158,6 +172,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
158172
{"BatchNorm", ConvertBatchNorm},
159173
{"clip", ConvertClip},
160174
{"Convolution", ConvertConvolution},
175+
{"Deconvolution", ConvertDeconvolution},
161176
{"Concat", ConvertConcatenate},
162177
{"Dropout", ConvertDropout},
163178
{"elemwise_add", ConvertElementwiseAdd},

src/operator/subgraph/tensorrt/nnvm_to_onnx.cc

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <mxnet/base.h>
3232
#include <nnvm/graph.h>
3333
#include <nnvm/pass_functions.h>
34+
#include <operator/nn/deconvolution-inl.h>
3435

3536
#include "../../../common/utils.h"
3637
#include "../../../ndarray/ndarray_function.h"
@@ -170,20 +171,25 @@ std::string ConvertNnvmGraphToOnnx(
170171
return serialized_onnx_graph;
171172
}
172173

173-
void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
174-
const nnvm::IndexedGraph& /*ig*/,
175-
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
176-
const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
177-
178-
node_proto->set_op_type("Conv");
174+
template <class ConvDeconvParam>
175+
void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
176+
const nnvm::IndexedGraph& /*ig*/,
177+
const array_view<IndexedGraph::NodeEntry>& /*input*/,
178+
const ConvDeconvParam& param,
179+
ConvDeconvType type) {
180+
if (type == ConvDeconvType::Convolution) {
181+
node_proto->set_op_type("Conv");
182+
} else {
183+
node_proto->set_op_type("ConvTranspose");
184+
}
179185

180-
const mxnet::TShape kernel = conv_param.kernel;
181-
const mxnet::TShape stride = conv_param.stride;
182-
const mxnet::TShape dilate = conv_param.dilate;
183-
const mxnet::TShape pad = conv_param.pad;
184-
const uint32_t num_group = conv_param.num_group;
186+
const mxnet::TShape kernel = param.kernel;
187+
const mxnet::TShape stride = param.stride;
188+
const mxnet::TShape dilate = param.dilate;
189+
const mxnet::TShape pad = param.pad;
190+
const uint32_t num_group = param.num_group;
185191
// const bool no_bias = conv_param.no_bias;
186-
const dmlc::optional<int> layout = conv_param.layout;
192+
const dmlc::optional<int> layout = param.layout;
187193

188194
// dilations
189195
AttributeProto* const dilations = node_proto->add_attribute();
@@ -226,8 +232,24 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
226232
for (const dim_t kval : stride) {
227233
strides->add_ints(static_cast<int64>(kval));
228234
}
235+
}
236+
237+
// Converts an MXNet Convolution node to the ONNX "Conv" op.
// Extracts the parsed ConvolutionParam from the node attributes and delegates
// to ConvDeconvConvertHelper, which emits the common kernel/stride/dilation/
// pad/group attributes shared by Conv and ConvTranspose.
void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
                          ConvDeconvType::Convolution);
}  // end ConvertConvolution
230244

245+
// Converts an MXNet Deconvolution node to the ONNX "ConvTranspose" op.
// Mirrors ConvertConvolution: pulls the parsed DeconvolutionParam and reuses
// ConvDeconvConvertHelper, which selects the op type from the ConvDeconvType
// tag and emits the shared convolution attributes.
void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
                          ConvDeconvType::Deconvolution);
}  // end ConvertDeconvolution
252+
231253
void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
232254
const nnvm::IndexedGraph& /*ig*/,
233255
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {

src/operator/subgraph/tensorrt/tensorrt-inl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TensorrtSelector : public SubgraphSelector {
8888
"clip",
8989
"Concat",
9090
"Convolution",
91+
"Deconvolution",
9192
"Dropout",
9293
"elemwise_add",
9394
"elemwise_sub",
@@ -104,6 +105,7 @@ class TensorrtSelector : public SubgraphSelector {
104105
const std::unordered_set<std::string> withWeightsOps = {
105106
"BatchNorm",
106107
"Convolution",
108+
"Deconvolution",
107109
"FullyConnected"
108110
};
109111

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import mxnet as mx
19+
from mxnet.test_utils import assert_almost_equal
20+
21+
def get_params():
    """Return (arg_params, aux_params) for the symbol built by get_symbol().

    Both the conv and the deconv layer get an all-ones (1, 1, 3, 3) weight;
    no auxiliary parameters are needed.
    """
    weight_names = ("trt_bn_test_conv_weight", "trt_bn_test_deconv_weight")
    arg_params = {name: mx.nd.ones((1, 1, 3, 3)) for name in weight_names}
    aux_params = {}
    return arg_params, aux_params
27+
28+
def get_symbol():
    """Build the test graph: a 3x3 bias-free Convolution feeding a 3x3
    bias-free Deconvolution, both single-channel and single-group."""
    data = mx.sym.Variable("data")
    conv = mx.sym.Convolution(data=data, kernel=(3, 3), no_bias=True,
                              num_filter=1, num_group=1,
                              name="trt_bn_test_conv")
    return mx.sym.Deconvolution(data=conv, kernel=(3, 3), no_bias=True,
                                num_filter=1, num_group=1,
                                name="trt_bn_test_deconv")
35+
36+
def test_deconvolution_produce_same_output_as_tensorrt():
    """Run the conv+deconv graph once natively and once through the TensorRT
    backend and assert both executors produce (nearly) identical outputs."""
    # Two independent parameter sets: one for the native executor, one for
    # the TensorRT executor (init_tensorrt_params mutates the TRT copies).
    arg_params, aux_params = get_params()
    arg_params_trt, aux_params_trt = get_params()

    sym = get_symbol()
    # Partition the same graph into a TensorRT subgraph.
    sym_trt = get_symbol().get_backend_symbol("TensorRT")

    mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)

    # Inference-only binds on GPU (TensorRT requires a GPU context).
    executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
    executor.copy_params_from(arg_params, aux_params)

    executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
                                       force_rebind=True)
    executor_trt.copy_params_from(arg_params_trt, aux_params_trt)

    # Same random input is fed to both executors.
    input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))

    y = executor.forward(is_train=False, data=input_data)
    y_trt = executor_trt.forward(is_train=False, data=input_data)

    print(y[0].asnumpy())
    print(y_trt[0].asnumpy())
    # Loose 1e-4 rtol/atol: TensorRT may reorder/fuse floating-point ops.
    assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)
60+
61+
if __name__ == '__main__':
    # Allow running this file directly; nose discovers and runs the test above.
    import nose
    nose.runmodule()

0 commit comments

Comments
 (0)