Skip to content

Commit

Permalink
feat!: add Scenario classes that match the ops status classes (#142)
Browse files Browse the repository at this point in the history
Adds classes that match the ops status classes:

* UnknownStatus
* ActiveStatus
* WaitingStatus
* MaintenanceStatus
* BlockedStatus
* ErrorStatus
  • Loading branch information
tonyandrewmeyer committed Sep 2, 2024
1 parent 5285060 commit f11809a
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 53 deletions.
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ With that, we can write the simplest possible scenario test:
def test_scenario_base():
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State())
assert out.unit_status == ops.UnknownStatus()
assert out.unit_status == scenario.UnknownStatus()
```

Note that you should always compare the app and unit status using `==`, not `is`. You can compare
them to either the `scenario` objects, or the `ops` ones.

Now let's start making it more complicated. Our charm sets a special state if it has leadership on 'start':

```python
Expand All @@ -110,7 +113,7 @@ class MyCharm(ops.CharmBase):
def test_status_leader(leader):
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State(leader=leader))
assert out.unit_status == ops.ActiveStatus('I rule' if leader else 'I am ruled')
assert out.unit_status == scenario.ActiveStatus('I rule' if leader else 'I am ruled')
```

By defining the right state we can programmatically define what answers will the charm get to all the questions it can
Expand Down Expand Up @@ -165,15 +168,15 @@ def test_statuses():
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
out = ctx.run(ctx.on.start(), scenario.State(leader=False))
assert ctx.unit_status_history == [
ops.UnknownStatus(),
ops.MaintenanceStatus('determining who the ruler is...'),
ops.WaitingStatus('checking this is right...'),
scenario.UnknownStatus(),
scenario.MaintenanceStatus('determining who the ruler is...'),
scenario.WaitingStatus('checking this is right...'),
]
assert out.unit_status == ops.ActiveStatus("I am ruled")
assert out.unit_status == scenario.ActiveStatus("I am ruled")

# similarly you can check the app status history:
assert ctx.app_status_history == [
ops.UnknownStatus(),
scenario.UnknownStatus(),
...
]
```
Expand All @@ -198,9 +201,9 @@ class MyCharm(ops.CharmBase):

# ...
ctx = scenario.Context(MyCharm, meta={"name": "foo"})
ctx.run(ctx.on.start(), scenario.State(unit_status=ops.ActiveStatus('foo')))
ctx.run(ctx.on.start(), scenario.State(unit_status=scenario.ActiveStatus('foo')))
assert ctx.unit_status_history == [
ops.ActiveStatus('foo'), # now the first status is active: 'foo'!
scenario.ActiveStatus('foo'), # now the first status is active: 'foo'!
# ...
]
```
Expand Down Expand Up @@ -248,7 +251,7 @@ def test_emitted_full():
capture_deferred_events=True,
capture_framework_events=True,
)
ctx.run(ctx.on.start(), scenario.State(deferred=[scenario.Event("update-status").deferred(MyCharm._foo)]))
ctx.run(ctx.on.start(), scenario.State(deferred=[ctx.on.update_status().deferred(MyCharm._foo)]))

assert len(ctx.emitted_events) == 5
assert [e.handle.kind for e in ctx.emitted_events] == [
Expand Down Expand Up @@ -396,8 +399,6 @@ meta = {
}
ctx = scenario.Context(ops.CharmBase, meta=meta, unit_id=1)
ctx.run(ctx.on.start(), state_in) # invalid: this unit's id cannot be the ID of a peer.


```

### SubordinateRelation
Expand Down
12 changes: 12 additions & 0 deletions scenario/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
from scenario.context import ActionOutput, Context
from scenario.state import (
Action,
ActiveStatus,
Address,
BindAddress,
BlockedStatus,
CloudCredential,
CloudSpec,
Container,
DeferredEvent,
ErrorStatus,
ExecOutput,
ICMPPort,
MaintenanceStatus,
Model,
Mount,
Network,
Expand All @@ -27,6 +31,8 @@
SubordinateRelation,
TCPPort,
UDPPort,
UnknownStatus,
WaitingStatus,
deferred,
)

Expand Down Expand Up @@ -58,4 +64,10 @@
"StoredState",
"State",
"DeferredEvent",
"ErrorStatus",
"BlockedStatus",
"WaitingStatus",
"MaintenanceStatus",
"ActiveStatus",
"UnknownStatus",
]
4 changes: 2 additions & 2 deletions scenario/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def _get_storage_root(self, name: str, index: int) -> Path:
def _record_status(self, state: "State", is_app: bool):
"""Record the previous status before a status change."""
if is_app:
self.app_status_history.append(cast("_EntityStatus", state.app_status))
self.app_status_history.append(state.app_status)
else:
self.unit_status_history.append(cast("_EntityStatus", state.unit_status))
self.unit_status_history.append(state.unit_status)

def manager(self, event: "_Event", state: "State"):
"""Context manager to introspect live charm object before and after the event is emitted.
Expand Down
4 changes: 3 additions & 1 deletion scenario/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Network,
PeerRelation,
Storage,
_EntityStatus,
_port_cls_by_protocol,
_RawPortProtocolLiteral,
_RawStatusLiteral,
Expand Down Expand Up @@ -338,7 +339,8 @@ def status_set(
is_app: bool = False,
):
self._context._record_status(self._state, is_app)
self._state._update_status(status, message, is_app)
status_obj = _EntityStatus.from_status_name(status, message)
self._state._update_status(status_obj, is_app)

def juju_log(self, level: str, message: str):
self._context.juju_log.append(JujuLogLine(level, message))
Expand Down
151 changes: 126 additions & 25 deletions scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Final,
FrozenSet,
Expand All @@ -32,6 +33,7 @@
)
from uuid import uuid4

import ops
import yaml
from ops import pebble
from ops.charm import CharmBase, CharmEvents
Expand Down Expand Up @@ -956,18 +958,17 @@ def get_notice(
class _EntityStatus:
"""This class represents StatusBase and should not be interacted with directly."""

# Why not use StatusBase directly? Because that's not json-serializable.
# Why not use StatusBase directly? Because that can't be used with
# dataclasses.asdict to then be JSON-serializable.

name: _RawStatusLiteral
message: str = ""

_entity_statuses: ClassVar[Dict[str, Type["_EntityStatus"]]] = {}

def __eq__(self, other):
if isinstance(other, (StatusBase, _EntityStatus)):
return (self.name, self.message) == (other.name, other.message)
logger.warning(
f"Comparing Status with {other} is not stable and will be forbidden soon."
f"Please compare with StatusBase directly.",
)
return super().__eq__(other)

def __repr__(self):
Expand All @@ -976,17 +977,89 @@ def __repr__(self):
return f"{status_type_name}()"
return f"{status_type_name}('{self.message}')"

@classmethod
def from_status_name(
cls,
name: _RawStatusLiteral,
message: str = "",
) -> "_EntityStatus":
# Note that this won't work for UnknownStatus.
# All subclasses have a default 'name' attribute, but the type checker can't tell that.
return cls._entity_statuses[name](message=message) # type:ignore

@classmethod
def from_ops(cls, obj: StatusBase) -> "_EntityStatus":
return cls.from_status_name(cast(_RawStatusLiteral, obj.name), obj.message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class UnknownStatus(_EntityStatus, ops.UnknownStatus):
__doc__ = ops.UnknownStatus.__doc__

name: Literal["unknown"] = "unknown"

def __init__(self):
super().__init__(name=self.name)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class ErrorStatus(_EntityStatus, ops.ErrorStatus):
__doc__ = ops.ErrorStatus.__doc__

name: Literal["error"] = "error"

def __init__(self, message: str = ""):
super().__init__(name="error", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class ActiveStatus(_EntityStatus, ops.ActiveStatus):
__doc__ = ops.ActiveStatus.__doc__

name: Literal["active"] = "active"

def __init__(self, message: str = ""):
super().__init__(name="active", message=message)

def _status_to_entitystatus(obj: StatusBase) -> _EntityStatus:
"""Convert StatusBase to _EntityStatus."""
statusbase_subclass = type(StatusBase.from_name(obj.name, obj.message))

class _MyClass(_EntityStatus, statusbase_subclass):
# Custom type inheriting from a specific StatusBase subclass to support instance checks:
# isinstance(state.unit_status, ops.ActiveStatus)
pass
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class BlockedStatus(_EntityStatus, ops.BlockedStatus):
__doc__ = ops.BlockedStatus.__doc__

return _MyClass(cast(_RawStatusLiteral, obj.name), obj.message)
name: Literal["blocked"] = "blocked"

def __init__(self, message: str = ""):
super().__init__(name="blocked", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class MaintenanceStatus(_EntityStatus, ops.MaintenanceStatus):
__doc__ = ops.MaintenanceStatus.__doc__

name: Literal["maintenance"] = "maintenance"

def __init__(self, message: str = ""):
super().__init__(name="maintenance", message=message)


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class WaitingStatus(_EntityStatus, ops.WaitingStatus):
__doc__ = ops.WaitingStatus.__doc__

name: Literal["waiting"] = "waiting"

def __init__(self, message: str = ""):
super().__init__(name="waiting", message=message)


_EntityStatus._entity_statuses.update(
unknown=UnknownStatus,
error=ErrorStatus,
active=ActiveStatus,
blocked=BlockedStatus,
maintenance=MaintenanceStatus,
waiting=WaitingStatus,
)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -1033,6 +1106,11 @@ def __post_init__(self):
"please use TCPPort, UDPPort, or ICMPPort",
)

def __eq__(self, other: object) -> bool:
if isinstance(other, (_Port, ops.Port)):
return (self.protocol, self.port) == (other.protocol, other.port)
return False


@dataclasses.dataclass(frozen=True)
class TCPPort(_Port):
Expand Down Expand Up @@ -1112,6 +1190,11 @@ class Storage(_max_posargs(1)):
index: int = dataclasses.field(default_factory=next_storage_index)
# Every new Storage instance gets a new one, if there's trouble, override.

def __eq__(self, other: object) -> bool:
if isinstance(other, (Storage, ops.Storage)):
return (self.name, self.index) == (other.name, other.index)
return False

def get_filesystem(self, ctx: "Context") -> Path:
"""Simulated filesystem root in this context."""
return ctx._get_storage_root(self.name, self.index)
Expand Down Expand Up @@ -1182,23 +1265,45 @@ class State(_max_posargs(0)):
)
"""Contents of a charm's stored state."""

# the current statuses. Will be cast to _EntitiyStatus in __post_init__
app_status: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown")
# the current statuses.
app_status: _EntityStatus = UnknownStatus()
"""Status of the application."""
unit_status: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown")
unit_status: _EntityStatus = UnknownStatus()
"""Status of the unit."""
workload_version: str = ""
"""Workload version."""

def __post_init__(self):
# Let people pass in the ops classes, and convert them to the appropriate Scenario classes.
for name in ["app_status", "unit_status"]:
val = getattr(self, name)
if isinstance(val, _EntityStatus):
pass
elif isinstance(val, StatusBase):
object.__setattr__(self, name, _status_to_entitystatus(val))
object.__setattr__(self, name, _EntityStatus.from_ops(val))
else:
raise TypeError(f"Invalid status.{name}: {val!r}")
normalised_ports = [
_Port(protocol=port.protocol, port=port.port)
if isinstance(port, ops.Port)
else port
for port in self.opened_ports
]
if self.opened_ports != normalised_ports:
object.__setattr__(self, "opened_ports", normalised_ports)
normalised_storage = [
Storage(name=storage.name, index=storage.index)
if isinstance(storage, ops.Storage)
else storage
for storage in self.storages
]
if self.storages != normalised_storage:
object.__setattr__(self, "storages", normalised_storage)
# ops.Container, ops.Model, ops.Relation, ops.Secret should not be instantiated by charmers.
# ops.Network does not have the relation name, so cannot be converted.
# ops.Resources does not contain the source of the resource, so cannot be converted.
# ops.StoredState is not convenient to initialise with data, so not useful here.

# It's convenient to pass a set, but we really want the attributes to be
# frozen sets to increase the immutability of State objects.
for name in [
Expand Down Expand Up @@ -1228,14 +1333,13 @@ def _update_workload_version(self, new_workload_version: str):

def _update_status(
self,
new_status: _RawStatusLiteral,
new_message: str = "",
new_status: _EntityStatus,
is_app: bool = False,
):
"""Update the current app/unit status and add the previous one to the history."""
"""Update the current app/unit status."""
name = "app_status" if is_app else "unit_status"
# bypass frozen dataclass
object.__setattr__(self, name, _EntityStatus(new_status, new_message))
object.__setattr__(self, name, new_status)

def _update_opened_ports(self, new_ports: FrozenSet[_Port]):
"""Update the current opened ports."""
Expand All @@ -1262,10 +1366,7 @@ def with_leadership(self, leader: bool) -> "State":
def with_unit_status(self, status: StatusBase) -> "State":
return dataclasses.replace(
self,
status=dataclasses.replace(
cast(_EntityStatus, self.unit_status),
unit=_status_to_entitystatus(status),
),
unit_status=_EntityStatus.from_ops(status),
)

def get_container(self, container: str, /) -> Container:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_e2e/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from scenario import Context
from scenario.context import InvalidEventError
from scenario.state import Action, State, _Event, next_action_id
from scenario.state import Action, State, next_action_id


@pytest.fixture(scope="function")
Expand Down
Loading

0 comments on commit f11809a

Please sign in to comment.