Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatible with TensorRT 8.2.1 #596

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions jetbot/ssd_tensorrt/FlattenConcat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cublas_v2.h>

#include "NvInferPlugin.h"
#include "NvUffParser.h"

// Macro for calling GPU functions
#define CHECK(status) \
Expand Down Expand Up @@ -60,13 +61,13 @@ class FlattenConcat : public IPluginV2
assert(d == a + length);
}

int getNbOutputs() const override
int32_t getNbOutputs() const noexcept
{
// We always return one output
return 1;
}

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept
{
// At least one input
assert(nbInputDims >= 1);
Expand All @@ -86,31 +87,32 @@ class FlattenConcat : public IPluginV2
flattenedOutputSize += inputVolume;
}

return DimsCHW(flattenedOutputSize, 1, 1);
// return DimsCHW(flattenedOutputSize, 1, 1);
return Dims3(flattenedOutputSize, 1, 1);
}

int initialize() override
int initialize() noexcept
{
// Called on engine initialization, we initialize cuBLAS library here,
// since we'll be using it for inference
CHECK(cublasCreate(&mCublas));
return 0;
}

void terminate() override
void terminate() noexcept
{
// Called on engine destruction, we destroy cuBLAS data structures,
// which were created in initialize()
CHECK(cublasDestroy(mCublas));
}

size_t getWorkspaceSize(int maxBatchSize) const override
size_t getWorkspaceSize(int maxBatchSize) const noexcept
{
// The operation is done in place, it doesn't use GPU memory
return 0;
}

int enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream) override
int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void*, cudaStream_t stream) noexcept
{
// Does the actual concat of inputs, which is just
// copying all inputs bytes to output byte array
Expand All @@ -132,7 +134,7 @@ class FlattenConcat : public IPluginV2
return 0;
}

size_t getSerializationSize() const override
size_t getSerializationSize() const noexcept
{
// Returns FlattenConcat plugin serialization size
size_t size = sizeof(mFlattenedInputSize[0]) * mFlattenedInputSize.size()
Expand All @@ -141,7 +143,7 @@ class FlattenConcat : public IPluginV2
return size;
}

void serialize(void* buffer) const override
void serialize(void* buffer) const noexcept
{
// Serializes FlattenConcat plugin into byte array

Expand All @@ -165,7 +167,7 @@ class FlattenConcat : public IPluginV2
assert(d == a + getSerializationSize());
}

void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override
void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept
{
// We only support one output
assert(nbOutputs == 1);
Expand Down Expand Up @@ -195,28 +197,39 @@ class FlattenConcat : public IPluginV2
}
}

// Reports which type/format combinations the plugin can execute.
// enqueue() copies the inputs as flat, linearly-laid-out float buffers
// (cublasScopy of contiguous elements), so only the linear layout is
// actually handled correctly. kLINEAR is the TensorRT 8 name for the old
// kNCHW; advertising vectorized/DLA formats (kCHW32, kHWC8, ...) here
// would let TensorRT feed data the kernel cannot interpret.
bool supportsFormat(DataType type, PluginFormat format) const noexcept override
{
    return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR);
}

// Plugin type string used by TensorRT to match the creator at deserialization.
const char* getPluginType() const noexcept override { return FLATTENCONCAT_PLUGIN_NAME; }

// Plugin version string; must match the version the creator reports.
const char* getPluginVersion() const noexcept override { return FLATTENCONCAT_PLUGIN_VERSION; }

// NOTE(review): upstream TensorRT samples call `delete this` here; this
// version leaves the object alive, which leaks unless ownership is managed
// elsewhere -- confirm before changing, as adding `delete this` could
// double-free with the current callers. Behavior kept as-is.
void destroy() noexcept override {}

// Deep-copies the plugin so each engine / execution context gets its own
// instance carrying the same per-input flattened sizes and output size.
IPluginV2* clone() const noexcept override
{
    return new FlattenConcat(mFlattenedInputSize.data(), mFlattenedInputSize.size(), mFlattenedOutputSize);
}

void setPluginNamespace(const char* pluginNamespace) override
void setPluginNamespace(const char* pluginNamespace) noexcept
{
mPluginNamespace = pluginNamespace;
}

const char* getPluginNamespace() const override
const char* getPluginNamespace() const noexcept
{
return mPluginNamespace.c_str();
}
Expand Down Expand Up @@ -260,29 +273,29 @@ class FlattenConcatPluginCreator : public IPluginCreator

~FlattenConcatPluginCreator() {}

// Name of the plugin this creator builds; must match FlattenConcat::getPluginType().
const char* getPluginName() const noexcept override { return FLATTENCONCAT_PLUGIN_NAME; }

// Version of the plugin this creator builds; must match FlattenConcat::getPluginVersion().
const char* getPluginVersion() const noexcept override { return FLATTENCONCAT_PLUGIN_VERSION; }

// Exposes the (static) field collection describing createPlugin() parameters.
const PluginFieldCollection* getFieldNames() noexcept override { return &mFC; }

// Builds a fresh, unconfigured plugin instance.
// NOTE(review): the field collection `fc` is ignored -- presumably the
// plugin is always configured later via configureWithFormat(); confirm
// before relying on creator-supplied fields.
IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override
{
    return new FlattenConcat();
}

// Reconstructs a plugin from the byte array produced by
// FlattenConcat::serialize() when a serialized engine is loaded.
IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override
{
    return new FlattenConcat(serialData, serialLength);
}

void setPluginNamespace(const char* pluginNamespace) override
void setPluginNamespace(const char* pluginNamespace) noexcept
{
mPluginNamespace = pluginNamespace;
}

const char* getPluginNamespace() const override
const char* getPluginNamespace() const noexcept
{
return mPluginNamespace.c_str();
}
Expand Down