diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 79db36ab..7c61f06c 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -60,7 +60,7 @@ jobs: # - name: 'Tar install directory and metal lib directory' # shell: bash # working-directory: ${{ steps.strings.outputs.install-output-dir }} - # run: | + # run: | # tar cvf artifact.tar . # - name: Upload install folder to archive @@ -150,4 +150,3 @@ jobs: export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}" source env/activate pytest -v test - diff --git a/.gitignore b/.gitignore index 2f78cf5b..0d20b648 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ *.pyc - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c337c650 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.7 + hooks: + - id: clang-format + types_or: [c++, c] + args: [-style=file, -i] + - repo: https://github.com/espressif/check-copyright/ + rev: v1.0.3 + hooks: + - id: check-copyright + args: ['--config', '.github/check-spdx.yaml'] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-added-large-files diff --git a/env/activate b/env/activate index ebca09fc..a486539a 100644 --- a/env/activate +++ b/env/activate @@ -35,4 +35,3 @@ else export TT_METAL_LOGGER_LEVEL="ERROR" fi - diff --git a/requirements.txt b/requirements.txt index 4ee83327..e6f692c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch-mlir +torch-mlir torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu diff --git a/setup.py b/setup.py index 9998ce0c..3b8e2da2 100644 --- a/setup.py +++ b/setup.py @@ -1,32 +1,41 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import os import sys import subprocess from setuptools import setup, Extension from setuptools.command.build_ext import build_ext + class CMakeBuild(build_ext): def build_extension(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - cfg = 'Debug' if self.debug else 'Release' + cfg = "Debug" if self.debug else "Release" cmake_args = [ - f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}', - f'-DPYTHON_EXECUTABLE={sys.executable}', - f'-DCMAKE_BUILD_TYPE={cfg}', + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", ] - build_args = ['--config', cfg] + build_args = ["--config", cfg] if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp) - subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) + subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp + ) + subprocess.check_call( + ["cmake", "--build", "."] + build_args, cwd=self.build_temp + ) + setup( - name='tt_torch', - version='0.1', - author='Aleks Knezevic', - description='TT PyTorch FrontEnd', - long_description='', - ext_modules=[Extension('c_bindings', sources=[])], - cmdclass={'build_ext': CMakeBuild}, + name="tt_torch", + version="0.1", + author="Aleks Knezevic", + description="TT PyTorch FrontEnd", + long_description="", + ext_modules=[Extension("c_bindings", sources=[])], + cmdclass={"build_ext": CMakeBuild}, zip_safe=False, -) \ No newline at end of file +) diff --git a/test/conftest.py b/test/conftest.py index e1a5c28d..ea4e5099 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,7 @@ import pytest import torch + @pytest.fixture(autouse=True) def run_around_tests(): yield diff --git a/test/test_basic.py b/test/test_basic.py index f0a40dd8..f18ca14b 100644 --- a/test/test_basic.py +++ b/test/test_basic.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import torch from torch import nn import pytest @@ -5,172 +8,180 @@ import tt_torch from tt_torch.tools.verify import verify_module + def test_add(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x - def forward(self, x): - return x + x - + verify_module(Basic(), [(256, 256)]) - verify_module(Basic(), [(256, 256)]) def test_concat_dim0(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cat((x, y), dim=0) + + verify_module(Basic(), [(32, 32), (64, 32)]) - def forward(self, x, y): - return torch.cat((x, y), dim = 0) - - verify_module(Basic(), [(32, 32), (64, 32)]) def test_concat_dim1(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cat((x, y), dim=1) + + verify_module(Basic(), [(32, 32), (32, 64)]) - def forward(self, x, y): - return torch.cat((x, y), dim = 1) - - verify_module(Basic(), [(32, 32), (32, 64)]) def test_concat_dim2(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cat((x, y), dim=2) + + verify_module(Basic(), [(32, 32, 32), (32, 32, 64)]) - def forward(self, x, y): - return torch.cat((x, y), dim = 2) - - verify_module(Basic(), [(32, 32, 32), (32, 32, 64)]) def test_concat_dim3(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.cat((x, y), dim=3) + + verify_module(Basic(), [(32, 32, 32, 32), (32, 32, 32, 64)]) - def forward(self, x, y): - return torch.cat((x, y), dim = 3) - - verify_module(Basic(), [(32, 32, 32, 32),(32, 32, 32, 64)]) def test_linear(): - class Basic(nn.Module): - def __init__(self): - super().__init__() - self.linear_a = nn.Linear(32, 64, bias=False) - self.linear_b = nn.Linear(64, 64, bias=False) + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.linear_a = nn.Linear(32, 64, bias=False) + self.linear_b = nn.Linear(64, 64, bias=False) + + def forward(self, x): + x = self.linear_a(x) + x = self.linear_b(x) + return x - def forward(self, x): - x = self.linear_a(x) - x = self.linear_b(x) - return x + verify_module(Basic(), [(32, 32)]) - verify_module(Basic(), [(32, 32)]) from torch_mlir import fx from torch_mlir.compiler_utils import OutputType + def test_linear_with_bias(): - pytest.xfail() - class Basic(nn.Module): - def __init__(self): - super().__init__() - self.linear_a = nn.Linear(32, 32) + pytest.xfail() - def forward(self, x): - x = self.linear_a(x) - return x + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.linear_a = nn.Linear(32, 32) - verify_module(Basic(), [(32, 32)]) + def forward(self, x): + x = self.linear_a(x) + return x + + verify_module(Basic(), [(32, 32)]) def test_relu(): - pytest.xfail() - class Basic(nn.Module): - def __init__(self): - super().__init__() + pytest.xfail() + + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.relu(x) - def forward(self, x): - return torch.relu(x) + verify_module(Basic(), [(32, 32)]) - verify_module(Basic(), [(32, 32)]) def test_rsqrt(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.rsqrt(x) + + verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1)) - def forward(self, x): - return torch.rsqrt(x) - - verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1)) def test_sqrt(): - class Basic(nn.Module): - def __init__(self): - super().__init__() + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sqrt(x) + + verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1)) - def forward(self, x): - return torch.sqrt(x) - - verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1)) dim0_cases = [] for begin in torch.arange(10).tolist(): - for end in torch.arange(90, 100).tolist(): - dim0_cases.append((begin, end, 0)) + for end in torch.arange(90, 100).tolist(): + dim0_cases.append((begin, end, 0)) dim1_cases = [] for begin in torch.arange(10).tolist(): - for end in torch.arange(90, 100).tolist(): - dim1_cases.append((begin, end, 1)) + for end in torch.arange(90, 100).tolist(): + dim1_cases.append((begin, end, 1)) -dim2_cases = [] +dim2_cases = [] for begin in torch.arange(0, 64, 32).tolist(): - for end in torch.arange(64, 128, 32).tolist(): - dim2_cases.append((begin, end, 2)) + for end in torch.arange(64, 128, 32).tolist(): + dim2_cases.append((begin, end, 2)) dim3_cases = [] for begin in torch.arange(0, 64, 32).tolist(): - for end in torch.arange(64, 128, 32).tolist(): - dim3_cases.append((begin, end, 3)) + for end in torch.arange(64, 128, 32).tolist(): + dim3_cases.append((begin, end, 3)) + @pytest.mark.parametrize( - "begin, end, dim", - [ - *dim2_cases, - *dim3_cases, - *dim0_cases, - *dim1_cases - ] + "begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases] ) def test_slice(begin, end, dim): - class Basic(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - if dim == 0: - return a[begin:end, :, :, :] - elif dim == 1: - return a[:, begin:end, :, :] - elif dim == 2: - return a[:, :, begin:end, :] - else: - return a[:, :, :, begin:end] - - shape = [10, 10, 10, 10] - shape[dim] = 128 - verify_module(Basic(), [shape]) + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + if dim == 0: + return a[begin:end, :, :, :] + elif dim == 1: + return a[:, begin:end, :, :] + elif dim == 2: + return a[:, :, begin:end, :] + else: + return a[:, :, :, begin:end] + + shape = [10, 10, 10, 10] + shape[dim] = 128 + verify_module(Basic(), [shape]) + def test_bert(): - pytest.xfail() - from torch_mlir import fx - from torch_mlir.compiler_utils import OutputType - from transformers import BertModel - bert = BertModel.from_pretrained("prajjwal1/bert-tiny") - verify_module(bert, [(1, 32)], input_data_types=[torch.int32]) - + pytest.xfail() + from torch_mlir import fx + from torch_mlir.compiler_utils import OutputType + from transformers import BertModel + + bert = BertModel.from_pretrained("prajjwal1/bert-tiny") + verify_module(bert, [(1, 32)], input_data_types=[torch.int32]) diff --git a/tt_torch/CMakeLists.txt b/tt_torch/CMakeLists.txt index fbc67602..86735ca2 100644 --- a/tt_torch/CMakeLists.txt +++ b/tt_torch/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(csrc) \ No newline at end of file +add_subdirectory(csrc) diff --git a/tt_torch/csrc/CMakeLists.txt b/tt_torch/csrc/CMakeLists.txt index ed74214a..c783f259 100644 --- a/tt_torch/csrc/CMakeLists.txt +++ b/tt_torch/csrc/CMakeLists.txt @@ -13,7 +13,7 @@ add_dependencies(TT_TORCH_MLIR set_target_properties(TT_TORCH_MLIR PROPERTIES COMPILE_FLAGS "-fno-rtti") install (TARGETS TT_TORCH_MLIR LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib) -target_include_directories(TT_TORCH_MLIR PUBLIC +target_include_directories(TT_TORCH_MLIR PUBLIC ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include/ttmlir/Target/Common ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/include @@ -62,12 +62,12 @@ set(STABLEHLO_LIBS StablehloRegister StablehloBase StablehloPasses - StablehloReferenceErrors + StablehloReferenceErrors StablehloReferenceProcess - StablehloReferenceToken + StablehloReferenceToken StablehloSerialization StablehloBroadcastUtils - StablehloPortableApi + StablehloPortableApi StablehloReferenceIndex StablehloReferenceProcessGrid StablehloReferenceTypes @@ -103,7 +103,7 @@ find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib set(TARGET_NAME tt_mlir) pybind11_add_module(${TARGET_NAME} bindings.cpp) -add_dependencies(${TARGET_NAME} +add_dependencies(${TARGET_NAME} TT_TORCH_MLIR tt-mlir ) @@ -116,7 +116,7 @@ target_include_directories(${TARGET_NAME} PUBLIC ${TTMLIR_TOOLCHAIN_DIR}/include ${TORCH_INCLUDE_DIRS} ) -target_link_libraries(${TARGET_NAME} PUBLIC +target_link_libraries(${TARGET_NAME} PUBLIC TT_TORCH_MLIR ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} @@ -131,4 +131,4 @@ set_target_properties(${TARGET_NAME} LIBRARY_OUTPUT_NAME ${TARGET_NAME} INSTALL_RPATH "$ORIGIN" ) -install (TARGETS ${TARGET_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib) \ No newline at end of file +install (TARGETS ${TARGET_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib) diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index 886d1638..31d865d8 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -1,124 +1,139 @@ -#include -#include -#include "tt-mlir-interface.hpp" +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "tt-mlir-interface.hpp" +#include +#include namespace py = pybind11; - tt::runtime::Binary compile(std::string_view code) { - return tt::torch::Compile(code); + return tt::torch::Compile(code); } - -static tt::target::DataType torch_scalar_type_to_dt(torch::ScalarType st) -{ - switch (st) - { - case torch::ScalarType::Byte: return tt::target::DataType::UInt8; - case torch::ScalarType::Char: return tt::target::DataType::UInt8; - case torch::ScalarType::Short: return tt::target::DataType::UInt16; - case torch::ScalarType::Int: return tt::target::DataType::UInt32; - case torch::ScalarType::Long: return tt::target::DataType::UInt32; - case torch::ScalarType::Half: return tt::target::DataType::Float16; - case torch::ScalarType::Float: return tt::target::DataType::Float32; - // case torch::ScalarType::Double: - // case torch::ScalarType::ComplexHalf: - // case torch::ScalarType::ComplexFloat: - // case torch::ScalarType::ComplexDouble: - // case torch::ScalarType::Bool: - case torch::ScalarType::BFloat16: return tt::target::DataType::BFloat16; - default: break; - } - assert(false && "Unsupported scalar type"); - +static tt::target::DataType torch_scalar_type_to_dt(torch::ScalarType st) { + switch (st) { + case torch::ScalarType::Byte: + return tt::target::DataType::UInt8; + case torch::ScalarType::Char: + return tt::target::DataType::UInt8; + case torch::ScalarType::Short: + return tt::target::DataType::UInt16; + case torch::ScalarType::Int: + return tt::target::DataType::UInt32; + case torch::ScalarType::Long: + return tt::target::DataType::UInt32; + case torch::ScalarType::Half: + return tt::target::DataType::Float16; + case torch::ScalarType::Float: + return tt::target::DataType::Float32; + // case torch::ScalarType::Double: + // case torch::ScalarType::ComplexHalf: + // case torch::ScalarType::ComplexFloat: + // case torch::ScalarType::ComplexDouble: + // case torch::ScalarType::Bool: + case torch::ScalarType::BFloat16: + return tt::target::DataType::BFloat16; + default: + break; + } + assert(false && "Unsupported scalar type"); } -static torch::ScalarType dt_to_torch_scalar_type(tt::target::DataType df) -{ - switch (df) - { - case tt::target::DataType::UInt8: return torch::ScalarType::Byte; - case tt::target::DataType::UInt16: return torch::ScalarType::Short; - case tt::target::DataType::UInt32: return torch::ScalarType::Int; - case tt::target::DataType::Float16: return torch::ScalarType::Half; - case tt::target::DataType::Float32: return torch::ScalarType::Float; - case tt::target::DataType::BFloat16: return torch::ScalarType::BFloat16; - default: break; - } - assert(false && "Unsupported scalar type"); - +static torch::ScalarType dt_to_torch_scalar_type(tt::target::DataType df) { + switch (df) { + case tt::target::DataType::UInt8: + return torch::ScalarType::Byte; + case tt::target::DataType::UInt16: + return torch::ScalarType::Short; + case tt::target::DataType::UInt32: + return torch::ScalarType::Int; + case tt::target::DataType::Float16: + return torch::ScalarType::Half; + case tt::target::DataType::Float32: + return torch::ScalarType::Float; + case tt::target::DataType::BFloat16: + return torch::ScalarType::BFloat16; + default: + break; + } + assert(false && "Unsupported scalar type"); } - -static tt::runtime::Tensor create_tensor(const torch::Tensor& tensor) -{ - auto data = std::shared_ptr( - tensor.data_ptr(), - [tensor](void*) { (void)tensor; } // Capture tensor by value to increase ref count and keep it alive - ); - - auto shape = std::vector(tensor.sizes().begin(), tensor.sizes().end()); - auto stride = std::vector(tensor.strides().begin(), tensor.strides().end()); - - return tt::runtime::createTensor( - data, shape, stride, tensor.element_size(), torch_scalar_type_to_dt(tensor.scalar_type())); +static tt::runtime::Tensor create_tensor(const torch::Tensor &tensor) { + auto data = std::shared_ptr( + tensor.data_ptr(), + [tensor](void *) { + (void)tensor; + } // Capture tensor by value to increase ref count and keep it alive + ); + + auto shape = + std::vector(tensor.sizes().begin(), tensor.sizes().end()); + auto stride = + std::vector(tensor.strides().begin(), tensor.strides().end()); + + return tt::runtime::createTensor( + data, shape, stride, tensor.element_size(), + torch_scalar_type_to_dt(tensor.scalar_type())); } template -std::vector as_vec_int64(std::vector const& vec) -{ - std::vector result; - result.reserve(vec.size()); - for (auto const& v : vec) - { - result.push_back(v); - } - return result; +std::vector as_vec_int64(std::vector const &vec) { + std::vector result; + result.reserve(vec.size()); + for (auto const &v : vec) { + result.push_back(v); + } + return result; } -std::vector run(const std::vector& inputs, tt::runtime::Binary binary) { - auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); - int dev_0 = chip_ids[0]; - auto device = tt::runtime::openDevice({dev_0}); - - int program_idx = 0; - auto input_descs = binary.getProgramInputs(program_idx); - - std::vector rt_inputs; - for (auto const& input : inputs) - { - rt_inputs.emplace_back(create_tensor(input)); - } - - std::vector outputs; - std::vector rt_outputs; - std::vector output_descs = binary.getProgramOutputs(program_idx); - outputs.reserve(output_descs.size()); - for (auto const& desc : output_descs) - { - std::vector shape = as_vec_int64(desc.shape); - std::vector stride = as_vec_int64(desc.stride); - - at::Tensor output = at::empty_strided(shape, stride, dt_to_torch_scalar_type(desc.dataType)); - outputs.emplace_back(std::move(output)); - rt_outputs.emplace_back(create_tensor(outputs.back())); - } - - tt::runtime::Event event = tt::runtime::submit(device, binary, program_idx, rt_inputs, rt_outputs); - (void)event; - tt::runtime::closeDevice(device); - return outputs; +std::vector run(const std::vector &inputs, + tt::runtime::Binary binary) { + auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc(); + int dev_0 = chip_ids[0]; + auto device = tt::runtime::openDevice({dev_0}); + + int program_idx = 0; + auto input_descs = binary.getProgramInputs(program_idx); + + std::vector rt_inputs; + for (auto const &input : inputs) { + rt_inputs.emplace_back(create_tensor(input)); + } + + std::vector outputs; + std::vector rt_outputs; + std::vector output_descs = + binary.getProgramOutputs(program_idx); + outputs.reserve(output_descs.size()); + for (auto const &desc : output_descs) { + std::vector shape = as_vec_int64(desc.shape); + std::vector stride = as_vec_int64(desc.stride); + + at::Tensor output = at::empty_strided( + shape, stride, dt_to_torch_scalar_type(desc.dataType)); + outputs.emplace_back(std::move(output)); + rt_outputs.emplace_back(create_tensor(outputs.back())); + } + + tt::runtime::Event event = + tt::runtime::submit(device, binary, program_idx, rt_inputs, rt_outputs); + (void)event; + tt::runtime::closeDevice(device); + return outputs; } PYBIND11_MODULE(tt_mlir, m) { - m.doc() = "tt_mlir"; - py::class_(m, "Binary") - .def("getProgramInputs", &tt::runtime::Binary::getProgramInputs) - .def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs); - m.def("compile", &compile, "A function that compiles a stableHLO model to a flatbuffer"); - m.def("run", &run, "Push inputs and run binary"); - m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc, - "Get the current system descriptor"); - m.def("get_num_available_devices", &tt::runtime::getNumAvailableDevices, - "Get the number of available devices"); + m.doc() = "tt_mlir"; + py::class_(m, "Binary") + .def("getProgramInputs", &tt::runtime::Binary::getProgramInputs) + .def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs); + m.def("compile", &compile, + "A function that compiles a stableHLO model to a flatbuffer"); + m.def("run", &run, "Push inputs and run binary"); + m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc, + "Get the current system descriptor"); + m.def("get_num_available_devices", &tt::runtime::getNumAvailableDevices, + "Get the number of available devices"); } diff --git a/tt_torch/csrc/tt-mlir-interface.cpp b/tt_torch/csrc/tt-mlir-interface.cpp index 2492159f..0f0b0187 100644 --- a/tt_torch/csrc/tt-mlir-interface.cpp +++ b/tt_torch/csrc/tt-mlir-interface.cpp @@ -14,45 +14,43 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/IR/MLIRContext.h" - -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "stablehlo/dialect/Serialization.h" // from @stablehlo -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/InitAllPasses.h" +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo #define TTMLIR_ENABLE_STABLEHLO +#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" -#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" #include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" -#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" -#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" +#include "ttmlir/Dialect/TTNN/Transforms/Passes.h" #include "ttmlir/RegisterAll.h" +#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" #include "tt/runtime/runtime.h" namespace tt::torch { - tt::runtime::Binary Compile(std::string_view code) { mlir::MLIRContext context; mlir::DialectRegistry registry; - + registry.insert(); registry.insert(); registry.insert(); @@ -77,43 +75,40 @@ tt::runtime::Binary Compile(std::string_view code) { mlir::tt::ttir::registerPasses(); mlir::tt::ttnn::registerPasses(); - // Implicit nesting required to call the stablehlo.composite --> func.call conversion. - mlir::PassManager shlo_pm(mlir_module.get()->getName(), mlir::PassManager::Nesting::Implicit); + // Implicit nesting required to call the stablehlo.composite --> func.call + // conversion. + mlir::PassManager shlo_pm(mlir_module.get()->getName(), + mlir::PassManager::Nesting::Implicit); mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options; shlo_options.arithDialectConversionsEnabled = true; shlo_options.removeDeadValuesEnabled = true; shlo_options.legalizeCompositeToCallEnabled = true; mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options); // Run the pass manager. - if (mlir::failed(shlo_pm.run(mlir_module.get()))) - { - throw std::runtime_error("Failed to run MLIR compiler pass pipeline."); + if (mlir::failed(shlo_pm.run(mlir_module.get()))) { + throw std::runtime_error("Failed to run MLIR compiler pass pipeline."); } mlir_module->dump(); - mlir::PassManager pm(mlir_module.get()->getName()); mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options; mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options); - + // Run the pass manager. - if (mlir::failed(pm.run(mlir_module.get()))) - { - throw std::runtime_error("Failed to run MLIR compiler pass pipeline."); + if (mlir::failed(pm.run(mlir_module.get()))) { + throw std::runtime_error("Failed to run MLIR compiler pass pipeline."); } mlir_module->dump(); auto binary_ptr = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); - - if (binary_ptr == nullptr) - { - throw std::runtime_error("Failed to generate flatbuffer binary."); + + if (binary_ptr == nullptr) { + throw std::runtime_error("Failed to generate flatbuffer binary."); } tt::runtime::Binary binary(binary_ptr); return binary; - } -} // namespace tt::torch +} // namespace tt::torch diff --git a/tt_torch/csrc/tt-mlir-interface.hpp b/tt_torch/csrc/tt-mlir-interface.hpp index c8b2f84c..5614681b 100644 --- a/tt_torch/csrc/tt-mlir-interface.hpp +++ b/tt_torch/csrc/tt-mlir-interface.hpp @@ -1,3 +1,7 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + #pragma once #include #include @@ -5,5 +9,5 @@ #include "tt/runtime/runtime.h" namespace tt::torch { - tt::runtime::Binary Compile(std::string_view code); -} // namespace tt::torch \ No newline at end of file +tt::runtime::Binary Compile(std::string_view code); +} // namespace tt::torch diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 3511af61..3f9b3df8 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import torch from torch._dynamo.backends.common import aot_autograd from torch.fx.experimental.proxy_tensor import make_fx @@ -16,21 +19,28 @@ ) from typing import List, Tuple, Union + def execute(gm, inputs): return gm(*inputs) -class Executor(): + +class Executor: def __init__(self, binary): self.binary = binary - + def __call__(self, *inputs): return tt_mlir.run(inputs, self.binary) + def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]): # Reduce the graph to only the nodes that are used # Traverse up the graph from output nodes to populate consumed nodes set - graph = module_or_graph.graph if isinstance(module_or_graph, torch.fx.GraphModule) else module_or_graph + graph = ( + module_or_graph.graph + if isinstance(module_or_graph, torch.fx.GraphModule) + else module_or_graph + ) consumed = set() working_nodes = [] for node in graph.nodes: @@ -57,6 +67,7 @@ def reduce_graph(module_or_graph: Union[torch.fx.Graph, torch.fx.GraphModule]): # Remove the output node if it's the only one graph.erase_node(node) + def _base_backend(gm: torch.fx.GraphModule, example_inputs): gm.graph.print_tabular() gm = pass_pipeline(gm, example_inputs) @@ -86,4 +97,6 @@ def backend(gm, example_inputs): # aten = make_fx(gm, tracing_mode="symbolic", decomposition_table={}, _allow_non_fake_inputs=True)(*example_inputs) # return _base_backend(aten, example_inputs) return _base_backend(gm, example_inputs) + + # backend = aot_autograd(fw_compiler=_base_backend) diff --git a/tt_torch/dynamo/decompositions.py b/tt_torch/dynamo/decompositions.py index 189e18e1..a00bb1ef 100644 --- a/tt_torch/dynamo/decompositions.py +++ b/tt_torch/dynamo/decompositions.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 from typing import Callable, Dict, List, Optional, Sequence, Union import contextlib @@ -6,6 +9,7 @@ import torch from torch._decomp import get_decompositions, remove_decompositions from torch_mlir.extras.fx_decomp_util import get_decomposition_table + DecompositionTable = Dict[torch._ops.OperatorBase, Callable] DecompositionOpsList = Sequence[ Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket] @@ -62,7 +66,8 @@ def _extend_context_manager( popped is table ), "contextmanager unbalanced: popped different that pushed" -#TODO: DO we ever need this? + +# TODO: DO we ever need this? def _get_default_decomposition_ops() -> DecompositionOpsList: aten = torch.ops.aten # default decompositions pulled from SHARK / torch._decomp diff --git a/tt_torch/dynamo/passes.py b/tt_torch/dynamo/passes.py index 9ef438cc..adecd7ed 100644 --- a/tt_torch/dynamo/passes.py +++ b/tt_torch/dynamo/passes.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import torch from torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions diff --git a/tt_torch/tools/utils.py b/tt_torch/tools/utils.py index 396c31de..3f424f1e 100644 --- a/tt_torch/tools/utils.py +++ b/tt_torch/tools/utils.py @@ -1,124 +1,128 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import re import json def extract_shape(shape_str): - if not shape_str.startswith("tensor<"): - breakpoint() - assert shape_str.startswith("tensor<") - assert shape_str.endswith(">") - shape_str = shape_str[len("tensor<") : -1] - dims = shape_str.split("x") - return [int(dim) for dim in dims[:-1]] + if not shape_str.startswith("tensor<"): + breakpoint() + assert shape_str.startswith("tensor<") + assert shape_str.endswith(">") + shape_str = shape_str[len("tensor<") : -1] + dims = shape_str.split("x") + return [int(dim) for dim in dims[:-1]] -def split_top(string, splitter=",", openers="([{<", closers = ")]}>", whitespace=" \n\t"): - outlist = [] - outstring = [] +def split_top(string, splitter=",", openers="([{<", closers=")]}>", whitespace=" \n\t"): + outlist = [] + outstring = [] - depth = 0 + depth = 0 - for c in string: - if c in openers: - depth += 1 - elif c in closers: - depth -= 1 + for c in string: + if c in openers: + depth += 1 + elif c in closers: + depth -= 1 - if depth < 0: - raise SyntaxError() + if depth < 0: + raise SyntaxError() - if not depth and c == splitter: - outlist.append("".join(outstring)) - outstring = [] - else: - if len(outstring): - outstring.append(c) - elif c not in whitespace: - outstring.append(c) + if not depth and c == splitter: + outlist.append("".join(outstring)) + outstring = [] + else: + if len(outstring): + outstring.append(c) + elif c not in whitespace: + outstring.append(c) + + outlist.append("".join(outstring)) - outlist.append("".join(outstring)) + return outlist - return outlist def print_shape(shape): - return "x".join([str(dim) for dim in shape]) + return "x".join([str(dim) for dim in shape]) + def parse_mlir(mlir_code, verbose=False): - ops = [] - unique_ops = {} - for line in mlir_code.splitlines(): - line = line.strip() - if not line.startswith('%'): - continue - if verbose: - print(line) - - output = line.split(" = ")[0].strip() - # if output == "%21": - # breakpoint() - op_name = line.split(' = ')[1].split(" ")[0] - if op_name.startswith("\""): - op_name = op_name.split("\"")[1] - elif "(" in op_name: - op_name = op_name.split("(")[0] - if verbose: - print(f" op_name: {op_name}") - #reduce is special cased - args_and_attr = line.split(op_name)[1] - if op_name == "stablehlo.reduce": - op_name += "_" + args_and_attr.split("applies")[1].strip().split(" ")[0] - dim = args_and_attr.split("dimensions = ")[1].split(" ")[0] - attr = {"dim": dim} - args = [args_and_attr.split(")")[0].strip("(")] - else: - args_and_attr = line.split(op_name)[1] - args_and_attr = args_and_attr[:args_and_attr.rfind(":")] - args_and_attr = split_top(args_and_attr) - args = [] - attr = {} - for arg in args_and_attr: - if "=" in arg: - key, value = arg.split("=") - attr[key.strip()] = value.strip() + ops = [] + unique_ops = {} + for line in mlir_code.splitlines(): + line = line.strip() + if not line.startswith("%"): + continue + if verbose: + print(line) + + output = line.split(" = ")[0].strip() + # if output == "%21": + # breakpoint() + op_name = line.split(" = ")[1].split(" ")[0] + if op_name.startswith('"'): + op_name = op_name.split('"')[1] + elif "(" in op_name: + op_name = op_name.split("(")[0] + if verbose: + print(f" op_name: {op_name}") + # reduce is special cased + args_and_attr = line.split(op_name)[1] + if op_name == "stablehlo.reduce": + op_name += "_" + args_and_attr.split("applies")[1].strip().split(" ")[0] + dim = args_and_attr.split("dimensions = ")[1].split(" ")[0] + attr = {"dim": dim} + args = [args_and_attr.split(")")[0].strip("(")] else: - args.append(arg.strip()) - if verbose: - print(f" args: {args}") - print(f" attr: {attr}") - io_shapes = line[line.rfind(":")+1:].strip() - io_shapes = io_shapes.split(" -> ") - if len(io_shapes) == 1: - # input and output shapes are the same - input_shapes = io_shapes[0].split(", ") - output_shapes = io_shapes[0].split(", ") - else: - input_shapes, output_shapes = io_shapes - output_shapes = output_shapes.split(", ") - input_shapes = input_shapes.strip("(").strip(")") - input_shapes = input_shapes.split(", ") - input_shapes = [extract_shape(shape) for shape in input_shapes] - output_shapes = [extract_shape(shape) for shape in output_shapes] - if verbose: - print(f" input_shapes: {input_shapes}") - print(f" output_shape: {output_shapes}") - op = (output, op_name, args, attr, input_shapes, output_shapes, line) - ops.append(op) - - if op_name not in unique_ops: - unique_ops[op_name] = {} - - if len(input_shapes) == 0: - key = "" - else: - key = print_shape(input_shapes[0]) - for shape in input_shapes[1:]: - key += f"_x_{print_shape(shape)}" - if key not in unique_ops[op_name]: - unique_ops[op_name][key] = {} - unique_ops[op_name][key]["ops"] = [] - unique_ops[op_name][key]["num_ops"] = 1 - else: - unique_ops[op_name][key]["num_ops"] += 1 - unique_ops[op_name][key]["ops"].append(op) - return ops, unique_ops + args_and_attr = line.split(op_name)[1] + args_and_attr = args_and_attr[: args_and_attr.rfind(":")] + args_and_attr = split_top(args_and_attr) + args = [] + attr = {} + for arg in args_and_attr: + if "=" in arg: + key, value = arg.split("=") + attr[key.strip()] = value.strip() + else: + args.append(arg.strip()) + if verbose: + print(f" args: {args}") + print(f" attr: {attr}") + io_shapes = line[line.rfind(":") + 1 :].strip() + io_shapes = io_shapes.split(" -> ") + if len(io_shapes) == 1: + # input and output shapes are the same + input_shapes = io_shapes[0].split(", ") + output_shapes = io_shapes[0].split(", ") + else: + input_shapes, output_shapes = io_shapes + output_shapes = output_shapes.split(", ") + input_shapes = input_shapes.strip("(").strip(")") + input_shapes = input_shapes.split(", ") + input_shapes = [extract_shape(shape) for shape in input_shapes] + output_shapes = [extract_shape(shape) for shape in output_shapes] + if verbose: + print(f" input_shapes: {input_shapes}") + print(f" output_shape: {output_shapes}") + op = (output, op_name, args, attr, input_shapes, output_shapes, line) + ops.append(op) + + if op_name not in unique_ops: + unique_ops[op_name] = {} + if len(input_shapes) == 0: + key = "" + else: + key = print_shape(input_shapes[0]) + for shape in input_shapes[1:]: + key += f"_x_{print_shape(shape)}" + if key not in unique_ops[op_name]: + unique_ops[op_name][key] = {} + unique_ops[op_name][key]["ops"] = [] + unique_ops[op_name][key]["num_ops"] = 1 + else: + unique_ops[op_name][key]["num_ops"] += 1 + unique_ops[op_name][key]["ops"].append(op) + return ops, unique_ops diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index 949c9a2a..de1ce7c9 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -1,15 +1,28 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 import torch import numpy as np from tt_torch.dynamo.backend import backend -def verify_module(mod, input_shapes, input_data_types=[torch.float32],required_pcc=0.99, required_atol=1e-2, input_range=(-0.5, 0.5)): + +def verify_module( + mod, + input_shapes, + input_data_types=[torch.float32], + required_pcc=0.99, + required_atol=1e-2, + input_range=(-0.5, 0.5), +): tt_mod = torch.compile(mod, backend=backend) if all([dtype.is_floating_point for dtype in input_data_types]): - low, high = input_range - inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes] # uniformly distribute random numbers within the input_range + low, high = input_range + inputs = [ + (low - high) * torch.rand(shape) + high for shape in input_shapes + ] # uniformly distribute random numbers within the input_range else: - inputs = [torch.randint(0, 1000, shape) for shape in input_shapes] + inputs = [torch.randint(0, 1000, shape) for shape in input_shapes] ret = tt_mod(*inputs) golden = mod(*inputs) @@ -21,4 +34,4 @@ def verify_module(mod, input_shapes, input_data_types=[torch.float32],required_p np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), ) ) - assert pcc >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" \ No newline at end of file + assert pcc >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"