From 6c1a746f22f473b44fe524aaf18e9279a3b01d2e Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 30 Sep 2022 15:55:04 +0800 Subject: [PATCH 1/2] [WIP]refactor: use static init instead of call_once. --- .../mlir/converters/torch_mlir_op_filter.cpp | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/pytorch_blade/src/compiler/mlir/converters/torch_mlir_op_filter.cpp b/pytorch_blade/src/compiler/mlir/converters/torch_mlir_op_filter.cpp index ed610eac86a..89cc96c98cc 100644 --- a/pytorch_blade/src/compiler/mlir/converters/torch_mlir_op_filter.cpp +++ b/pytorch_blade/src/compiler/mlir/converters/torch_mlir_op_filter.cpp @@ -37,8 +37,8 @@ bool IsTorchMlirSupported(const torch::jit::Node& node) { } // clang-format off -const std::unordered_set &GetTorchMlirWhiteList() { - static std::unordered_set white_list{ +std::unordered_set CreateTorchMlirWhiteList() { + std::unordered_set white_list{ "aten::_autocast_to_reduced_precision", "aten::__and__", "aten::add", @@ -113,24 +113,26 @@ const std::unordered_set &GetTorchMlirWhiteList() { "torch_blade::fake_quant" }; + auto custom_ops = env::ReadStringFromEnvVar("TORCH_MHLO_OP_WHITE_LIST", ""); + std::ostringstream ostr; + ostr << "User defined white list: ["; + std::istringstream f(custom_ops); + std::string s; + for (auto s : StrSplit(custom_ops, ';')) { + white_list.insert(std::string(s)); + ostr << s << ", "; + } + ostr << "]"; + LOG(INFO) << ostr.str(); - static std::once_flag white; - std::call_once(white, []() { - auto custom_ops = env::ReadStringFromEnvVar("TORCH_MHLO_OP_WHITE_LIST", ""); - std::ostringstream ostr; - ostr << "User defined white list: ["; - std::istringstream f(custom_ops); - std::string s; - for (auto s : StrSplit(custom_ops, ';')) { - white_list.insert(std::string(s)); - ostr << s << ", "; - } - ostr << "]"; - LOG(INFO) << ostr.str(); - }); return white_list; } // clang-format off +const std::unordered_set &GetTorchMlirWhiteList() { + static std::unordered_set white_list = CreateTorchMlirWhiteList(); + return white_list; +} + } // namespace blade } // namespace torch From 43cb14e9d4346acb6b2f083f644f6309e24c5203 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 30 Sep 2022 19:03:06 +0800 Subject: [PATCH 2/2] remove call_once in InitBladeDiscEngine --- .../src/compiler/mlir/runtime/disc_engine.cpp | 11 +++++------ pytorch_blade/src/compiler/mlir/runtime/disc_engine.h | 1 - .../ltc/disc_compiler/passes/register_disc_class.cpp | 1 - 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytorch_blade/src/compiler/mlir/runtime/disc_engine.cpp b/pytorch_blade/src/compiler/mlir/runtime/disc_engine.cpp index 4b6b6c480be..cb21c4e44ec 100644 --- a/pytorch_blade/src/compiler/mlir/runtime/disc_engine.cpp +++ b/pytorch_blade/src/compiler/mlir/runtime/disc_engine.cpp @@ -99,16 +99,15 @@ const char* GetBackendName() { return DiscEngine::GetBackendName(); } +namespace { bool InitBladeDiscEngine() { - static std::once_flag flag; - std::call_once(flag, [&]() { - auto torch_blade_engine_creator = - torch::blade::backends::EngineCreatorRegister().RegisterBackend( - DiscEngine::GetBackendName(), &DiscEngine::Create); - }); + torch::blade::backends::EngineCreatorRegister().RegisterBackend( + DiscEngine::GetBackendName(), &DiscEngine::Create); return true; } static bool init_dummy = InitBladeDiscEngine(); +} // namespace + } // namespace disc } // namespace blade } // namespace torch diff --git a/pytorch_blade/src/compiler/mlir/runtime/disc_engine.h b/pytorch_blade/src/compiler/mlir/runtime/disc_engine.h index 588dda6f54f..1c723b0e3bc 100644 --- a/pytorch_blade/src/compiler/mlir/runtime/disc_engine.h +++ b/pytorch_blade/src/compiler/mlir/runtime/disc_engine.h @@ -15,7 +15,6 @@ namespace torch { namespace blade { namespace disc { const char* GetBackendName(); -bool InitBladeDiscEngine(); } // namespace disc } // namespace blade } // namespace torch diff --git a/pytorch_blade/src/ltc/disc_compiler/passes/register_disc_class.cpp b/pytorch_blade/src/ltc/disc_compiler/passes/register_disc_class.cpp index 9d3f237381e..b3bf8a32687 100644 --- a/pytorch_blade/src/ltc/disc_compiler/passes/register_disc_class.cpp +++ b/pytorch_blade/src/ltc/disc_compiler/passes/register_disc_class.cpp @@ -125,7 +125,6 @@ void ReplaceDiscClass( std::vector RegisterDiscClass( const std::shared_ptr& graph) { torch::blade::backends::InitTorchBladeEngine(); - torch::blade::disc::InitBladeDiscEngine(); std::vector disc_inputs; std::vector disc_nodes; std::copy_if(