From d880baef3373bd08c02aa09999a23f16300d7226 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Thu, 4 Apr 2024 22:21:12 +1300 Subject: [PATCH 01/12] Remove _DCBase. --- scenario/state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scenario/state.py b/scenario/state.py index d89838dd..be171ebb 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -266,6 +266,7 @@ class Secret(_max_posargs(1)): # mapping from revision IDs to each revision's contents contents: Dict[int, "RawSecretRevisionContents"] +class Secret: id: str # CAUTION: ops-created Secrets (via .add_secret()) will have a canonicalized # secret id (`secret:` prefix) From 1f18d0d921935bc9baba622b695c4cd0ca63830a Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Tue, 28 May 2024 18:56:57 +1200 Subject: [PATCH 02/12] Adjust the consistency checker and expose the Resource class. --- README.md | 86 +++++++++++++++++---------------- scenario/__init__.py | 3 ++ scenario/consistency_checker.py | 14 +++--- scenario/state.py | 84 ++++++++++++++++++++++++++++---- 4 files changed, 129 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 92a13f84..0bf9ab55 100644 --- a/README.md +++ b/README.md @@ -322,22 +322,21 @@ class MyCharm(ops.CharmBase): def test_relation_data(): - state_in = scenario.State(relations=[ - scenario.Relation( - endpoint="foo", - interface="bar", - remote_app_name="remote", - local_unit_data={"abc": "foo"}, - remote_app_data={"cde": "baz!"}, - ), - ]) + rel = scenario.Relation( + endpoint="foo", + interface="bar", + remote_app_name="remote", + local_unit_data={"abc": "foo"}, + remote_app_data={"cde": "baz!"}, + ) + state_in = scenario.State(relations={rel}) ctx = scenario.Context(MyCharm, meta={"name": "foo"}) state_out = ctx.run(ctx.on.start(), state_in) - assert state_out.relations[0].local_unit_data == {"abc": "baz!"} - # you can do this to check that there are no other differences: - assert state_out.relations == [ + assert state_out.get_relation(rel.id).local_unit_data == {"abc": "baz!"} + # You can do this to check that there are no other differences: + assert state_out.relations == { scenario.Relation( endpoint="foo", interface="bar", @@ -345,7 +344,7 @@ def test_relation_data(): local_unit_data={"abc": "baz!"}, remote_app_data={"cde": "baz!"}, ), - ] + } # which is very idiomatic and superbly explicit. Noice. ``` @@ -381,11 +380,11 @@ be mindful when using `PeerRelation` not to include **"this unit"**'s ID in `pee be flagged by the Consistency Checker: ```python -state_in = scenario.State(relations=[ +state_in = scenario.State(relations={ scenario.PeerRelation( endpoint="peers", peers_data={1: {}, 2: {}, 42: {'foo': 'bar'}}, - )]) + )}) meta = { "name": "invalid", @@ -508,15 +507,15 @@ When testing a Kubernetes charm, you can mock container interactions. When using be no containers. So if the charm were to `self.unit.containers`, it would get back an empty dict. To give the charm access to some containers, you need to pass them to the input state, like so: -`State(containers=[...])` +`State(containers={...})` An example of a state including some containers: ```python -state = scenario.State(containers=[ +state = scenario.State(containers={ scenario.Container(name="foo", can_connect=True), scenario.Container(name="bar", can_connect=False) -]) +}) ``` In this case, `self.unit.get_container('foo').can_connect()` would return `True`, while for 'bar' it would give `False`. @@ -535,7 +534,7 @@ container = scenario.Container( can_connect=True, mounts={'local': scenario.Mount(location='/local/share/config.yaml', source=local_file)} ) -state = scenario.State(containers=[container]) +state = scenario.State(containers={container}) ``` In this case, if the charm were to: @@ -567,12 +566,12 @@ class MyCharm(ops.CharmBase): def test_pebble_push(): with tempfile.NamedTemporaryFile() as local_file: - container = scenario,Container( + container = scenario.Container( name='foo', can_connect=True, mounts={'local': Mount(location='/local/share/config.yaml', source=local_file.name)} ) - state_in = State(containers=[container]) + state_in = State(containers={container}) ctx = Context( MyCharm, meta={"name": "foo", "containers": {"foo": {}}} @@ -606,7 +605,7 @@ class MyCharm(ops.CharmBase): def test_pebble_push(): container = scenario.Container(name='foo', can_connect=True) - state_in = scenario.State(containers=[container]) + state_in = scenario.State(containers={container}) ctx = scenario.Context( MyCharm, meta={"name": "foo", "containers": {"foo": {}}} @@ -652,7 +651,7 @@ def test_pebble_exec(): stdout=LS_LL) } ) - state_in = scenario.State(containers=[container]) + state_in = scenario.State(containers={container}) ctx = scenario.Context( MyCharm, meta={"name": "foo", "containers": {"foo": {}}}, @@ -708,7 +707,7 @@ storage = scenario.Storage("foo") # Setup storage with some content: (storage.get_filesystem(ctx) / "myfile.txt").write_text("helloworld") -with ctx.manager(ctx.on.update_status(), scenario.State(storage=[storage])) as mgr: +with ctx.manager(ctx.on.update_status(), scenario.State(storage={storage})) as mgr: foo = mgr.charm.model.storages["foo"][0] loc = foo.location path = loc / "myfile.txt" @@ -753,11 +752,11 @@ So a natural follow-up Scenario test suite for this case would be: ctx = scenario.Context(MyCharm, meta=MyCharm.META) foo_0 = scenario.Storage('foo') # The charm is notified that one of the storages it has requested is ready: -ctx.run(ctx.on.storage_attached(foo_0), scenario.State(storage=[foo_0])) +ctx.run(ctx.on.storage_attached(foo_0), scenario.State(storage={foo_0})) foo_1 = scenario.Storage('foo') # The charm is notified that the other storage is also ready: -ctx.run(ctx.on.storage_attached(foo_1), scenario.State(storage=[foo_0, foo_1])) +ctx.run(ctx.on.storage_attached(foo_1), scenario.State(storage={foo_0, foo_}])) ``` ## Ports @@ -766,7 +765,7 @@ Since `ops 2.6.0`, charms can invoke the `open-port`, `close-port`, and `opened- - simulate a charm run with a port opened by some previous execution ctx = scenario.Context(MyCharm, meta=MyCharm.META) -ctx.run(ctx.on.start(), scenario.State(opened_ports=[scenario.TCPPort(42)])) +ctx.run(ctx.on.start(), scenario.State(opened_ports={scenario.TCPPort(42)})) ``` - assert that a charm has called `open-port` or `close-port`: ```python @@ -775,7 +774,7 @@ state1 = ctx.run(ctx.on.start(), scenario.State()) assert state1.opened_ports == [scenario.TCPPort(42)] state2 = ctx.run(ctx.on.stop(), state1) -assert state2.opened_ports == [] +assert state2.opened_ports == {} ``` ## Secrets @@ -784,12 +783,13 @@ Scenario has secrets. Here's how you use them. ```python state = scenario.State( - secrets=[ + secrets={ scenario.Secret( {0: {'key': 'public'}}, id='foo', - ) - ] + contents={0: {'key': 'public'}} + ), + }, ) ``` @@ -813,15 +813,15 @@ To specify a secret owned by this unit (or app): ```python state = scenario.State( - secrets=[ + secrets={ scenario.Secret( {0: {'key': 'private'}}, id='foo', owner='unit', # or 'app' remote_grants={0: {"remote"}} # the secret owner has granted access to the "remote" app over some relation with ID 0 - ) - ] + ), + }, ) ``` @@ -829,14 +829,14 @@ To specify a secret owned by some other application and give this unit (or app) ```python state = scenario.State( - secrets=[ + secrets={ scenario.Secret( {0: {'key': 'public'}}, id='foo', # owner=None, which is the default revision=0, # the revision that this unit (or app) is currently tracking - ) - ] + ), + }, ) ``` @@ -853,15 +853,16 @@ class MyCharmType(ops.CharmBase): assert self.my_stored_state.foo == 'bar' # this will pass! -state = scenario.State(stored_state=[ +state = scenario.State(stored_states={ scenario.StoredState( owner_path="MyCharmType", name="my_stored_state", content={ 'foo': 'bar', 'baz': {42: 42}, - }) -]) + }), + }, +) ``` And the charm's runtime will see `self.my_stored_state.foo` and `.baz` as expected. Also, you can run assertions on it on @@ -879,7 +880,8 @@ So, the only consistency-level check we enforce in Scenario when it comes to res import pathlib ctx = scenario.Context(MyCharm, meta={'name': 'juliette', "resources": {"foo": {"type": "oci-image"}}}) -with ctx.manager(ctx.on.start(), scenario.State(resources={'foo': '/path/to/resource.tar'})) as mgr: +resource = scenario.Resource('foo', '/path/to/resource.tar') +with ctx.manager(ctx.on.start(), scenario.State(resources={resource})) as mgr: # If the charm, at runtime, were to call self.model.resources.fetch("foo"), it would get '/path/to/resource.tar' back. path = mgr.charm.model.resources.fetch('foo') assert path == pathlib.Path('/path/to/resource.tar') @@ -1060,7 +1062,7 @@ class MyCharm(ops.CharmBase): def test_start_on_deferred_update_status(MyCharm): foo_relation = scenario.Relation('foo') scenario.State( - relations=[foo_relation], + relations={foo_relation}, deferred=[ scenario.deferred('foo_relation_changed', handler=MyCharm._on_foo_relation_changed, diff --git a/scenario/__init__.py b/scenario/__init__.py index a73570a6..276f8682 100644 --- a/scenario/__init__.py +++ b/scenario/__init__.py @@ -18,6 +18,7 @@ Notice, PeerRelation, Relation, + Resource, Secret, State, StateValidationError, @@ -52,6 +53,8 @@ "ICMPPort", "TCPPort", "UDPPort", + "Port", + "Resource", "Storage", "StoredState", "State", diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 87d9da80..695af322 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -108,7 +108,7 @@ def check_resource_consistency( warnings = [] resources_from_meta = set(charm_spec.meta.get("resources", {})) - resources_from_state = set(state.resources) + resources_from_state = {resource.name for resource in state.resources} if not resources_from_meta.issuperset(resources_from_state): errors.append( f"any and all resources passed to State.resources need to have been defined in " @@ -330,11 +330,11 @@ def check_storages_consistency( **_kwargs, # noqa: U101 ) -> Results: """Check the consistency of the state.storages with the charm_spec.metadata (metadata.yaml).""" - state_storage = state.storage + state_storage = state.storages meta_storage = (charm_spec.meta or {}).get("storage", {}) errors = [] - if missing := {s.name for s in state.storage}.difference( + if missing := {s.name for s in state_storage}.difference( set(meta_storage.keys()), ): errors.append( @@ -347,7 +347,7 @@ def check_storages_consistency( if tag in seen: errors.append( f"duplicate storage in State: storage {s.name} with index {s.index} " - f"occurs multiple times in State.storage.", + f"occurs multiple times in State.storages.", ) seen.append(tag) @@ -628,12 +628,12 @@ def check_storedstate_consistency( state: "State", **_kwargs, # noqa: U101 ) -> Results: - """Check the internal consistency of `state.storedstate`.""" + """Check the internal consistency of `state.stored_states`.""" errors = [] # Attribute names must be unique on each object. names = defaultdict(list) - for ss in state.stored_state: + for ss in state.stored_states: names[ss.owner_path].append(ss.name) for owner, owner_names in names.items(): if len(owner_names) != len(set(owner_names)): @@ -642,7 +642,7 @@ def check_storedstate_consistency( ) # The content must be marshallable. - for ss in state.stored_state: + for ss in state.stored_states: # We don't need the marshalled state, just to know that it can be. # This is the same "only simple types" check that ops does. try: diff --git a/scenario/state.py b/scenario/state.py index be171ebb..a8ba86c7 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -15,6 +15,7 @@ Callable, Dict, Final, + FrozenSet, Generic, List, Literal, @@ -1007,6 +1008,14 @@ def get_filesystem(self, ctx: "Context") -> Path: return ctx._get_storage_root(self.name, self.index) +@dataclasses.dataclass(frozen=True) +class Resource(_max_posargs(0)): + """Represents a resource made available to the charm.""" + + name: str + path: "PathLike" + + @dataclasses.dataclass(frozen=True) class State(_max_posargs(0)): """Represents the juju-owned portion of a unit's state. @@ -1020,7 +1029,7 @@ class State(_max_posargs(0)): default_factory=dict, ) """The present configuration of this charm.""" - relations: List["AnyRelation"] = dataclasses.field(default_factory=list) + relations: FrozenSet["AnyRelation"] = dataclasses.field(default_factory=frozenset) """All relations that currently exist for this charm.""" networks: Dict[str, Network] = dataclasses.field(default_factory=dict) """Manual overrides for any relation and extra bindings currently provisioned for this charm. @@ -1030,36 +1039,38 @@ class State(_max_posargs(0)): support it, but use at your own risk.] If a metadata-defined extra-binding is left empty, it will be defaulted. """ - containers: List[Container] = dataclasses.field(default_factory=list) + containers: FrozenSet[Container] = dataclasses.field(default_factory=frozenset) """All containers (whether they can connect or not) that this charm is aware of.""" - storage: List[Storage] = dataclasses.field(default_factory=list) + storages: FrozenSet[Storage] = dataclasses.field(default_factory=frozenset) """All ATTACHED storage instances for this charm. If a storage is not attached, omit it from this listing.""" # we don't use sets to make json serialization easier - opened_ports: List[_Port] = dataclasses.field(default_factory=list) + opened_ports: FrozenSet[_Port] = dataclasses.field(default_factory=frozenset) """Ports opened by juju on this charm.""" leader: bool = False """Whether this charm has leadership.""" model: Model = Model() """The model this charm lives in.""" - secrets: List[Secret] = dataclasses.field(default_factory=list) + secrets: FrozenSet[Secret] = dataclasses.field(default_factory=frozenset) """The secrets this charm has access to (as an owner, or as a grantee). The presence of a secret in this list entails that the charm can read it. Whether it can manage it or not depends on the individual secret's `owner` flag.""" - resources: Dict[str, "PathLike"] = dataclasses.field(default_factory=dict) - """Mapping from resource name to path at which the resource can be found.""" + resources: FrozenSet[Resource] = dataclasses.field(default_factory=frozenset) + """All resources that this charm can access.""" planned_units: int = 1 """Number of non-dying planned units that are expected to be running this application. Use with caution.""" - # represents the OF's event queue. These events will be emitted before the event being + # Represents the OF's event queue. These events will be emitted before the event being # dispatched, and represent the events that had been deferred during the previous run. # If the charm defers any events during "this execution", they will be appended # to this list. deferred: List["DeferredEvent"] = dataclasses.field(default_factory=list) """Events that have been deferred on this charm by some previous execution.""" - stored_state: List["StoredState"] = dataclasses.field(default_factory=list) + stored_states: FrozenSet["StoredState"] = dataclasses.field( + default_factory=frozenset, + ) """Contents of a charm's stored state.""" # the current statuses. Will be cast to _EntitiyStatus in __post_init__ @@ -1079,6 +1090,23 @@ def __post_init__(self): object.__setattr__(self, name, _status_to_entitystatus(val)) else: raise TypeError(f"Invalid status.{name}: {val!r}") + # It's convenient to pass a set, but we really want the attributes to be + # frozen sets to increase the immutability of State objects. + for name in [ + "relations", + "containers", + "storages", + "opened_ports", + "secrets", + "resources", + "stored_states", + ]: + val = getattr(self, name) + # We check for "not frozenset" rather than "is set" so that you can + # actually pass a tuple or list or really any iterable of hashable + # objects, and it will end up as a frozenset. + if not isinstance(val, frozenset): + object.__setattr__(self, name, frozenset(val)) def _update_workload_version(self, new_workload_version: str): """Update the current app version and record the previous one.""" @@ -1130,6 +1158,44 @@ def get_container(self, container: Union[str, Container]) -> Container: raise ValueError(f"container: {container_name} not found in the State") return containers[0] + def get_secret(self, *, id: str, label: str) -> Secret: + """Get secret from this State, based on the secret's id or label.""" + for secret in self.secrets: + if secret.id == id or secret.label == label: + return secret + if id: + message = f"secret: id={id} not found in the State" + elif label: + message = f"secret: label={label} not found in the State" + else: + message = "An id or label must be provided." + raise ValueError(message) + + def get_stored_state( + self, + name: str, + owner_path: Optional[str] = None, + ) -> StoredState: + """Get stored state from this State, based on the stored state's name and owner_path.""" + for stored_state in self.stored_states: + if stored_state.name == name and stored_state.owner_path == owner_path: + return stored_state + raise ValueError(f"stored state: {name} not found in the State") + + def get_storage(self, name: str, index: int = 0) -> Storage: + """Get storage from this State, based on the storage's name and index.""" + for storage in self.storages: + if storage.name == name and storage.index == index: + return storage + raise ValueError(f"storage: name={name}, index={index} not found in the State") + + def get_relation(self, id: int) -> "AnyRelation": + """Get relation from this State, based on the relation's id.""" + for relation in self.relations: + if relation.relation_id == id: + return relation + raise ValueError(f"relation: id={id} not found in the State") + def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: """Get all relations on this endpoint from the current state.""" From 7c665624f482ca15fe59a8dd1783b2d1075772b5 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Thu, 30 May 2024 19:04:49 +1200 Subject: [PATCH 03/12] Finish the conversion (all tests pass). --- scenario/consistency_checker.py | 7 +- scenario/mocking.py | 45 ++++---- scenario/runtime.py | 16 +-- scenario/state.py | 142 ++++++++++++++++++++++++- tests/helpers.py | 20 +++- tests/test_charm_spec_autoload.py | 2 +- tests/test_consistency_checker.py | 77 +++++++------- tests/test_e2e/test_deferred.py | 8 +- tests/test_e2e/test_event_bind.py | 62 +++++++++++ tests/test_e2e/test_pebble.py | 32 +++--- tests/test_e2e/test_play_assertions.py | 6 +- tests/test_e2e/test_ports.py | 7 +- tests/test_e2e/test_relations.py | 28 ++--- tests/test_e2e/test_secrets.py | 58 +++++----- tests/test_e2e/test_state.py | 22 ++-- tests/test_e2e/test_storage.py | 6 +- tests/test_e2e/test_stored_state.py | 24 +++-- 17 files changed, 394 insertions(+), 168 deletions(-) create mode 100644 tests/test_e2e/test_event_bind.py diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 695af322..92dc30d5 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -4,7 +4,7 @@ import marshal import os import re -from collections import Counter, defaultdict +from collections import defaultdict from collections.abc import Sequence from numbers import Number from typing import TYPE_CHECKING, Iterable, List, NamedTuple, Tuple, Union @@ -593,11 +593,6 @@ def check_containers_consistency( f"Missing from metadata: {diff}.", ) - # guard against duplicate container names - names = Counter(state_containers) - if dupes := [n for n in names if names[n] > 1]: - errors.append(f"Duplicate container name(s): {dupes}.") - return Results(errors, []) diff --git a/scenario/mocking.py b/scenario/mocking.py index b17627d3..2f1745c6 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -131,7 +131,9 @@ def open_port( port_ = _port_cls_by_protocol[protocol](port=port) ports = self._state.opened_ports if port_ not in ports: - ports.append(port_) + ports.add(port_) + if ports != self._state.opened_ports: + self._state._update_opened_ports(frozenset(ports)) def close_port( self, @@ -142,6 +144,8 @@ def close_port( ports = self._state.opened_ports if _port in ports: ports.remove(_port) + if ports != self._state.opened_ports: + self._state._update_opened_ports(frozenset(ports)) def get_pebble(self, socket_path: str) -> "Client": container_name = socket_path.split("/")[ @@ -168,11 +172,9 @@ def _get_relation_by_id( rel_id, ) -> Union["Relation", "SubordinateRelation", "PeerRelation"]: try: - return next( - filter(lambda r: r.id == rel_id, self._state.relations), - ) - except StopIteration: - raise RelationNotFoundError() + return self._state.get_relation(rel_id) + except ValueError: + raise RelationNotFoundError() from None def _get_secret(self, id=None, label=None): # FIXME: what error would a charm get IRL? @@ -371,7 +373,9 @@ def secret_add( rotate=rotate, owner=owner, ) - self._state.secrets.append(secret) + secrets = set(self._state.secrets) + secrets.add(secret) + self._state._update_secrets(frozenset(secrets)) return secret_id def _check_can_manage_secret( @@ -546,7 +550,7 @@ def action_get(self): def storage_add(self, name: str, count: int = 1): if not isinstance(count, int) or isinstance(count, bool): raise TypeError( - f"storage count must be integer, got: {count} ({type(count)})", + f"storage count must be integer, got: {count} ({type(count)}", ) if "/" in name: @@ -557,7 +561,7 @@ def storage_add(self, name: str, count: int = 1): def storage_list(self, name: str) -> List[int]: return [ - storage.index for storage in self._state.storage if storage.name == name + storage.index for storage in self._state.storages if storage.name == name ] def _storage_event_details(self) -> Tuple[int, str]: @@ -584,7 +588,7 @@ def storage_get(self, storage_name_id: str, attribute: str) -> str: name, index = storage_name_id.split("/") index = int(index) storages: List[Storage] = [ - s for s in self._state.storage if s.name == name and s.index == index + s for s in self._state.storages if s.name == name and s.index == index ] # should not really happen: sanity checks. In practice, ops will guard against these paths. @@ -624,16 +628,19 @@ def add_metrics( "it's deprecated API)", ) + # TODO: It seems like this method has no tests. def resource_get(self, resource_name: str) -> str: - try: - return str(self._state.resources[resource_name]) - except KeyError: - # ops will not let us get there if the resource name is unknown from metadata. - # but if the user forgot to add it in State, then we remind you of that. - raise RuntimeError( - f"Inconsistent state: " - f"resource {resource_name} not found in State. please pass it.", - ) + # We assume that there are few enough resources that a linear search + # will perform well enough. + for resource in self._state.resources: + if resource.name == resource_name: + return str(resource.path) + # ops will not let us get there if the resource name is unknown from metadata. + # but if the user forgot to add it in State, then we remind you of that. + raise RuntimeError( + f"Inconsistent state: " + f"resource {resource_name} not found in State. please pass it.", + ) def credential_get(self) -> CloudSpec_Ops: if not self._context.app_trusted: diff --git a/scenario/runtime.py b/scenario/runtime.py index 97a7c773..91e4705b 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -10,7 +10,7 @@ import typing from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, FrozenSet, List, Optional, Type, Union import yaml from ops import pebble @@ -62,12 +62,12 @@ def _open_db(self) -> SQLiteStorage: """Open the db.""" return SQLiteStorage(self._state_file) - def get_stored_state(self) -> List["StoredState"]: + def get_stored_states(self) -> FrozenSet["StoredState"]: """Load any StoredState data structures from the db.""" db = self._open_db() - stored_state = [] + stored_states = set() for handle_path in db.list_snapshots(): if not EVENT_REGEX.match(handle_path) and ( match := STORED_STATE_REGEX.match(handle_path) @@ -75,10 +75,10 @@ def get_stored_state(self) -> List["StoredState"]: stored_state_snapshot = db.load_snapshot(handle_path) kwargs = match.groupdict() sst = StoredState(content=stored_state_snapshot, **kwargs) - stored_state.append(sst) + stored_states.add(sst) db.close() - return stored_state + return frozenset(stored_states) def get_deferred_events(self) -> List["DeferredEvent"]: """Load any DeferredEvent data structures from the db.""" @@ -119,7 +119,7 @@ def apply_state(self, state: "State"): ) from e db.save_snapshot(event.handle_path, event.snapshot_data) - for stored_state in state.stored_state: + for stored_state in state.stored_states: db.save_snapshot(stored_state.handle_path, stored_state.content) db.close() @@ -388,8 +388,8 @@ def _close_storage(self, state: "State", temporary_charm_root: Path): """Now that we're done processing this event, read the charm state and expose it.""" store = self._get_state_db(temporary_charm_root) deferred = store.get_deferred_events() - stored_state = store.get_stored_state() - return dataclasses.replace(state, deferred=deferred, stored_state=stored_state) + stored_state = store.get_stored_states() + return dataclasses.replace(state, deferred=deferred, stored_states=stored_state) @contextmanager def _exec_ctx(self, ctx: "Context"): diff --git a/scenario/state.py b/scenario/state.py index a8ba86c7..dd7798b8 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -290,6 +290,29 @@ class Secret: expire: Optional[datetime.datetime] = None rotate: Optional[SecretRotate] = None + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, Secret): + return False + for attr in ( + "id", + "owner", + "revision", + "label", + "description", + "expire", + "rotate", + "remote_grants", + "contents", + ): + if getattr(self, attr) != getattr(value, attr): + return False + return True + def _set_revision(self, revision: int): """Set a new tracked revision.""" # bypass frozen dataclass @@ -453,6 +476,23 @@ def __post_init__(self): for databag in self._databags: self._validate_databag(databag) + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, _RelationBase): + return False + if self.endpoint != value.endpoint or self.interface != value.interface: + return False + if ( + self.local_app_data != value.local_app_data + or self.local_unit_data != value.local_unit_data + ): + return False + return True + def _validate_databag(self, databag: dict): if not isinstance(databag, dict): raise StateValidationError( @@ -486,6 +526,30 @@ class Relation(_RelationBase): default_factory=lambda: {0: DEFAULT_JUJU_DATABAG.copy()}, # dedup ) + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, Relation): + return False + if ( + self.endpoint != value.endpoint + or self.interface != value.interface + or self.limit != value.limit + or self.remote_app_name != value.remote_app_name + ): + return False + if ( + self.local_app_data != value.local_app_data + or self.local_unit_data != value.local_unit_data + or self.remote_app_data != value.remote_app_data + or self.remote_units_data != value.remote_units_data + ): + return False + return True + @property def _remote_app_name(self) -> str: """Who is on the other end of this relation?""" @@ -520,6 +584,30 @@ class SubordinateRelation(_RelationBase): remote_app_name: str = "remote" remote_unit_id: int = 0 + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, SubordinateRelation): + return False + if ( + self.endpoint != value.endpoint + or self.interface != value.interface + or self.remote_app_name != value.remote_app_name + or self.remote_unit_id != value.remote_unit_id + ): + return False + if ( + self.local_app_data != value.local_app_data + or self.local_unit_data != value.local_unit_data + or self.remote_app_data != value.remote_app_data + or self.remote_unit_data != value.remote_unit_data + ): + return False + return True + @property def _remote_unit_ids(self) -> Tuple[int]: """Ids of the units on the other end of this relation.""" @@ -555,6 +643,23 @@ class PeerRelation(_RelationBase): # mapping from peer unit IDs to their databag contents. # Consistency checks will validate that *this unit*'s ID is not in here. + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, PeerRelation): + return False + if self.endpoint != value.endpoint or self.interface != value.interface: + return False + if ( + self.local_app_data != value.local_app_data + or self.local_unit_data != value.local_unit_data + ): + return False + return True + @property def _databags(self): """Yield all databags in this relation.""" @@ -763,6 +868,9 @@ class Container(_max_posargs(1)): notices: List[Notice] = dataclasses.field(default_factory=list) + def __hash__(self) -> int: + return hash(self.name) + def _render_services(self): # copied over from ops.testing._TestingPebbleClient._render_services() services = {} # type: Dict[str, pebble.Service] @@ -905,6 +1013,18 @@ class StoredState(_max_posargs(1)): def handle_path(self): return f"{self.owner_path or ''}/{self._data_type_name}[{self.name}]" + def __hash__(self) -> int: + return hash(self.handle_path) + + def __eq__(self, value: object) -> bool: + if value is self: + return True + if not isinstance(value, StoredState): + return False + if self.handle_path != value.handle_path: + return False + return self.content == value.content + _RawPortProtocolLiteral = Literal["tcp", "udp", "icmp"] @@ -1127,6 +1247,16 @@ def _update_status( # bypass frozen dataclass object.__setattr__(self, name, _EntityStatus(new_status, new_message)) + def _update_opened_ports(self, new_ports: FrozenSet[Port]): + """Update the current opened ports.""" + # bypass frozen dataclass + object.__setattr__(self, "opened_ports", new_ports) + + def _update_secrets(self, new_secrets: FrozenSet[Secret]): + """Update the current secrets.""" + # bypass frozen dataclass + object.__setattr__(self, "secrets", new_secrets) + def with_can_connect(self, container_name: str, can_connect: bool) -> "State": def replacer(container: Container): if container.name == container_name: @@ -1158,7 +1288,12 @@ def get_container(self, container: Union[str, Container]) -> Container: raise ValueError(f"container: {container_name} not found in the State") return containers[0] - def get_secret(self, *, id: str, label: str) -> Secret: + def get_secret( + self, + *, + id: Optional[str] = None, + label: Optional[str] = None, + ) -> Secret: """Get secret from this State, based on the secret's id or label.""" for secret in self.secrets: if secret.id == id or secret.label == label: @@ -1192,7 +1327,7 @@ def get_storage(self, name: str, index: int = 0) -> Storage: def get_relation(self, id: int) -> "AnyRelation": """Get relation from this State, based on the relation's id.""" for relation in self.relations: - if relation.relation_id == id: + if relation.id == id: return relation raise ValueError(f"relation: id={id} not found in the State") @@ -1211,9 +1346,10 @@ def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: if normalize_name(r.endpoint) == normalized_endpoint ) + # TODO: It seems like this method has no tests. def get_storages(self, name: str) -> Tuple["Storage", ...]: """Get all storages with this name.""" - return tuple(s for s in self.storage if s.name == name) + return tuple(s for s in self.storages if s.name == name) def _is_valid_charmcraft_25_metadata(meta: Dict[str, Any]): diff --git a/tests/helpers.py b/tests/helpers.py index 7dd1f835..47c3d339 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -67,11 +67,21 @@ def trigger( return state_out -def jsonpatch_delta(input: "State", output: "State"): - patch = jsonpatch.make_patch( - dataclasses.asdict(output), - dataclasses.asdict(input), - ).patch +def jsonpatch_delta(self, other: "State"): + dict_other = dataclasses.asdict(other) + dict_self = dataclasses.asdict(self) + for attr in ( + "relations", + "containers", + "storages", + "opened_ports", + "secrets", + "resources", + "stored_states", + ): + dict_other[attr] = [dataclasses.asdict(o) for o in dict_other[attr]] + dict_self[attr] = [dataclasses.asdict(o) for o in dict_self[attr]] + patch = jsonpatch.make_patch(dict_other, dict_self).patch return sort_patch(patch) diff --git a/tests/test_charm_spec_autoload.py b/tests/test_charm_spec_autoload.py index 51ba1391..fb738f87 100644 --- a/tests/test_charm_spec_autoload.py +++ b/tests/test_charm_spec_autoload.py @@ -144,7 +144,7 @@ def test_relations_ok(tmp_path, legacy): ) as charm: # this would fail if there were no 'cuddles' relation defined in meta ctx = Context(charm) - ctx.run(ctx.on.start(), State(relations=[Relation("cuddles")])) + ctx.run(ctx.on.start(), State(relations={Relation("cuddles")})) @pytest.mark.parametrize("legacy", (True, False)) diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 82321558..41d39836 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -16,6 +16,7 @@ Notice, PeerRelation, Relation, + Resource, Secret, State, Storage, @@ -63,7 +64,7 @@ def test_workload_event_without_container(): _CharmSpec(MyCharm, {}), ) assert_consistent( - State(containers=[Container("foo")]), + State(containers={Container("foo")}), _Event("foo-pebble-ready", container=Container("foo")), _CharmSpec(MyCharm, {"containers": {"foo": {}}}), ) @@ -74,12 +75,12 @@ def test_workload_event_without_container(): ) notice = Notice("example.com/foo") assert_consistent( - State(containers=[Container("foo", notices=[notice])]), + State(containers={Container("foo", notices=[notice])}), _Event("foo-pebble-custom-notice", container=Container("foo"), notice=notice), _CharmSpec(MyCharm, {"containers": {"foo": {}}}), ) assert_inconsistent( - State(containers=[Container("foo")]), + State(containers={Container("foo")}), _Event("foo-pebble-custom-notice", container=Container("foo"), notice=notice), _CharmSpec(MyCharm, {"containers": {"foo": {}}}), ) @@ -87,12 +88,12 @@ def test_workload_event_without_container(): def test_container_meta_mismatch(): assert_inconsistent( - State(containers=[Container("bar")]), + State(containers={Container("bar")}), _Event("foo"), _CharmSpec(MyCharm, {"containers": {"baz": {}}}), ) assert_consistent( - State(containers=[Container("bar")]), + State(containers={Container("bar")}), _Event("foo"), _CharmSpec(MyCharm, {"containers": {"bar": {}}}), ) @@ -100,12 +101,12 @@ def test_container_meta_mismatch(): def test_container_in_state_but_no_container_in_meta(): assert_inconsistent( - State(containers=[Container("bar")]), + State(containers={Container("bar")}), _Event("foo"), _CharmSpec(MyCharm, {}), ) assert_consistent( - State(containers=[Container("bar")]), + State(containers={Container("bar")}), _Event("foo"), _CharmSpec(MyCharm, {"containers": {"bar": {}}}), ) @@ -119,7 +120,7 @@ def test_container_not_in_state(): _CharmSpec(MyCharm, {"containers": {"bar": {}}}), ) assert_consistent( - State(containers=[container]), + State(containers={container}), _Event("bar_pebble_ready", container=container), _CharmSpec(MyCharm, {"containers": {"bar": {}}}), ) @@ -132,7 +133,7 @@ def test_evt_bad_container_name(): _CharmSpec(MyCharm, {}), ) assert_consistent( - State(containers=[Container("bar")]), + State(containers={Container("bar")}), _Event("bar-pebble-ready", container=Container("bar")), _CharmSpec(MyCharm, {"containers": {"bar": {}}}), ) @@ -147,7 +148,7 @@ def test_evt_bad_relation_name(suffix): ) relation = Relation("bar") assert_consistent( - State(relations=[relation]), + State(relations={relation}), _Event(f"bar{suffix}", relation=relation), _CharmSpec(MyCharm, {"requires": {"bar": {"interface": "xxx"}}}), ) @@ -158,7 +159,7 @@ def test_evt_no_relation(suffix): assert_inconsistent(State(), _Event(f"foo{suffix}"), _CharmSpec(MyCharm, {})) relation = Relation("bar") assert_consistent( - State(relations=[relation]), + State(relations={relation}), _Event(f"bar{suffix}", relation=relation), _CharmSpec(MyCharm, {"requires": {"bar": {"interface": "xxx"}}}), ) @@ -262,13 +263,13 @@ def test_config_secret_old_juju(juju_version): def test_secrets_jujuv_bad(bad_v): secret = Secret("secret:foo", {0: {"a": "b"}}) assert_inconsistent( - State(secrets=[secret]), + State(secrets={secret}), _Event("bar"), _CharmSpec(MyCharm, {}), bad_v, ) assert_inconsistent( - State(secrets=[secret]), + State(secrets={secret}), secret.changed_event, _CharmSpec(MyCharm, {}), bad_v, @@ -285,7 +286,7 @@ def test_secrets_jujuv_bad(bad_v): @pytest.mark.parametrize("good_v", ("3.0", "3.1", "3", "3.33", "4", "100")) def test_secrets_jujuv_bad(good_v): assert_consistent( - State(secrets=[Secret(id="secret:foo", contents={0: {"a": "b"}})]), + State(secrets={Secret(id="secret:foo", contents={0: {"a": "b"}})}), _Event("bar"), _CharmSpec(MyCharm, {}), good_v, @@ -308,12 +309,12 @@ def test_secret_not_in_state(): def test_peer_relation_consistency(): assert_inconsistent( - State(relations=[Relation("foo")]), + State(relations={Relation("foo")}), _Event("bar"), _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), ) assert_consistent( - State(relations=[PeerRelation("foo")]), + State(relations={PeerRelation("foo")}), _Event("bar"), _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), ) @@ -335,7 +336,7 @@ def test_duplicate_endpoints_inconsistent(): def test_sub_relation_consistency(): assert_inconsistent( - State(relations=[Relation("foo")]), + State(relations={Relation("foo")}), _Event("bar"), _CharmSpec( MyCharm, @@ -344,7 +345,7 @@ def test_sub_relation_consistency(): ) assert_consistent( - State(relations=[SubordinateRelation("foo")]), + State(relations={SubordinateRelation("foo")}), _Event("bar"), _CharmSpec( MyCharm, @@ -355,7 +356,7 @@ def test_sub_relation_consistency(): def test_relation_sub_inconsistent(): assert_inconsistent( - State(relations=[SubordinateRelation("foo")]), + State(relations={SubordinateRelation("foo")}), _Event("bar"), _CharmSpec(MyCharm, {"requires": {"foo": {"interface": "bar"}}}), ) @@ -369,7 +370,7 @@ def test_relation_not_in_state(): _CharmSpec(MyCharm, {"requires": {"foo": {"interface": "bar"}}}), ) assert_consistent( - State(relations=[relation]), + State(relations={relation}), _Event("foo_relation_changed", relation=relation), _CharmSpec(MyCharm, {"requires": {"foo": {"interface": "bar"}}}), ) @@ -377,7 +378,7 @@ def test_relation_not_in_state(): def test_dupe_containers_inconsistent(): assert_inconsistent( - State(containers=[Container("foo"), Container("foo")]), + State(containers={Container("foo"), Container("foo")}), _Event("bar"), _CharmSpec(MyCharm, {"containers": {"foo": {}}}), ) @@ -459,7 +460,7 @@ def test_action_params_type(ptype, good, bad): def test_duplicate_relation_ids(): assert_inconsistent( - State(relations=[Relation("foo", id=1), Relation("bar", id=1)]), + State(relations={Relation("foo", id=1), Relation("bar", id=1)}), _Event("start"), _CharmSpec( MyCharm, @@ -472,13 +473,13 @@ def test_duplicate_relation_ids(): def test_relation_without_endpoint(): assert_inconsistent( - State(relations=[Relation("foo", id=1), Relation("bar", id=1)]), + State(relations={Relation("foo", id=1), Relation("bar", id=1)}), _Event("start"), _CharmSpec(MyCharm, meta={"name": "charlemagne"}), ) assert_consistent( - State(relations=[Relation("foo", id=1), Relation("bar", id=2)]), + State(relations={Relation("foo", id=1), Relation("bar", id=2)}), _Event("start"), _CharmSpec( MyCharm, @@ -492,12 +493,12 @@ def test_relation_without_endpoint(): def test_storage_event(): storage = Storage("foo") assert_inconsistent( - State(storage=[storage]), + State(storages={storage}), _Event("foo-storage-attached"), _CharmSpec(MyCharm, meta={"name": "rupert"}), ) assert_inconsistent( - State(storage=[storage]), + State(storages={storage}), _Event("foo-storage-attached"), _CharmSpec( MyCharm, meta={"name": "rupert", "storage": {"foo": {"type": "filesystem"}}} @@ -510,19 +511,19 @@ def test_storage_states(): storage2 = Storage("foo", index=1) assert_inconsistent( - State(storage=[storage1, storage2]), + State(storages={storage1, storage2}), _Event("start"), _CharmSpec(MyCharm, meta={"name": "everett"}), ) assert_consistent( - State(storage=[storage1, dataclasses.replace(storage2, index=2)]), + State(storages={storage1, dataclasses.replace(storage2, index=2)}), _Event("start"), _CharmSpec( MyCharm, meta={"name": "frank", "storage": {"foo": {"type": "filesystem"}}} ), ) assert_consistent( - State(storage=[storage1, dataclasses.replace(storage2, name="marx")]), + State(storages={storage1, dataclasses.replace(storage2, name="marx")}), _Event("start"), _CharmSpec( MyCharm, @@ -560,7 +561,7 @@ def test_storage_not_in_state(): def test_resource_states(): # happy path assert_consistent( - State(resources={"foo": "/foo/bar.yaml"}), + State(resources={Resource("foo", "/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -580,7 +581,7 @@ def test_resource_states(): # resource not defined in meta assert_inconsistent( - State(resources={"bar": "/foo/bar.yaml"}), + State(resources={Resource("bar", "/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -589,7 +590,7 @@ def test_resource_states(): ) assert_inconsistent( - State(resources={"bar": "/foo/bar.yaml"}), + State(resources={Resource("bar", "/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -672,12 +673,12 @@ def test_cloudspec_consistency(): def test_storedstate_consistency(): assert_consistent( State( - stored_state=[ + stored_states={ StoredState(content={"foo": "bar"}), StoredState(name="my_stored_state", content={"foo": 1}), StoredState(owner_path="MyCharmLib", content={"foo": None}), StoredState(owner_path="OtherCharmLib", content={"foo": (1, 2, 3)}), - ] + } ), _Event("start"), _CharmSpec( @@ -689,10 +690,10 @@ def test_storedstate_consistency(): ) assert_inconsistent( State( - stored_state=[ + stored_states={ StoredState(owner_path=None, content={"foo": "bar"}), StoredState(owner_path=None, name="_stored", content={"foo": "bar"}), - ] + } ), _Event("start"), _CharmSpec( @@ -704,11 +705,11 @@ def test_storedstate_consistency(): ) assert_inconsistent( State( - stored_state=[ + stored_states={ StoredState( owner_path=None, content={"secret": Secret(id="foo", contents={})} ) - ] + } ), _Event("start"), _CharmSpec( diff --git a/tests/test_e2e/test_deferred.py b/tests/test_e2e/test_deferred.py index fccb326c..f988dcc5 100644 --- a/tests/test_e2e/test_deferred.py +++ b/tests/test_e2e/test_deferred.py @@ -120,7 +120,7 @@ def test_deferred_relation_event(mycharm): out = trigger( State( - relations=[rel], + relations={rel}, deferred=[ deferred( event="foo_relation_changed", @@ -152,7 +152,7 @@ def test_deferred_relation_event_from_relation(mycharm): rel = Relation(endpoint="foo", remote_app_name="remote") out = trigger( State( - relations=[rel], + relations={rel}, deferred=[ ctx.on.relation_changed(rel, remote_unit=1).deferred( handler=mycharm._on_event @@ -190,7 +190,7 @@ def test_deferred_workload_event(mycharm): out = trigger( State( - containers=[ctr], + containers={ctr}, deferred=[ _Event("foo_pebble_ready", container=ctr).deferred( handler=mycharm._on_event @@ -238,7 +238,7 @@ def test_defer_reemit_relation_event(mycharm): rel = Relation("foo") mycharm.defer_next = 1 - state_1 = ctx.run(ctx.on.relation_created(rel), State(relations=[rel])) + state_1 = ctx.run(ctx.on.relation_created(rel), State(relations={rel})) mycharm.defer_next = 0 state_2 = ctx.run(ctx.on.start(), state_1) diff --git a/tests/test_e2e/test_event_bind.py b/tests/test_e2e/test_event_bind.py new file mode 100644 index 00000000..141592fe --- /dev/null +++ b/tests/test_e2e/test_event_bind.py @@ -0,0 +1,62 @@ +import pytest + +from scenario import Container, Event, Relation, Secret, State +from scenario.state import BindFailedError + + +def test_bind_relation(): + event = Event("foo-relation-changed") + foo_relation = Relation("foo") + state = State(relations={foo_relation}) + assert event.bind(state).relation is foo_relation + + +def test_bind_relation_complex_name(): + event = Event("foo-bar-baz-relation-changed") + foo_relation = Relation("foo_bar_baz") + state = State(relations={foo_relation}) + assert event.bind(state).relation is foo_relation + + +def test_bind_relation_notfound(): + event = Event("foo-relation-changed") + state = State() + with pytest.raises(BindFailedError): + event.bind(state) + + +def test_bind_relation_toomany(caplog): + event = Event("foo-relation-changed") + foo_relation = Relation("foo") + foo_relation1 = Relation("foo") + state = State(relations={foo_relation, foo_relation1}) + event.bind(state) + assert "too many relations" in caplog.text + + +def test_bind_secret(): + event = Event("secret-changed") + secret = Secret("foo", {"a": "b"}) + state = State(secrets={secret}) + assert event.bind(state).secret is secret + + +def test_bind_secret_notfound(): + event = Event("secret-changed") + state = State() + with pytest.raises(BindFailedError): + event.bind(state) + + +def test_bind_container(): + event = Event("foo-pebble-ready") + container = Container("foo") + state = State(containers={container}) + assert event.bind(state).container is container + + +def test_bind_container_notfound(): + event = Event("foo-pebble-ready") + state = State() + with pytest.raises(BindFailedError): + event.bind(state) diff --git a/tests/test_e2e/test_pebble.py b/tests/test_e2e/test_pebble.py index 7dfbba67..08acebc3 100644 --- a/tests/test_e2e/test_pebble.py +++ b/tests/test_e2e/test_pebble.py @@ -61,7 +61,7 @@ def callback(self: CharmBase): assert can_connect == self.unit.get_container("foo").can_connect() trigger( - State(containers=[Container(name="foo", can_connect=can_connect)]), + State(containers={Container(name="foo", can_connect=can_connect)}), charm_type=charm_cls, meta={"name": "foo", "containers": {"foo": {}}}, event="start", @@ -82,13 +82,13 @@ def callback(self: CharmBase): trigger( State( - containers=[ + containers={ Container( name="foo", can_connect=True, mounts={"bar": Mount(location="/bar/baz.txt", source=pth)}, ) - ] + } ), charm_type=charm_cls, meta={"name": "foo", "containers": {"foo": {}}}, @@ -122,7 +122,7 @@ def callback(self: CharmBase): can_connect=True, mounts={"foo": Mount(location="/foo", source=td.name)}, ) - state = State(containers=[container]) + state = State(containers={container}) ctx = Context( charm_type=charm_cls, @@ -156,7 +156,7 @@ def callback(self: CharmBase): else: # nothing has changed - out_purged = dataclasses.replace(out, stored_state=state.stored_state) + out_purged = dataclasses.replace(out, stored_states=state.stored_states) assert not jsonpatch_delta(out_purged, state) @@ -197,13 +197,13 @@ def callback(self: CharmBase): trigger( State( - containers=[ + containers={ Container( name="foo", can_connect=True, exec_mock={(cmd,): ExecOutput(stdout="hello pebble")}, ) - ] + } ), charm_type=charm_cls, meta={"name": "foo", "containers": {"foo": {}}}, @@ -220,7 +220,7 @@ def callback(self: CharmBase): container = Container(name="foo", can_connect=True) trigger( - State(containers=[container]), + State(containers={container}), charm_type=charm_cls, meta={"name": "foo", "containers": {"foo": {}}}, event="pebble_ready", @@ -287,14 +287,14 @@ def _on_ready(self, event): ) out = trigger( - State(containers=[container]), + State(containers={container}), charm_type=PlanCharm, meta={"name": "foo", "containers": {"foo": {}}}, event="pebble_ready", ) serv = lambda name, obj: pebble.Service(name, raw=obj) - container = out.containers[0] + container = out.get_container(container.name) assert container.plan.services == { "barserv": serv("barserv", {"startup": "disabled"}), "fooserv": serv("fooserv", {"startup": "enabled"}), @@ -308,13 +308,13 @@ def _on_ready(self, event): def test_exec_wait_error(charm_cls): state = State( - containers=[ + containers={ Container( name="foo", can_connect=True, exec_mock={("foo",): ExecOutput(stdout="hello pebble", return_code=1)}, ) - ] + } ) ctx = Context(charm_cls, meta={"name": "foo", "containers": {"foo": {}}}) @@ -328,7 +328,7 @@ def test_exec_wait_error(charm_cls): def test_exec_wait_output(charm_cls): state = State( - containers=[ + containers={ Container( name="foo", can_connect=True, @@ -336,7 +336,7 @@ def test_exec_wait_output(charm_cls): ("foo",): ExecOutput(stdout="hello pebble", stderr="oepsie") }, ) - ] + } ) ctx = Context(charm_cls, meta={"name": "foo", "containers": {"foo": {}}}) @@ -350,13 +350,13 @@ def test_exec_wait_output(charm_cls): def test_exec_wait_output_error(charm_cls): state = State( - containers=[ + containers={ Container( name="foo", can_connect=True, exec_mock={("foo",): ExecOutput(stdout="hello pebble", return_code=1)}, ) - ] + } ) ctx = Context(charm_cls, meta={"name": "foo", "containers": {"foo": {}}}) diff --git a/tests/test_e2e/test_play_assertions.py b/tests/test_e2e/test_play_assertions.py index 7fe07899..103940af 100644 --- a/tests/test_e2e/test_play_assertions.py +++ b/tests/test_e2e/test_play_assertions.py @@ -62,7 +62,7 @@ def post_event(charm): assert out.unit_status == ActiveStatus("yabadoodle") - out_purged = dataclasses.replace(out, stored_state=initial_state.stored_state) + out_purged = dataclasses.replace(out, stored_states=initial_state.stored_states) assert jsonpatch_delta(out_purged, initial_state) == [ { "op": "replace", @@ -100,7 +100,7 @@ def check_relation_data(charm): assert remote_app_data == {"yaba": "doodle"} state_in = State( - relations=[ + relations={ Relation( endpoint="relation_test", interface="azdrubales", @@ -109,7 +109,7 @@ def check_relation_data(charm): remote_app_data={"yaba": "doodle"}, remote_units_data={0: {"foo": "bar"}, 1: {"baz": "qux"}}, ) - ] + } ) trigger( state_in, diff --git a/tests/test_e2e/test_ports.py b/tests/test_e2e/test_ports.py index 3a19148f..76a46878 100644 --- a/tests/test_e2e/test_ports.py +++ b/tests/test_e2e/test_ports.py @@ -27,15 +27,16 @@ def ctx(): def test_open_port(ctx): - out = ctx.run(ctx.on.start(), State()) - port = out.opened_ports.pop() + out = ctx.run(ctx.on.start()), State()) + assert len(out.opened_ports) == 1 + port = tuple(out.opened_ports)[0] assert port.protocol == "tcp" assert port.port == 12 def test_close_port(ctx): - out = ctx.run(ctx.on.stop(), State(opened_ports=[TCPPort(42)])) + out = ctx.run(ctx.on.stop(), State(opened_ports={TCPPort(42)})) assert not out.opened_ports diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 853c7ba5..9ba0ed61 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -64,10 +64,10 @@ def pre_event(charm: CharmBase): State( config={"foo": "bar"}, leader=True, - relations=[ + relations={ Relation(endpoint="foo", interface="foo", remote_app_name="remote"), Relation(endpoint="qux", interface="qux", remote_app_name="remote"), - ], + }, ), "start", mycharm, @@ -97,9 +97,9 @@ def test_relation_events(mycharm, evt_name): trigger( State( - relations=[ + relations={ relation, - ], + }, ), f"relation_{evt_name}", mycharm, @@ -141,9 +141,9 @@ def callback(charm: CharmBase, e): trigger( State( - relations=[ + relations={ relation, - ], + }, ), f"relation_{evt_name}", mycharm, @@ -202,7 +202,7 @@ def callback(charm: CharmBase, event): }, }, ) - state = State(relations=[relation]) + state = State(relations={relation}) kwargs = {} if has_unit: kwargs["remote_unit"] = remote_unit_id @@ -242,9 +242,9 @@ def callback(charm: CharmBase, event): trigger( State( - relations=[ + relations={ relation, - ], + }, ), f"relation_{evt_name}", mycharm, @@ -302,9 +302,9 @@ def callback(charm: CharmBase, event): trigger( State( - relations=[ + relations={ relation, - ], + }, ), f"relation_{evt_name}", mycharm, @@ -356,7 +356,7 @@ def test_relation_event_trigger(relation, evt_name, mycharm): "peers": {"b": {"interface": "i2"}}, } state = trigger( - State(relations=[relation]), + State(relations={relation}), f"relation_{evt_name}", mycharm, meta=meta, @@ -389,7 +389,7 @@ def post_event(charm: CharmBase): assert len(relation.units) == 1 trigger( - State(relations=[sub1, sub2]), + State(relations={sub1, sub2}), "update_status", mycharm, meta=meta, @@ -417,7 +417,7 @@ def test_broken_relation_not_in_model_relations(mycharm): ctx = Context( mycharm, meta={"name": "local", "requires": {"foo": {"interface": "foo"}}} ) - with ctx.manager(ctx.on.relation_broken(rel), state=State(relations=[rel])) as mgr: + with ctx.manager(ctx.on.relation_broken(rel), state=State(relations={rel})) as mgr: charm = mgr.charm assert charm.model.get_relation("foo") is None diff --git a/tests/test_e2e/test_secrets.py b/tests/test_e2e/test_secrets.py index 7229bd9f..7c398725 100644 --- a/tests/test_e2e/test_secrets.py +++ b/tests/test_e2e/test_secrets.py @@ -39,7 +39,7 @@ def test_get_secret_no_secret(mycharm): def test_get_secret(mycharm): ctx = Context(mycharm, meta={"name": "local"}) with ctx.manager( - state=State(secrets=[Secret(id="foo", contents={0: {"a": "b"}})]), + state=State(secrets={Secret(id="foo", contents={0: {"a": "b"}})}), event=ctx.on.update_status(), ) as mgr: assert mgr.charm.model.get_secret(id="foo").get_content()["a"] == "b" @@ -51,7 +51,7 @@ def test_get_secret_get_refresh(mycharm, owner): with ctx.manager( ctx.on.update_status(), State( - secrets=[ + secrets={ Secret( id="foo", contents={ @@ -60,7 +60,7 @@ def test_get_secret_get_refresh(mycharm, owner): }, owner=owner, ) - ] + } ), ) as mgr: charm = mgr.charm @@ -74,7 +74,7 @@ def test_get_secret_nonowner_peek_update(mycharm, app): ctx.on.update_status(), State( leader=app, - secrets=[ + secrets={ Secret( id="foo", contents={ @@ -82,7 +82,7 @@ def test_get_secret_nonowner_peek_update(mycharm, app): 1: {"a": "c"}, }, ), - ], + }, ), ) as mgr: charm = mgr.charm @@ -100,7 +100,7 @@ def test_get_secret_owner_peek_update(mycharm, owner): with ctx.manager( ctx.on.update_status(), State( - secrets=[ + secrets={ Secret( id="foo", contents={ @@ -109,7 +109,7 @@ def test_get_secret_owner_peek_update(mycharm, owner): }, owner=owner, ) - ] + } ), ) as mgr: charm = mgr.charm @@ -171,7 +171,7 @@ def test_add(mycharm, app): charm.unit.add_secret({"foo": "bar"}, label="mylabel") assert mgr.output.secrets - secret = mgr.output.secrets[0] + secret = mgr.output.get_secret(label="mylabel") assert secret.contents[0] == {"foo": "bar"} assert secret.label == "mylabel" @@ -215,7 +215,7 @@ def test_set_legacy_behaviour(mycharm): == rev3 ) - assert state_out.secrets[0].contents == { + assert state_out.get_secret(label="mylabel").contents == { 0: rev1, 1: rev2, 2: rev3, @@ -247,7 +247,7 @@ def test_set(mycharm): assert secret.get_content() == rev2 assert secret.peek_content() == secret.get_content(refresh=True) == rev3 - assert state_out.secrets[0].contents == { + assert state_out.get_secret(label="mylabel").contents == { 0: rev1, 1: rev2, 2: rev3, @@ -276,7 +276,7 @@ def test_set_juju33(mycharm): assert secret.peek_content() == rev3 assert secret.get_content(refresh=True) == rev3 - assert state_out.secrets[0].contents == { + assert state_out.get_secret(label="mylabel").contents == { 0: rev1, 1: rev2, 2: rev3, @@ -290,7 +290,7 @@ def test_meta(mycharm, app): ctx.on.update_status(), State( leader=True, - secrets=[ + secrets={ Secret( owner="app" if app else "unit", id="foo", @@ -301,7 +301,7 @@ def test_meta(mycharm, app): 0: {"a": "b"}, }, ) - ], + }, ), ) as mgr: charm = mgr.charm @@ -330,7 +330,7 @@ def test_secret_permission_model(mycharm, leader, owner): ctx.on.update_status(), State( leader=leader, - secrets=[ + secrets={ Secret( id="foo", label="mylabel", @@ -341,7 +341,7 @@ def test_secret_permission_model(mycharm, leader, owner): 0: {"a": "b"}, }, ) - ], + }, ), ) as mgr: secret = mgr.charm.model.get_secret(id="foo") @@ -383,7 +383,7 @@ def test_grant(mycharm, app): ctx.on.update_status(), State( relations=[Relation("foo", "remote")], - secrets=[ + secrets={ Secret( owner="unit", id="foo", @@ -394,7 +394,7 @@ def test_grant(mycharm, app): 0: {"a": "b"}, }, ) - ], + }, ), ) as mgr: charm = mgr.charm @@ -404,7 +404,7 @@ def test_grant(mycharm, app): secret.grant(relation=foo) else: secret.grant(relation=foo, unit=foo.units.pop()) - vals = list(mgr.output.secrets[0].remote_grants.values()) + vals = list(mgr.output.get_secret(label="mylabel").remote_grants.values()) assert vals == [{"remote"}] if app else [{"remote/0"}] @@ -415,7 +415,7 @@ def test_update_metadata(mycharm): with ctx.manager( ctx.on.update_status(), State( - secrets=[ + secrets={ Secret( owner="unit", id="foo", @@ -424,7 +424,7 @@ def test_update_metadata(mycharm): 0: {"a": "b"}, }, ) - ], + }, ), ) as mgr: secret = mgr.charm.model.get_secret(label="mylabel") @@ -435,7 +435,7 @@ def test_update_metadata(mycharm): rotate=SecretRotate.DAILY, ) - secret_out = mgr.output.secrets[0] + secret_out = mgr.output.get_secret(label="babbuccia") assert secret_out.label == "babbuccia" assert secret_out.rotate == SecretRotate.DAILY assert secret_out.description == "blu" @@ -475,8 +475,8 @@ def post_event(charm: CharmBase): out = trigger( State( - relations=[Relation("foo", "remote")], - secrets=[ + relations={Relation("foo", "remote")}, + secrets={ Secret( id="foo", label="mylabel", @@ -486,7 +486,7 @@ def post_event(charm: CharmBase): 0: {"a": "b"}, }, ) - ], + }, ), "update_status", mycharm, @@ -508,9 +508,9 @@ def __init__(self, *args): state = State( leader=True, - relations=[ + relations={ Relation("bar", remote_app_name=relation_remote_app, id=relation_id) - ], + }, ) with ctx.manager(ctx.on.start(), state) as mgr: @@ -521,7 +521,7 @@ def __init__(self, *args): secret.grant(bar_relation) assert mgr.output.secrets - scenario_secret = mgr.output.secrets[0] + scenario_secret = mgr.output.get_secret(label="mylabel") assert relation_remote_app in scenario_secret.remote_grants[relation_id] with ctx.manager(ctx.on.start(), mgr.output) as mgr: @@ -529,7 +529,7 @@ def __init__(self, *args): secret = charm.model.get_secret(label="mylabel") secret.revoke(bar_relation) - scenario_secret = mgr.output.secrets[0] + scenario_secret = mgr.output.get_secret(label="mylabel") assert scenario_secret.remote_grants == {} with ctx.manager(ctx.on.start(), mgr.output) as mgr: @@ -537,7 +537,7 @@ def __init__(self, *args): secret = charm.model.get_secret(label="mylabel") secret.remove_all_revisions() - assert not mgr.output.secrets[0].contents # secret wiped + assert not mgr.output.get_secret(label="mylabel").contents # secret wiped def test_no_additional_positional_arguments(): diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index 3f119909..d2bd6a50 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -67,7 +67,7 @@ def state(): def test_bare_event(state, mycharm): out = trigger(state, "start", mycharm, meta={"name": "foo"}) - out_purged = replace(out, stored_state=state.stored_state) + out_purged = replace(out, stored_states=state.stored_states) assert jsonpatch_delta(state, out_purged) == [] @@ -106,7 +106,7 @@ def call(charm: CharmBase, e): assert out.workload_version == "" # ignore stored state in the delta - out_purged = replace(out, stored_state=state.stored_state) + out_purged = replace(out, stored_states=state.stored_states) assert jsonpatch_delta(out_purged, state) == sort_patch( [ {"op": "replace", "path": "/app_status/message", "value": "foo barz"}, @@ -126,7 +126,7 @@ def pre_event(charm: CharmBase): assert container.can_connect() is connect trigger( - State(containers=[Container(name="foo", can_connect=connect)]), + State(containers={Container(name="foo", can_connect=connect)}), "start", mycharm, meta={ @@ -155,7 +155,7 @@ def pre_event(charm: CharmBase): assert not rel.data[unit] state = State( - relations=[ + relations={ Relation( endpoint="foo", interface="bar", @@ -165,7 +165,7 @@ def pre_event(charm: CharmBase): local_unit_data={"c": "d"}, remote_units_data={0: {}, 1: {"e": "f"}, 2: {}}, ) - ] + } ) trigger( state, @@ -215,7 +215,7 @@ def pre_event(charm: CharmBase): state = State( leader=True, planned_units=4, - relations=[relation], + relations={relation}, ) assert not mycharm.called @@ -231,16 +231,18 @@ def pre_event(charm: CharmBase): ) assert mycharm.called - assert asdict(out.relations[0]) == asdict( + assert asdict(out.get_relation(relation.id)) == asdict( replace( relation, local_app_data={"a": "b"}, local_unit_data={"c": "d", **DEFAULT_JUJU_DATABAG}, ) ) - - assert out.relations[0].local_app_data == {"a": "b"} - assert out.relations[0].local_unit_data == {"c": "d", **DEFAULT_JUJU_DATABAG} + assert out.get_relation(relation.id).local_app_data == {"a": "b"} + assert out.get_relation(relation.id).local_unit_data == { + "c": "d", + **DEFAULT_JUJU_DATABAG, + } @pytest.mark.parametrize( diff --git a/tests/test_e2e/test_storage.py b/tests/test_e2e/test_storage.py index b62288bb..5885d0dd 100644 --- a/tests/test_e2e/test_storage.py +++ b/tests/test_e2e/test_storage.py @@ -66,7 +66,7 @@ def test_storage_usage(storage_ctx): (storage.get_filesystem(storage_ctx) / "myfile.txt").write_text("helloworld") with storage_ctx.manager( - storage_ctx.on.update_status(), State(storage=[storage]) + storage_ctx.on.update_status(), State(storage={storage}) ) as mgr: foo = mgr.charm.model.storages["foo"][0] loc = foo.location @@ -85,9 +85,9 @@ def test_storage_usage(storage_ctx): def test_storage_attached_event(storage_ctx): storage = Storage("foo") - storage_ctx.run(storage_ctx.on.storage_attached(storage), State(storage=[storage])) + storage_ctx.run(storage_ctx.on.storage_attached(storage), State(storage={storage})) def test_storage_detaching_event(storage_ctx): storage = Storage("foo") - storage_ctx.run(storage_ctx.on.storage_detaching(storage), State(storage=[storage])) + storage_ctx.run(storage_ctx.on.storage_detaching(storage), State(storage={storage})) diff --git a/tests/test_e2e/test_stored_state.py b/tests/test_e2e/test_stored_state.py index 38c38efd..7e97f212 100644 --- a/tests/test_e2e/test_stored_state.py +++ b/tests/test_e2e/test_stored_state.py @@ -32,25 +32,37 @@ def _on_event(self, event): def test_stored_state_default(mycharm): out = trigger(State(), "start", mycharm, meta=mycharm.META) - assert out.stored_state[0].content == {"foo": "bar", "baz": {12: 142}} + assert out.get_stored_state("_stored", "MyCharm").content == { + "foo": "bar", + "baz": {12: 142}, + } + assert out.get_stored_state("_stored2", "MyCharm").content == { + "foo": "bar", + "baz": {12: 142}, + } def test_stored_state_initialized(mycharm): out = trigger( State( - stored_state=[ + stored_state={ StoredState( owner_path="MyCharm", name="_stored", content={"foo": "FOOX"} ), - ] + } ), "start", mycharm, meta=mycharm.META, ) - # todo: ordering is messy? - assert out.stored_state[1].content == {"foo": "FOOX", "baz": {12: 142}} - assert out.stored_state[0].content == {"foo": "bar", "baz": {12: 142}} + assert out.get_stored_state("_stored", "MyCharm").content == { + "foo": "FOOX", + "baz": {12: 142}, + } + assert out.get_stored_state("_stored2", "MyCharm").content == { + "foo": "bar", + "baz": {12: 142}, + } def test_positional_arguments(): From 97422c6b19620e71ff415d44b8bcfe8a9c57123d Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Thu, 30 May 2024 19:06:02 +1200 Subject: [PATCH 04/12] Don't add __eq__ for now. --- scenario/state.py | 99 ----------------------------------------------- 1 file changed, 99 deletions(-) diff --git a/scenario/state.py b/scenario/state.py index dd7798b8..68adfd17 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -293,26 +293,6 @@ class Secret: def __hash__(self) -> int: return hash(self.id) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, Secret): - return False - for attr in ( - "id", - "owner", - "revision", - "label", - "description", - "expire", - "rotate", - "remote_grants", - "contents", - ): - if getattr(self, attr) != getattr(value, attr): - return False - return True - def _set_revision(self, revision: int): """Set a new tracked revision.""" # bypass frozen dataclass @@ -479,20 +459,6 @@ def __post_init__(self): def __hash__(self) -> int: return hash(self.id) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, _RelationBase): - return False - if self.endpoint != value.endpoint or self.interface != value.interface: - return False - if ( - self.local_app_data != value.local_app_data - or self.local_unit_data != value.local_unit_data - ): - return False - return True - def _validate_databag(self, databag: dict): if not isinstance(databag, dict): raise StateValidationError( @@ -529,27 +495,6 @@ class Relation(_RelationBase): def __hash__(self) -> int: return hash(self.id) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, Relation): - return False - if ( - self.endpoint != value.endpoint - or self.interface != value.interface - or self.limit != value.limit - or self.remote_app_name != value.remote_app_name - ): - return False - if ( - self.local_app_data != value.local_app_data - or self.local_unit_data != value.local_unit_data - or self.remote_app_data != value.remote_app_data - or self.remote_units_data != value.remote_units_data - ): - return False - return True - @property def _remote_app_name(self) -> str: """Who is on the other end of this relation?""" @@ -587,27 +532,6 @@ class SubordinateRelation(_RelationBase): def __hash__(self) -> int: return hash(self.id) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, SubordinateRelation): - return False - if ( - self.endpoint != value.endpoint - or self.interface != value.interface - or self.remote_app_name != value.remote_app_name - or self.remote_unit_id != value.remote_unit_id - ): - return False - if ( - self.local_app_data != value.local_app_data - or self.local_unit_data != value.local_unit_data - or self.remote_app_data != value.remote_app_data - or self.remote_unit_data != value.remote_unit_data - ): - return False - return True - @property def _remote_unit_ids(self) -> Tuple[int]: """Ids of the units on the other end of this relation.""" @@ -646,20 +570,6 @@ class PeerRelation(_RelationBase): def __hash__(self) -> int: return hash(self.id) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, PeerRelation): - return False - if self.endpoint != value.endpoint or self.interface != value.interface: - return False - if ( - self.local_app_data != value.local_app_data - or self.local_unit_data != value.local_unit_data - ): - return False - return True - @property def _databags(self): """Yield all databags in this relation.""" @@ -1016,15 +926,6 @@ def handle_path(self): def __hash__(self) -> int: return hash(self.handle_path) - def __eq__(self, value: object) -> bool: - if value is self: - return True - if not isinstance(value, StoredState): - return False - if self.handle_path != value.handle_path: - return False - return self.content == value.content - _RawPortProtocolLiteral = Literal["tcp", "udp", "icmp"] From 7e1bbc26ae7ed0202da1673ad51b6d893d5acf34 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Thu, 30 May 2024 19:43:14 +1200 Subject: [PATCH 05/12] Update scenario/mocking.py --- scenario/mocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scenario/mocking.py b/scenario/mocking.py index 2f1745c6..de9f4414 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -550,7 +550,7 @@ def action_get(self): def storage_add(self, name: str, count: int = 1): if not isinstance(count, int) or isinstance(count, bool): raise TypeError( - f"storage count must be integer, got: {count} ({type(count)}", + f"storage count must be integer, got: {count} ({type(count)})", ) if "/" in name: From 3b34ee37c9a35a1a9d200bb0a89950846e867abc Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Fri, 31 May 2024 18:13:46 +1200 Subject: [PATCH 06/12] Allow getting components by passing in the old entity. --- scenario/mocking.py | 2 +- scenario/state.py | 92 ++++++++++++++++++++--------- tests/test_e2e/test_stored_state.py | 8 +-- 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/scenario/mocking.py b/scenario/mocking.py index de9f4414..c8b3ca2b 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -154,7 +154,7 @@ def get_pebble(self, socket_path: str) -> "Client": container_root = self._context._get_container_root(container_name) try: mounts = self._state.get_container(container_name).mounts - except ValueError: + except KeyError: # container not defined in state. mounts = {} diff --git a/scenario/state.py b/scenario/state.py index 68adfd17..ec822224 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -1179,58 +1179,96 @@ def with_unit_status(self, status: StatusBase) -> "State": ), ) - def get_container(self, container: Union[str, Container]) -> Container: + def get_container(self, container: Union[str, Container], /) -> Container: """Get container from this State, based on an input container or its name.""" container_name = ( container.name if isinstance(container, Container) else container ) - containers = [c for c in self.containers if c.name == container_name] - if not containers: - raise ValueError(f"container: {container_name} not found in the State") - return containers[0] + for state_container in self.containers: + if state_container.name == container_name: + return state_container + raise KeyError(f"container: {container_name} not found in the State") def get_secret( self, *, id: Optional[str] = None, label: Optional[str] = None, + secret: Optional[Secret] = None, ) -> Secret: - """Get secret from this State, based on the secret's id or label.""" + """Get secret from this State, based on an input secret or the secret's id or label.""" + if id is None and label is None and secret is None: + raise ValueError("An id or label or Secret must be provided.") + if id is not None and secret is not None and secret.id != id: + raise ValueError("id and secret.id must match.") + if label is not None and secret is not None and secret.label != label: + raise ValueError("label and secret.label must match.") + + if secret: + if secret.id is not None: + id = secret.id + if secret.label is not None: + label = secret.label for secret in self.secrets: - if secret.id == id or secret.label == label: + if ( + (id and label and secret.id == id and secret.label == label) + or (id and label is None and secret.id == id) + or (id is None and label and secret.label == label) + ): return secret - if id: - message = f"secret: id={id} not found in the State" - elif label: - message = f"secret: label={label} not found in the State" - else: - message = "An id or label must be provided." - raise ValueError(message) + raise KeyError("secret: not found in the State") def get_stored_state( self, - name: str, + stored_state: Union[str, StoredState], + /, + *, owner_path: Optional[str] = None, ) -> StoredState: - """Get stored state from this State, based on the stored state's name and owner_path.""" + """Get stored state from this State, based on an input StoredState or the stored state's name and owner_path.""" + if ( + isinstance(stored_state, StoredState) + and owner_path is not None + and stored_state.owner_path != owner_path + ): + raise ValueError("owner_path and stored_state.owner_path must match.") + + name = ( + stored_state.name if isinstance(stored_state, StoredState) else stored_state + ) for stored_state in self.stored_states: if stored_state.name == name and stored_state.owner_path == owner_path: return stored_state raise ValueError(f"stored state: {name} not found in the State") - def get_storage(self, name: str, index: int = 0) -> Storage: - """Get storage from this State, based on the storage's name and index.""" - for storage in self.storages: - if storage.name == name and storage.index == index: - return storage + def get_storage( + self, + storage: Union[str, Storage], + /, + *, + index: Optional[int] = 0, + ) -> Storage: + """Get storage from this State, based on an input storage or the storage's name and index.""" + if ( + isinstance(storage, Storage) + and index is not None + and storage.index != index + ): + raise ValueError("index and storage.index must match.") + + name = storage.name if isinstance(storage, Storage) else storage + for state_storage in self.storages: + if state_storage.name == name and storage.index == index: + return state_storage raise ValueError(f"storage: name={name}, index={index} not found in the State") - def get_relation(self, id: int) -> "AnyRelation": - """Get relation from this State, based on the relation's id.""" - for relation in self.relations: - if relation.id == id: - return relation - raise ValueError(f"relation: id={id} not found in the State") + def get_relation(self, relation: Union[int, Relation]) -> "AnyRelation": + """Get relation from this State, based on an input relation or the relation's id.""" + relation_id = relation.id if isinstance(relation, Relation) else relation + for state_relation in self.relations: + if state_relation.id == relation_id: + return state_relation + raise KeyError(f"relation: id={relation_id} not found in the State") def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: """Get all relations on this endpoint from the current state.""" diff --git a/tests/test_e2e/test_stored_state.py b/tests/test_e2e/test_stored_state.py index 7e97f212..863f3e8f 100644 --- a/tests/test_e2e/test_stored_state.py +++ b/tests/test_e2e/test_stored_state.py @@ -32,11 +32,11 @@ def _on_event(self, event): def test_stored_state_default(mycharm): out = trigger(State(), "start", mycharm, meta=mycharm.META) - assert out.get_stored_state("_stored", "MyCharm").content == { + assert out.get_stored_state("_stored", owner_path="MyCharm").content == { "foo": "bar", "baz": {12: 142}, } - assert out.get_stored_state("_stored2", "MyCharm").content == { + assert out.get_stored_state("_stored2", owner_path="MyCharm").content == { "foo": "bar", "baz": {12: 142}, } @@ -55,11 +55,11 @@ def test_stored_state_initialized(mycharm): mycharm, meta=mycharm.META, ) - assert out.get_stored_state("_stored", "MyCharm").content == { + assert out.get_stored_state("_stored", owner_path="MyCharm").content == { "foo": "FOOX", "baz": {12: 142}, } - assert out.get_stored_state("_stored2", "MyCharm").content == { + assert out.get_stored_state("_stored2", owner_path="MyCharm").content == { "foo": "bar", "baz": {12: 142}, } From 9a6a18a33273394beadd7db69cef5d10c5059408 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Wed, 5 Jun 2024 20:09:09 +1200 Subject: [PATCH 07/12] Revert back to the simpler get_ methods. --- scenario/state.py | 78 +++++++++++++++-------------------------------- 1 file changed, 24 insertions(+), 54 deletions(-) diff --git a/scenario/state.py b/scenario/state.py index ec822224..b629ac68 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -1179,36 +1179,23 @@ def with_unit_status(self, status: StatusBase) -> "State": ), ) - def get_container(self, container: Union[str, Container], /) -> Container: - """Get container from this State, based on an input container or its name.""" - container_name = ( - container.name if isinstance(container, Container) else container - ) + def get_container(self, container: str, /) -> Container: + """Get container from this State, based on its name.""" for state_container in self.containers: - if state_container.name == container_name: + if state_container.name == container: return state_container - raise KeyError(f"container: {container_name} not found in the State") + raise KeyError(f"container: {container} not found in the State") def get_secret( self, *, id: Optional[str] = None, label: Optional[str] = None, - secret: Optional[Secret] = None, ) -> Secret: - """Get secret from this State, based on an input secret or the secret's id or label.""" - if id is None and label is None and secret is None: - raise ValueError("An id or label or Secret must be provided.") - if id is not None and secret is not None and secret.id != id: - raise ValueError("id and secret.id must match.") - if label is not None and secret is not None and secret.label != label: - raise ValueError("label and secret.label must match.") - - if secret: - if secret.id is not None: - id = secret.id - if secret.label is not None: - label = secret.label + """Get secret from this State, based on the secret's id or label.""" + if id is None and label is None: + raise ValueError("An id or label must be provided.") + for secret in self.secrets: if ( (id and label and secret.id == id and secret.label == label) @@ -1220,55 +1207,38 @@ def get_secret( def get_stored_state( self, - stored_state: Union[str, StoredState], + stored_state: str, /, *, owner_path: Optional[str] = None, ) -> StoredState: - """Get stored state from this State, based on an input StoredState or the stored state's name and owner_path.""" - if ( - isinstance(stored_state, StoredState) - and owner_path is not None - and stored_state.owner_path != owner_path - ): - raise ValueError("owner_path and stored_state.owner_path must match.") - - name = ( - stored_state.name if isinstance(stored_state, StoredState) else stored_state - ) - for stored_state in self.stored_states: - if stored_state.name == name and stored_state.owner_path == owner_path: - return stored_state - raise ValueError(f"stored state: {name} not found in the State") + """Get stored state from this State, based on the stored state's name and owner_path.""" + for ss in self.stored_states: + if ss.name == stored_state and ss.owner_path == owner_path: + return ss + raise ValueError(f"stored state: {stored_state} not found in the State") def get_storage( self, - storage: Union[str, Storage], + storage: str, /, *, index: Optional[int] = 0, ) -> Storage: - """Get storage from this State, based on an input storage or the storage's name and index.""" - if ( - isinstance(storage, Storage) - and index is not None - and storage.index != index - ): - raise ValueError("index and storage.index must match.") - - name = storage.name if isinstance(storage, Storage) else storage + """Get storage from this State, based on the storage's name and index.""" for state_storage in self.storages: - if state_storage.name == name and storage.index == index: + if state_storage.name == storage and storage.index == index: return state_storage - raise ValueError(f"storage: name={name}, index={index} not found in the State") + raise ValueError( + f"storage: name={storage}, index={index} not found in the State", + ) - def get_relation(self, relation: Union[int, Relation]) -> "AnyRelation": - """Get relation from this State, based on an input relation or the relation's id.""" - relation_id = relation.id if isinstance(relation, Relation) else relation + def get_relation(self, relation: int, /) -> "AnyRelation": + """Get relation from this State, based on the relation's id.""" for state_relation in self.relations: - if state_relation.id == relation_id: + if state_relation.id == relation: return state_relation - raise KeyError(f"relation: id={relation_id} not found in the State") + raise KeyError(f"relation: id={relation} not found in the State") def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: """Get all relations on this endpoint from the current state.""" From 29bf3e5e9793aa0ea4e74fe778db3b0ff883bf53 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Mon, 8 Jul 2024 18:11:08 +1200 Subject: [PATCH 08/12] Fix merges. --- README.md | 9 ++--- scenario/__init__.py | 1 - scenario/consistency_checker.py | 2 +- scenario/mocking.py | 4 +- scenario/runtime.py | 2 +- scenario/state.py | 3 +- tests/helpers.py | 4 +- tests/test_consistency_checker.py | 31 ++------------- tests/test_context_on.py | 2 +- tests/test_e2e/test_event_bind.py | 62 ----------------------------- tests/test_e2e/test_ports.py | 2 +- tests/test_e2e/test_state.py | 12 +++--- tests/test_e2e/test_storage.py | 8 ++-- tests/test_e2e/test_stored_state.py | 2 +- 14 files changed, 29 insertions(+), 115 deletions(-) delete mode 100644 tests/test_e2e/test_event_bind.py diff --git a/README.md b/README.md index 0bf9ab55..8d127852 100644 --- a/README.md +++ b/README.md @@ -707,7 +707,7 @@ storage = scenario.Storage("foo") # Setup storage with some content: (storage.get_filesystem(ctx) / "myfile.txt").write_text("helloworld") -with ctx.manager(ctx.on.update_status(), scenario.State(storage={storage})) as mgr: +with ctx.manager(ctx.on.update_status(), scenario.State(storages={storage})) as mgr: foo = mgr.charm.model.storages["foo"][0] loc = foo.location path = loc / "myfile.txt" @@ -752,11 +752,11 @@ So a natural follow-up Scenario test suite for this case would be: ctx = scenario.Context(MyCharm, meta=MyCharm.META) foo_0 = scenario.Storage('foo') # The charm is notified that one of the storages it has requested is ready: -ctx.run(ctx.on.storage_attached(foo_0), scenario.State(storage={foo_0})) +ctx.run(ctx.on.storage_attached(foo_0), scenario.State(storages={foo_0})) foo_1 = scenario.Storage('foo') # The charm is notified that the other storage is also ready: -ctx.run(ctx.on.storage_attached(foo_1), scenario.State(storage={foo_0, foo_}])) +ctx.run(ctx.on.storage_attached(foo_1), scenario.State(storages={foo_0, foo_1})) ``` ## Ports @@ -787,7 +787,6 @@ state = scenario.State( scenario.Secret( {0: {'key': 'public'}}, id='foo', - contents={0: {'key': 'public'}} ), }, ) @@ -880,7 +879,7 @@ So, the only consistency-level check we enforce in Scenario when it comes to res import pathlib ctx = scenario.Context(MyCharm, meta={'name': 'juliette', "resources": {"foo": {"type": "oci-image"}}}) -resource = scenario.Resource('foo', '/path/to/resource.tar') +resource = scenario.Resource(name='foo', path='/path/to/resource.tar') with ctx.manager(ctx.on.start(), scenario.State(resources={resource})) as mgr: # If the charm, at runtime, were to call self.model.resources.fetch("foo"), it would get '/path/to/resource.tar' back. path = mgr.charm.model.resources.fetch('foo') diff --git a/scenario/__init__.py b/scenario/__init__.py index 276f8682..fafc3631 100644 --- a/scenario/__init__.py +++ b/scenario/__init__.py @@ -53,7 +53,6 @@ "ICMPPort", "TCPPort", "UDPPort", - "Port", "Resource", "Storage", "StoredState", diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 92dc30d5..274adc5c 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -265,7 +265,7 @@ def _check_storage_event( f"storage event {event.name} refers to storage {storage.name} " f"which is not declared in the charm metadata (metadata.yaml) under 'storage'.", ) - elif storage not in state.storage: + elif storage not in state.storages: errors.append( f"cannot emit {event.name} because storage {storage.name} " f"is not in the state.", diff --git a/scenario/mocking.py b/scenario/mocking.py index c8b3ca2b..0e2da1b5 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -129,7 +129,7 @@ def open_port( # fixme: the charm will get hit with a StateValidationError # here, not the expected ModelError... port_ = _port_cls_by_protocol[protocol](port=port) - ports = self._state.opened_ports + ports = set(self._state.opened_ports) if port_ not in ports: ports.add(port_) if ports != self._state.opened_ports: @@ -141,7 +141,7 @@ def close_port( port: Optional[int] = None, ): _port = _port_cls_by_protocol[protocol](port=port) - ports = self._state.opened_ports + ports = set(self._state.opened_ports) if _port in ports: ports.remove(_port) if ports != self._state.opened_ports: diff --git a/scenario/runtime.py b/scenario/runtime.py index 91e4705b..97abe921 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -347,7 +347,7 @@ def _virtual_charm_root(self): elif ( not spec.is_autoloaded and any_metadata_files_present_in_charm_virtual_root ): - logger.warn( + logger.warning( f"Some metadata files found in custom user-provided charm_root " f"{charm_virtual_root} while you have passed meta, config or actions to " f"Context.run(). " diff --git a/scenario/state.py b/scenario/state.py index b629ac68..21b9e551 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -267,7 +267,6 @@ class Secret(_max_posargs(1)): # mapping from revision IDs to each revision's contents contents: Dict[int, "RawSecretRevisionContents"] -class Secret: id: str # CAUTION: ops-created Secrets (via .add_secret()) will have a canonicalized # secret id (`secret:` prefix) @@ -1148,7 +1147,7 @@ def _update_status( # bypass frozen dataclass object.__setattr__(self, name, _EntityStatus(new_status, new_message)) - def _update_opened_ports(self, new_ports: FrozenSet[Port]): + def _update_opened_ports(self, new_ports: FrozenSet[_Port]): """Update the current opened ports.""" # bypass frozen dataclass object.__setattr__(self, "opened_ports", new_ports) diff --git a/tests/helpers.py b/tests/helpers.py index 47c3d339..4602f082 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -52,10 +52,10 @@ def trigger( if isinstance(event, str): if event.startswith("relation_"): assert len(state.relations) == 1, "shortcut only works with one relation" - event = getattr(ctx.on, event)(state.relations[0]) + event = getattr(ctx.on, event)(tuple(state.relations)[0]) elif event.startswith("pebble_"): assert len(state.containers) == 1, "shortcut only works with one container" - event = getattr(ctx.on, event)(state.containers[0]) + event = getattr(ctx.on, event)(tuple(state.containers)[0]) else: event = getattr(ctx.on, event)() with ctx.manager(event, state=state) as mgr: diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 41d39836..51e3913e 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -376,14 +376,6 @@ def test_relation_not_in_state(): ) -def test_dupe_containers_inconsistent(): - assert_inconsistent( - State(containers={Container("foo"), Container("foo")}), - _Event("bar"), - _CharmSpec(MyCharm, {"containers": {"foo": {}}}), - ) - - def test_action_not_in_meta_inconsistent(): action = Action("foo", params={"bar": "baz"}) assert_inconsistent( @@ -549,7 +541,7 @@ def test_storage_not_in_state(): ), ) assert_consistent( - State(storage=[storage]), + State(storages=[storage]), _Event("foo_storage_attached", storage=storage), _CharmSpec( MyCharm, @@ -561,7 +553,7 @@ def test_storage_not_in_state(): def test_resource_states(): # happy path assert_consistent( - State(resources={Resource("foo", "/foo/bar.yaml")}), + State(resources={Resource(name="foo", path="/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -581,7 +573,7 @@ def test_resource_states(): # resource not defined in meta assert_inconsistent( - State(resources={Resource("bar", "/foo/bar.yaml")}), + State(resources={Resource(name="bar", path="/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -590,7 +582,7 @@ def test_resource_states(): ) assert_inconsistent( - State(resources={Resource("bar", "/foo/bar.yaml")}), + State(resources={Resource(name="bar", path="/foo/bar.yaml")}), _Event("start"), _CharmSpec( MyCharm, @@ -688,21 +680,6 @@ def test_storedstate_consistency(): }, ), ) - assert_inconsistent( - State( - stored_states={ - StoredState(owner_path=None, content={"foo": "bar"}), - StoredState(owner_path=None, name="_stored", content={"foo": "bar"}), - } - ), - _Event("start"), - _CharmSpec( - MyCharm, - meta={ - "name": "foo", - }, - ), - ) assert_inconsistent( State( stored_states={ diff --git a/tests/test_context_on.py b/tests/test_context_on.py index d9609d2e..1c98b4ea 100644 --- a/tests/test_context_on.py +++ b/tests/test_context_on.py @@ -156,7 +156,7 @@ def test_revision_secret_events_as_positional_arg(event_name): def test_storage_events(event_name, event_kind): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) storage = scenario.Storage("foo") - state_in = scenario.State(storage=[storage]) + state_in = scenario.State(storages=[storage]) # These look like: # ctx.run(ctx.on.storage_attached(storage), state) with ctx.manager(getattr(ctx.on, event_name)(storage), state_in) as mgr: diff --git a/tests/test_e2e/test_event_bind.py b/tests/test_e2e/test_event_bind.py deleted file mode 100644 index 141592fe..00000000 --- a/tests/test_e2e/test_event_bind.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest - -from scenario import Container, Event, Relation, Secret, State -from scenario.state import BindFailedError - - -def test_bind_relation(): - event = Event("foo-relation-changed") - foo_relation = Relation("foo") - state = State(relations={foo_relation}) - assert event.bind(state).relation is foo_relation - - -def test_bind_relation_complex_name(): - event = Event("foo-bar-baz-relation-changed") - foo_relation = Relation("foo_bar_baz") - state = State(relations={foo_relation}) - assert event.bind(state).relation is foo_relation - - -def test_bind_relation_notfound(): - event = Event("foo-relation-changed") - state = State() - with pytest.raises(BindFailedError): - event.bind(state) - - -def test_bind_relation_toomany(caplog): - event = Event("foo-relation-changed") - foo_relation = Relation("foo") - foo_relation1 = Relation("foo") - state = State(relations={foo_relation, foo_relation1}) - event.bind(state) - assert "too many relations" in caplog.text - - -def test_bind_secret(): - event = Event("secret-changed") - secret = Secret("foo", {"a": "b"}) - state = State(secrets={secret}) - assert event.bind(state).secret is secret - - -def test_bind_secret_notfound(): - event = Event("secret-changed") - state = State() - with pytest.raises(BindFailedError): - event.bind(state) - - -def test_bind_container(): - event = Event("foo-pebble-ready") - container = Container("foo") - state = State(containers={container}) - assert event.bind(state).container is container - - -def test_bind_container_notfound(): - event = Event("foo-pebble-ready") - state = State() - with pytest.raises(BindFailedError): - event.bind(state) diff --git a/tests/test_e2e/test_ports.py b/tests/test_e2e/test_ports.py index 76a46878..80365a01 100644 --- a/tests/test_e2e/test_ports.py +++ b/tests/test_e2e/test_ports.py @@ -27,7 +27,7 @@ def ctx(): def test_open_port(ctx): - out = ctx.run(ctx.on.start()), State()) + out = ctx.run(ctx.on.start(), State()) assert len(out.opened_ports) == 1 port = tuple(out.opened_ports)[0] diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index d2bd6a50..7ec5d14f 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -287,13 +287,13 @@ def test_container_default_values(): def test_state_default_values(): state = State() assert state.config == {} - assert state.relations == [] + assert state.relations == frozenset() assert state.networks == {} - assert state.containers == [] - assert state.storage == [] - assert state.opened_ports == [] - assert state.secrets == [] - assert state.resources == {} + assert state.containers == frozenset() + assert state.storages == frozenset() + assert state.opened_ports == frozenset() + assert state.secrets == frozenset() + assert state.resources == frozenset() assert state.deferred == [] assert isinstance(state.model, Model) assert state.leader is False diff --git a/tests/test_e2e/test_storage.py b/tests/test_e2e/test_storage.py index 5885d0dd..3e6912fb 100644 --- a/tests/test_e2e/test_storage.py +++ b/tests/test_e2e/test_storage.py @@ -66,7 +66,7 @@ def test_storage_usage(storage_ctx): (storage.get_filesystem(storage_ctx) / "myfile.txt").write_text("helloworld") with storage_ctx.manager( - storage_ctx.on.update_status(), State(storage={storage}) + storage_ctx.on.update_status(), State(storages={storage}) ) as mgr: foo = mgr.charm.model.storages["foo"][0] loc = foo.location @@ -85,9 +85,11 @@ def test_storage_usage(storage_ctx): def test_storage_attached_event(storage_ctx): storage = Storage("foo") - storage_ctx.run(storage_ctx.on.storage_attached(storage), State(storage={storage})) + storage_ctx.run(storage_ctx.on.storage_attached(storage), State(storages={storage})) def test_storage_detaching_event(storage_ctx): storage = Storage("foo") - storage_ctx.run(storage_ctx.on.storage_detaching(storage), State(storage={storage})) + storage_ctx.run( + storage_ctx.on.storage_detaching(storage), State(storages={storage}) + ) diff --git a/tests/test_e2e/test_stored_state.py b/tests/test_e2e/test_stored_state.py index 863f3e8f..94b9c301 100644 --- a/tests/test_e2e/test_stored_state.py +++ b/tests/test_e2e/test_stored_state.py @@ -45,7 +45,7 @@ def test_stored_state_default(mycharm): def test_stored_state_initialized(mycharm): out = trigger( State( - stored_state={ + stored_states={ StoredState( owner_path="MyCharm", name="_stored", content={"foo": "FOOX"} ), From c2ff88240f83aa7db8aa4661132f83c4ac52be09 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Mon, 8 Jul 2024 18:39:02 +1200 Subject: [PATCH 09/12] Remove unused method (was used in the old binding, not generally useful). --- scenario/state.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/scenario/state.py b/scenario/state.py index 21b9e551..cb0fb53f 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -1254,11 +1254,6 @@ def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: if normalize_name(r.endpoint) == normalized_endpoint ) - # TODO: It seems like this method has no tests. - def get_storages(self, name: str) -> Tuple["Storage", ...]: - """Get all storages with this name.""" - return tuple(s for s in self.storages if s.name == name) - def _is_valid_charmcraft_25_metadata(meta: Dict[str, Any]): # Check whether this dict has the expected mandatory metadata fields according to the From 28b8765d5a800031f77ee479ed3651253b32cb2f Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Mon, 8 Jul 2024 18:39:16 +1200 Subject: [PATCH 10/12] Add a basic test for resources. --- tests/test_e2e/test_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index 7ec5d14f..95b433a0 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -15,6 +15,7 @@ Model, Network, Relation, + Resource, State, ) from tests.helpers import jsonpatch_delta, sort_patch, trigger @@ -249,6 +250,7 @@ def pre_event(charm: CharmBase): "klass,num_args", [ (State, (1,)), + (Resource, (1, )), (Address, (0, 2)), (BindAddress, (0, 2)), (Network, (0, 2)), From ffa3b7a88f132824ef7b2c263a502ab0b8a45ef8 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Mon, 8 Jul 2024 18:39:37 +1200 Subject: [PATCH 11/12] Add a basic test for resources. --- tests/test_e2e/test_resource.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_e2e/test_resource.py diff --git a/tests/test_e2e/test_resource.py b/tests/test_e2e/test_resource.py new file mode 100644 index 00000000..a0354923 --- /dev/null +++ b/tests/test_e2e/test_resource.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import pathlib + +import pytest +import ops + +from scenario import Context, Resource, State + + +class ResourceCharm(ops.CharmBase): + def __init__(self, framework): + super().__init__(framework) + + +def test_get_resource(): + ctx = Context(ResourceCharm, meta={"name": "resource-charm", "resources": {"foo": {"type": "file"}, "bar": {"type": "file"}}}) + resource1 = Resource(name="foo", path=pathlib.Path("/tmp/foo")) + resource2 = Resource(name="bar", path=pathlib.Path("~/bar")) + with ctx.manager(ctx.on.update_status(), state=State(resources={resource1, resource2})) as mgr: + assert mgr.charm.model.resources.fetch("foo") == resource1.path + assert mgr.charm.model.resources.fetch("bar") == resource2.path + with pytest.raises(NameError): + mgr.charm.model.resources.fetch("baz") From f02c9712043ba4aadddaffafacda4a5f24b1cbb7 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Tue, 9 Jul 2024 09:52:27 +1200 Subject: [PATCH 12/12] Make networks a set as well. --- README.md | 2 +- scenario/consistency_checker.py | 2 +- scenario/mocking.py | 5 ++++- scenario/state.py | 18 ++++++++++++++++-- tests/helpers.py | 1 + tests/test_consistency_checker.py | 6 +++--- tests/test_e2e/test_network.py | 4 ++-- tests/test_e2e/test_resource.py | 14 +++++++++++--- tests/test_e2e/test_state.py | 6 +++--- 9 files changed, 42 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 8d127852..d730e788 100644 --- a/README.md +++ b/README.md @@ -495,7 +495,7 @@ If you want to, you can override any of these relation or extra-binding associat ```python state = scenario.State(networks={ - 'foo': scenario.Network.default(private_address='192.0.2.1') + scenario.Network.default("foo", private_address='192.0.2.1') }) ``` diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 274adc5c..c4281ece 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -462,7 +462,7 @@ def check_network_consistency( if metadata.get("scope") != "container" # mark of a sub } - state_bindings = set(state.networks) + state_bindings = {network.binding_name for network in state.networks} if diff := state_bindings.difference(meta_bindings.union(non_sub_relations)): errors.append( f"Some network bindings defined in State are not in metadata.yaml: {diff}.", diff --git a/scenario/mocking.py b/scenario/mocking.py index 0e2da1b5..7cea2752 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -313,7 +313,10 @@ def network_get(self, binding_name: str, relation_id: Optional[int] = None): raise RelationNotFoundError() # We look in State.networks for an override. If not given, we return a default network. - network = self._state.networks.get(binding_name, Network.default()) + try: + network = self._state.get_network(binding_name) + except KeyError: + network = Network.default("default") # The name is not used in the output. return network.hook_tool_output_fmt() # setter methods: these can mutate the state. diff --git a/scenario/state.py b/scenario/state.py index cb0fb53f..fd691f85 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -355,11 +355,15 @@ def hook_tool_output_fmt(self): @dataclasses.dataclass(frozen=True) -class Network(_max_posargs(0)): +class Network(_max_posargs(1)): + binding_name: str bind_addresses: List[BindAddress] ingress_addresses: List[str] egress_subnets: List[str] + def __hash__(self) -> int: + return hash(self.binding_name) + def hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook tool would return { @@ -371,6 +375,7 @@ def hook_tool_output_fmt(self): @classmethod def default( cls, + binding_name: str, private_address: str = "192.0.2.0", hostname: str = "", cidr: str = "", @@ -381,6 +386,7 @@ def default( ) -> "Network": """Helper to create a minimal, heavily defaulted Network.""" return cls( + binding_name=binding_name, bind_addresses=[ BindAddress( interface_name=interface_name, @@ -1051,7 +1057,7 @@ class State(_max_posargs(0)): """The present configuration of this charm.""" relations: FrozenSet["AnyRelation"] = dataclasses.field(default_factory=frozenset) """All relations that currently exist for this charm.""" - networks: Dict[str, Network] = dataclasses.field(default_factory=dict) + networks: FrozenSet[Network] = dataclasses.field(default_factory=frozenset) """Manual overrides for any relation and extra bindings currently provisioned for this charm. If a metadata-defined relation endpoint is not explicitly mapped to a Network in this field, it will be defaulted. @@ -1116,6 +1122,7 @@ def __post_init__(self): "relations", "containers", "storages", + "networks", "opened_ports", "secrets", "resources", @@ -1185,6 +1192,13 @@ def get_container(self, container: str, /) -> Container: return state_container raise KeyError(f"container: {container} not found in the State") + def get_network(self, binding_name: str, /) -> Network: + """Get network from this State, based on its binding name.""" + for network in self.networks: + if network.binding_name == binding_name: + return network + raise KeyError(f"network: {binding_name} not found in the State") + def get_secret( self, *, diff --git a/tests/helpers.py b/tests/helpers.py index 4602f082..c8060d1c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -78,6 +78,7 @@ def jsonpatch_delta(self, other: "State"): "secrets", "resources", "stored_states", + "networks", ): dict_other[attr] = [dataclasses.asdict(o) for o in dict_other[attr]] dict_self[attr] = [dataclasses.asdict(o) for o in dict_self[attr]] diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 51e3913e..82d9c76a 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -593,7 +593,7 @@ def test_resource_states(): def test_networks_consistency(): assert_inconsistent( - State(networks={"foo": Network.default()}), + State(networks={Network.default("foo")}), _Event("start"), _CharmSpec( MyCharm, @@ -602,7 +602,7 @@ def test_networks_consistency(): ) assert_inconsistent( - State(networks={"foo": Network.default()}), + State(networks={Network.default("foo")}), _Event("start"), _CharmSpec( MyCharm, @@ -615,7 +615,7 @@ def test_networks_consistency(): ) assert_consistent( - State(networks={"foo": Network.default()}), + State(networks={Network.default("foo")}), _Event("start"), _CharmSpec( MyCharm, diff --git a/tests/test_e2e/test_network.py b/tests/test_e2e/test_network.py index 440a01be..e2b70ea4 100644 --- a/tests/test_e2e/test_network.py +++ b/tests/test_e2e/test_network.py @@ -51,7 +51,7 @@ def test_ip_get(mycharm): id=1, ), ], - networks={"foo": Network.default(private_address="4.4.4.4")}, + networks={Network.default("foo", private_address="4.4.4.4")}, ), ) as mgr: # we have a network for the relation @@ -113,7 +113,7 @@ def test_no_relation_error(mycharm): id=1, ), ], - networks={"bar": Network.default()}, + networks={Network.default("bar")}, ), ) as mgr: with pytest.raises(RelationNotFoundError): diff --git a/tests/test_e2e/test_resource.py b/tests/test_e2e/test_resource.py index a0354923..c4237ea6 100644 --- a/tests/test_e2e/test_resource.py +++ b/tests/test_e2e/test_resource.py @@ -4,8 +4,8 @@ import pathlib -import pytest import ops +import pytest from scenario import Context, Resource, State @@ -16,10 +16,18 @@ def __init__(self, framework): def test_get_resource(): - ctx = Context(ResourceCharm, meta={"name": "resource-charm", "resources": {"foo": {"type": "file"}, "bar": {"type": "file"}}}) + ctx = Context( + ResourceCharm, + meta={ + "name": "resource-charm", + "resources": {"foo": {"type": "file"}, "bar": {"type": "file"}}, + }, + ) resource1 = Resource(name="foo", path=pathlib.Path("/tmp/foo")) resource2 = Resource(name="bar", path=pathlib.Path("~/bar")) - with ctx.manager(ctx.on.update_status(), state=State(resources={resource1, resource2})) as mgr: + with ctx.manager( + ctx.on.update_status(), state=State(resources={resource1, resource2}) + ) as mgr: assert mgr.charm.model.resources.fetch("foo") == resource1.path assert mgr.charm.model.resources.fetch("bar") == resource2.path with pytest.raises(NameError): diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index 95b433a0..aaa3246f 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -250,10 +250,10 @@ def pre_event(charm: CharmBase): "klass,num_args", [ (State, (1,)), - (Resource, (1, )), + (Resource, (1,)), (Address, (0, 2)), (BindAddress, (0, 2)), - (Network, (0, 2)), + (Network, (1, 2)), ], ) def test_positional_arguments(klass, num_args): @@ -290,7 +290,7 @@ def test_state_default_values(): state = State() assert state.config == {} assert state.relations == frozenset() - assert state.networks == {} + assert state.networks == frozenset() assert state.containers == frozenset() assert state.storages == frozenset() assert state.opened_ports == frozenset()