Convert boolean inputs to bfloat16 to match TTNN types
mmanzoorTT committed Nov 21, 2024
1 parent 1f885ac · commit af95f51
Showing 1 changed file with 16 additions and 0 deletions.
tt_torch/dynamo/backend.py (16 additions, 0 deletions)
@@ -132,6 +132,9 @@ def __init__(self, gm, compiler_config=None):
         if compiler_config is None:
             compiler_config = CompilerConfig()
         self.compiler_config = compiler_config
+        # Dictionary to keep track of the type conversion for unsupported hardware
+        # types and use it to convert the input arguments to supported types.
+        self.type_conversion = {torch.bool: torch.bfloat16}
 
     def set_binary(self, binary):
         self.binary = binary
@@ -378,6 +381,19 @@ def run_gm_op_by_op(self, *inputs):
         return outputs
 
     def __call__(self, *inputs):
+        new_inputs = ()
+        for input in inputs:
+            input_type = input.dtype
+            if input_type in self.type_conversion.keys():
+                new_inputs = new_inputs + (
+                    (input.to(dtype=self.type_conversion[input_type])),
+                )
+                continue
+
+            new_inputs = new_inputs + ((input),)
+
+        inputs = new_inputs
+
         if self.compiler_config.compile_depth == CompileDepth.EXECUTE:
             assert self.binary is not None, "Binary must be set for EXECUTE mode"
             return tt_mlir.run(inputs, self.binary)
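For reference, below is a minimal standalone sketch of the dtype-conversion step this commit adds to __call__, written outside the tt-torch codebase. The helper name convert_unsupported_dtypes and the module-level TYPE_CONVERSION table are illustrative assumptions, not identifiers from the repository; the behavior mirrors the committed loop: any tensor whose dtype has no TTNN support (here only torch.bool) is cast to the mapped dtype before being handed to the runtime, and every other tensor passes through unchanged.

import torch

# Illustrative stand-in for the Executor.type_conversion dictionary added in
# this commit; not part of tt-torch itself.
TYPE_CONVERSION = {torch.bool: torch.bfloat16}

def convert_unsupported_dtypes(inputs, type_conversion=TYPE_CONVERSION):
    # Cast tensors with unsupported dtypes to the mapped supported dtype;
    # leave all other tensors untouched and preserve argument order.
    return tuple(
        t.to(dtype=type_conversion[t.dtype]) if t.dtype in type_conversion else t
        for t in inputs
    )

# Example: a boolean mask becomes bfloat16 (True -> 1.0, False -> 0.0).
mask = torch.tensor([True, False, True])
(converted,) = convert_unsupported_dtypes((mask,))
print(converted.dtype)  # torch.bfloat16

The committed code builds the new tuple by repeated concatenation with an explicit continue; the generator expression above produces the same result and is only a stylistic alternative.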
