Skip to content

Commit

Permalink
Add slice tests, fix module compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Oct 28, 2024
1 parent 46675ae commit 00fe305
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
23 changes: 23 additions & 0 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module

def test_add():
Expand Down Expand Up @@ -116,6 +117,28 @@ def forward(self, x):

verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1))

@pytest.mark.parametrize("begin_W", torch.arange(64).tolist())
@pytest.mark.parametrize("end_W", torch.arange(64, 128).tolist())
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_slice(begin_W, end_W, dim):
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a):
if dim == 0:
return a[begin_W:end_W, :, :, :]
elif dim == 1:
return a[:, begin_W:end_W, :, :]
elif dim == 2:
return a[:, :, begin_W:end_W, :]
else:
return a[:, :, :, begin_W:end_W]

shape = [10, 10, 10, 10]
shape[dim] = 128
verify_module(Basic(), [shape])

def test_bert():
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "5267b3964628f8d76f225b1fd0e4403fdcd29d08")
set(TT_MLIR_VERSION "1fae4f9430fac400d653b2c87f7f549bd34a29a8")

if (TOOLCHAIN STREQUAL "ON")
cmake_minimum_required(VERSION 3.20)
Expand Down
5 changes: 5 additions & 0 deletions tt_torch/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ target_include_directories(TT_TORCH_MLIR PUBLIC
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/
${TTMLIR_TOOLCHAIN_DIR}/include
${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/lib
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/lib
)

set(STABLEHLO_LIBS
Expand Down Expand Up @@ -78,6 +81,8 @@ set(STABLEHLO_LIBS
StablehloReferenceValue
)

target_link_libraries(TT_TORCH_MLIR PUBLIC ${PROJECT_SOURCE_DIR}/install/lib/libTTMLIRStatic.a)
target_link_libraries(TT_TORCH_MLIR PUBLIC ${PROJECT_SOURCE_DIR}/install/lib/libTTMLIRTosaToTTIR.a)
target_link_libraries(TT_TORCH_MLIR PUBLIC
LLVM
MLIR
Expand Down
10 changes: 8 additions & 2 deletions tt_torch/csrc/tt-mlir-interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "mlir/InitAllPasses.h"

#define TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"
#include "ttmlir/RegisterAll.h"

#include "tt/runtime/runtime.h"

Expand All @@ -56,8 +58,10 @@ tt::runtime::Binary Compile(std::string_view code) {
registry.insert<mlir::ml_program::MLProgramDialect>();
registry.insert<mlir::shape::ShapeDialect>();

mlir::tt::registerAllDialects(registry);
mlir::stablehlo::registerAllDialects(registry);
mlir::func::registerAllExtensions(registry);
mlir::tt::registerAllExtensions(registry);

context.appendDialectRegistry(registry);

Expand All @@ -69,14 +73,16 @@ tt::runtime::Binary Compile(std::string_view code) {
mlir::ParserConfig{&context, /*verifyAfterParse=*/true});

mlir_module->dump();

mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

mlir::PassManager shlo_pm(mlir_module.get()->getName());
// Implicit nesting required to call the stablehlo.composite --> func.call conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(), mlir::PassManager::Nesting::Implicit);
mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
shlo_options.legalizeCompositeToCallEnabled = true;
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
// Run the pass manager.
if (mlir::failed(shlo_pm.run(mlir_module.get())))
Expand Down

0 comments on commit 00fe305

Please sign in to comment.