Skip to content

Commit

Permalink
Add slice tests, fix module compilation
Browse files Browse the repository at this point in the history
Fixed tests

.
  • Loading branch information
LPanosTT committed Oct 29, 2024
1 parent 89c288d commit e154344
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 4 deletions.
11 changes: 11 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

@pytest.fixture(autouse=True)
def run_around_tests():
yield
torch._dynamo.reset()
49 changes: 49 additions & 0 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module

def test_add():
Expand Down Expand Up @@ -116,6 +117,54 @@ def forward(self, x):

verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1))

dim0_cases = []
for begin in torch.arange(10).tolist():
for end in torch.arange(90, 100).tolist():
dim0_cases.append((begin, end, 0))

dim1_cases = []
for begin in torch.arange(10).tolist():
for end in torch.arange(90, 100).tolist():
dim1_cases.append((begin, end, 1))

dim2_cases = []
for begin in torch.arange(0, 64, 32).tolist():
for end in torch.arange(64, 128, 32).tolist():
dim2_cases.append((begin, end, 2))

dim3_cases = []
for begin in torch.arange(0, 64, 32).tolist():
for end in torch.arange(64, 128, 32).tolist():
dim3_cases.append((begin, end, 3))

@pytest.mark.parametrize(
"begin, end, dim",
[
*dim2_cases,
*dim3_cases,
*dim0_cases,
*dim1_cases
]
)
def test_slice(begin, end, dim):
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a):
if dim == 0:
return a[begin:end, :, :, :]
elif dim == 1:
return a[:, begin:end, :, :]
elif dim == 2:
return a[:, :, begin:end, :]
else:
return a[:, :, :, begin:end]

shape = [10, 10, 10, 10]
shape[dim] = 128
verify_module(Basic(), [shape])

def test_bert():
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
Expand Down
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

set(TT_MLIR_VERSION "5267b3964628f8d76f225b1fd0e4403fdcd29d08")
set(TT_MLIR_VERSION "d7c655cb57d1b1d2d611d861d9c5e8981894fc50")

if (TOOLCHAIN STREQUAL "ON")
cmake_minimum_required(VERSION 3.20)
Expand Down
5 changes: 4 additions & 1 deletion tt_torch/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ target_include_directories(TT_TORCH_MLIR PUBLIC
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/
${TTMLIR_TOOLCHAIN_DIR}/include
${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo
${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include
)

set(STABLEHLO_LIBS
Expand Down Expand Up @@ -82,8 +83,10 @@ target_link_libraries(TT_TORCH_MLIR PUBLIC
LLVM
MLIR
TTMLIR
TTMLIRStableHLOToTTIR
TTMLIRStatic
TTMLIRTosaToTTIR
MLIRTTIRPipelines
TTMLIRStableHLOToTTIR
${STABLEHLO_LIBS}
)
target_link_directories(TT_TORCH_MLIR PUBLIC
Expand Down
10 changes: 8 additions & 2 deletions tt_torch/csrc/tt-mlir-interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "mlir/InitAllPasses.h"

#define TTMLIR_ENABLE_STABLEHLO
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"
#include "ttmlir/RegisterAll.h"

#include "tt/runtime/runtime.h"

Expand All @@ -56,8 +58,10 @@ tt::runtime::Binary Compile(std::string_view code) {
registry.insert<mlir::ml_program::MLProgramDialect>();
registry.insert<mlir::shape::ShapeDialect>();

mlir::tt::registerAllDialects(registry);
mlir::stablehlo::registerAllDialects(registry);
mlir::func::registerAllExtensions(registry);
mlir::tt::registerAllExtensions(registry);

context.appendDialectRegistry(registry);

Expand All @@ -69,14 +73,16 @@ tt::runtime::Binary Compile(std::string_view code) {
mlir::ParserConfig{&context, /*verifyAfterParse=*/true});

mlir_module->dump();

mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

mlir::PassManager shlo_pm(mlir_module.get()->getName());
// Implicit nesting required to call the stablehlo.composite --> func.call conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(), mlir::PassManager::Nesting::Implicit);
mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
shlo_options.legalizeCompositeToCallEnabled = true;
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
// Run the pass manager.
if (mlir::failed(shlo_pm.run(mlir_module.get())))
Expand Down

0 comments on commit e154344

Please sign in to comment.