Sandeepkumar skb groupnorm plugin #437
Merged
jaybdub merged 34 commits into NVIDIA-AI-IOT:master from jaybdub:sandeepkumar-skb-groupnorm_plugin on Nov 4, 2020
Changes from all commits (34 commits)
0fc7192
group_norm - plugin and converter; Changes to build the plugin is not…
b3dd0bc
adding build support for group_norm
a9b464f
adding test for group_norm and importing correct plugin
871f4d7
Adding a common file to register extensions
9497b97
Removing individual extension registrations from individual plugins
1c3daea
Building plugin - building directly from plugins.cpp
9abb4b0
correcting plugin import in the converters
c7d4277
1. Adding support for in the plugin. 2. Minor clean-ups
4df27e5
Stale code and comment clean up
8328e6d
Adding test with eps
adfd9d8
trying at::native_group_norm
044c0cc
Adding support for weight and bias in groupnorm plugin
1aa5ee0
reverting interpolate converter back to orig state
94969ed
groupnorm fixes
jaybdub 6b9b30e
fixed group norm comment
jaybdub 926051b
added group norm to changelog
jaybdub 02a1d4b
group_norm - plugin and converter; Changes to build the plugin is not…
1df95eb
adding build support for group_norm
b168e43
adding test for group_norm and importing correct plugin
8430b49
Adding a common file to register extensions
5839d68
Removing individual extension registrations from individual plugins
e73d047
Building plugin - building directly from plugins.cpp
5adfd96
correcting plugin import in the converters
2da64e4
1. Adding support for in the plugin. 2. Minor clean-ups
d3e87de
Stale code and comment clean up
1b365e0
Adding test with eps
c6fffc6
trying at::native_group_norm
3ec1f1e
Adding support for weight and bias in groupnorm plugin
9363f45
reverting interpolate converter back to orig state
926e4cc
groupnorm fixes
jaybdub 9612ba4
fixed group norm comment
jaybdub 50d6c94
added group norm to changelog
jaybdub 58bfc1e
fix changelog
jaybdub d30a871
merge remote
@@ -5,6 +5,7 @@
PLUGINS = [
    'interpolate',
    'group_norm',
]

BASE_FOLDER = 'torch2trt/converters'
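The new 'group_norm' entry adds the plugin to the list of extensions compiled when torch2trt is built with plugin support. As a quick sanity check (a minimal sketch, assuming the extension was built and that importing torch2trt loads the compiled plugin library), the TensorRT plugin registry should then expose a 'group_norm' creator in the 'torch2trt' namespace:

import tensorrt as trt
import torch2trt  # assumption: importing torch2trt loads the compiled plugin extension

registry = trt.get_plugin_registry()
names = [c.name for c in registry.plugin_creator_list if c.plugin_namespace == 'torch2trt']
print('group_norm' in names)  # expected True once the plugin is built and registered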
@@ -0,0 +1,48 @@
import torch.nn as nn
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def has_group_norm_plugin():
    try:
        from torch2trt.plugins import GroupNormPlugin
        return True
    except:
        return False


def get_group_norm_plugin(num_groups, weight, bias, eps):
    from torch2trt.plugins import GroupNormPlugin
    PLUGIN_NAME = 'group_norm'
    registry = trt.get_plugin_registry()
    creator = [c for c in registry.plugin_creator_list if c.name == PLUGIN_NAME and c.plugin_namespace == 'torch2trt'][0]
    torch2trt_plugin = GroupNormPlugin(num_groups=num_groups, weight=weight, bias=bias, eps=eps)
    return creator.deserialize_plugin(PLUGIN_NAME, torch2trt_plugin.serializeToString())


@tensorrt_converter('torch.nn.GroupNorm.forward', has_group_norm_plugin())
def convert_group_norm_trt(ctx):
    module = ctx.method_args[0]
    input = ctx.method_args[1]
    num_groups = module.num_groups
    weight = module.weight
    bias = module.bias
    eps = module.eps
    input_trt = add_missing_trt_tensors(ctx.network, [input])
    output = ctx.method_return
    plugin = get_group_norm_plugin(num_groups, weight, bias, eps)

    layer = ctx.network.add_plugin_v2(input_trt, plugin)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
def test_group_norm_trt_g2_fp32():
    return torch.nn.GroupNorm(2, 10)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 112, 112)], has_group_norm_plugin())
def test_group_norm_trt_g2_eps_fp32():
    return torch.nn.GroupNorm(2, 10, eps=1e-4)
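For context, the converter above hooks torch.nn.GroupNorm.forward, so a module containing GroupNorm should convert end to end once the plugin is available. A minimal usage sketch, not part of this diff (the model and input shapes below are illustrative only):

import torch
from torch2trt import torch2trt

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, kernel_size=3),  # hypothetical example layers
    torch.nn.GroupNorm(2, 10),
).cuda().eval()

x = torch.randn(1, 3, 112, 112).cuda()
model_trt = torch2trt(model, [x])  # triggers convert_group_norm_trt for the GroupNorm layer
print(torch.max(torch.abs(model(x) - model_trt(x))))  # rough numerical check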
@@ -0,0 +1,296 @@
#include <torch/extension.h>
#include <torch/script.h>
#include <iostream>
#include <string>
#include <sstream>
#include <NvInfer.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <torch/torch.h>
#include <cuda_runtime_api.h>

using namespace nvinfer1;

namespace torch2trt {

class GroupNormPlugin : public IPluginV2 {
private:
  // configured by class
  at::TensorOptions tensor_options;
  std::vector<int64_t> input_sizes;
  std::vector<int64_t> output_sizes;
  DataType dtype;

  // group norm parameters, configured by user
  int64_t num_groups;
  at::Tensor weight;
  at::Tensor bias;
  double eps;

public:

  // create from arguments
  GroupNormPlugin(int64_t num_groups, at::Tensor weight, at::Tensor bias, double eps) :
    num_groups{num_groups}, weight{weight}, bias{bias}, eps{eps}
  {}

  GroupNormPlugin(const char *data, size_t length) : GroupNormPlugin(std::string(data, length)) {}

  GroupNormPlugin(const std::string &data) {
    deserializeFromString(data);
  }

  void deserializeFromString(const std::string &data) {
    std::istringstream data_stream(data);
    torch::serialize::InputArchive input_archive;
    input_archive.load_from(data_stream);
    {
      torch::IValue value;
      input_archive.read("num_groups", value);
#ifdef USE_DEPRECATED_INTLIST
      num_groups = value.toIntListRef().vec();
#else
      num_groups = value.toInt();
#endif
    }
    {
      torch::IValue value;
      input_archive.read("weight", value);
      weight = value.toTensor();
    }
    {
      torch::IValue value;
      input_archive.read("bias", value);
      bias = value.toTensor();
    }
    {
      torch::IValue value;
      input_archive.read("eps", value);
#ifdef USE_DEPRECATED_INTLIST
      eps = value.toDoubleListRef().vec();
#else
      eps = value.toDouble();
#endif
    }
    {
      torch::IValue value;
      input_archive.read("dtype", value);
      dtype = (DataType) value.toInt();
    }
    {
      torch::IValue value;
      input_archive.read("input_sizes", value);
#ifdef USE_DEPRECATED_INTLIST
      input_sizes = value.toIntListRef().vec();
#else
      input_sizes = value.toIntVector();
#endif
    }
    {
      torch::IValue value;
      input_archive.read("output_sizes", value);
#ifdef USE_DEPRECATED_INTLIST
      output_sizes = value.toIntListRef().vec();
#else
      output_sizes = value.toIntVector();
#endif
    }
  }

  std::string serializeToString() const {
    torch::serialize::OutputArchive output_archive;
    output_archive.write("num_groups", torch::IValue(num_groups));
    output_archive.write("weight", torch::IValue(weight));
    output_archive.write("bias", torch::IValue(bias));
    output_archive.write("eps", torch::IValue(eps));
    output_archive.write("dtype", torch::IValue((int) dtype));
    output_archive.write("input_sizes", torch::IValue(input_sizes));
    output_archive.write("output_sizes", torch::IValue(output_sizes));
    std::ostringstream data_str;
    output_archive.save_to(data_str);
    return data_str.str();
  }

  const char* getPluginType() const override {
    return "group_norm";
  };

  const char* getPluginVersion() const override {
    return "1";
  }

  int getNbOutputs() const override {
    return 1;
  }

  Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override {
    Dims dims;
    dims.nbDims = inputs->nbDims;

    for (int i = 0; i < inputs->nbDims; i++) {
      dims.d[i] = inputs->d[i];
    }

    return dims;
  }

  bool supportsFormat(DataType type, PluginFormat format) const override {
    if (format != PluginFormat::kNCHW) {
      return false;
    }
    if (type == DataType::kINT32 || type == DataType::kINT8) {
      return false;
    }
    return true;
  }

  void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
      int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {

    // set data type
    if (type == DataType::kFLOAT) {
      tensor_options = tensor_options.dtype(c10::kFloat);
      dtype = type;
    } else if (type == DataType::kHALF) {
      tensor_options = tensor_options.dtype(c10::kHalf);
      dtype = type;
    }

    // set input sizes
    input_sizes.resize(inputDims[0].nbDims);
    for (int i = 0; i < inputDims[0].nbDims; i++) {
      input_sizes[i] = inputDims[0].d[i];
    }

    // set output sizes
    output_sizes.resize(outputDims[0].nbDims);
    for (int i = 0; i < outputDims[0].nbDims; i++) {
      output_sizes[i] = outputDims[0].d[i];
    }
  }

  int initialize() override {
    // set device
    tensor_options = tensor_options.device(c10::kCUDA);

    // set data type
    if (dtype == DataType::kFLOAT) {
      tensor_options = tensor_options.dtype(c10::kFloat);
    } else if (dtype == DataType::kHALF) {
      tensor_options = tensor_options.dtype(c10::kHalf);
    }

    weight = weight.to(tensor_options);
    bias = bias.to(tensor_options);

    return 0;
  }

  void terminate() override {}

  size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }

  int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override {
    // get input / output dimensions
    std::vector<long> batch_input_sizes = input_sizes;
    std::vector<long> batch_output_sizes = output_sizes;
    batch_input_sizes.insert(batch_input_sizes.begin(), batchSize);
    batch_output_sizes.insert(batch_output_sizes.begin(), batchSize);

    // create tensor wrappers
    at::Tensor input = at::from_blob((void*) inputs[0], batch_input_sizes, [](void*){}, tensor_options);
    at::Tensor output = at::from_blob(outputs[0], batch_output_sizes, [](void*){}, tensor_options);

    // create new torch cuda stream
    at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
    at::cuda::CUDAStreamGuard torch_guard(torch_stream);

    // capture current work on tensorrt cuda stream
    cudaEvent_t event;
    cudaEventCreate(&event);
    cudaEventRecord(event, stream);

    // make torch cuda stream wait on tensorrt work
    cudaStreamWaitEvent(torch_stream.stream(), event, 0);

    // enqueue work
    // group_norm function from PyTorch: https://pytorch.org/cppdocs/api/function_namespaceat_1a6bc1e9504ea440c6c96ff8a8b94333f2.html#exhale-function-namespaceat-1a6bc1e9504ea440c6c96ff8a8b94333f2
    at::Tensor output_tmp = at::group_norm(input, num_groups, weight, bias, eps=eps);
    output.copy_(output_tmp);

    // capture event on enqueued stream
    cudaEvent_t torch_event;
    cudaEventCreate(&torch_event);
    cudaEventRecord(torch_event, torch_stream.stream());
    cudaStreamWaitEvent(stream, torch_event, 0);

    cudaEventDestroy(event);
    cudaEventDestroy(torch_event);

    return 0;
  }

  size_t getSerializationSize() const override {
    return serializeToString().size();
  }

  void serialize(void* buffer) const override {
    std::string data = serializeToString();
    size_t size = getSerializationSize();
    data.copy((char *) buffer, size);
  }

  void destroy() override {}

  IPluginV2* clone() const override {
    return new GroupNormPlugin(num_groups, weight, bias, eps);
  }

  void setPluginNamespace(const char* pluginNamespace) override {}

  const char *getPluginNamespace() const override {
    return "torch2trt";
  }

};

class GroupNormPluginCreator : public IPluginCreator {
public:
  GroupNormPluginCreator() {}

  const char *getPluginNamespace() const override {
    return "torch2trt";
  }

  const char *getPluginName() const override {
    return "group_norm";
  }

  const char *getPluginVersion() const override {
    return "1";
  }

  IPluginV2 *deserializePlugin(const char *name, const void *data, size_t length) override {
    return new GroupNormPlugin((const char*) data, length);
  }

  void setPluginNamespace(const char *N) override {}
  const PluginFieldCollection *getFieldNames() override { return nullptr; }

  IPluginV2 *createPlugin(const char *name, const PluginFieldCollection *fc) override { return nullptr; }

};


REGISTER_TENSORRT_PLUGIN(GroupNormPluginCreator);

} // namespace torch2trt
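The plugin's enqueue method wraps the raw TensorRT buffers in ATen tensors, delegates the math to at::group_norm, and synchronizes the PyTorch stream with the TensorRT stream via CUDA events. The C++ call mirrors the Python-level group_norm, which the following sketch illustrates (shapes and eps are illustrative only, not taken from this diff):

import torch
import torch.nn.functional as F

x = torch.randn(1, 10, 112, 112)
gn = torch.nn.GroupNorm(2, 10, eps=1e-4)

# nn.GroupNorm and the functional form the plugin relies on should agree numerically
y_module = gn(x)
y_functional = F.group_norm(x, 2, gn.weight, gn.bias, eps=1e-4)
print(torch.allclose(y_module, y_functional, atol=1e-6))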