From 6d83807a13a02fd7ac2c3bb0433c7945c6976fdf Mon Sep 17 00:00:00 2001 From: jiapei100 Date: Wed, 2 Aug 2023 14:35:29 -0700 Subject: [PATCH] Compatible with TensorRT 8.2.1 --- jetbot/ssd_tensorrt/FlattenConcat.cpp | 65 ++++++++++++++++----------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/jetbot/ssd_tensorrt/FlattenConcat.cpp b/jetbot/ssd_tensorrt/FlattenConcat.cpp index 6fb4d148..ad1b6dc5 100644 --- a/jetbot/ssd_tensorrt/FlattenConcat.cpp +++ b/jetbot/ssd_tensorrt/FlattenConcat.cpp @@ -7,6 +7,7 @@ #include #include "NvInferPlugin.h" +#include "NvUffParser.h" // Macro for calling GPU functions #define CHECK(status) \ @@ -60,13 +61,13 @@ class FlattenConcat : public IPluginV2 assert(d == a + length); } - int getNbOutputs() const override + int32_t getNbOutputs() const noexcept { // We always return one output return 1; } - Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override + Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept { // At least one input assert(nbInputDims >= 1); @@ -86,10 +87,11 @@ class FlattenConcat : public IPluginV2 flattenedOutputSize += inputVolume; } - return DimsCHW(flattenedOutputSize, 1, 1); + // return DimsCHW(flattenedOutputSize, 1, 1); + return Dims3(flattenedOutputSize, 1, 1); } - int initialize() override + int initialize() noexcept { // Called on engine initialization, we initialize cuBLAS library here, // since we'll be using it for inference @@ -97,20 +99,20 @@ class FlattenConcat : public IPluginV2 return 0; } - void terminate() override + void terminate() noexcept { // Called on engine destruction, we destroy cuBLAS data structures, // which were created in initialize() CHECK(cublasDestroy(mCublas)); } - size_t getWorkspaceSize(int maxBatchSize) const override + size_t getWorkspaceSize(int maxBatchSize) const noexcept { // The operation is done in place, it doesn't use GPU memory return 0; } - - int enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream) override + + int32_t enqueue(int32_t batchSize, void const* const* inputs, void* const* outputs, void*, cudaStream_t stream) noexcept { // Does the actual concat of inputs, which is just // copying all inputs bytes to output byte array @@ -132,7 +134,7 @@ class FlattenConcat : public IPluginV2 return 0; } - size_t getSerializationSize() const override + size_t getSerializationSize() const noexcept { // Returns FlattenConcat plugin serialization size size_t size = sizeof(mFlattenedInputSize[0]) * mFlattenedInputSize.size() @@ -141,7 +143,7 @@ class FlattenConcat : public IPluginV2 return size; } - void serialize(void* buffer) const override + void serialize(void* buffer) const noexcept { // Serializes FlattenConcat plugin into byte array @@ -165,7 +167,7 @@ class FlattenConcat : public IPluginV2 assert(d == a + getSerializationSize()); } - void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override + void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) noexcept { // We only support one output assert(nbOutputs == 1); @@ -195,28 +197,39 @@ class FlattenConcat : public IPluginV2 } } - bool supportsFormat(DataType type, PluginFormat format) const override + bool supportsFormat(DataType type, PluginFormat format) const noexcept { - return (type == DataType::kFLOAT && format == PluginFormat::kNCHW); + return (type == DataType::kFLOAT && (format == TensorFormat::kLINEAR || + format == TensorFormat::kCHW2 || + format == TensorFormat::kHWC8 || + format == TensorFormat::kCHW4 || + format == TensorFormat::kCHW16 || + format == TensorFormat::kCHW32 || + format == TensorFormat::kDHWC8 || + format == TensorFormat::kCDHW32 || + format == TensorFormat::kHWC || + format == TensorFormat::kDLA_LINEAR || + format == TensorFormat::kDLA_HWC4 || + format == TensorFormat::kHWC16)); } - const char* getPluginType() const override { return FLATTENCONCAT_PLUGIN_NAME; } + const char* getPluginType() const noexcept { return FLATTENCONCAT_PLUGIN_NAME; } - const char* getPluginVersion() const override { return FLATTENCONCAT_PLUGIN_VERSION; } + const char* getPluginVersion() const noexcept { return FLATTENCONCAT_PLUGIN_VERSION; } - void destroy() override {} + void destroy() noexcept {} - IPluginV2* clone() const override + IPluginV2* clone() const noexcept { return new FlattenConcat(mFlattenedInputSize.data(), mFlattenedInputSize.size(), mFlattenedOutputSize); } - void setPluginNamespace(const char* pluginNamespace) override + void setPluginNamespace(const char* pluginNamespace) noexcept { mPluginNamespace = pluginNamespace; } - const char* getPluginNamespace() const override + const char* getPluginNamespace() const noexcept { return mPluginNamespace.c_str(); } @@ -260,29 +273,29 @@ class FlattenConcatPluginCreator : public IPluginCreator ~FlattenConcatPluginCreator() {} - const char* getPluginName() const override { return FLATTENCONCAT_PLUGIN_NAME; } + const char* getPluginName() const noexcept { return FLATTENCONCAT_PLUGIN_NAME; } - const char* getPluginVersion() const override { return FLATTENCONCAT_PLUGIN_VERSION; } + const char* getPluginVersion() const noexcept { return FLATTENCONCAT_PLUGIN_VERSION; } - const PluginFieldCollection* getFieldNames() override { return &mFC; } + const PluginFieldCollection* getFieldNames() noexcept { return &mFC; } - IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override + IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept { return new FlattenConcat(); } - IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override + IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept { return new FlattenConcat(serialData, serialLength); } - void setPluginNamespace(const char* pluginNamespace) override + void setPluginNamespace(const char* pluginNamespace) noexcept { mPluginNamespace = pluginNamespace; } - const char* getPluginNamespace() const override + const char* getPluginNamespace() const noexcept { return mPluginNamespace.c_str(); }