Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: add Scenario classes that match the ops status classes #142

Merged
merged 29 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c2d3a1d
Remove _DCBase.
tonyandrewmeyer Apr 4, 2024
240fe6e
Remove deprecated functionality.
tonyandrewmeyer Apr 18, 2024
aa4088c
Add basic StoredState consistency checks.
tonyandrewmeyer Apr 1, 2024
4f9c191
Fix broken files.
tonyandrewmeyer May 30, 2024
63af2a8
Update scenario/mocking.py
tonyandrewmeyer May 30, 2024
90cdfe6
Test the code in the README.
tonyandrewmeyer Jun 6, 2024
64815be
Update tests and docs to match final (hopefully\!) API decision.
tonyandrewmeyer Jun 4, 2024
9f2e0e1
Style fixes.
tonyandrewmeyer Jun 4, 2024
a3fe6fe
Fix tests.
tonyandrewmeyer Jun 4, 2024
fc0dcca
Remove the old shortcuts on state components.
tonyandrewmeyer Jun 4, 2024
0c05854
Move the checks that were on binding to the consistency checker.
tonyandrewmeyer Jun 5, 2024
00b46e6
Update tests now that emitting custom events is not possible.
tonyandrewmeyer Jun 5, 2024
f3c1880
Support 'ctx.on.event_name' for specifying events.
tonyandrewmeyer Apr 24, 2024
885c34a
Update tests and docs to match final (hopefully\!) API decision.
tonyandrewmeyer Jun 4, 2024
4ed80fa
Fix tests.
tonyandrewmeyer Jun 4, 2024
cc62c8d
Style fixes.
tonyandrewmeyer Jun 6, 2024
990d5d9
Add Scenario versions of the status classes.
tonyandrewmeyer Jun 7, 2024
5230696
Add __eq__ to Port and Storage comparing to ops versions.
tonyandrewmeyer Jun 7, 2024
b9110c5
Explicitly use the Scenario classes, per review.
tonyandrewmeyer Jun 7, 2024
68f8309
Fix merge.
tonyandrewmeyer Jul 9, 2024
5c54588
Move the conversion from string/ops into _EntityStatus, per review.
tonyandrewmeyer Jul 9, 2024
075bcaa
Clarify == vs is for statuses, per review.
tonyandrewmeyer Jul 9, 2024
4b79027
Add clarifying comments, per review.
tonyandrewmeyer Jul 9, 2024
3ee146e
Copy the status docstrings from ops.
tonyandrewmeyer Jul 9, 2024
2977f59
Fix merging.
tonyandrewmeyer Jul 9, 2024
4789350
Merge branch '7.0' into add-status-classes
tonyandrewmeyer Jul 9, 2024
8ca4048
Fix merge.
tonyandrewmeyer Jul 9, 2024
7bc77c3
Use the deferred() method not _Event, per review.
tonyandrewmeyer Jul 9, 2024
2ed3e37
Be less clever and more efficient populating the str:class mapping, p…
tonyandrewmeyer Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ With that, we can write the simplest possible scenario test:
def test_scenario_base():
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State())
assert out.unit_status == ops.UnknownStatus()
assert out.unit_status == scenario.UnknownStatus()
```

Note that you should always compare the app and unit status using `==`, not `is`. You can compare
them to either the `scenario` objects, or the `ops` ones.

Now let's start making it more complicated. Our charm sets a special state if it has leadership on 'start':

```python
Expand All @@ -110,7 +113,7 @@ class MyCharm(ops.CharmBase):
def test_status_leader(leader):
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State(leader=leader))
assert out.unit_status == ops.ActiveStatus('I rule' if leader else 'I am ruled')
assert out.unit_status == scenario.ActiveStatus('I rule' if leader else 'I am ruled')
```

By defining the right state we can programmatically define what answers will the charm get to all the questions it can
Expand Down Expand Up @@ -165,15 +168,15 @@ def test_statuses():
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State(leader=False))
assert ctx.unit_status_history == [
ops.UnknownStatus(),
ops.MaintenanceStatus('determining who the ruler is...'),
ops.WaitingStatus('checking this is right...'),
scenario.UnknownStatus(),
scenario.MaintenanceStatus('determining who the ruler is...'),
scenario.WaitingStatus('checking this is right...'),
]
assert out.unit_status == ops.ActiveStatus("I am ruled")
assert out.unit_status == scenario.ActiveStatus("I am ruled")

# similarly you can check the app status history:
assert ctx.app_status_history == [
ops.UnknownStatus(),
scenario.UnknownStatus(),
...
]
```
Expand All @@ -198,9 +201,9 @@ class MyCharm(ops.CharmBase):

# ...
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
ctx.run(ctx.on.start(), scenario.State(unit_status=ops.ActiveStatus('foo')))
ctx.run(ctx.on.start(), scenario.State(unit_status=scenario.ActiveStatus('foo')))
assert ctx.unit_status_history == [
ops.ActiveStatus('foo'), # now the first status is active: 'foo'!
scenario.ActiveStatus('foo'), # now the first status is active: 'foo'!
# ...
]
```
Expand Down Expand Up @@ -248,7 +251,7 @@ def test_emitted_full():
capture_deferred_events=True,
capture_framework_events=True,
)
ctx.run(ctx.on.start(), scenario.State(deferred=[scenario.Event("update-status").deferred(MyCharm._foo)]))
ctx.run(ctx.on.start(), scenario.State(deferred=[ctx.on.update_status().deferred(MyCharm._foo)]))

assert len(ctx.emitted_events) == 5
assert [e.handle.kind for e in ctx.emitted_events] == [
Expand Down Expand Up @@ -396,8 +399,6 @@ meta = {
}
ctx = scenario.Context(ops.CharmBase, meta=meta, unit_id=1)
ctx.run(ctx.on.start(), state_in) # invalid: this unit's id cannot be the ID of a peer.


```

### SubordinateRelation
Expand Down
12 changes: 12 additions & 0 deletions scenario/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from scenario.context import ActionOutput, Context
from scenario.state import (
Action,
ActiveStatus,
Address,
BindAddress,
BlockedStatus,
CloudCredential,
CloudSpec,
Container,
DeferredEvent,
ErrorStatus,
ExecOutput,
ICMPPort,
MaintenanceStatus,
Model,
Mount,
Network,
Expand All @@ -27,6 +31,8 @@
SubordinateRelation,
TCPPort,
UDPPort,
UnknownStatus,
WaitingStatus,
deferred,
)

Expand Down Expand Up @@ -58,4 +64,10 @@
"StoredState",
"State",
"DeferredEvent",
"ErrorStatus",
"BlockedStatus",
"WaitingStatus",
"MaintenanceStatus",
"ActiveStatus",
"UnknownStatus",
]
4 changes: 2 additions & 2 deletions scenario/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,9 @@ def _get_storage_root(self, name: str, index: int) -> Path:
def _record_status(self, state: "State", is_app: bool):
"""Record the previous status before a status change."""
if is_app:
self.app_status_history.append(cast("_EntityStatus", state.app_status))
self.app_status_history.append(state.app_status)
else:
self.unit_status_history.append(cast("_EntityStatus", state.unit_status))
self.unit_status_history.append(state.unit_status)

def manager(self, event: "_Event", state: "State"):
"""Context manager to introspect live charm object before and after the event is emitted.
Expand Down
4 changes: 3 additions & 1 deletion scenario/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Network,
PeerRelation,
Storage,
_EntityStatus,
_port_cls_by_protocol,
_RawPortProtocolLiteral,
_RawStatusLiteral,
Expand Down Expand Up @@ -335,7 +336,8 @@ def status_set(
is_app: bool = False,
):
self._context._record_status(self._state, is_app)
self._state._update_status(status, message, is_app)
status_obj = _EntityStatus.from_status_name(status, message)
self._state._update_status(status_obj, is_app)

def juju_log(self, level: str, message: str):
self._context.juju_log.append(JujuLogLine(level, message))
Expand Down
151 changes: 126 additions & 25 deletions scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Final,
FrozenSet,
Expand All @@ -29,6 +30,7 @@
)
from uuid import uuid4

import ops
import yaml
from ops import pebble
from ops.charm import CharmBase, CharmEvents
Expand Down Expand Up @@ -875,18 +877,17 @@ def get_notice(
class _EntityStatus:
"""This class represents StatusBase and should not be interacted with directly."""

# Why not use StatusBase directly? Because that's not json-serializable.
# Why not use StatusBase directly? Because that can't be used with
# dataclasses.asdict to then be JSON-serializable.

name: _RawStatusLiteral
message: str = ""

_entity_statuses: ClassVar[Dict[str, Type["_EntityStatus"]]] = {}

def __eq__(self, other):
if isinstance(other, (StatusBase, _EntityStatus)):
return (self.name, self.message) == (other.name, other.message)
logger.warning(
f"Comparing Status with {other} is not stable and will be forbidden soon."
f"Please compare with StatusBase directly.",
)
return super().__eq__(other)

def __repr__(self):
Expand All @@ -895,17 +896,89 @@ def __repr__(self):
return f"{status_type_name}()"
return f"{status_type_name}('{self.message}')"

@classmethod
def from_status_name(
cls,
name: _RawStatusLiteral,
message: str = "",
) -> "_EntityStatus":
# Note that this won't work for UnknownStatus.
# All subclasses have a default 'name' attribute, but the type checker can't tell that.
return cls._entity_statuses[name](message=message) # type:ignore

@classmethod
def from_ops(cls, obj: StatusBase) -> "_EntityStatus":
return cls.from_status_name(cast(_RawStatusLiteral, obj.name), obj.message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class UnknownStatus(_EntityStatus, ops.UnknownStatus):
__doc__ = ops.UnknownStatus.__doc__

name: Literal["unknown"] = "unknown"

def __init__(self):
super().__init__(name=self.name)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class ErrorStatus(_EntityStatus, ops.ErrorStatus):
__doc__ = ops.ErrorStatus.__doc__

name: Literal["error"] = "error"

def __init__(self, message: str = ""):
super().__init__(name="error", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class ActiveStatus(_EntityStatus, ops.ActiveStatus):
__doc__ = ops.ActiveStatus.__doc__

name: Literal["active"] = "active"

def __init__(self, message: str = ""):
super().__init__(name="active", message=message)

def _status_to_entitystatus(obj: StatusBase) -> _EntityStatus:
"""Convert StatusBase to _EntityStatus."""
statusbase_subclass = type(StatusBase.from_name(obj.name, obj.message))

class _MyClass(_EntityStatus, statusbase_subclass):
# Custom type inheriting from a specific StatusBase subclass to support instance checks:
# isinstance(state.unit_status, ops.ActiveStatus)
pass
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class BlockedStatus(_EntityStatus, ops.BlockedStatus):
__doc__ = ops.BlockedStatus.__doc__

return _MyClass(cast(_RawStatusLiteral, obj.name), obj.message)
name: Literal["blocked"] = "blocked"

def __init__(self, message: str = ""):
super().__init__(name="blocked", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class MaintenanceStatus(_EntityStatus, ops.MaintenanceStatus):
__doc__ = ops.MaintenanceStatus.__doc__

name: Literal["maintenance"] = "maintenance"

def __init__(self, message: str = ""):
super().__init__(name="maintenance", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class WaitingStatus(_EntityStatus, ops.WaitingStatus):
__doc__ = ops.WaitingStatus.__doc__

name: Literal["waiting"] = "waiting"

def __init__(self, message: str = ""):
super().__init__(name="waiting", message=message)


_EntityStatus._entity_statuses.update(
unknown=UnknownStatus,
error=ErrorStatus,
active=ActiveStatus,
blocked=BlockedStatus,
maintenance=MaintenanceStatus,
waiting=WaitingStatus,
)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -950,6 +1023,11 @@ def __post_init__(self):
"please use TCPPort, UDPPort, or ICMPPort",
)

def __eq__(self, other: object) -> bool:
if isinstance(other, (_Port, ops.Port)):
return (self.protocol, self.port) == (other.protocol, other.port)
return False


@dataclasses.dataclass(frozen=True)
class TCPPort(_Port):
Expand Down Expand Up @@ -1029,6 +1107,11 @@ class Storage(_max_posargs(1)):
index: int = dataclasses.field(default_factory=next_storage_index)
# Every new Storage instance gets a new one, if there's trouble, override.

def __eq__(self, other: object) -> bool:
if isinstance(other, (Storage, ops.Storage)):
return (self.name, self.index) == (other.name, other.index)
return False

def get_filesystem(self, ctx: "Context") -> Path:
"""Simulated filesystem root in this context."""
return ctx._get_storage_root(self.name, self.index)
Expand Down Expand Up @@ -1099,23 +1182,45 @@ class State(_max_posargs(0)):
)
"""Contents of a charm's stored state."""

# the current statuses. Will be cast to _EntitiyStatus in __post_init__
app_status: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown")
# the current statuses.
app_status: _EntityStatus = UnknownStatus()
"""Status of the application."""
unit_status: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown")
unit_status: _EntityStatus = UnknownStatus()
"""Status of the unit."""
workload_version: str = ""
"""Workload version."""

def __post_init__(self):
# Let people pass in the ops classes, and convert them to the appropriate Scenario classes.
tonyandrewmeyer marked this conversation as resolved.
Show resolved Hide resolved
for name in ["app_status", "unit_status"]:
val = getattr(self, name)
if isinstance(val, _EntityStatus):
pass
elif isinstance(val, StatusBase):
object.__setattr__(self, name, _status_to_entitystatus(val))
object.__setattr__(self, name, _EntityStatus.from_ops(val))
else:
raise TypeError(f"Invalid status.{name}: {val!r}")
normalised_ports = [
_Port(protocol=port.protocol, port=port.port)
if isinstance(port, ops.Port)
else port
for port in self.opened_ports
]
if self.opened_ports != normalised_ports:
object.__setattr__(self, "opened_ports", normalised_ports)
normalised_storage = [
Storage(name=storage.name, index=storage.index)
if isinstance(storage, ops.Storage)
else storage
for storage in self.storages
]
if self.storages != normalised_storage:
object.__setattr__(self, "storages", normalised_storage)
# ops.Container, ops.Model, ops.Relation, ops.Secret should not be instantiated by charmers.
# ops.Network does not have the relation name, so cannot be converted.
# ops.Resources does not contain the source of the resource, so cannot be converted.
# ops.StoredState is not convenient to initialise with data, so not useful here.

# It's convenient to pass a set, but we really want the attributes to be
# frozen sets to increase the immutability of State objects.
for name in [
Expand Down Expand Up @@ -1145,14 +1250,13 @@ def _update_workload_version(self, new_workload_version: str):

def _update_status(
self,
new_status: _RawStatusLiteral,
new_message: str = "",
new_status: _EntityStatus,
is_app: bool = False,
):
"""Update the current app/unit status and add the previous one to the history."""
"""Update the current app/unit status."""
name = "app_status" if is_app else "unit_status"
# bypass frozen dataclass
object.__setattr__(self, name, _EntityStatus(new_status, new_message))
object.__setattr__(self, name, new_status)

def _update_opened_ports(self, new_ports: FrozenSet[_Port]):
"""Update the current opened ports."""
Expand All @@ -1179,10 +1283,7 @@ def with_leadership(self, leader: bool) -> "State":
def with_unit_status(self, status: StatusBase) -> "State":
return dataclasses.replace(
self,
status=dataclasses.replace(
cast(_EntityStatus, self.unit_status),
unit=_status_to_entitystatus(status),
),
unit_status=_EntityStatus.from_ops(status),
)

def get_container(self, container: str, /) -> Container:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2e/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from scenario import Context
from scenario.context import InvalidEventError
from scenario.state import Action, State, _Event, next_action_id
from scenario.state import Action, State, next_action_id


@pytest.fixture(scope="function")
Expand Down
Loading
Loading