diff --git a/tests/data/models.py b/tests/data/models.py index 7b8f90017..82c00e3e4 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections.abc import Callable from functools import lru_cache +from typing import Final from rflx.error import Location from rflx.expression import ( @@ -621,10 +623,13 @@ def universal_options() -> Sequence: return Sequence("Universal::Options", universal_option()) +UNIVERSAL_MESSAGE_ID: Final = ID("Universal::Message") + + @lru_cache def universal_message() -> Message: return Message( - "Universal::Message", + UNIVERSAL_MESSAGE_ID, [ Link(INITIAL, Field("Message_Type")), Link( @@ -826,16 +831,22 @@ def session() -> Session: ) -def spark_test_models() -> list[Model]: - """Return models corresponding to generated code in tests/spark/generated.""" +def spark_test_models() -> list[Callable[[], Model]]: + """ + Return callables that create models corresponding to generated code in tests/spark/generated. + + Using callable functions instead of the models directly enables the caller to postpone the + time-consuming creation of the models to a later time. For instance, when using this function to + parameterize a test function, no model creation is necessary during collection time. + """ return [ - derivation_model(), - enumeration_model(), - ethernet_model(), - expression_model(), - null_message_in_tlv_message_model(), - null_model(), - sequence_model(), - tlv_model(), - Model(fixed_size_simple_message().dependencies), + derivation_model, + enumeration_model, + ethernet_model, + expression_model, + null_message_in_tlv_message_model, + null_model, + sequence_model, + tlv_model, + lambda: Model(fixed_size_simple_message().dependencies), ] diff --git a/tests/unit/generator_test.py b/tests/unit/generator_test.py index 280d69f32..5557e8f15 100644 --- a/tests/unit/generator_test.py +++ b/tests/unit/generator_test.py @@ -4,6 +4,7 @@ import typing as ty from collections.abc import Callable, Sequence from dataclasses import dataclass +from functools import lru_cache from pathlib import Path from typing import Optional @@ -154,8 +155,8 @@ def test_generate_partial_update(tmp_path: Path) -> None: @pytest.mark.parametrize("model", models.spark_test_models()) -def test_equality(model: Model, tmp_path: Path) -> None: - assert_equal_code(model, Integration(), GENERATED_DIR, tmp_path, accept_extra_files=True) +def test_equality(model: Callable[[], Model], tmp_path: Path) -> None: + assert_equal_code(model(), Integration(), GENERATED_DIR, tmp_path, accept_extra_files=True) @pytest.mark.parametrize("embedded", [True, False]) @@ -295,24 +296,26 @@ def test_prefixed_type_identifier() -> None: assert common.prefixed_type_identifier(ID(t), "P") == t.name -DUMMY_SESSION = ir.Session( - identifier=ID("P::S"), - states=[ - ir.State( - "State", - [ir.Transition("Final", ir.ComplexExpr([], ir.BoolVal(value=True)), None, None)], - None, - [], - None, - None, - ), - ], - declarations=[], - parameters=[], - types={t.identifier: t for t in models.universal_model().types}, - location=None, - variable_id=id_generator(), -) +@lru_cache +def dummy_session() -> ir.Session: + return ir.Session( + identifier=ID("P::S"), + states=[ + ir.State( + "State", + [ir.Transition("Final", ir.ComplexExpr([], ir.BoolVal(value=True)), None, None)], + None, + [], + None, + None, + ), + ], + declarations=[], + parameters=[], + types={t.identifier: t for t in models.universal_model().types}, + location=None, + variable_id=id_generator(), + ) @pytest.mark.parametrize( @@ -376,8 +379,8 @@ def test_session_create_abstract_function( expected: Sequence[ada.SubprogramDeclaration], ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -480,8 +483,8 @@ def test_session_create_abstract_functions_error( error_msg: str, ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -711,10 +714,10 @@ def test_session_evaluate_declarations( session_global: bool, expected: EvaluatedDeclaration, ) -> None: - allocator = AllocatorGenerator(DUMMY_SESSION, Integration()) + allocator = AllocatorGenerator(dummy_session(), Integration()) allocator._allocation_slots[Location((1, 1))] = 1 # noqa: SLF001 - session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN) + session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN) assert ( session_generator._evaluate_declarations( # noqa: SLF001 [declaration], @@ -978,10 +981,10 @@ def test_session_declare( expected: EvaluatedDeclarationStr, ) -> None: loc: Location = Location((1, 1)) - allocator = AllocatorGenerator(DUMMY_SESSION, Integration()) + allocator = AllocatorGenerator(dummy_session(), Integration()) allocator._allocation_slots[loc] = 1 # noqa: SLF001 - session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN) + session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN) result = session_generator._declare( # noqa: SLF001 ID("X"), @@ -1073,8 +1076,8 @@ def test_session_declare_error( error_msg: str, ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1268,8 +1271,8 @@ def _update_str(self) -> None: ], ) def test_session_state_action(action: ir.Stmt, expected: str) -> None: - allocator = AllocatorGenerator(DUMMY_SESSION, Integration()) - session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN) + allocator = AllocatorGenerator(dummy_session(), Integration()) + session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN) allocator._allocation_slots[Location((1, 1))] = 1 # noqa: SLF001 assert ( @@ -1323,8 +1326,8 @@ def test_session_state_action_error( error_msg: str, ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1696,8 +1699,8 @@ def test_session_assign_error( error_type: type[BaseError], error_msg: str, ) -> None: - allocator = AllocatorGenerator(DUMMY_SESSION, Integration()) - session_generator = SessionGenerator(DUMMY_SESSION, allocator, debug=Debug.BUILTIN) + allocator = AllocatorGenerator(dummy_session(), Integration()) + session_generator = SessionGenerator(dummy_session(), allocator, debug=Debug.BUILTIN) alloc_id = Location((1, 1)) allocator._allocation_slots[alloc_id] = 1 # noqa: SLF001 @@ -1773,8 +1776,8 @@ def test_session_append_error( error_msg: str, ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1816,8 +1819,8 @@ def test_session_append_error( ) def test_session_read_error(read: ir.Read, error_type: type[BaseError], error_msg: str) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1847,8 +1850,8 @@ def test_session_read_error(read: ir.Read, error_type: type[BaseError], error_ms ) def test_session_write_error(write: ir.Write, error_type: type[BaseError], error_msg: str) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1908,8 +1911,8 @@ def test_session_write_error(write: ir.Write, error_type: type[BaseError], error ) def test_session_to_ada_expr(expression: ir.Expr, expected: ada.Expr) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) @@ -1932,8 +1935,8 @@ def test_session_to_ada_expr_equality( expected: ada.Expr, ) -> None: session_generator = SessionGenerator( - DUMMY_SESSION, - AllocatorGenerator(DUMMY_SESSION, Integration()), + dummy_session(), + AllocatorGenerator(dummy_session(), Integration()), debug=Debug.BUILTIN, ) diff --git a/tests/unit/model/message_test.py b/tests/unit/model/message_test.py index 577d4306c..8b487a6ae 100644 --- a/tests/unit/model/message_test.py +++ b/tests/unit/model/message_test.py @@ -120,36 +120,39 @@ M_SMPL_REF.identifier, ) -PARAMETERIZED_MESSAGE = Message( - "P::M", - [ - Link( - INITIAL, - Field("F1"), - size=Mul(Variable("P1"), Number(8)), - ), - Link( - Field("F1"), - FINAL, - condition=Equal(Variable("P2"), Variable("One")), - ), - Link( - Field("F1"), - Field("F2"), - condition=Equal(Variable("P2"), Variable("Two")), - ), - Link( - Field("F2"), - FINAL, - ), - ], - { - Field("P1"): models.integer(), - Field("P2"): models.enumeration(), - Field("F1"): OPAQUE, - Field("F2"): models.integer(), - }, -) + +@lru_cache +def parameterized_message() -> Message: + return Message( + "P::M", + [ + Link( + INITIAL, + Field("F1"), + size=Mul(Variable("P1"), Number(8)), + ), + Link( + Field("F1"), + FINAL, + condition=Equal(Variable("P2"), Variable("One")), + ), + Link( + Field("F1"), + Field("F2"), + condition=Equal(Variable("P2"), Variable("Two")), + ), + Link( + Field("F2"), + FINAL, + ), + ], + { + Field("P1"): models.integer(), + Field("P2"): models.enumeration(), + Field("F1"): OPAQUE, + Field("F2"): models.integer(), + }, + ) @lru_cache @@ -182,16 +185,16 @@ def test_invalid_identifier() -> None: @pytest.mark.parametrize( "parameter_type", [ - models.null_message(), - models.tlv_message(), - models.sequence_integer_vector(), - models.sequence_inner_messages(), - OPAQUE, + models.null_message, + models.tlv_message, + models.sequence_integer_vector, + models.sequence_inner_messages, + lambda: OPAQUE, ], ) -def test_invalid_parameter_type_composite(parameter_type: Type) -> None: +def test_invalid_parameter_type_composite(parameter_type: abc.Callable[[], Type]) -> None: structure = [Link(INITIAL, Field("X")), Link(Field("X"), FINAL)] - types = {Field(ID("P", Location((1, 2)))): parameter_type, Field("X"): models.integer()} + types = {Field(ID("P", Location((1, 2)))): parameter_type(), Field("X"): models.integer()} assert_message_model_error( structure, @@ -417,7 +420,7 @@ def test_cycle() -> None: def test_parameters() -> None: assert not models.ethernet_frame().parameters - assert PARAMETERIZED_MESSAGE.parameters == ( + assert parameterized_message().parameters == ( Field("P1"), Field("P2"), ) @@ -433,21 +436,21 @@ def test_fields() -> None: Field("Type_Length"), Field("Payload"), ) - assert PARAMETERIZED_MESSAGE.fields == ( + assert parameterized_message().fields == ( Field("F1"), Field("F2"), ) def test_parameter_types() -> None: - assert PARAMETERIZED_MESSAGE.parameter_types == { + assert parameterized_message().parameter_types == { Field("P1"): models.integer(), Field("P2"): models.enumeration(), } def test_field_types() -> None: - assert PARAMETERIZED_MESSAGE.field_types == { + assert parameterized_message().field_types == { Field("F1"): OPAQUE, Field("F2"): models.integer(), } @@ -4838,7 +4841,7 @@ def test_boolean_variable_as_condition() -> None: ("message", "condition"), [ ( - Message( + lambda: Message( "P::M", [ Link(INITIAL, Field("Tag")), @@ -4856,7 +4859,7 @@ def test_boolean_variable_as_condition() -> None: ), ), ( - Message( + lambda: Message( "P::M", [ Link(INITIAL, Field("Tag")), @@ -4876,7 +4879,7 @@ def test_boolean_variable_as_condition() -> None: ), ], ) -def test_always_true_refinement(message: Message, condition: Expr) -> None: +def test_always_true_refinement(message: abc.Callable[[], Message], condition: Expr) -> None: with pytest.raises( RecordFluxError, match=( @@ -4886,7 +4889,7 @@ def test_always_true_refinement(message: Message, condition: Expr) -> None: ): Refinement( "In_Message", - message, + message(), Field(ID("Value", location=Location((10, 20)))), models.message(), condition, @@ -4897,7 +4900,7 @@ def test_always_true_refinement(message: Message, condition: Expr) -> None: ("message", "condition"), [ ( - Message( + lambda: Message( "P::M", [ Link(INITIAL, Field("Tag")), @@ -4915,7 +4918,7 @@ def test_always_true_refinement(message: Message, condition: Expr) -> None: ), ), ( - Message( + lambda: Message( "P::M", [ Link(INITIAL, Field("Tag")), @@ -4935,7 +4938,7 @@ def test_always_true_refinement(message: Message, condition: Expr) -> None: ), ], ) -def test_always_false_refinement(message: Message, condition: Expr) -> None: +def test_always_false_refinement(message: abc.Callable[[], Message], condition: Expr) -> None: with pytest.raises( RecordFluxError, match=( @@ -4945,7 +4948,7 @@ def test_always_false_refinement(message: Message, condition: Expr) -> None: ): Refinement( "In_Message", - message, + message(), Field(ID("Value", location=Location((10, 20)))), models.message(), condition, @@ -5057,69 +5060,62 @@ def test_possibly_always_true_refinement( ) in captured.out -@pytest.mark.parametrize( - ("unchecked", "expected"), - [ - ( - UncheckedMessage( - ID("P::No_Ref"), - [ - Link(INITIAL, Field("F1"), size=Number(16)), - Link(Field("F1"), Field("F2")), - Link( - Field("F2"), - Field("F3"), - LessEqual(Variable("F2"), Number(100)), - first=First("F2"), - ), - Link( - Field("F2"), - Field("F4"), - GreaterEqual(Variable("F2"), Number(200)), - first=First("F2"), - ), - Link(Field("F3"), FINAL, Equal(Variable("F3"), Variable("One"))), - Link(Field("F4"), FINAL), - ], - [], - [ - (Field("F1"), OPAQUE.identifier, []), - (Field("F2"), models.integer().identifier, []), - (Field("F3"), models.enumeration().identifier, []), - (Field("F4"), models.integer().identifier, []), - ], +def test_unchecked_message_checked() -> None: + unchecked = UncheckedMessage( + ID("P::No_Ref"), + [ + Link(INITIAL, Field("F1"), size=Number(16)), + Link(Field("F1"), Field("F2")), + Link( + Field("F2"), + Field("F3"), + LessEqual(Variable("F2"), Number(100)), + first=First("F2"), ), - Message( - ID("P::No_Ref"), - [ - Link(INITIAL, Field("F1"), size=Number(16)), - Link(Field("F1"), Field("F2")), - Link( - Field("F2"), - Field("F3"), - LessEqual(Variable("F2"), Number(100)), - first=First("F2"), - ), - Link( - Field("F2"), - Field("F4"), - GreaterEqual(Variable("F2"), Number(200)), - first=First("F2"), - ), - Link(Field("F3"), FINAL, Equal(Variable("F3"), Variable("One"))), - Link(Field("F4"), FINAL), - ], - { - Field("F1"): OPAQUE, - Field("F2"): models.integer(), - Field("F3"): models.enumeration(), - Field("F4"): models.integer(), - }, + Link( + Field("F2"), + Field("F4"), + GreaterEqual(Variable("F2"), Number(200)), + first=First("F2"), ), - ), - ], -) -def test_unchecked_message_checked(unchecked: UncheckedMessage, expected: Message) -> None: + Link(Field("F3"), FINAL, Equal(Variable("F3"), Variable("One"))), + Link(Field("F4"), FINAL), + ], + [], + [ + (Field("F1"), OPAQUE.identifier, []), + (Field("F2"), models.integer().identifier, []), + (Field("F3"), models.enumeration().identifier, []), + (Field("F4"), models.integer().identifier, []), + ], + ) + expected = Message( + ID("P::No_Ref"), + [ + Link(INITIAL, Field("F1"), size=Number(16)), + Link(Field("F1"), Field("F2")), + Link( + Field("F2"), + Field("F3"), + LessEqual(Variable("F2"), Number(100)), + first=First("F2"), + ), + Link( + Field("F2"), + Field("F4"), + GreaterEqual(Variable("F2"), Number(200)), + first=First("F2"), + ), + Link(Field("F3"), FINAL, Equal(Variable("F3"), Variable("One"))), + Link(Field("F4"), FINAL), + ], + { + Field("F1"): OPAQUE, + Field("F2"): models.integer(), + Field("F3"): models.enumeration(), + Field("F4"): models.integer(), + }, + ) assert unchecked.checked([OPAQUE, models.enumeration(), models.integer()]) == expected diff --git a/tests/unit/model/model_test.py b/tests/unit/model/model_test.py index 158bf184d..426ff0730 100644 --- a/tests/unit/model/model_test.py +++ b/tests/unit/model/model_test.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from collections.abc import Sequence +from collections.abc import Callable, Sequence from copy import copy from pathlib import Path @@ -173,23 +173,26 @@ def test_invalid_enumeration_type_builtin_literals() -> None: @pytest.mark.parametrize( ("types", "model"), [ - ([models.tlv_message()], models.tlv_model()), - ([models.tlv_with_checksum_message()], models.tlv_with_checksum_model()), - ([models.ethernet_frame()], models.ethernet_model()), - ([models.enumeration_message()], models.enumeration_model()), - ([models.universal_refinement()], models.universal_model()), + ([models.tlv_message], models.tlv_model), + ([models.tlv_with_checksum_message], models.tlv_with_checksum_model), + ([models.ethernet_frame], models.ethernet_model), + ([models.enumeration_message], models.enumeration_model), + ([models.universal_refinement], models.universal_model), ( [ - models.sequence_message(), - models.sequence_messages_message(), - models.sequence_sequence_size_defined_by_message_size(), + models.sequence_message, + models.sequence_messages_message, + models.sequence_sequence_size_defined_by_message_size, ], - models.sequence_model(), + models.sequence_model, ), ], ) -def test_init_introduce_type_dependencies(types: Sequence[Type], model: Model) -> None: - assert Model(types).types == model.types +def test_init_introduce_type_dependencies( + types: Sequence[Callable[[], Type]], + model: Callable[[], Model], +) -> None: + assert Model([t() for t in types]).types == model().types def test_invalid_enumeration_type_identical_literals() -> None: @@ -407,7 +410,7 @@ def test_write_specification_files_line_too_long(tmp_path: Path) -> None: ([], []), ( [mty.UncheckedInteger(ID("P::T"), Number(0), Number(128), Number(8), Location((1, 2)))], - [mty.Integer(ID("P::T"), Number(0), Number(128), Number(8), Location((1, 2)))], + [lambda: mty.Integer(ID("P::T"), Number(0), Number(128), Number(8), Location((1, 2)))], ), ( [ @@ -427,14 +430,14 @@ def test_write_specification_files_line_too_long(tmp_path: Path) -> None: ), ], [ - mty.Integer( + lambda: mty.Integer( ID("P::I"), Number(0), Number(128), Number(8), Location((1, 2)), ), - mty.Enumeration( + lambda: mty.Enumeration( ID("P::E"), [(ID("A"), Number(0)), (ID("B"), Number(1))], Number(8), @@ -462,8 +465,8 @@ def test_write_specification_files_line_too_long(tmp_path: Path) -> None: ), ], [ - OPAQUE, - Message( + lambda: OPAQUE, + lambda: Message( "P::M", [ Link(INITIAL, Field("F"), size=Number(16)), @@ -479,14 +482,16 @@ def test_write_specification_files_line_too_long(tmp_path: Path) -> None: ) def test_unchecked_model_checked( unchecked: list[UncheckedTopLevelDeclaration], - expected: list[TopLevelDeclaration], + expected: list[Callable[[], TopLevelDeclaration]], tmp_path: Path, ) -> None: cache = Cache(tmp_path / "test.json") - assert UncheckedModel(unchecked, RecordFluxError()).checked(cache=cache) == Model(expected) + declarations = [d() for d in expected] + + assert UncheckedModel(unchecked, RecordFluxError()).checked(cache=cache) == Model(declarations) - messages = [d for d in expected if isinstance(d, Message)] + messages = [d for d in declarations if isinstance(d, Message)] if messages: for d in messages: cache.is_verified(Digest(d)) diff --git a/tests/unit/model/session_test.py b/tests/unit/model/session_test.py index a169e8dce..45a9b86ad 100644 --- a/tests/unit/model/session_test.py +++ b/tests/unit/model/session_test.py @@ -1927,7 +1927,7 @@ def test_type_error_in_renaming_declaration() -> None: [ decl.VariableDeclaration( "M", - models.universal_message().identifier, + models.UNIVERSAL_MESSAGE_ID, location=Location((1, 2)), ), ], @@ -1943,7 +1943,7 @@ def test_type_error_in_renaming_declaration() -> None: stmt.Read( "C1", expr.MessageAggregate( - models.universal_message().identifier, + models.UNIVERSAL_MESSAGE_ID, {"Message_Type": expr.Variable("Universal::MT_Null")}, ), location=Location((1, 2)), @@ -2835,107 +2835,100 @@ def test_message_assignment_from_function() -> None: ) -@pytest.mark.parametrize( - ("unchecked", "expected"), - [ - ( - UncheckedSession( - ID("P::S"), - [ - State( - "A", - declarations=[], - actions=[stmt.Read("X", expr.Variable("M"))], - transitions=[ - Transition("B"), - ], - ), - State( - "B", - declarations=[ - decl.VariableDeclaration("Z", BOOLEAN.identifier, expr.Variable("Y")), - ], - actions=[], - transitions=[ - Transition( - "null", - condition=expr.And( - expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), - ), - description="rfc1149.txt+45:4-47:8", - ), - Transition("A"), - ], - description="rfc1149.txt+51:4-52:9", - ), +def test_unchecked_session_checked() -> None: + unchecked = UncheckedSession( + ID("P::S"), + [ + State( + "A", + declarations=[], + actions=[stmt.Read("X", expr.Variable("M"))], + transitions=[ + Transition("B"), ], - [ - decl.VariableDeclaration("M", "TLV::Message"), - decl.VariableDeclaration("Y", BOOLEAN.identifier, expr.FALSE), + ), + State( + "B", + declarations=[ + decl.VariableDeclaration("Z", BOOLEAN.identifier, expr.Variable("Y")), ], - [ - decl.ChannelDeclaration("X", readable=True, writable=True), - decl.FunctionDeclaration("F", [], BOOLEAN.identifier), - decl.FunctionDeclaration( - "G", - [decl.Argument("P", BOOLEAN.identifier)], - BOOLEAN.identifier, + actions=[], + transitions=[ + Transition( + "null", + condition=expr.And( + expr.Equal(expr.Variable("Z"), expr.TRUE), + expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + ), + description="rfc1149.txt+45:4-47:8", ), + Transition("A"), ], - Location((1, 2)), + description="rfc1149.txt+51:4-52:9", ), - Session( - "P::S", - [ - State( - "A", - declarations=[], - actions=[stmt.Read("X", expr.Variable("M"))], - transitions=[ - Transition("B"), - ], - ), - State( - "B", - declarations=[ - decl.VariableDeclaration("Z", BOOLEAN.identifier, expr.Variable("Y")), - ], - actions=[], - transitions=[ - Transition( - "null", - condition=expr.And( - expr.Equal(expr.Variable("Z"), expr.TRUE), - expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), - ), - description="rfc1149.txt+45:4-47:8", - ), - Transition("A"), - ], - description="rfc1149.txt+51:4-52:9", - ), + ], + [ + decl.VariableDeclaration("M", "TLV::Message"), + decl.VariableDeclaration("Y", BOOLEAN.identifier, expr.FALSE), + ], + [ + decl.ChannelDeclaration("X", readable=True, writable=True), + decl.FunctionDeclaration("F", [], BOOLEAN.identifier), + decl.FunctionDeclaration( + "G", + [decl.Argument("P", BOOLEAN.identifier)], + BOOLEAN.identifier, + ), + ], + Location((1, 2)), + ) + expected = Session( + "P::S", + [ + State( + "A", + declarations=[], + actions=[stmt.Read("X", expr.Variable("M"))], + transitions=[ + Transition("B"), ], - [ - decl.VariableDeclaration("M", "TLV::Message"), - decl.VariableDeclaration("Y", BOOLEAN.identifier, expr.FALSE), + ), + State( + "B", + declarations=[ + decl.VariableDeclaration("Z", BOOLEAN.identifier, expr.Variable("Y")), ], - [ - decl.ChannelDeclaration("X", readable=True, writable=True), - decl.FunctionDeclaration("F", [], BOOLEAN.identifier), - decl.FunctionDeclaration( - "G", - [decl.Argument("P", BOOLEAN.identifier)], - BOOLEAN.identifier, + actions=[], + transitions=[ + Transition( + "null", + condition=expr.And( + expr.Equal(expr.Variable("Z"), expr.TRUE), + expr.Equal(expr.Call("G", [expr.Variable("F")]), expr.TRUE), + ), + description="rfc1149.txt+45:4-47:8", ), + Transition("A"), ], - [BOOLEAN, models.tlv_message()], - Location((1, 2)), + description="rfc1149.txt+51:4-52:9", ), - ), - ], -) -def test_unchecked_session_checked(unchecked: UncheckedSession, expected: Session) -> None: + ], + [ + decl.VariableDeclaration("M", "TLV::Message"), + decl.VariableDeclaration("Y", BOOLEAN.identifier, expr.FALSE), + ], + [ + decl.ChannelDeclaration("X", readable=True, writable=True), + decl.FunctionDeclaration("F", [], BOOLEAN.identifier), + decl.FunctionDeclaration( + "G", + [decl.Argument("P", BOOLEAN.identifier)], + BOOLEAN.identifier, + ), + ], + [BOOLEAN, models.tlv_message()], + Location((1, 2)), + ) assert ( unchecked.checked( [BOOLEAN, models.tlv_message()], diff --git a/tests/unit/model/type_test.py b/tests/unit/model/type_test.py index 425931d62..d1170a6c6 100644 --- a/tests/unit/model/type_test.py +++ b/tests/unit/model/type_test.py @@ -1,3 +1,5 @@ +from typing import Callable + import pytest import rflx.typing_ as rty @@ -403,22 +405,22 @@ def test_sequence_dependencies() -> None: ("element_type", "error"), [ ( - Sequence("P::B", models.integer(), Location((3, 4))), + lambda: Sequence("P::B", models.integer(), Location((3, 4))), r':1:2: model: error: invalid element type of sequence "A"\n' r':3:4: model: info: type "B" must be scalar or message', ), ( - OPAQUE, + lambda: OPAQUE, r':1:2: model: error: invalid element type of sequence "A"\n' r'__BUILTINS__:0:0: model: info: type "Opaque" must be scalar or message', ), ( - Message("P::B", [], {}, location=Location((3, 4))), + lambda: Message("P::B", [], {}, location=Location((3, 4))), r':1:2: model: error: invalid element type of sequence "A"\n' r":3:4: model: info: null messages must not be used as sequence element", ), ( - Message( + lambda: Message( "P::B", [Link(INITIAL, Field("A"), size=Size("Message")), Link(Field("A"), FINAL)], {Field("A"): OPAQUE}, @@ -429,7 +431,7 @@ def test_sequence_dependencies() -> None: ' on "Message\'Size" or "Message\'Last"', ), ( - Message( + lambda: Message( "P::B", [ Link(INITIAL, Field("A"), condition=Equal(Size("Message"), Number(8))), @@ -444,9 +446,9 @@ def test_sequence_dependencies() -> None: ), ], ) -def test_sequence_invalid_element_type(element_type: Type, error: str) -> None: +def test_sequence_invalid_element_type(element_type: Callable[[], Type], error: str) -> None: with pytest.raises(RecordFluxError, match=f"^{error}$"): - Sequence("P::A", element_type, Location((1, 2))) + Sequence("P::A", element_type(), Location((1, 2))) def test_sequence_unsupported_element_type() -> None: diff --git a/tools/generate_spark_test_code.py b/tools/generate_spark_test_code.py index de8d80ed8..1bcc02168 100755 --- a/tools/generate_spark_test_code.py +++ b/tools/generate_spark_test_code.py @@ -49,7 +49,7 @@ def generate_spark_tests() -> None: parser = Parser(cached=True) parser.parse(*SPECIFICATION_FILES) - model = merge_models([parser.create_model(), *models.spark_test_models()]) + model = merge_models([parser.create_model(), *[m() for m in models.spark_test_models()]]) Generator( "RFLX", reproducible=True,