diff --git a/pyproject.toml b/pyproject.toml index f8c89a48..7151bf22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "ops-scenario" -version = "7.0.1" +version = "7.0.2" authors = [ { name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" } diff --git a/scenario/__init__.py b/scenario/__init__.py index 2ba5a24c..82c48485 100644 --- a/scenario/__init__.py +++ b/scenario/__init__.py @@ -60,7 +60,7 @@ def test_base(): assert out.unit_status == UnknownStatus() """ -from scenario.context import Context, Manager +from scenario.context import Context, ExecArgs, Manager from scenario.state import ( ActionFailed, ActiveStatus, @@ -119,6 +119,7 @@ def test_base(): "DeferredEvent", "ErrorStatus", "Exec", + "ExecArgs", "ICMPPort", "JujuLogLine", "MaintenanceStatus", diff --git a/scenario/context.py b/scenario/context.py index cb5331d5..9075170e 100644 --- a/scenario/context.py +++ b/scenario/context.py @@ -2,6 +2,7 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. +import dataclasses import functools import tempfile from contextlib import contextmanager @@ -9,11 +10,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Type, Union, cast import ops -import ops.testing from scenario.errors import AlreadyEmittedError, ContextSetupError from scenario.logger import logger as scenario_logger -from scenario.runtime import Runtime +from scenario.runtime import Runtime, _CharmType from scenario.state import ( ActionFailed, CheckInfo, @@ -36,6 +36,31 @@ _DEFAULT_JUJU_VERSION = "3.5" +# This needs to exactly match the class in ops/testing.py, because it will +# replace that one when Scenario is available in ops.testing. We cannot import +# it here because that would create a circular import. +@dataclasses.dataclass(frozen=True) +class ExecArgs: + """Represent arguments captured from the :meth:`ops.Container.exec` method call. + + These arguments will be available in the :attr:`Context.exec_history` dictionary. + + See :meth:`ops.pebble.Client.exec` for documentation of properties. + """ + + command: List[str] + environment: Dict[str, str] + working_dir: Optional[str] + timeout: Optional[float] + user_id: Optional[int] + user: Optional[str] + group_id: Optional[int] + group: Optional[str] + stdin: Optional[Union[str, bytes]] + encoding: Optional[str] + combine_stderr: bool + + class Manager: """Context manager to offer test code some runtime charm object introspection. @@ -426,7 +451,7 @@ def test_foo(): def __init__( self, - charm_type: Type[ops.testing.CharmType], + charm_type: Type[_CharmType], meta: Optional[Dict[str, Any]] = None, *, actions: Optional[Dict[str, Any]] = None, @@ -508,7 +533,7 @@ def __init__( self.juju_log: List["JujuLogLine"] = [] self.app_status_history: List["_EntityStatus"] = [] self.unit_status_history: List["_EntityStatus"] = [] - self.exec_history: Dict[str, List[ops.testing.ExecArgs]] = {} + self.exec_history: Dict[str, List[ExecArgs]] = {} self.workload_version_history: List[str] = [] self.removed_secret_revisions: List[int] = [] self.emitted_events: List[ops.EventBase] = [] @@ -644,7 +669,10 @@ def run(self, event: "_Event", state: "State") -> "State": assert self._output_state is not None if event.action: if self._action_failure_message is not None: - raise ActionFailed(self._action_failure_message, self._output_state) + raise ActionFailed( + self._action_failure_message, + state=self._output_state, + ) return self._output_state @contextmanager diff --git a/scenario/mocking.py b/scenario/mocking.py index f5207a37..2187175f 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -1,13 +1,22 @@ #!/usr/bin/env python3 # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. + +import dataclasses import datetime +import fnmatch +import http +import io +import os import shutil +import signal from pathlib import Path from typing import ( TYPE_CHECKING, Any, + BinaryIO, Dict, + Iterable, List, Literal, Mapping, @@ -19,7 +28,7 @@ cast, ) -from ops import JujuVersion, pebble +from ops import Container, JujuVersion, pebble from ops.model import CloudSpec as CloudSpec_Ops from ops.model import ModelError from ops.model import Port as Port_Ops @@ -33,8 +42,8 @@ _ModelBackend, ) from ops.pebble import Client, ExecError -from ops.testing import ExecArgs, _TestingPebbleClient +from scenario.context import ExecArgs from scenario.errors import ActionMissingFromContextError from scenario.logger import logger as scenario_logger from scenario.state import ( @@ -66,9 +75,9 @@ def __init__( change_id: int, args: ExecArgs, return_code: int, - stdin: Optional[TextIO], - stdout: Optional[TextIO], - stderr: Optional[TextIO], + stdin: Optional[Union[TextIO, io.BytesIO]], + stdout: Optional[Union[TextIO, io.BytesIO]], + stderr: Optional[Union[TextIO, io.BytesIO]], ): self._change_id = change_id self._args = args @@ -85,7 +94,7 @@ def __del__(self): def _close_stdin(self): if self._args.stdin is None and self.stdin is not None: self.stdin.seek(0) - self._args.stdin = self.stdin.read() + object.__setattr__(self._args, "stdin", self.stdin.read()) def wait(self): self._close_stdin() @@ -99,7 +108,12 @@ def wait_output(self): stdout = self.stdout.read() if self.stdout is not None else None stderr = self.stderr.read() if self.stderr is not None else None if self._return_code != 0: - raise ExecError(list(self._args.command), self._return_code, stdout, stderr) + raise ExecError( + list(self._args.command), + self._return_code, + stdout, # type: ignore + stderr, # type: ignore + ) return stdout, stderr def send_signal(self, sig: Union[int, str]): # noqa: U100 @@ -167,15 +181,18 @@ def get_pebble(self, socket_path: str) -> "Client": # container not defined in state. mounts = {} - return _MockPebbleClient( - socket_path=socket_path, - container_root=container_root, - mounts=mounts, - state=self._state, - event=self._event, - charm_spec=self._charm_spec, - context=self._context, - container_name=container_name, + return cast( + Client, + _MockPebbleClient( + socket_path=socket_path, + container_root=container_root, + mounts=mounts, + state=self._state, + event=self._event, + charm_spec=self._charm_spec, + context=self._context, + container_name=container_name, + ), ) def _get_relation_by_id(self, rel_id) -> "RelationBase": @@ -607,7 +624,7 @@ def storage_add(self, name: str, count: int = 1): ) if "/" in name: - # this error is raised by ops.testing but not by ops at runtime + # this error is raised by Harness but not by ops at runtime raise ModelError('storage name cannot contain "/"') self._context.requested_storages[name] = count @@ -708,7 +725,7 @@ def credential_get(self) -> CloudSpec_Ops: return self._state.model.cloud_spec._to_ops() -class _MockPebbleClient(_TestingPebbleClient): +class _MockPebbleClient: def __init__( self, socket_path: str, @@ -743,6 +760,10 @@ def __init__( self._root = container_root + self._notices: Dict[Tuple[str, str], pebble.Notice] = {} + self._last_notice_id = 0 + self._changes: Dict[str, pebble.Change] = {} + # load any existing notices and check information from the state self._notices: Dict[Tuple[str, str], pebble.Notice] = {} self._check_infos: Dict[str, pebble.CheckInfo] = {} @@ -781,7 +802,7 @@ def _layers(self) -> Dict[str, pebble.Layer]: def _service_status(self) -> Dict[str, pebble.ServiceStatus]: return self._container.service_statuses - # Based on a method of the same name from ops.testing. + # Based on a method of the same name from Harness. def _find_exec_handler(self, command) -> Optional["Exec"]: handlers = {exec.command_prefix: exec for exec in self._container.execs} # Start with the full command and, each loop iteration, drop the last @@ -796,6 +817,25 @@ def _find_exec_handler(self, command) -> Optional["Exec"]: # matter how much of it was used, so we have failed to find a handler. return None + def _transform_exec_handler_output( + self, + data: Union[str, bytes], + encoding: Optional[str], + ) -> Union[io.BytesIO, io.StringIO]: + if isinstance(data, bytes): + if encoding is None: + return io.BytesIO(data) + else: + return io.StringIO(data.decode(encoding=encoding)) + else: + if encoding is None: + raise ValueError( + f"exec handler must return bytes if encoding is None," + f"not {data.__class__.__name__}", + ) + else: + return io.StringIO(cast(str, data)) + def exec( self, command: List[str], @@ -881,3 +921,633 @@ def _check_connection(self): f"can_connect=True for container {self._container.name}?" ) raise pebble.ConnectionError(msg) + + def get_system_info(self) -> pebble.SystemInfo: + self._check_connection() + return pebble.SystemInfo(version="1.0.0") + + def get_warnings( + self, + select: pebble.WarningState = pebble.WarningState.PENDING, + ) -> List["pebble.Warning"]: + self._check_connection() + raise NotImplementedError(self.get_warnings) + + def ack_warnings(self, timestamp: datetime.datetime) -> int: + self._check_connection() + raise NotImplementedError(self.ack_warnings) + + def get_changes( + self, + select: pebble.ChangeState = pebble.ChangeState.IN_PROGRESS, + service: Optional[str] = None, + ) -> List[pebble.Change]: + raise NotImplementedError(self.get_changes) + + def get_change(self, change_id: str) -> pebble.Change: + self._check_connection() + try: + return self._changes[change_id] + except KeyError: + message = f'cannot find change with id "{change_id}"' + raise self._api_error(404, message) from None + + def abort_change(self, change_id: pebble.ChangeID) -> pebble.Change: + raise NotImplementedError(self.abort_change) + + def autostart_services(self, timeout: float = 30.0, delay: float = 0.1): + self._check_connection() + for name, service in self._render_services().items(): + # TODO: jam 2021-04-20 This feels awkward that Service.startup might be a string or + # might be an enum. Probably should make Service.startup a property rather than an + # attribute. + if service.startup == "": + startup = pebble.ServiceStartup.DISABLED + else: + startup = pebble.ServiceStartup(service.startup) + if startup == pebble.ServiceStartup.ENABLED: + self._service_status[name] = pebble.ServiceStatus.ACTIVE + + def replan_services(self, timeout: float = 30.0, delay: float = 0.1): + return self.autostart_services(timeout, delay) + + def start_services( + self, + services: List[str], + timeout: float = 30.0, + delay: float = 0.1, + ): + self._check_connection() + + # Note: jam 2021-04-20 We don't implement ChangeID, but the default caller of this is + # Container.start() which currently ignores the return value + known_services = self._render_services() + # Names appear to be validated before any are activated, so do two passes + for name in services: + if name not in known_services: + # TODO: jam 2021-04-20 This needs a better error type + raise RuntimeError(f'400 Bad Request: service "{name}" does not exist') + for name in services: + self._service_status[name] = pebble.ServiceStatus.ACTIVE + + def stop_services( + self, + services: List[str], + timeout: float = 30.0, + delay: float = 0.1, + ): + self._check_connection() + + # Note: jam 2021-04-20 We don't implement ChangeID, but the default caller of this is + # Container.stop() which currently ignores the return value + known_services = self._render_services() + for name in services: + if name not in known_services: + # TODO: jam 2021-04-20 This needs a better error type + # 400 Bad Request: service "bal" does not exist + raise RuntimeError(f'400 Bad Request: service "{name}" does not exist') + for name in services: + self._service_status[name] = pebble.ServiceStatus.INACTIVE + + def restart_services( + self, + services: List[str], + timeout: float = 30.0, + delay: float = 0.1, + ): + self._check_connection() + + # Note: jam 2021-04-20 We don't implement ChangeID, but the default caller of this is + # Container.restart() which currently ignores the return value + known_services = self._render_services() + for name in services: + if name not in known_services: + # TODO: jam 2021-04-20 This needs a better error type + # 400 Bad Request: service "bal" does not exist + raise RuntimeError(f'400 Bad Request: service "{name}" does not exist') + for name in services: + self._service_status[name] = pebble.ServiceStatus.ACTIVE + + def wait_change( + self, + change_id: pebble.ChangeID, + timeout: float = 30.0, + delay: float = 0.1, + ) -> pebble.Change: + raise NotImplementedError(self.wait_change) + + def add_layer( + self, + label: str, + layer: Union[str, "pebble.LayerDict", pebble.Layer], + *, + combine: bool = False, + ): + if isinstance(layer, (str, dict)): + layer_obj = pebble.Layer(layer) + elif isinstance(layer, pebble.Layer): + layer_obj = layer + else: + raise TypeError( + f"layer must be str, dict, or pebble.Layer, not {type(layer).__name__}", + ) + + self._check_connection() + + if label in self._layers: + if not combine: + raise RuntimeError(f'400 Bad Request: layer "{label}" already exists') + layer = self._layers[label] + + for name, service in layer_obj.services.items(): + # 'override' is actually single quoted in the real error, but + # it shouldn't be, hopefully that gets cleaned up. + if not service.override: + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" must define' + f'"override" for service "{name}"', + ) + if service.override not in ("merge", "replace"): + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" has invalid ' + f'"override" value on service "{name}"', + ) + elif service.override == "replace": + layer.services[name] = service + elif service.override == "merge": + if combine and name in layer.services: + layer.services[name]._merge(service) + else: + layer.services[name] = service + + for name, check in layer_obj.checks.items(): + if not check.override: + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" must define' + f'"override" for check "{name}"', + ) + if check.override not in ("merge", "replace"): + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" has invalid ' + f'"override" value for check "{name}"', + ) + elif check.override == "replace": + layer.checks[name] = check + elif check.override == "merge": + if combine and name in layer.checks: + layer.checks[name]._merge(check) + else: + layer.checks[name] = check + + for name, log_target in layer_obj.log_targets.items(): + if not log_target.override: + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" must define' + f'"override" for log target "{name}"', + ) + if log_target.override not in ("merge", "replace"): + raise RuntimeError( + f'500 Internal Server Error: layer "{label}" has invalid ' + f'"override" value for log target "{name}"', + ) + elif log_target.override == "replace": + layer.log_targets[name] = log_target + elif log_target.override == "merge": + if combine and name in layer.log_targets: + layer.log_targets[name]._merge(log_target) + else: + layer.log_targets[name] = log_target + + else: + self._layers[label] = layer_obj + + def _render_services(self) -> Dict[str, pebble.Service]: + services: Dict[str, pebble.Service] = {} + for key in sorted(self._layers.keys()): + layer = self._layers[key] + for name, service in layer.services.items(): + services[name] = service + return services + + def _render_checks(self) -> Dict[str, pebble.Check]: + checks: Dict[str, pebble.Check] = {} + for key in sorted(self._layers.keys()): + layer = self._layers[key] + for name, check in layer.checks.items(): + checks[name] = check + return checks + + def _render_log_targets(self) -> Dict[str, pebble.LogTarget]: + log_targets: Dict[str, pebble.LogTarget] = {} + for key in sorted(self._layers.keys()): + layer = self._layers[key] + for name, log_target in layer.log_targets.items(): + log_targets[name] = log_target + return log_targets + + def get_services( + self, + names: Optional[List[str]] = None, + ) -> List[pebble.ServiceInfo]: + if isinstance(names, str): + raise TypeError( + f'start_services should take a list of names, not just "{names}"', + ) + + self._check_connection() + services = self._render_services() + infos: List[pebble.ServiceInfo] = [] + if names is None: + names = sorted(services.keys()) + for name in sorted(names): + try: + service = services[name] + except KeyError: + # in pebble, it just returns "nothing matched" if there are 0 matches, + # but it ignores services it doesn't recognize + continue + status = self._service_status.get(name, pebble.ServiceStatus.INACTIVE) + if service.startup == "": + startup = pebble.ServiceStartup.DISABLED + else: + startup = pebble.ServiceStartup(service.startup) + info = pebble.ServiceInfo( + name, + startup=startup, + current=pebble.ServiceStatus(status), + ) + infos.append(info) + return infos + + @staticmethod + def _check_absolute_path(path: str): + if not path.startswith("/"): + raise pebble.PathError( + "generic-file-error", + f"paths must be absolute, got {path!r}", + ) + + def pull( + self, + path: str, + *, + encoding: Optional[str] = "utf-8", + ) -> Union[BinaryIO, TextIO]: + self._check_connection() + self._check_absolute_path(path) + file_path = self._root / path[1:] + try: + return cast( + Union[BinaryIO, TextIO], + file_path.open("rb" if encoding is None else "r", encoding=encoding), + ) + except FileNotFoundError: + raise pebble.PathError( + "not-found", + f"stat {path}: no such file or directory", + ) from None + except IsADirectoryError: + raise pebble.PathError( + "generic-file-error", + f'can only read a regular file: "{path}"', + ) from None + + def push( + self, + path: str, + source: Union[bytes, str, io.StringIO, io.BytesIO, BinaryIO], + *, + encoding: str = "utf-8", + make_dirs: bool = False, + permissions: Optional[int] = None, + user_id: Optional[int] = None, + user: Optional[str] = None, + group_id: Optional[int] = None, + group: Optional[str] = None, + ) -> None: + self._check_connection() + if permissions is not None and not (0 <= permissions <= 0o777): + raise pebble.PathError( + "generic-file-error", + f"permissions not within 0o000 to 0o777: {permissions:#o}", + ) + self._check_absolute_path(path) + file_path = self._root / path[1:] + if make_dirs and not file_path.parent.exists(): + self.make_dir( + os.path.dirname(path), + make_parents=True, + permissions=None, + user_id=user_id, + user=user, + group_id=group_id, + group=group, + ) + permissions = permissions if permissions is not None else 0o644 + try: + if isinstance(source, str): + file_path.write_text(source, encoding=encoding) + elif isinstance(source, bytes): + file_path.write_bytes(source) + else: + # If source is binary, open file in binary mode and ignore encoding param + is_binary = isinstance(source.read(0), bytes) # type: ignore + open_mode = "wb" if is_binary else "w" + open_encoding = None if is_binary else encoding + with file_path.open(open_mode, encoding=open_encoding) as f: + shutil.copyfileobj(cast(io.IOBase, source), cast(io.IOBase, f)) + os.chmod(file_path, permissions) + except FileNotFoundError as e: + raise pebble.PathError( + "not-found", + f"parent directory not found: {e.args[0]}", + ) from None + except NotADirectoryError: + raise pebble.PathError( + "generic-file-error", + f"open {path}.~: not a directory", + ) from None + + def list_files( + self, + path: str, + *, + pattern: Optional[str] = None, + itself: bool = False, + ) -> List[pebble.FileInfo]: + self._check_connection() + self._check_absolute_path(path) + file_path = self._root / path[1:] + if not file_path.exists(): + raise self._api_error(404, f"stat {path}: no such file or directory") + files = [file_path] + if not itself: + try: + files = [file_path / file for file in os.listdir(file_path)] + except NotADirectoryError: + pass + + if pattern is not None: + files = [file for file in files if fnmatch.fnmatch(file.name, pattern)] + + file_infos = [Container._build_fileinfo(file) for file in files] + for file_info in file_infos: + rel_path = os.path.relpath(file_info.path, start=self._root) + rel_path = "/" if rel_path == "." else "/" + rel_path + file_info.path = rel_path + if rel_path == "/": + file_info.name = "/" + return file_infos + + def make_dir( + self, + path: str, + *, + make_parents: bool = False, + permissions: Optional[int] = None, + user_id: Optional[int] = None, + user: Optional[str] = None, + group_id: Optional[int] = None, + group: Optional[str] = None, + ) -> None: + self._check_connection() + if permissions is not None and not (0 <= permissions <= 0o777): + raise pebble.PathError( + "generic-file-error", + f"permissions not within 0o000 to 0o777: {permissions:#o}", + ) + self._check_absolute_path(path) + dir_path = self._root / path[1:] + if not dir_path.parent.exists() and not make_parents: + raise pebble.PathError("not-found", f"parent directory not found: {path}") + if not dir_path.parent.exists() and make_parents: + self.make_dir( + os.path.dirname(path), + make_parents=True, + permissions=permissions, + user_id=user_id, + user=user, + group_id=group_id, + group=group, + ) + try: + permissions = permissions if permissions else 0o755 + dir_path.mkdir() + os.chmod(dir_path, permissions) + except FileExistsError: + if not make_parents: + raise pebble.PathError( + "generic-file-error", + f"mkdir {path}: file exists", + ) from None + except NotADirectoryError as e: + # Attempted to create a subdirectory of a file + raise pebble.PathError( + "generic-file-error", + f"not a directory: {e.args[0]}", + ) from None + + def remove_path(self, path: str, *, recursive: bool = False): + self._check_connection() + self._check_absolute_path(path) + file_path = self._root / path[1:] + if not file_path.exists(): + if recursive: + return + raise pebble.PathError( + "not-found", + f"remove {path}: no such file or directory", + ) + if file_path.is_dir(): + if recursive: + shutil.rmtree(file_path) + else: + try: + file_path.rmdir() + except OSError as e: + raise pebble.PathError( + "generic-file-error", + "cannot remove non-empty directory without recursive=True", + ) from e + else: + file_path.unlink() + + def send_signal(self, sig: Union[int, str], service_names: Iterable[str]): + if not service_names: + raise TypeError("send_signal expected at least 1 service name, got 0") + self._check_connection() + + # Convert signal to str + if isinstance(sig, int): + sig = signal.Signals(sig).name + + # Pebble first validates the service name, and then the signal name. + + plan = self.get_plan() + for service in service_names: + if ( + service not in plan.services + or not self.get_services([service])[0].is_running() + ): + # conform with the real pebble api + message = f'cannot send signal to "{service}": service is not running' + raise self._api_error(500, message) + + # Check if signal name is valid. + try: + signal.Signals[sig] + except KeyError: + # Conform with the real Pebble API. + first_service = next(iter(service_names)) + message = ( + f'cannot send signal to "{first_service}": invalid signal name "{sig}"' + ) + raise self._api_error(500, message) from None + + def get_checks( + self, + level: Optional[pebble.CheckLevel] = None, + names: Optional[Iterable[str]] = None, + ) -> List[pebble.CheckInfo]: + if names is not None: + names = frozenset(names) + return [ + info + for info in self._check_infos.values() + if (level is None or level == info.level) + and (names is None or info.name in names) + ] + + def notify( + self, + type: pebble.NoticeType, + key: str, + *, + data: Optional[Dict[str, str]] = None, + repeat_after: Optional[datetime.timedelta] = None, + ) -> str: + notice_id, _ = self._notify(type, key, data=data, repeat_after=repeat_after) + return notice_id + + def _notify( + self, + type: pebble.NoticeType, + key: str, + *, + data: Optional[Dict[str, str]] = None, + repeat_after: Optional[datetime.timedelta] = None, + ) -> Tuple[str, bool]: + """Record an occurrence of a notice with the specified details. + + Return a tuple of (notice_id, new_or_repeated). + """ + # The shape of the code below is taken from State.AddNotice in Pebble. + now = datetime.datetime.now(tz=datetime.timezone.utc) + + new_or_repeated = False + unique_key = (type.value, key) + notice = self._notices.get(unique_key) + if notice is None: + # First occurrence of this notice uid+type+key + self._last_notice_id += 1 + notice = pebble.Notice( + id=str(self._last_notice_id), + user_id=0, # Charm should always be able to read pebble_notify notices. + type=type, + key=key, + first_occurred=now, + last_occurred=now, + last_repeated=now, + expire_after=datetime.timedelta(days=7), + occurrences=1, + last_data=data or {}, + repeat_after=repeat_after, + ) + self._notices[unique_key] = notice + new_or_repeated = True + else: + # Additional occurrence, update existing notice + last_repeated = notice.last_repeated + if repeat_after is None or now > notice.last_repeated + repeat_after: + # Update last repeated time if repeat-after time has elapsed (or is None) + last_repeated = now + new_or_repeated = True + notice = dataclasses.replace( + notice, + last_occurred=now, + last_repeated=last_repeated, + occurrences=notice.occurrences + 1, + last_data=data or {}, + repeat_after=repeat_after, + ) + self._notices[unique_key] = notice + + return notice.id, new_or_repeated + + def _api_error(self, code: int, message: str) -> pebble.APIError: + status = http.HTTPStatus(code).phrase + body = { + "type": "error", + "status-code": code, + "status": status, + "result": {"message": message}, + } + return pebble.APIError(body, code, status, message) + + def get_notice(self, id: str) -> pebble.Notice: + for notice in self._notices.values(): + if notice.id == id: + return notice + raise self._api_error(404, f'cannot find notice with ID "{id}"') + + def get_notices( + self, + *, + users: Optional[pebble.NoticesUsers] = None, + user_id: Optional[int] = None, + types: Optional[Iterable[Union[pebble.NoticeType, str]]] = None, + keys: Optional[Iterable[str]] = None, + ) -> List[pebble.Notice]: + # Similar logic as api_notices.go:v1GetNotices in Pebble. + + filter_user_id = 0 # default is to filter by request UID (root) + if user_id is not None: + filter_user_id = user_id + if users is not None: + if user_id is not None: + raise self._api_error(400, 'cannot use both "users" and "user_id"') + filter_user_id = None + + if types is not None: + types = { + (t.value if isinstance(t, pebble.NoticeType) else t) for t in types + } + if keys is not None: + keys = set(keys) + + notices = [ + notice + for notice in self._notices.values() + if self._notice_matches(notice, filter_user_id, types, keys) + ] + notices.sort(key=lambda notice: notice.last_repeated) + return notices + + @staticmethod + def _notice_matches( + notice: pebble.Notice, + user_id: Optional[int] = None, + types: Optional[Set[str]] = None, + keys: Optional[Set[str]] = None, + ) -> bool: + # Same logic as NoticeFilter.matches in Pebble. + # For example: if user_id filter is set and it doesn't match, return False. + if user_id is not None and not ( + notice.user_id is None or user_id == notice.user_id + ): + return False + if types is not None and notice.type not in types: + return False + if keys is not None and notice.key not in keys: + return False + return True diff --git a/scenario/runtime.py b/scenario/runtime.py index f4df73db..1792c2a0 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, List, Optional, Type, TypeVar, Union import yaml -from ops import CollectStatusEvent, pebble +from ops import CollectStatusEvent, charm, pebble from ops.framework import ( CommitEvent, EventBase, @@ -37,9 +37,10 @@ SubordinateRelation, ) -if TYPE_CHECKING: # pragma: no cover - from ops.testing import CharmType +# CharmType represents user charms that are derived from CharmBase. +_CharmType = TypeVar("_CharmType", bound=charm.CharmBase) +if TYPE_CHECKING: # pragma: no cover from scenario.context import Context from scenario.state import State, _CharmSpec, _Event @@ -294,7 +295,7 @@ def _get_event_env(self, state: "State", event: "_Event", charm_root: Path): return env @staticmethod - def _wrap(charm_type: Type["CharmType"]) -> Type["CharmType"]: + def _wrap(charm_type: Type["_CharmType"]) -> Type["_CharmType"]: # dark sorcery to work around framework using class attrs to hold on to event sources # todo this should only be needed if we call play multiple times on the same runtime. # can we avoid it? @@ -307,7 +308,7 @@ class WrappedCharm(charm_type): # type: ignore on = WrappedEvents() WrappedCharm.__name__ = charm_type.__name__ - return typing.cast(Type["CharmType"], WrappedCharm) + return typing.cast(Type["_CharmType"], WrappedCharm) @contextmanager def _virtual_charm_root(self): diff --git a/scenario/state.py b/scenario/state.py index 9179735d..b089c582 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -117,7 +117,7 @@ class ActionFailed(Exception): """Raised at the end of the hook if the charm has called ``event.fail()``.""" - def __init__(self, message: str, state: "State"): + def __init__(self, message: str, *, state: "State"): self.message = message self.state = state diff --git a/tests/helpers.py b/tests/helpers.py index 5ceffa9d..35655340 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,11 +18,10 @@ from scenario.context import _DEFAULT_JUJU_VERSION, Context if TYPE_CHECKING: # pragma: no cover - from ops.testing import CharmType - + from scenario.runtime import _CharmType from scenario.state import State, _Event - _CT = TypeVar("_CT", bound=Type[CharmType]) + _CT = TypeVar("_CT", bound=Type[_CharmType]) logger = logging.getLogger() @@ -30,9 +29,9 @@ def trigger( state: "State", event: Union[str, "_Event"], - charm_type: Type["CharmType"], - pre_event: Optional[Callable[["CharmType"], None]] = None, - post_event: Optional[Callable[["CharmType"], None]] = None, + charm_type: Type["_CharmType"], + pre_event: Optional[Callable[["_CharmType"], None]] = None, + post_event: Optional[Callable[["_CharmType"], None]] = None, meta: Optional[Dict[str, Any]] = None, actions: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, diff --git a/tests/test_charm_spec_autoload.py b/tests/test_charm_spec_autoload.py index 57b93a31..f90e5add 100644 --- a/tests/test_charm_spec_autoload.py +++ b/tests/test_charm_spec_autoload.py @@ -1,17 +1,15 @@ import importlib import sys -import tempfile from contextlib import contextmanager from pathlib import Path from typing import Type import pytest import yaml -from ops import CharmBase -from ops.testing import CharmType from scenario import Context, Relation, State from scenario.context import ContextSetupError +from scenario.runtime import _CharmType from scenario.state import MetadataNotFoundError, _CharmSpec CHARM = """ @@ -22,7 +20,7 @@ class MyCharm(CharmBase): pass @contextmanager -def import_name(name: str, source: Path) -> Type[CharmType]: +def import_name(name: str, source: Path) -> Type[_CharmType]: pkg_path = str(source.parent) sys.path.append(pkg_path) charm = importlib.import_module("charm")