diff --git a/mlir/dialects/scf.py b/mlir/dialects/scf.py index a3b974f..1c54cc7 100644 --- a/mlir/dialects/scf.py +++ b/mlir/dialects/scf.py @@ -5,7 +5,7 @@ from mlir.dialect import Dialect, DialectOp, is_op, UnaryOperation import mlir.astnodes as mast from dataclasses import dataclass -from typing import Optional +from typing import Optional, List, Tuple @dataclass @@ -13,10 +13,15 @@ class SCFForOp(DialectOp): index: mast.SsaId begin: mast.SsaId end: mast.SsaId + step: mast.SsaId body: mast.Region - step: Optional[mast.SsaId] = None - _syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} {body.region}', - 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}'] + iter_args: Optional[List[Tuple[mast.SsaId, mast.SsaId]]] = None + iter_args_types: Optional[List[mast.Type]] = None + out_type: Optional[mast.Type] = None + _syntax_ = ['scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} : {out_type.type} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} {body.region}', + 'scf.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} iter_args {iter_args.argument_assignment_list_parens} -> {iter_args_types.type_list_parens} : {out_type.type} {body.region}'] @dataclass diff --git a/mlir/lark/mlir.lark b/mlir/lark/mlir.lark index dbd7fe5..39de4a5 100644 --- a/mlir/lark/mlir.lark +++ b/mlir/lark/mlir.lark @@ -221,6 +221,9 @@ region_list : "(" region? ("," region)* ")" // Arguments named_argument : ssa_id ":" type optional_attr_dict argument_list : (named_argument ("," named_argument)*) | (type optional_attr_dict ("," type optional_attr_dict)*) +argument_assignment : ssa_id "=" ssa_id +argument_assignment_list_no_parens : argument_assignment ("," argument_assignment)* +argument_assignment_list_parens : ("(" ")") | ("(" argument_assignment_list_no_parens ")") // Return values function_result : type optional_attr_dict diff --git a/mlir/parser_transformer.py b/mlir/parser_transformer.py index 63f4d04..d552c1e 100644 --- a/mlir/parser_transformer.py +++ b/mlir/parser_transformer.py @@ -200,6 +200,7 @@ def block_label(self, value): symbol_use_list = list operation_list = list argument_list = list + argument_assignment_list_no_parens = list definition_list = list function_list = list module_list = list diff --git a/tests/test_syntax.py b/tests/test_syntax.py index eecfc35..325aaba 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -158,6 +158,25 @@ def test_affine(parser: Optional[Parser] = None): module = parser.parse(code) print(module.pretty()) +def test_scf_for(parser: Optional[Parser] = None): + code = """ +module { + func.func @reduce(%buffer: memref<1024xf32>, %lb: index, + %ub: index, %step: index) -> (f32) { + %sum_0 = arith.constant 0.0 : f32 + %sum = scf.for %iv = %lb to %ub step %step + iter_args(%sum_iter = %sum_0) -> (f32) { + %t = load %buffer[%iv] : memref<1024xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + scf.yield %sum_next : f32 + } + return %sum : f32 + } +} + """ + parser = parser or Parser() + module = parser.parse(code) + print(module.pretty()) def test_definitions(parser: Optional[Parser] = None): code = '''