From 6989656396c0e352efb9af6e8af8c6602157b8da Mon Sep 17 00:00:00 2001 From: Peter Law Date: Sat, 24 Sep 2022 23:04:11 +0100 Subject: [PATCH] Annotate almost everything The majority of these were added via MonkeyType from running the tests, with some manual adjustment afterwards. This also includes some minor refactors where doing so made the types clearer. This does leave one typing issue unresolved as it is highlighting a potential bug. --- routemaster/app.py | 8 +- routemaster/cli.py | 2 +- routemaster/conftest.py | 63 ++++++----- routemaster/context.py | 30 +++++- routemaster/cron.py | 2 +- routemaster/exit_conditions/analysis.py | 4 +- routemaster/exit_conditions/error_display.py | 5 +- routemaster/exit_conditions/evaluator.py | 102 ++++++++++++++---- routemaster/exit_conditions/exceptions.py | 3 +- routemaster/exit_conditions/peephole.py | 4 +- routemaster/feeds.py | 10 +- routemaster/logging/base.py | 64 +++++++---- routemaster/logging/plugins.py | 4 +- routemaster/logging/python_logger.py | 16 +-- routemaster/logging/split_logger.py | 6 +- routemaster/middleware.py | 2 +- .../091a6e84d9ac_initial_state_machine.py | 2 +- routemaster/server/endpoints.py | 60 +++++++---- routemaster/state_machine/actions.py | 3 +- routemaster/state_machine/api.py | 5 +- routemaster/state_machine/utils.py | 17 ++- routemaster/utils.py | 8 +- routemaster/validation.py | 25 +++-- 23 files changed, 307 insertions(+), 138 deletions(-) diff --git a/routemaster/app.py b/routemaster/app.py index 7ad715cd..52c27c4b 100644 --- a/routemaster/app.py +++ b/routemaster/app.py @@ -1,7 +1,7 @@ """Core App singleton that holds state for the application.""" import threading import contextlib -from typing import Dict, Optional +from typing import Dict, Iterator, Optional from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.engine import Engine @@ -32,7 +32,7 @@ def __init__( self.config = config self.initialise() - def initialise(self): + def initialise(self) -> None: """ Initialise this instance of the app. @@ -66,7 +66,7 @@ def session(self) -> Session: return self._current_session - def set_rollback(self): + def set_rollback(self) -> None: """Mark the current session as needing rollback.""" if self._current_session is None: raise RuntimeError( @@ -77,7 +77,7 @@ def set_rollback(self): self._needs_rollback = True @contextlib.contextmanager - def new_session(self): + def new_session(self) -> Iterator[None]: """Run a single session in this scope.""" if self._current_session is not None: raise RuntimeError("There is already a session running.") diff --git a/routemaster/cli.py b/routemaster/cli.py index f1e705b2..0be72689 100644 --- a/routemaster/cli.py +++ b/routemaster/cli.py @@ -107,7 +107,7 @@ def post_fork(): cron_thread.stop() -def _validate_config(app: App): +def _validate_config(app: App) -> None: try: validate_config(app, app.config) except ValidationError as e: diff --git a/routemaster/conftest.py b/routemaster/conftest.py index d18b9b73..b815b881 100644 --- a/routemaster/conftest.py +++ b/routemaster/conftest.py @@ -8,7 +8,7 @@ import functools import contextlib import subprocess -from typing import Any, Dict +from typing import Any, Dict, Callable, Iterator, ContextManager from unittest import mock import pytest @@ -19,7 +19,10 @@ from sqlalchemy import create_engine from werkzeug.test import Client from sqlalchemy.orm import sessionmaker +from _pytest.fixtures import SubRequest +from sqlalchemy.orm.session import Session +import routemaster.config.model from routemaster import state_machine from routemaster.db import Label, History, metadata from routemaster.app import App @@ -45,6 +48,7 @@ from routemaster.logging import BaseLogger, SplitLogger, register_loggers from routemaster.webhooks import ( WebhookResult, + WebhookRunner, webhook_runner_for_state_machine, ) from routemaster.middleware import wrap_application @@ -235,7 +239,8 @@ class TestApp(App): 2. We can set a flag on access to `.db` so that we needn't bother with resetting the database if nothing has actually been changed. """ - def __init__(self, config): + + def __init__(self, config: routemaster.config.model.Config) -> None: self.config = config self.session_used = False self.logger = SplitLogger(config, loggers=register_loggers(config)) @@ -249,13 +254,13 @@ def __init__(self, config): } @property - def session(self): + def session(self) -> Session: """Start if necessary and return a shared session.""" self.session_used = True return super().session -def get_test_app(**kwargs): +def get_test_app(**kwargs) -> TestApp: """Instantiate an app with testing parameters.""" return TestApp(Config( state_machines=kwargs.get('state_machines', TEST_STATE_MACHINES), @@ -270,28 +275,28 @@ def get_test_app(**kwargs): @pytest.fixture() -def client(custom_app=None): +def client(custom_app: None = None) -> Client: """Create a werkzeug test client.""" _app = get_test_app() if custom_app is None else custom_app - server.config.app = _app + server.config.app = _app # type: ignore[attr-defined] _app.logger.init_flask(server) return Client(wrap_application(_app, server), werkzeug.Response) @pytest.fixture() -def app(**kwargs): +def app(**kwargs: Any) -> TestApp: """Create an `App` config object for testing.""" return get_test_app(**kwargs) @pytest.fixture() -def custom_app(): +def custom_app() -> Callable[..., TestApp]: """Return the test app generator so that we can pass in custom config.""" return get_test_app @pytest.fixture() -def app_env(): +def app_env() -> Dict[str, str]: """ Create a dict of environment variables. @@ -307,7 +312,7 @@ def app_env(): @pytest.fixture(autouse=True, scope='session') -def database_creation(request): +def database_creation(request: SubRequest) -> Iterator[None]: """Wrap test session in creating and destroying all required tables.""" metadata.drop_all(bind=TEST_ENGINE) metadata.create_all(bind=TEST_ENGINE) @@ -315,7 +320,7 @@ def database_creation(request): @pytest.fixture(autouse=True) -def database_clear(app): +def database_clear(app: TestApp) -> Iterator[None]: """Truncate all tables after each test.""" yield if app.session_used: @@ -328,7 +333,10 @@ def database_clear(app): @pytest.fixture() -def create_label(app, mock_test_feed): +def create_label( + app: TestApp, + mock_test_feed: Callable[[], ContextManager[None]], +) -> Callable[[str, str, Dict[str, Any]], LabelRef]: """Create a label in the database.""" def _create( @@ -348,7 +356,7 @@ def _create( @pytest.fixture() -def delete_label(app): +def delete_label(app: TestApp) -> Callable[[str, str], None]: """ Mark a label in the database as deleted. """ @@ -364,7 +372,10 @@ def _delete(name: str, state_machine_name: str) -> None: @pytest.fixture() -def create_deleted_label(create_label, delete_label): +def create_deleted_label( + create_label: Callable[[str, str, Dict[str, Any]], LabelRef], + delete_label: Callable[[str, str], None], +) -> Callable[[str, str], LabelRef]: """ Create a label in the database and then delete it. """ @@ -378,7 +389,10 @@ def _create_and_delete(name: str, state_machine_name: str) -> LabelRef: @pytest.fixture() -def mock_webhook(): +def mock_webhook() -> Callable[ + [WebhookResult], + ContextManager[Callable[[StateMachine], WebhookRunner]], +]: """Mock the test config's webhook call.""" @contextlib.contextmanager def _mock(result=WebhookResult.SUCCESS): @@ -392,7 +406,7 @@ def _mock(result=WebhookResult.SUCCESS): @pytest.fixture() -def mock_test_feed(): +def mock_test_feed() -> Callable[[Dict[str, Any]], ContextManager[None]]: """Mock out the test feed.""" @contextlib.contextmanager def _mock(data={'should_do_alternate_action': False}): @@ -414,7 +428,7 @@ def _mock(data={'should_do_alternate_action': False}): @pytest.fixture() -def assert_history(app): +def assert_history(app: TestApp) -> Callable: """Assert that the database history matches what is expected.""" def _assert(entries): with app.new_session(): @@ -432,7 +446,7 @@ def _assert(entries): @pytest.fixture() -def set_metadata(app): +def set_metadata(app: TestApp) -> Callable: """Directly set the metadata for a label in the database.""" def _inner(label, update): with app.new_session(): @@ -449,7 +463,7 @@ def _inner(label, update): @pytest.fixture() -def make_context(app): +def make_context(app: TestApp) -> Callable: """Factory for Contexts that provides sane defaults for testing.""" def _inner(**kwargs): logger = BaseLogger(app.config) @@ -491,9 +505,9 @@ def version(): @pytest.fixture() -def current_state(app): +def current_state(app: TestApp) -> Callable[[LabelRef], str]: """Get the current state of a label.""" - def _inner(label): + def _inner(label: LabelRef) -> str: with app.new_session(): return app.session.query( History.new_state, @@ -501,13 +515,14 @@ def _inner(label): label_name=label.name, label_state_machine=label.state_machine, ).order_by( - History.id.desc(), + # TODO: use the sqlalchemy mypy plugin rather than our stubs + History.id.desc(), # type: ignore[attr-defined] ).limit(1).scalar() return _inner @pytest.fixture() -def unused_tcp_port(): +def unused_tcp_port() -> int: """Returns an unused TCP port, inspired by pytest-asyncio.""" with contextlib.closing(socket.socket()) as sock: sock.bind(('127.0.0.1', 0)) @@ -515,7 +530,7 @@ def unused_tcp_port(): @pytest.fixture() -def routemaster_serve_subprocess(unused_tcp_port): +def routemaster_serve_subprocess(unused_tcp_port: int) -> Callable: """ Fixture to spawn a routemaster server as a subprocess. diff --git a/routemaster/context.py b/routemaster/context.py index 010a1f9c..0e0e1d9b 100644 --- a/routemaster/context.py +++ b/routemaster/context.py @@ -1,10 +1,25 @@ """Context definition for exit condition programs.""" import datetime -from typing import Any, Dict, Iterable, Optional, Sequence +from typing import ( + Any, + Dict, + Tuple, + Union, + Callable, + Iterable, + Optional, + Sequence, + ContextManager, +) + +import requests from routemaster.feeds import Feed from routemaster.utils import get_path +ResponseLogger = Callable[[requests.Response], None] +FeedLoggingContext = Callable[[str], ContextManager[ResponseLogger]] + class Context(object): """Execution context for exit condition programs.""" @@ -18,7 +33,7 @@ def __init__( feeds: Dict[str, Feed], accessed_variables: Iterable[str], current_history_entry: Optional[Any], - feed_logging_context, + feed_logging_context: FeedLoggingContext, ) -> None: """Create an execution context.""" if now.tzinfo is None: @@ -65,7 +80,12 @@ def _lookup_history(self, path: Sequence[str]) -> Any: 'previous_state': self.current_history_entry.old_state, }[variable_name] - def property_handler(self, property_name, value, **kwargs): + def property_handler( + self, + property_name: Union[Tuple[str, ...]], + value: Any, + **kwargs: Any, + ) -> bool: """Handle a property in execution.""" if property_name == ('passed',): epoch = kwargs['since'] @@ -82,8 +102,8 @@ def _pre_warm_feeds( self, label: str, accessed_variables: Iterable[str], - logging_context, - ): + logging_context: FeedLoggingContext, + ) -> None: for accessed_variable in accessed_variables: parts = accessed_variable.split('.') diff --git a/routemaster/cron.py b/routemaster/cron.py index ba490917..f2dae883 100644 --- a/routemaster/cron.py +++ b/routemaster/cron.py @@ -82,7 +82,7 @@ def process_job( # Bound when scheduling a specific job for a state fn: LabelStateProcessor, label_provider: LabelProvider, -): +) -> None: """Process a single instance of a single cron job.""" def _iter_labels_until_terminating( diff --git a/routemaster/exit_conditions/analysis.py b/routemaster/exit_conditions/analysis.py index c1a840ee..2b4158ea 100644 --- a/routemaster/exit_conditions/analysis.py +++ b/routemaster/exit_conditions/analysis.py @@ -1,10 +1,12 @@ """Analysis of compiled programs.""" +from typing import Any, Tuple, Union, Iterator + from routemaster.exit_conditions.operations import Operation -def find_accessed_keys(instructions): +def find_accessed_keys(instructions: Any) -> Iterator[Union[Tuple[str, str], Tuple[str], Tuple[str, str, str]]]: """Yield each key accessed under the program.""" for instruction, *args in instructions: if instruction == Operation.LOOKUP: diff --git a/routemaster/exit_conditions/error_display.py b/routemaster/exit_conditions/error_display.py index b9dde3fe..eaa49032 100644 --- a/routemaster/exit_conditions/error_display.py +++ b/routemaster/exit_conditions/error_display.py @@ -1,7 +1,8 @@ """Human-readable ParseError handling.""" +from typing import Tuple -def _find_line_containing(source, index): +def _find_line_containing(source: str, index: int) -> Tuple[int, str, int]: """Find (line number, line, offset) triple for an index into a string.""" lines = source.splitlines() @@ -20,7 +21,7 @@ def _find_line_containing(source, index): raise AssertionError("index >> len(source)") -def format_parse_error_message(*, source, error): +def format_parse_error_message(*, source, error) -> str: """Format a parse error on some source for nicer display.""" error_line_number, error_line, error_offset = _find_line_containing( source, diff --git a/routemaster/exit_conditions/evaluator.py b/routemaster/exit_conditions/evaluator.py index 30a1f61d..2650a270 100644 --- a/routemaster/exit_conditions/evaluator.py +++ b/routemaster/exit_conditions/evaluator.py @@ -1,63 +1,115 @@ """Exit condition program evaluator.""" -from routemaster.exit_conditions.operations import Operation - +import datetime +from typing import Any, Dict, List, Tuple, Union, Callable, Iterable -def _evaluate_to_bool(stack, lookup, property_handler): +from routemaster.exit_conditions.operations import Operation +from routemaster.exit_conditions.prepositions import Preposition + +Stack = List[Union[ + Any, + bool, + int, + str, + None, + datetime.datetime, + Tuple[int, int], +]] + + +def _evaluate_to_bool( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: top_of_stack = stack.pop() stack.append(bool(top_of_stack)) -def _evaluate_not(stack, lookup, property_handler): +def _evaluate_not( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: top_of_stack = stack.pop() stack.append(not top_of_stack) -def _evaluate_and(stack, lookup, property_handler): +def _evaluate_and( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs and rhs) -def _evaluate_or(stack, lookup, property_handler): +def _evaluate_or( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs or rhs) -def _evaluate_literal(stack, lookup, property_handler, value): +def _evaluate_literal( + stack: Stack, + lookup: Callable, + property_handler: Callable, + value: Any, +) -> None: stack.append(value) -def _evaluate_lookup(stack, lookup, property_handler, key): +def _evaluate_lookup( + stack: Stack, + lookup: Callable, + property_handler: Callable, + key: Tuple[str, ...], +) -> None: stack.append(lookup(key)) -def _evaluate_eq(stack, lookup, property_handler): +def _evaluate_eq( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs == rhs) -def _evaluate_lt(stack, lookup, property_handler): +def _evaluate_lt( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() - stack.append(lhs < rhs) + stack.append(lhs < rhs) # type: ignore[operator] -def _evaluate_gt(stack, lookup, property_handler): +def _evaluate_gt( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() - stack.append(lhs > rhs) + stack.append(lhs > rhs) # type: ignore[operator] def _evaluate_property( - stack, - lookup, - property_handler, - property_name, - prepositions, -): + stack: Stack, + lookup: Callable, + property_handler: Callable, + property_name: Tuple[str, ...], + prepositions: Tuple[Preposition, ...], +) -> None: prepositional_arguments = {} for preposition in reversed(prepositions): prepositional_arguments[preposition.value] = stack.pop() @@ -67,7 +119,7 @@ def _evaluate_property( ) -EVALUATORS = { +EVALUATORS: Dict[Operation, Callable] = { Operation.TO_BOOL: _evaluate_to_bool, Operation.AND: _evaluate_and, Operation.OR: _evaluate_or, @@ -81,13 +133,17 @@ def _evaluate_property( } -def evaluate(instructions, lookup, property_handler): +def evaluate( + instructions: Iterable[Any], + lookup: Callable, + property_handler: Callable, +) -> bool: """ Run the instructions given in `instructions`. Returns the single result. """ - stack = [] + stack: Stack = [] for instruction, *args in instructions: EVALUATORS[instruction](stack, lookup, property_handler, *args) - return stack.pop() + return stack.pop() # type: ignore[return-value] diff --git a/routemaster/exit_conditions/exceptions.py b/routemaster/exit_conditions/exceptions.py index 9871078f..8dd537df 100644 --- a/routemaster/exit_conditions/exceptions.py +++ b/routemaster/exit_conditions/exceptions.py @@ -1,10 +1,11 @@ """Exceptions for use in exit condition handling.""" +from typing import Tuple class ParseError(Exception): """Errors that occur when tokenizing or parsing.""" - def __init__(self, message, location): + def __init__(self, message: str, location: Tuple[int, int]) -> None: """ Construct by message and location. diff --git a/routemaster/exit_conditions/peephole.py b/routemaster/exit_conditions/peephole.py index e77e03f0..f33f539e 100644 --- a/routemaster/exit_conditions/peephole.py +++ b/routemaster/exit_conditions/peephole.py @@ -1,5 +1,7 @@ """Peephole evaluator optimiser.""" +from typing import Any, List + from routemaster.exit_conditions.operations import Operation MATCHERS = [ @@ -77,7 +79,7 @@ ] -def peephole_optimise(instructions): +def peephole_optimise(instructions: Any) -> List[Any]: """Run peephole optimisations over a given instruction sequence.""" instructions = list(instructions) diff --git a/routemaster/feeds.py b/routemaster/feeds.py index 11dec059..9c337c7c 100644 --- a/routemaster/feeds.py +++ b/routemaster/feeds.py @@ -1,14 +1,16 @@ """Creation and fetching of feed data.""" import threading -from typing import Any, Dict, Callable, Optional +from typing import Any, Dict, Union, Callable, Optional, Sequence from dataclasses import dataclass import requests +from requests.sessions import Session from routemaster.utils import get_path, template_url +from routemaster.config.model import StateMachine -def feeds_for_state_machine(state_machine) -> Dict[str, 'Feed']: +def feeds_for_state_machine(state_machine: StateMachine) -> Dict[str, 'Feed']: """Get a mapping of feed prefixes to unfetched feeds.""" return { x.name: Feed(x.url, state_machine.name) # type: ignore @@ -24,7 +26,7 @@ class FeedNotFetched(Exception): _feed_sessions = threading.local() -def _get_feed_session(): +def _get_feed_session() -> Session: # We cache sessions per thread so that we can use `requests.Session`'s # underlying `urllib3` connection pooling. if not hasattr(_feed_sessions, 'session'): @@ -59,7 +61,7 @@ def prefetch( response.raise_for_status() self.data = response.json() - def lookup(self, path): + def lookup(self, path: Union[Sequence[str]]) -> Optional[Union[bool, str]]: """Lookup data from a feed's contents.""" if self.data is None: raise FeedNotFetched(self.url) diff --git a/routemaster/logging/base.py b/routemaster/logging/base.py index 3c04359f..0615b8d2 100644 --- a/routemaster/logging/base.py +++ b/routemaster/logging/base.py @@ -1,15 +1,21 @@ """Base class for logging plugins.""" import contextlib +from typing import Any, Dict, Type, Tuple, Callable, Iterator, Optional + +import requests +from flask.app import Flask + +from routemaster.config.model import State, Config, StateMachine class BaseLogger: """Base class for logging plugins.""" - def __init__(self, config, *args, **kwargs) -> None: + def __init__(self, config: Optional[Config], *args, **kwargs) -> None: self.config = config - def init_flask(self, flask_app): + def init_flask(self, flask_app: Flask) -> None: """ Entrypoint for configuring logging on the flask server. @@ -19,55 +25,69 @@ def init_flask(self, flask_app): pass @contextlib.contextmanager - def process_cron(self, state_machine, state, fn_name): + def process_cron( + self, + state_machine: StateMachine, + state: State, + fn_name: str, + ) -> Iterator[None]: """Wraps the processing of a cron job for logging purposes.""" yield @contextlib.contextmanager - def process_webhook(self, state_machine, state): + def process_webhook( + self, + state_machine: StateMachine, + state: State, + ) -> Iterator[None]: """Wraps the processing of a webhook for logging purposes.""" yield - def process_request_started(self, environ): + def process_request_started(self, environ: Dict[str, Any]) -> None: """Request started.""" pass def process_request_finished( self, - environ, + environ: Dict[str, Any], *, - status, - headers, - exc_info, - ): + status: int, + headers: Dict[str, Any], + exc_info: Optional[Tuple[Type[RuntimeError], RuntimeError, None]], + ) -> None: """Completes the processing of a request.""" pass def webhook_response( self, - state_machine, - state, - response, - ): + state_machine: StateMachine, + state: State, + response: requests.Response, + ) -> None: """Logs the receipt of a response from a webhook.""" pass @contextlib.contextmanager - def process_feed(self, state_machine, state, feed_url): + def process_feed( + self, + state_machine: StateMachine, + state: State, + feed_url: str, + ) -> Iterator[None]: """Wraps the processing of a feed for logging purposes.""" yield def feed_response( self, - state_machine, - state, - feed_url, - response, - ): + state_machine: StateMachine, + state: State, + feed_url: str, + response: requests.Response, + ) -> None: """Logs the receipt of a response from a feed.""" pass - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable[[str], None]: """Implement the Python logger API.""" if name in ( 'debug', @@ -81,5 +101,5 @@ def __getattr__(self, name): raise AttributeError(name) - def _log_handler(self, *args, **kwargs): + def _log_handler(self, *args, **kwargs) -> None: pass diff --git a/routemaster/logging/plugins.py b/routemaster/logging/plugins.py index 76afc62e..9398073a 100644 --- a/routemaster/logging/plugins.py +++ b/routemaster/logging/plugins.py @@ -1,15 +1,17 @@ """Plugin loading and configuration.""" import importlib +from typing import Any, List, Union from routemaster.config import Config, LoggingPluginConfig from routemaster.logging.base import BaseLogger +from routemaster.logging.python_logger import PythonLogger class PluginConfigurationException(Exception): """Raised to signal an invalid plugin that was loaded.""" -def register_loggers(config: Config): +def register_loggers(config: Config) -> List[Union[Any, PythonLogger, BaseLogger]]: """ Iterate through all plugins in the config file and instatiate them. """ diff --git a/routemaster/logging/python_logger.py b/routemaster/logging/python_logger.py index 2d0b8ef7..85585282 100644 --- a/routemaster/logging/python_logger.py +++ b/routemaster/logging/python_logger.py @@ -3,7 +3,9 @@ import time import logging import contextlib +from typing import Any, Dict, Type, Tuple, Callable, Iterator, Optional +from routemaster.config.model import State, StateMachine from routemaster.logging.base import BaseLogger @@ -25,7 +27,7 @@ def __init__(self, *args, log_level: str) -> None: self.logger.info(f"Started logger with level {log_level}") @contextlib.contextmanager - def process_cron(self, state_machine, state, fn_name): + def process_cron(self, state_machine: StateMachine, state: State, fn_name: str) -> Iterator[None]: """Process a cron job, logging information to the Python logger.""" self.logger.info( f"Started cron {fn_name} for state {state.name} in " @@ -46,12 +48,12 @@ def process_cron(self, state_machine, state, fn_name): def process_request_finished( self, - environ, + environ: Dict[str, Any], *, - status, - headers, - exc_info, - ): + status: int, + headers: Dict[str, Any], + exc_info: Optional[Tuple[Type[RuntimeError], RuntimeError, None]], + ) -> None: """Process a web request and log some basic info about it.""" self.info("{method} {path} {status}".format( method=environ.get('REQUEST_METHOD'), @@ -59,6 +61,6 @@ def process_request_finished( status=status, )) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable: """Fall back to the logger API.""" return getattr(self.logger, name) diff --git a/routemaster/logging/split_logger.py b/routemaster/logging/split_logger.py index 8021619c..f961268f 100644 --- a/routemaster/logging/split_logger.py +++ b/routemaster/logging/split_logger.py @@ -2,7 +2,7 @@ import functools import contextlib -from typing import List +from typing import List, Iterator from routemaster.logging.base import BaseLogger @@ -39,12 +39,12 @@ def __init__(self, *args, loggers: List[BaseLogger]) -> None: ): setattr(self, fn, functools.partial(self._log_all_ctx, fn)) - def _log_all(self, name, *args, **kwargs): + def _log_all(self, name: str, *args, **kwargs) -> None: for logger in self.loggers: getattr(logger, name)(*args, **kwargs) @contextlib.contextmanager - def _log_all_ctx(self, name, *args, **kwargs): + def _log_all_ctx(self, name: str, *args, **kwargs) -> Iterator[None]: with contextlib.ExitStack() as stack: for logger in self.loggers: logger_ctx = getattr(logger, name) diff --git a/routemaster/middleware.py b/routemaster/middleware.py index 0cf8ef58..aa90fee5 100644 --- a/routemaster/middleware.py +++ b/routemaster/middleware.py @@ -23,7 +23,7 @@ ACTIVE_MIDDLEWARES = [] -def middleware(fn: WSGIMiddleware): +def middleware(fn: WSGIMiddleware) -> Callable: """Decorator: add `fn` to ACTIVE_MIDDLEWARES.""" ACTIVE_MIDDLEWARES.append(fn) return fn diff --git a/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py b/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py index 6c7f881f..26e11fac 100644 --- a/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py +++ b/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py @@ -2,7 +2,7 @@ Initial state machine Revision ID: 091a6e84d9ac -Revises: +Revises: """ import sqlalchemy as sa diff --git a/routemaster/server/endpoints.py b/routemaster/server/endpoints.py index 9d3038ba..00d9a4c4 100644 --- a/routemaster/server/endpoints.py +++ b/routemaster/server/endpoints.py @@ -1,10 +1,14 @@ """Core API endpoints for routemaster service.""" +from typing import Tuple, Union + import sqlalchemy import pkg_resources from flask import Flask, abort, jsonify, request +from flask.wrappers import Response from routemaster import state_machine +from routemaster.app import App from routemaster.state_machine import ( LabelRef, UnknownLabel, @@ -17,7 +21,7 @@ @server.route('/', methods=['GET']) -def status(): +def status() -> Union[Tuple[Response, int], Response]: """ Status check endpoint. @@ -27,13 +31,16 @@ def status(): - 503 Service Unavailable: if there is any detected reason why the service might not be able to serve requests. """ + by_key = pkg_resources.working_set.by_key # type: ignore[attr-defined] try: - version = pkg_resources.working_set.by_key['routemaster'].version + version = by_key['routemaster'].version except KeyError: # pragma: no cover version = 'development' + app: App = server.config.app # type: ignore[attr-defined] # see cli.py + try: - server.config.app.session.query(sqlalchemy.literal(1)).one() + app.session.query(sqlalchemy.literal(1)).one() return jsonify({ 'status': 'ok', 'state-machines': '/state-machines', @@ -48,20 +55,22 @@ def status(): @server.route('/state-machines', methods=['GET']) -def get_state_machines(): +def get_state_machines() -> Response: """ List the state machines known to this server. Successful return codes return a list of dictionaries containing at least the name of each state machine. """ + app: App = server.config.app # type: ignore[attr-defined] # see cli.py + return jsonify({ 'state-machines': [ { 'name': x.name, 'labels': f'/state-machines/{x.name}/labels', } - for x in server.config.app.config.state_machines.values() + for x in app.config.state_machines.values() ], }) @@ -70,7 +79,7 @@ def get_state_machines(): '/state-machines//labels', methods=['GET'], ) -def get_labels(state_machine_name): +def get_labels(state_machine_name: str) -> Response: """ List the labels in a state machine. @@ -81,7 +90,7 @@ def get_labels(state_machine_name): Successful return codes return a list of dictionaries containing at least the name of each label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py try: state_machine_instance = app.config.state_machines[state_machine_name] @@ -100,7 +109,7 @@ def get_labels(state_machine_name): '/state-machines//labels/', methods=['GET'], ) -def get_label(state_machine_name, label_name): +def get_label(state_machine_name: str, label_name: str) -> Response: """ Get a label within a given state machine. @@ -111,7 +120,7 @@ def get_label(state_machine_name, label_name): Successful return codes return the full metadata for the label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) try: @@ -126,6 +135,7 @@ def get_label(state_machine_name, label_name): abort(404, f"State machine '{label.state_machine}' does not exist.") state = state_machine.get_label_state(app, label) + assert state is not None # label deletion handled above return jsonify(metadata=metadata, state=state.name) @@ -133,7 +143,10 @@ def get_label(state_machine_name, label_name): '/state-machines//labels/', methods=['POST'], ) -def create_label(state_machine_name, label_name): +def create_label( + state_machine_name: str, + label_name: str, +) -> Tuple[Response, int]: """ Create a label with a given metadata, and start it in the state machine. @@ -145,14 +158,16 @@ def create_label(state_machine_name, label_name): Successful return codes return the full created metadata for the label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) data = request.get_json() try: - initial_metadata = data['metadata'] + initial_metadata = data['metadata'] # type: ignore[index] except KeyError: abort(400, "No metadata given") + except TypeError: + abort(400, "Metadata must be a mapping") try: initial_state_name = \ @@ -168,10 +183,10 @@ def create_label(state_machine_name, label_name): @server.route( - '/state-machines//labels/', # noqa + '/state-machines//labels/', # noqa methods=['PATCH'], ) -def update_label(state_machine_name, label_name): +def update_label(state_machine_name: str, label_name: str) -> Response: """ Update a label in a state machine. @@ -186,12 +201,13 @@ def update_label(state_machine_name, label_name): Successful return codes return the full new metadata for a label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) + data = request.get_json() try: - patch_metadata = request.get_json()['metadata'] - except KeyError: + patch_metadata = data['metadata'] # type: ignore[index] + except (TypeError, KeyError): abort(400, "No new metadata") try: @@ -200,8 +216,6 @@ def update_label(state_machine_name, label_name): label, patch_metadata, ) - state = state_machine.get_label_state(app, label) - return jsonify(metadata=new_metadata, state=state.name) except UnknownStateMachine: msg = f"State machine '{state_machine_name}' does not exist" abort(404, msg) @@ -212,12 +226,16 @@ def update_label(state_machine_name, label_name): f"'{state_machine_name}'.", ) + state = state_machine.get_label_state(app, label) + assert state is not None # label deletion handled above + return jsonify(metadata=new_metadata, state=state.name) + @server.route( '/state-machines//labels/', methods=['DELETE'], ) -def delete_label(state_machine_name, label_name): +def delete_label(state_machine_name: str, label_name: str) -> Tuple[str, int]: """ Delete a label in a state machine. @@ -228,7 +246,7 @@ def delete_label(state_machine_name, label_name): - 204 No content: if the label is successfully deleted (or did not exist). - 404 Not Found: if the state machine does not exist. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) try: diff --git a/routemaster/state_machine/actions.py b/routemaster/state_machine/actions.py index 715b01a8..21a07687 100644 --- a/routemaster/state_machine/actions.py +++ b/routemaster/state_machine/actions.py @@ -4,6 +4,7 @@ import hashlib import functools +import routemaster.db.model from routemaster.db import History from routemaster.app import App from routemaster.utils import template_url @@ -95,7 +96,7 @@ def process_action( return True -def _calculate_idempotency_token(label: LabelRef, latest_history) -> str: +def _calculate_idempotency_token(label: LabelRef, latest_history: routemaster.db.model.History) -> str: """ We want to make sure that an action is only performed once. diff --git a/routemaster/state_machine/api.py b/routemaster/state_machine/api.py index 8bf40ee0..e7b0b4c1 100644 --- a/routemaster/state_machine/api.py +++ b/routemaster/state_machine/api.py @@ -159,7 +159,7 @@ def _process_transitions_for_metadata_update( label: LabelRef, state_machine: StateMachine, state_pending_update: State, -): +) -> None: with app.session.begin_nested(): lock_label(app, label) current_state = get_current_state(app, label, state_machine) @@ -224,6 +224,7 @@ def delete_label(app: App, label: LabelRef) -> None: class LabelStateProcessor(Protocol): """Type signature for the label state processor callable.""" + def __call__( self, *, @@ -247,7 +248,7 @@ def process_cron( app: App, state_machine: StateMachine, state: State, -): +) -> None: """ Cron event entrypoint. """ diff --git a/routemaster/state_machine/utils.py b/routemaster/state_machine/utils.py index 2560805b..ea76da0f 100644 --- a/routemaster/state_machine/utils.py +++ b/routemaster/state_machine/utils.py @@ -3,8 +3,19 @@ import datetime import functools import contextlib -from typing import Any, Dict, List, Tuple, Optional, Sequence, Collection +from typing import ( + Any, + Dict, + List, + Tuple, + Callable, + Iterator, + Optional, + Sequence, + Collection, +) +import requests import dateutil.tz from sqlalchemy import func @@ -20,6 +31,8 @@ UnknownStateMachine, ) +ResponseLogger = Callable[[requests.Response], None] + def get_state_machine(app: App, label: LabelRef) -> StateMachine: """Finds the state machine instance by name in the app config.""" @@ -237,7 +250,7 @@ def context_for_label( accessed_variables.append(state.next_states.path) @contextlib.contextmanager - def feed_logging_context(feed_url): + def feed_logging_context(feed_url: str) -> Iterator[ResponseLogger]: with logger.process_feed(state_machine, state, feed_url): yield functools.partial( logger.feed_response, diff --git a/routemaster/utils.py b/routemaster/utils.py index c8a777e5..b5309488 100644 --- a/routemaster/utils.py +++ b/routemaster/utils.py @@ -1,9 +1,11 @@ """Shared utilities.""" import contextlib -from typing import Any, Dict, Sequence +from typing import Any, Dict, Iterator, Sequence +from routemaster.logging.base import BaseLogger -def dict_merge(d1, d2): + +def dict_merge(d1: Dict[str, Any], d2: Dict[str, Any]) -> Dict[str, Any]: """ Recursively merge two dicts to create a new dict. @@ -36,7 +38,7 @@ def get_path(path: Sequence[str], d: Dict[str, Any]) -> Any: @contextlib.contextmanager -def suppress_exceptions(logger): +def suppress_exceptions(logger: BaseLogger) -> Iterator[None]: """Catch all exceptions and log to a provided logger.""" try: yield diff --git a/routemaster/validation.py b/routemaster/validation.py index db75bc05..fcd4ddd2 100644 --- a/routemaster/validation.py +++ b/routemaster/validation.py @@ -1,12 +1,15 @@ """Validation of state machines.""" import collections +from typing import Union import networkx from sqlalchemy import func +import routemaster.config.model from routemaster.db import History from routemaster.app import App from routemaster.config import Config, StateMachine +from routemaster.conftest import TestApp class ValidationError(Exception): @@ -14,13 +17,13 @@ class ValidationError(Exception): pass -def validate_config(app: App, config: Config): +def validate_config(app: App, config: Config) -> None: """Validate that a given config satisfies invariants.""" for state_machine in config.state_machines.values(): _validate_state_machine(app, state_machine) -def _validate_state_machine(app: App, state_machine: StateMachine): +def _validate_state_machine(app: App, state_machine: StateMachine) -> None: """Validate that a given state machine is internally consistent.""" with app.new_session(): _validate_route_start_to_end(state_machine) @@ -38,13 +41,15 @@ def _build_graph(state_machine: StateMachine) -> networkx.Graph: return graph -def _validate_route_start_to_end(state_machine): +def _validate_route_start_to_end( + state_machine: routemaster.config.model.StateMachine, +) -> None: graph = _build_graph(state_machine) if not networkx.is_connected(graph): raise ValidationError("Graph is not fully connected") -def _validate_unique_state_names(state_machine): +def _validate_unique_state_names(state_machine: routemaster.config.model.StateMachine) -> None: state_name_counts = collections.Counter([ x.name for x in state_machine.states ]) @@ -58,7 +63,9 @@ def _validate_unique_state_names(state_machine): ) -def _validate_all_states_exist(state_machine): +def _validate_all_states_exist( + state_machine: routemaster.config.model.StateMachine, +) -> None: state_names = set(x.name for x in state_machine.states) for state in state_machine.states: for destination_name in state.next_states.all_destinations(): @@ -66,14 +73,18 @@ def _validate_all_states_exist(state_machine): raise ValidationError(f"{destination_name} does not exist") -def _validate_no_labels_in_nonexistent_states(state_machine, app): +def _validate_no_labels_in_nonexistent_states( + state_machine: routemaster.config.model.StateMachine, + app: Union[App, TestApp], +) -> None: states = [x.name for x in state_machine.states] states_by_rank = app.session.query( History.label_name, History.new_state, func.row_number().over( - order_by=History.id.desc(), + # TODO: use the sqlalchemy mypy plugin rather than our stubs file + order_by=History.id.desc(), # type: ignore[attr-defined] partition_by=History.label_name, ).label('rank'), ).filter_by(