Added option to remove constant scalars from FX graph and convert them to inputs
AleksKnezevic committed Jan 5, 2025
1 parent 092a5eb commit a7cc192
Showing 6 changed files with 86 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -38,4 +38,5 @@ You can use the following environment variables to override default behaviour:
| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify intermediate tensors against pytorch when running with compile depth `EXECUTE_OP_BY_OP`. | False |
| TT_TORCH_CONSTEVAL | Enables evaluation of constant expressions (consteval) in the Torch FX graph prior to compilation. | False |
| TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False |
| TT_TORCH_EMBEDDEDD_CONSTANTS | Removes embedded constants from the Torch FX graph and converts them to constant inputs. | False |
| TT_TORCH_ENABLE_IR_PRINTING | Enables printing MLIR for all conversion steps from StableHLO to TTNN. Be warned: this forces a single-core compile, so it is much slower. | False |
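For reference, a minimal sketch of turning the new option on, either through the environment variable or directly on `CompilerConfig`. The import path `tt_torch.tools.utils` is taken from the utils.py diff below; everything else is plain Python and is an illustration, not part of this commit:

```python
import os

# Environment-variable route: apply_environment_overrides() reads the string and
# casts it to int, so "1" enables the option. Set it before constructing the config.
os.environ["TT_TORCH_EMBEDDEDD_CONSTANTS"] = "1"  # spelling matches this commit

# Programmatic route, mirroring the test changes below.
from tt_torch.tools.utils import CompilerConfig

cc = CompilerConfig()
cc.remove_embedded_constants = True
```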
3 changes: 1 addition & 2 deletions tests/models/autoencoder_linear/test_autoencoder_linear.py
@@ -94,10 +94,9 @@ def test_autoencoder_linear(record_property, mode, nightly):
    cc = CompilerConfig()
    cc.enable_consteval = True
    cc.consteval_parameters = True
    cc.remove_embedded_constants = True
    if nightly:
        cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
    else:
        cc.compile_depth = CompileDepth.TTNN_IR

    tester = ThisTester(model_name, mode, compiler_config=cc)
    results = tester.test_model()
36 changes: 31 additions & 5 deletions tests/torch/test_basic.py
@@ -178,13 +178,24 @@ def forward(self, x):
    verify_module(Basic(), input_shapes=[(32, 32)])


from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
@pytest.mark.xfail(
    strict=True,
    reason="Embedded constants currently broken, see https://github.com/tenstorrent/tt-torch/issues/152",
)
def test_linear_with_bias():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_a = nn.Linear(32, 32)

        def forward(self, x):
            x = self.linear_a(x)
            return x

def test_linear_with_bias():
    pytest.xfail()
    verify_module(Basic(), input_shapes=[(32, 32)])


def test_linear_with_bias_no_embedded_constants():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
@@ -194,7 +205,22 @@ def forward(self, x):
            x = self.linear_a(x)
            return x

    verify_module(Basic(), input_shapes=[(32, 32)])
    cc = CompilerConfig()
    cc.remove_embedded_constants = True
    verify_module(Basic(), input_shapes=[(32, 32)], compiler_config=cc)


def test_constant():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x + 1.0

    cc = CompilerConfig()
    cc.remove_embedded_constants = True
    verify_module(Basic(), input_shapes=[(1, 768)], compiler_config=cc)


def test_maximum():
7 changes: 7 additions & 0 deletions tt_torch/csrc/bindings.cpp
@@ -67,8 +67,15 @@ static tt::runtime::Tensor create_tensor(const torch::Tensor &tensor) {

  auto shape =
      std::vector<uint32_t>(tensor.sizes().begin(), tensor.sizes().end());
  if (shape.empty()) {
    shape.push_back(1);
  }

  auto stride =
      std::vector<uint32_t>(tensor.strides().begin(), tensor.strides().end());
  if (stride.empty()) {
    stride.push_back(1);
  }

  return tt::runtime::createTensor(
      data, shape, stride, tensor.element_size(),
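As context for the two padding blocks above: the scalar constants produced by the new FX pass are 0-d tensors, whose sizes() and strides() are empty. A quick plain-PyTorch illustration (nothing repo-specific):

```python
import torch

scalar = torch.tensor(1.0)             # what inline_constants creates for an embedded 1.0
print(scalar.size(), scalar.stride())  # torch.Size([]) () -- both empty for a 0-d tensor
# create_tensor pads these to shape [1] / stride [1] so the runtime always
# receives at least a rank-1 tensor.
```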
45 changes: 42 additions & 3 deletions tt_torch/dynamo/passes.py
@@ -181,17 +181,49 @@ def inline_parameters(gm):
    return gm, parameters


def order_constant_inputs(gm, parameters, constants):
def order_constant_inputs(gm, parameters, constants, embedded_constants):
    constant_inputs = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.target in parameters:
                constant_inputs.append(parameters[node.target])
            elif node.target in constants:
                constant_inputs.append(constants[node.target])
            elif node.target in embedded_constants:
                constant_inputs.append(embedded_constants[node.target])
    return constant_inputs


def inline_constants(gm, example_inputs):
    inlied_constats = {}
    placeholders = {}

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # start appending after last placeholder
            gm.graph.inserting_after(node)

        if node.op != "call_function" or node.target._overloadname != "Tensor":
            continue

        for idx, arg in enumerate(node.args):
            if isinstance(arg, (int, float)):
                if arg not in placeholders:
                    name = f"const_{arg}"
                    placeholder = gm.graph.placeholder(name)
                    placeholders[arg] = placeholder
                    new_arg = torch.tensor(arg)
                    inlied_constats[name] = new_arg
                else:
                    placeholder = placeholders[arg]

                args = list(node.args)
                args[idx] = placeholder
                node.args = tuple(args)

    return gm, inlied_constats
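For intuition, here is a standalone torch.fx sketch of the same idea — lifting a literal scalar out of a call node and feeding it in as an extra graph input. It uses symbolic_trace and a sanitized placeholder name so that recompile() produces valid Python, whereas inline_constants above runs on the aten-level dynamo graph and keeps the raw f"const_{arg}" name; treat it as an illustration, not the pass itself:

```python
import torch
import torch.fx as fx


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0  # the 1.0 is baked into the traced graph as a literal argument


gm = fx.symbolic_trace(AddOne())

# Append new constant placeholders after the existing ones, as the pass above does.
last_placeholder = [n for n in gm.graph.nodes if n.op == "placeholder"][-1]
lifted = {}
for node in gm.graph.nodes:
    if node.op != "call_function":
        continue
    for idx, arg in enumerate(node.args):
        if not isinstance(arg, (int, float)):
            continue
        if arg not in lifted:
            name = f"const_{arg}".replace(".", "_")  # sanitized placeholder name
            with gm.graph.inserting_after(last_placeholder):
                lifted[arg] = gm.graph.placeholder(name)
        args = list(node.args)
        args[idx] = lifted[arg]
        node.args = tuple(args)

gm.recompile()
print(gm(torch.ones(2, 2), torch.tensor(1.0)))  # the scalar is now an explicit input
```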


def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    decompositions = DEFAULT_DECOMPOSITION_TABLE
    decompositions.update(CUSTOM_DECOMPOSITION_TABLE)
@@ -201,10 +233,17 @@ def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    elif compiler_config.consteval_parameters:
        raise Exception("consteval_parameters is enabled but enable_consteval is not")
    else:
        constants = []
        constants = {}
    gm = bypass_redundant_getitem(gm)
    gm, parameters = inline_parameters(gm)
    constant_inputs = order_constant_inputs(gm, parameters, constants)
    if compiler_config.remove_embedded_constants:
        gm, embedded_constants = inline_constants(gm, example_inputs)
    else:
        embedded_constants = {}

    constant_inputs = order_constant_inputs(
        gm, parameters, constants, embedded_constants
    )

    # some constant folding operations are performed by changing tensor strides, we
    # want all the strides to be 1, so make them contiguous
4 changes: 4 additions & 0 deletions tt_torch/tools/utils.py
@@ -129,6 +129,7 @@ def __init__(self):
        self.single_op_timeout = 5
        self.enable_intermediate_verification = False
        self.enable_consteval = False
        self.remove_embedded_constants = False
        self._consteval_parameters = False

        self.apply_environment_overrides()
@@ -156,6 +157,9 @@ def apply_environment_overrides(self):
        consteval_parameters = os.environ.get("TT_TORCH_CONSTEVAL_PARAMETERS")
        if consteval_parameters and int(consteval_parameters):
            self.consteval_parameters = True
        remove_embedded_constants = os.environ.get("TT_TORCH_EMBEDDEDD_CONSTANTS")
        if remove_embedded_constants and int(remove_embedded_constants):
            self.remove_embedded_constants = True

    def post_init(self):
        if self.consteval_parameters:
