Commit 9457b8c

Merge pull request #570 from jaybdub/linear_functional_converter
Linear functional converter
2 parents e01279c + ddb3558 commit 9457b8c

File tree: 7 files changed (+219, -12 lines)


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -2,6 +2,11 @@
 
 ## [Master]
 
+- Added converter for ``torch.nn.functional.layer_norm``
+- Added converter for ``torch.nn.functional.gelu``
+- Added converter for ``torch.nn.functional.linear``
+- Added converter for ``torch.nn.functional.silu``
+
 ## [0.2.0] - 03/02/2021
 
 - Added converter for ``torch.Tensor.flatten``
```
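For context, a minimal usage sketch of these functional converters, assuming torch2trt with this commit, TensorRT, and a CUDA device are available; the `TinyBlock` module and shapes are illustrative, not part of the commit:

```python
import torch
from torch2trt import torch2trt


class TinyBlock(torch.nn.Module):
    """Hypothetical module whose forward dispatches through the new converters."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 5)      # calls torch.nn.functional.linear
        self.norm = torch.nn.LayerNorm(5)     # calls torch.nn.functional.layer_norm

    def forward(self, x):
        return torch.nn.functional.gelu(self.norm(self.fc(x)))


model = TinyBlock().cuda().eval()
x = torch.randn(1, 10).cuda()

# Build a TensorRT engine and compare against the PyTorch output
model_trt = torch2trt(model, [x])
print(torch.max(torch.abs(model(x) - model_trt(x))))
```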

torch2trt/converters/Linear.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -2,34 +2,35 @@
 from torch2trt.module_test import add_module_test
 
 
-@tensorrt_converter('torch.nn.Linear.forward')
+@tensorrt_converter('torch.nn.functional.linear')
 def convert_Linear(ctx):
-    module = ctx.method_args[0]
-    input = ctx.method_args[1]
+    input = ctx.method_args[0]
+    weight = get_arg(ctx, 'weight', 1, None)
+    bias = get_arg(ctx, 'bias', 2, None)
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
 
     # reshape to ...xNx1x1
     layer = ctx.network.add_shuffle(input_trt)
     layer.reshape_dims = tuple(input_trt.shape) + (1, 1)
 
-    bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype))
-    if module.bias is not None:
-        bias = module.bias.detach().cpu().numpy()
+    bias_trt = trt.Weights(torch_dtype_to_trt(weight.dtype))
+    if bias is not None:
+        bias_trt = bias.detach().cpu().numpy()
 
     # add fully connected
     layer = ctx.network.add_fully_connected(
         input=layer.get_output(0),
-        num_outputs=module.out_features,
-        kernel=module.weight.detach().cpu().numpy(),
-        bias=bias)
+        num_outputs=int(weight.shape[0]),
+        kernel=weight.detach().cpu().numpy(),
+        bias=bias_trt)
 
     # reshape back to N
     layer = ctx.network.add_shuffle(layer.get_output(0))
     layer.reshape_dims = tuple(output.shape[1:])
 
     output._trt = layer.get_output(0)
-
+
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)])
@@ -42,4 +43,4 @@ def test_Linear_basic():
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)])
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)])
 def test_Linear_no_bias():
-    return torch.nn.Linear(10, 5, bias=False)
+    return torch.nn.Linear(10, 5, bias=False)
```
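The converter now targets `torch.nn.functional.linear` directly and maps it onto a TensorRT fully-connected layer with `kernel = weight` and `num_outputs = weight.shape[0]`, after reshaping the input to `...xNx1x1`. A minimal pure-PyTorch sketch of the identity this relies on (shapes are illustrative, no TensorRT required):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10)   # (..., in_features)
w = torch.randn(5, 10)      # (out_features, in_features)
b = torch.randn(5)

# F.linear(x, w, b) computes x @ w.T + b, which is what the fully-connected
# layer expresses once the input is reshaped to (..., in_features, 1, 1)
print(torch.allclose(F.linear(x, w, b), x.matmul(w.t()) + b, atol=1e-5))
```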

torch2trt/converters/__init__.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -28,11 +28,13 @@
 from .div import *
 from .expand import *
 from .floordiv import *
+from .gelu import *
 from .getitem import *
+from .group_norm import *
 from .identity import *
 from .instance_norm import *
 from .interpolate import *
-from .group_norm import *
+from .layer_norm import *
 from .max import *
 from .max_pool2d import *
 from .mean import *
@@ -50,6 +52,7 @@
 from .relu import *
 from .relu6 import *
 from .sigmoid import *
+from .silu import *
 from .softmax import *
 from .split import *
 from .stack import *
```

torch2trt/converters/gelu.py

Lines changed: 63 additions & 0 deletions
```python
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test
import math


@tensorrt_converter('torch.nn.functional.gelu')
def convert_gelu_v1(ctx):
    # approximate equation 1 from paper
    input = get_arg(ctx, 'input', 0, None)
    output = ctx.method_return

    x, c05, c1, cs2pi, c044, c3 = add_missing_trt_tensors(
        ctx.network,
        [input, 0.5, 1.0, math.sqrt(2.0 / math.pi), 0.044715, 3.0]
    )

    x, c05, c1, cs2pi, c044, c3 = broadcast_trt_tensors(
        ctx.network,
        [x, c05, c1, cs2pi, c044, c3],
        len(output.shape) - 1
    )

    y = ctx.network.add_elementwise(x, c3, trt.ElementWiseOperation.POW).get_output(0)
    y = ctx.network.add_elementwise(y, c044, trt.ElementWiseOperation.PROD).get_output(0)
    y = ctx.network.add_elementwise(x, y, trt.ElementWiseOperation.SUM).get_output(0)
    y = ctx.network.add_elementwise(y, cs2pi, trt.ElementWiseOperation.PROD).get_output(0)
    y = ctx.network.add_activation(y, trt.ActivationType.TANH).get_output(0)
    y = ctx.network.add_elementwise(y, c1, trt.ElementWiseOperation.SUM).get_output(0)
    y = ctx.network.add_elementwise(x, y, trt.ElementWiseOperation.PROD).get_output(0)
    y = ctx.network.add_elementwise(y, c05, trt.ElementWiseOperation.PROD).get_output(0)

    output._trt = y


# @tensorrt_converter('torch.nn.functional.gelu')
# def convert_gelu_v2(ctx):
#     # sigmoid approximation of GELU
#     input = get_arg(ctx, 'input', 0, None)
#     output = ctx.method_return

#     x, c1702 = add_missing_trt_tensors(
#         ctx.network,
#         [input, 1.702]
#     )

#     x, c1702 = broadcast_trt_tensors(
#         ctx.network,
#         [x, c1702],
#         len(output.shape) - 1
#     )

#     y = ctx.network.add_elementwise(x, c1702, trt.ElementWiseOperation.PROD).get_output(0)
#     y = ctx.network.add_activation(y, trt.ActivationType.SIGMOID).get_output(0)
#     y = ctx.network.add_elementwise(x, y, trt.ElementWiseOperation.PROD).get_output(0)

#     output._trt = y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)])
def test_gelu():
    return torch.nn.GELU()
```
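The active converter builds the tanh approximation of GELU out of elementwise POW/PROD/SUM and a TANH activation. A minimal PyTorch sketch of that formula (the `gelu_tanh_reference` helper is illustrative, not part of the commit), checked against the exact erf-based `F.gelu` with a loose tolerance since the tanh form is only an approximation:

```python
import math
import torch
import torch.nn.functional as F


def gelu_tanh_reference(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    # the same expression the converter assembles from TensorRT layers
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


x = torch.randn(1, 5, 3)
print(torch.allclose(gelu_tanh_reference(x), F.gelu(x), atol=1e-2))
```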

torch2trt/converters/layer_norm.py

Lines changed: 103 additions & 0 deletions
```python
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.layer_norm')
def convert_layernorm(ctx):
    input = get_arg(ctx, 'input', 0, None)
    shape = get_arg(ctx, 'normalized_shape', 1, None)
    weight = get_arg(ctx, 'weight', 2, None)
    bias = get_arg(ctx, 'bias', 3, None)
    eps = get_arg(ctx, 'eps', 4, 1e-05)
    output = ctx.method_return

    input_trt, eps_trt = add_missing_trt_tensors(
        ctx.network,
        [input, eps]
    )

    input_trt, eps_trt = broadcast_trt_tensors(
        ctx.network,
        [input_trt, eps_trt],
        len(output.shape) - 1
    )

    if weight is not None:
        _, weight_trt = add_missing_trt_tensors(
            ctx.network,
            [input, weight]
        )
        _, weight_trt = broadcast_trt_tensors(
            ctx.network,
            [input_trt, weight_trt],
            len(output.shape) - 1
        )

    if bias is not None:
        _, bias_trt = add_missing_trt_tensors(
            ctx.network,
            [input, bias]
        )
        _, bias_trt = broadcast_trt_tensors(
            ctx.network,
            [input_trt, bias_trt],
            len(output.shape) - 1
        )

    if isinstance(shape, int):
        shape = (shape,)
    dim = tuple([-i - 1 for i in range(len(shape))])
    dim = torch_dim_resolve_negative(dim, len(input.shape))
    axes = torch_dim_to_trt_axes(dim)

    ux = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, axes, keep_dims=True).get_output(0)
    numerator = ctx.network.add_elementwise(input_trt, ux, trt.ElementWiseOperation.SUB).get_output(0)
    varx = ctx.network.add_elementwise(numerator, numerator, trt.ElementWiseOperation.PROD).get_output(0)
    varx = ctx.network.add_reduce(varx, trt.ReduceOperation.AVG, axes, keep_dims=True).get_output(0)
    denom = ctx.network.add_elementwise(varx, eps_trt, trt.ElementWiseOperation.SUM).get_output(0)
    denom = ctx.network.add_unary(denom, trt.UnaryOperation.SQRT).get_output(0)
    y = ctx.network.add_elementwise(numerator, denom, trt.ElementWiseOperation.DIV).get_output(0)

    if weight is not None:
        y = ctx.network.add_elementwise(y, weight_trt, trt.ElementWiseOperation.PROD).get_output(0)

    if bias is not None:
        y = ctx.network.add_elementwise(y, bias_trt, trt.ElementWiseOperation.SUM).get_output(0)

    output._trt = y


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_1d():
    return torch.nn.LayerNorm(3)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_2d():
    return torch.nn.LayerNorm((5, 3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_3d():
    return torch.nn.LayerNorm((5, 5, 3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_1d_nonaffine():
    return torch.nn.LayerNorm(3, elementwise_affine=False)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_2d_nonaffine():
    return torch.nn.LayerNorm((5, 3), elementwise_affine=False)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 5, 3)])
def test_layer_norm_3d_nonaffine():
    return torch.nn.LayerNorm((5, 5, 3), elementwise_affine=False)
```
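The converter normalizes over the trailing `normalized_shape` dimensions using a biased variance and adds `eps` inside the square root. A minimal PyTorch sketch of that computation (the `layer_norm_reference` helper is illustrative, not part of the commit), checked against `F.layer_norm`:

```python
import torch
import torch.nn.functional as F


def layer_norm_reference(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # mean and biased variance over the trailing normalized dims,
    # y = (x - mean) / sqrt(var + eps) * weight + bias
    dims = tuple(range(-len(normalized_shape), 0))
    mean = x.mean(dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dims, keepdim=True)
    y = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y


x = torch.randn(1, 5, 5, 3)
print(torch.allclose(layer_norm_reference(x, (5, 3)), F.layer_norm(x, (5, 3)), atol=1e-5))
```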

torch2trt/converters/silu.py

Lines changed: 21 additions & 0 deletions
```python
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.nn.functional.silu')
def convert_silu(ctx):
    input = get_arg(ctx, 'input', pos=0, default=None)
    output = ctx.method_return
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

    layer = ctx.network.add_activation(input_trt, trt.ActivationType.SIGMOID)
    layer = ctx.network.add_elementwise(input_trt, layer.get_output(0), trt.ElementWiseOperation.PROD)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)])
def test_silu():
    return torch.nn.SiLU()
```
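For reference, the identity the SIGMOID + PROD layers express, as a short PyTorch check (purely illustrative):

```python
import torch

x = torch.randn(1, 5, 3)
# SiLU(x) = x * sigmoid(x)
print(torch.allclose(x * torch.sigmoid(x), torch.nn.functional.silu(x)))
```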

torch2trt/torch2trt.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -87,6 +87,17 @@ def trt_num_outputs(engine):
     return count
 
 
+def torch_dim_resolve_negative(dim, ndim):
+    if not isinstance(dim, tuple):
+        dim = (dim,)
+    pos = []
+    for d in dim:
+        if d < 0:
+            d = ndim + d
+        pos.append(d)
+    return tuple(pos)
+
+
 def torch_dim_to_trt_axes(dim):
     """Converts torch dim, or tuple of dims to a tensorrt axes bitmask"""
     if not isinstance(dim, tuple):
```
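A quick sketch of how the new helper is used by the layer_norm converter to turn negative trailing dims into positive indices before building the reduce-axes bitmask (assumes torch2trt with this commit and TensorRT are installed, since importing the module pulls in tensorrt):

```python
from torch2trt.torch2trt import torch_dim_resolve_negative

# layer_norm over the last two dims of a 4-D tensor: (-1, -2) -> (3, 2)
print(torch_dim_resolve_negative((-1, -2), 4))   # (3, 2)

# a single int is wrapped into a tuple: -1 in a 3-D tensor -> (2,)
print(torch_dim_resolve_negative(-1, 3))         # (2,)
```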
