Skip to content

Commit

Permalink
[Emit] Support lowering of Int8 MLIR data types
Browse files Browse the repository at this point in the history
  • Loading branch information
nvukobratTT committed Aug 14, 2024
1 parent d9a1278 commit 8809689
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions pybuda/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class MLIRGenerator
return builder_.create<mlir::tensor::EmptyOp>(
get_tt_forge_operation_location(graph, node),
shape_vec,
get_float_type(node));
get_data_type(node));
}

/// Emit the return operation for the function.
Expand All @@ -364,8 +364,8 @@ class MLIRGenerator
mlir::ValueRange(returnValues));
}

/// Get the MLIR float type type for a TTForge node.
mlir::FloatType get_float_type(graphlib::Node *node)
/// Get the MLIR data type for a TTForge node.
mlir::Type get_data_type(graphlib::Node *node)
{
switch (node->output_df())
{
Expand All @@ -375,7 +375,10 @@ class MLIRGenerator
return builder_.getBF16Type();
case tt::DataFormat::Float16:
return builder_.getF16Type();
case tt::DataFormat::Int8:
return builder_.getI8Type();
default:
log_error("Unsupported data format during lowering from TTForge to TTIR: {}", node->output_df());
TT_ASSERT(false);
}
// TODO add all supported types in switch
Expand All @@ -390,7 +393,7 @@ class MLIRGenerator
{
shape_vec.push_back((int64_t)dim);
}
return mlir::RankedTensorType::get(shape_vec, get_float_type(node));
return mlir::RankedTensorType::get(shape_vec, get_data_type(node));
}

/// Get the location for a module.
Expand Down

0 comments on commit 8809689

Please sign in to comment.