
Sandeepkumar skb groupnorm plugin #437


Merged
34 commits
0fc7192
group_norm - plugin and converter; Changes to build the plugin is not…
Oct 19, 2020
b3dd0bc
adding build support for group_norm
Oct 20, 2020
a9b464f
adding test for group_norm and importing correct plugin
Oct 20, 2020
871f4d7
Adding a common file to register extensions
Oct 20, 2020
9497b97
Removing individual extension registrations from individual plugins
Oct 20, 2020
1c3daea
Building plugin - building directly from plugins.cpp
Oct 20, 2020
9abb4b0
correcting plugin import in the converters
Oct 20, 2020
c7d4277
1. Adding support for in the plugin. 2. Minor clean-ups
Oct 22, 2020
4df27e5
Stale code and comment clean up
Oct 22, 2020
8328e6d
Adding test with eps
Oct 22, 2020
adfd9d8
trying at::native_group_norm
Oct 26, 2020
044c0cc
Adding support for weight and bias in groupnorm plugin
Oct 26, 2020
1aa5ee0
reverting interpolate converter back to orig state
Oct 27, 2020
94969ed
groupnorm fixes
jaybdub Nov 2, 2020
6b9b30e
fixed group norm comment
jaybdub Nov 4, 2020
926051b
added group norm to changelog
jaybdub Nov 4, 2020
02a1d4b
group_norm - plugin and converter; Changes to build the plugin is not…
Oct 19, 2020
1df95eb
adding build support for group_norm
Oct 20, 2020
b168e43
adding test for group_norm and importing correct plugin
Oct 20, 2020
8430b49
Adding a common file to register extensions
Oct 20, 2020
5839d68
Removing individual extension registrations from individual plugins
Oct 20, 2020
e73d047
Building plugin - building directly from plugins.cpp
Oct 20, 2020
5adfd96
correcting plugin import in the converters
Oct 20, 2020
2da64e4
1. Adding support for in the plugin. 2. Minor clean-ups
Oct 22, 2020
d3e87de
Stale code and comment clean up
Oct 22, 2020
1b365e0
Adding test with eps
Oct 22, 2020
c6fffc6
trying at::native_group_norm
Oct 26, 2020
3ec1f1e
Adding support for weight and bias in groupnorm plugin
Oct 26, 2020
9363f45
reverting interpolate converter back to orig state
Oct 27, 2020
926e4cc
groupnorm fixes
jaybdub Nov 2, 2020
9612ba4
fixed group norm comment
jaybdub Nov 4, 2020
50d6c94
added group norm to changelog
jaybdub Nov 4, 2020
58bfc1e
fix changelog
jaybdub Nov 4, 2020
d30a871
merge remote
jaybdub Nov 4, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,5 +5,6 @@
### Added

- Added names for TensorRT layers
- Added GroupNorm plugin which internally uses PyTorch aten::group_norm
- Replaced Tensor.ndim references with len(tensor.shape) to support older pytorch versions
- Added reduced precision documentation page
1 change: 1 addition & 0 deletions build.py
@@ -5,6 +5,7 @@

PLUGINS = [
'interpolate',
'group_norm',
]

BASE_FOLDER = 'torch2trt/converters'
5 changes: 2 additions & 3 deletions setup.py
@@ -14,7 +14,7 @@ def trt_lib_dir():
plugins_ext_module = CUDAExtension(
name='plugins',
sources=[
'torch2trt/plugins/interpolate.cpp'
'torch2trt/plugins/plugins.cpp'
],
include_dirs=[
trt_inc_dir()
@@ -29,8 +29,7 @@ def trt_lib_dir():
'cxx': ['-DUSE_DEPRECATED_INTLIST'] if torch.__version__ < "1.5" else [],
'nvcc': []
}
)

)
if '--plugins' in sys.argv:
ext_modules.append(plugins_ext_module)
sys.argv.remove('--plugins')
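(For reference: as the setup.py diff above shows, the plugins extension, now built from torch2trt/plugins/plugins.cpp instead of interpolate.cpp alone, is only compiled when setup.py is invoked with the --plugins flag, e.g. python setup.py install --plugins. Without it, torch2trt.plugins has no GroupNormPlugin and the converter and tests introduced below are disabled via has_group_norm_plugin().)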
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
@@ -32,6 +32,7 @@
from .identity import *
from .instance_norm import *
from .interpolate import *
from .group_norm import *
from .max import *
from .max_pool2d import *
from .mean import *
48 changes: 48 additions & 0 deletions torch2trt/converters/group_norm.py
@@ -0,0 +1,48 @@
import torch
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def has_group_norm_plugin():
    # the plugin extension is only importable when torch2trt was built with --plugins
    try:
        from torch2trt.plugins import GroupNormPlugin
        return True
    except ImportError:
        return False


def get_group_norm_plugin(num_groups, weight, bias, eps):
    from torch2trt.plugins import GroupNormPlugin
    PLUGIN_NAME = 'group_norm'
    registry = trt.get_plugin_registry()
    # find the creator registered by the C++ plugin via REGISTER_TENSORRT_PLUGIN
    creator = [c for c in registry.plugin_creator_list if c.name == PLUGIN_NAME and c.plugin_namespace == 'torch2trt'][0]
    torch2trt_plugin = GroupNormPlugin(num_groups=num_groups, weight=weight, bias=bias, eps=eps)
    return creator.deserialize_plugin(PLUGIN_NAME, torch2trt_plugin.serializeToString())


@tensorrt_converter('torch.nn.GroupNorm.forward', has_group_norm_plugin())
def convert_group_norm_trt(ctx):
    # ctx.method_args holds (module, input) for nn.GroupNorm.forward
    module = ctx.method_args[0]
    input = ctx.method_args[1]
    num_groups = module.num_groups
    weight = module.weight
    bias = module.bias
    eps = module.eps
    input_trt = add_missing_trt_tensors(ctx.network, [input])
    output = ctx.method_return
    plugin = get_group_norm_plugin(num_groups, weight, bias, eps)

    layer = ctx.network.add_plugin_v2(input_trt, plugin)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
def test_group_norm_trt_g2_fp32():
    return torch.nn.GroupNorm(2, 10)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
def test_group_norm_trt_g2_eps_fp32():
    return torch.nn.GroupNorm(2, 10, eps=1e-4)
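
For context, a minimal usage sketch of the new converter, assuming torch2trt was built with the plugin extension (the module, input shape, and error check are illustrative and not part of this diff):

import torch
from torch2trt import torch2trt

# GroupNorm module matching the test cases above
model = torch.nn.GroupNorm(2, 10).cuda().eval()
x = torch.randn(1, 10, 112, 112).cuda()

# convert_group_norm_trt is invoked for nn.GroupNorm.forward during conversion
model_trt = torch2trt(model, [x])
print(torch.max(torch.abs(model(x) - model_trt(x))))  # should be small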



296 changes: 296 additions & 0 deletions torch2trt/plugins/group_norm.cpp
@@ -0,0 +1,296 @@
#include <torch/extension.h>
#include <torch/script.h>
#include <iostream>
#include <string>
#include <sstream>
#include <NvInfer.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <torch/torch.h>
#include <cuda_runtime_api.h>

using namespace nvinfer1;

namespace torch2trt {

class GroupNormPlugin : public IPluginV2 {
private:
// configured by class
at::TensorOptions tensor_options;
std::vector<int64_t> input_sizes;
std::vector<int64_t> output_sizes;
DataType dtype;

// group norm parameters, configured by user
int64_t num_groups;
at::Tensor weight;
at::Tensor bias;
double eps;


public:

// create from arguments
GroupNormPlugin(int64_t num_groups, at::Tensor weight, at::Tensor bias, double eps) :
num_groups{num_groups}, weight{weight}, bias{bias}, eps{eps}
{}

GroupNormPlugin(const char *data, size_t length) : GroupNormPlugin(std::string(data, length)) {}

GroupNormPlugin(const std::string &data){
deserializeFromString(data);
}

void deserializeFromString(const std::string &data) {
std::istringstream data_stream(data);
torch::serialize::InputArchive input_archive;
input_archive.load_from(data_stream);
{
torch::IValue value;
input_archive.read("num_groups", value);
// num_groups is a scalar, so no deprecated IntList handling is required
num_groups = value.toInt();
}
{
torch::IValue value;
input_archive.read("weight", value);
weight = value.toTensor();
}
{
torch::IValue value;
input_archive.read("bias", value);
bias = value.toTensor();
}

{
torch::IValue value;
input_archive.read("eps", value);
// eps is a scalar double, so no deprecated list handling is required
eps = value.toDouble();
}
{
torch::IValue value;
input_archive.read("dtype", value);
dtype = (DataType) value.toInt();
}
{
torch::IValue value;
input_archive.read("input_sizes", value);
#ifdef USE_DEPRECATED_INTLIST
input_sizes = value.toIntListRef().vec();
#else
input_sizes = value.toIntVector();
#endif
}
{
torch::IValue value;
input_archive.read("output_sizes", value);
#ifdef USE_DEPRECATED_INTLIST
output_sizes = value.toIntListRef().vec();
#else
output_sizes = value.toIntVector();
#endif
}
}
std::string serializeToString() const {
torch::serialize::OutputArchive output_archive;
output_archive.write("num_groups", torch::IValue(num_groups));
output_archive.write("weight", torch::IValue(weight));
output_archive.write("bias", torch::IValue(bias));
output_archive.write("eps", torch::IValue(eps));
output_archive.write("dtype", torch::IValue((int) dtype));
output_archive.write("input_sizes", torch::IValue(input_sizes));
output_archive.write("output_sizes", torch::IValue(output_sizes));
std::ostringstream data_str;
output_archive.save_to(data_str);
return data_str.str();
}

const char* getPluginType() const override {
return "group_norm";
}

const char* getPluginVersion() const override {
return "1";
}

int getNbOutputs() const override {
return 1;
}

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
Dims dims;
dims.nbDims = inputs->nbDims;

for (int i = 0; i < inputs->nbDims; i++) {
dims.d[i] = inputs->d[i];
}

return dims;
}

bool supportsFormat(DataType type, PluginFormat format) const override {
if (format != PluginFormat::kNCHW) {
return false;
}
if (type == DataType::kINT32 || type == DataType::kINT8) {
return false;
}
return true;
}

void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {

// set data type
if (type == DataType::kFLOAT) {
tensor_options = tensor_options.dtype(c10::kFloat);
dtype = type;
} else if (type == DataType::kHALF) {
tensor_options = tensor_options.dtype(c10::kHalf);
dtype = type;
}

// set input sizes
input_sizes.resize(inputDims[0].nbDims);
for (int i = 0; i < inputDims[0].nbDims; i++) {
input_sizes[i] = inputDims[0].d[i];
}

// set output sizes
output_sizes.resize(outputDims[0].nbDims);
for (int i = 0; i < outputDims[0].nbDims; i++) {
output_sizes[i] = outputDims[0].d[i];
}
}

int initialize() override {
// set device
tensor_options = tensor_options.device(c10::kCUDA);

// set data type
if (dtype == DataType::kFLOAT) {
tensor_options = tensor_options.dtype(c10::kFloat);
} else if (dtype == DataType::kHALF) {
tensor_options = tensor_options.dtype(c10::kHalf);
}


weight = weight.to(tensor_options);
bias = bias.to(tensor_options);

return 0;
}

void terminate() override {}

size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }

int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
// get input / output dimensions
std::vector<int64_t> batch_input_sizes = input_sizes;
std::vector<int64_t> batch_output_sizes = output_sizes;
batch_input_sizes.insert(batch_input_sizes.begin(), batchSize);
batch_output_sizes.insert(batch_output_sizes.begin(), batchSize);

// create tensor wrappers
at::Tensor input = at::from_blob((void*) inputs[0], batch_input_sizes, [](void*){}, tensor_options);
at::Tensor output = at::from_blob(outputs[0], batch_output_sizes, [](void*){}, tensor_options);

// create new torch cuda stream
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
at::cuda::CUDAStreamGuard torch_guard(torch_stream);

// capture current work on tensorrt cuda stream
cudaEvent_t event;
cudaEventCreate(&event);
cudaEventRecord(event, stream);

// make torch cuda stream wait on tensorrt work
cudaStreamWaitEvent(torch_stream.stream(), event, 0);



// enqueue work
// Group_norm function from PyTorch: https://pytorch.org/cppdocs/api/function_namespaceat_1a6bc1e9504ea440c6c96ff8a8b94333f2.html#exhale-function-namespaceat-1a6bc1e9504ea440c6c96ff8a8b94333f2
at::Tensor output_tmp = at::group_norm(input, num_groups, weight, bias, eps);
output.copy_(output_tmp);

// capture event on enqueued stream
cudaEvent_t torch_event;
cudaEventCreate(&torch_event);
cudaEventRecord(torch_event, torch_stream.stream());
cudaStreamWaitEvent(stream, torch_event, 0);

cudaEventDestroy(event);
cudaEventDestroy(torch_event);

return 0;
}


size_t getSerializationSize() const override {
return serializeToString().size();
}

void serialize(void* buffer) const override {
std::string data = serializeToString();
size_t size = getSerializationSize();
data.copy((char *) buffer, size);
}

void destroy() override {}

IPluginV2* clone() const override {
return new GroupNormPlugin(num_groups, weight, bias, eps);
}

void setPluginNamespace(const char* pluginNamespace) override {}

const char *getPluginNamespace() const override {
return "torch2trt";
}

};

class GroupNormPluginCreator : public IPluginCreator {
public:
GroupNormPluginCreator() {}

const char *getPluginNamespace() const override {
return "torch2trt";
}

const char *getPluginName() const override {
return "group_norm";
}

const char *getPluginVersion() const override {
return "1";
}

IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
return new GroupNormPlugin((const char*) data, length);
}

void setPluginNamespace(const char *N) override {}
const PluginFieldCollection *getFieldNames() override { return nullptr; }

IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }

};


REGISTER_TENSORRT_PLUGIN(GroupNormPluginCreator);

} // namespace torch2trt
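
For readers who want to see what the at::group_norm call inside enqueue() computes, here is a short illustrative check (not part of this PR) of the group-normalization math the plugin delegates to PyTorch:

import torch
import torch.nn.functional as F

# Each group of C // num_groups channels is normalized over its channels and
# spatial dims, then the per-channel affine weight and bias are applied.
x = torch.randn(2, 10, 8, 8)
num_groups, eps = 2, 1e-5
weight, bias = torch.randn(10), torch.randn(10)

ref = F.group_norm(x, num_groups, weight, bias, eps)

xg = x.reshape(2, num_groups, -1)
mean = xg.mean(dim=-1, keepdim=True)
var = xg.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as group norm uses
manual = ((xg - mean) / torch.sqrt(var + eps)).reshape_as(x)
manual = manual * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

print(torch.allclose(ref, manual, atol=1e-5))  # True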



