diff --git a/routemaster/app.py b/routemaster/app.py index 7ad715cd..52c27c4b 100644 --- a/routemaster/app.py +++ b/routemaster/app.py @@ -1,7 +1,7 @@ """Core App singleton that holds state for the application.""" import threading import contextlib -from typing import Dict, Optional +from typing import Dict, Iterator, Optional from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.engine import Engine @@ -32,7 +32,7 @@ def __init__( self.config = config self.initialise() - def initialise(self): + def initialise(self) -> None: """ Initialise this instance of the app. @@ -66,7 +66,7 @@ def session(self) -> Session: return self._current_session - def set_rollback(self): + def set_rollback(self) -> None: """Mark the current session as needing rollback.""" if self._current_session is None: raise RuntimeError( @@ -77,7 +77,7 @@ def set_rollback(self): self._needs_rollback = True @contextlib.contextmanager - def new_session(self): + def new_session(self) -> Iterator[None]: """Run a single session in this scope.""" if self._current_session is not None: raise RuntimeError("There is already a session running.") diff --git a/routemaster/cli.py b/routemaster/cli.py index f1e705b2..0be72689 100644 --- a/routemaster/cli.py +++ b/routemaster/cli.py @@ -107,7 +107,7 @@ def post_fork(): cron_thread.stop() -def _validate_config(app: App): +def _validate_config(app: App) -> None: try: validate_config(app, app.config) except ValidationError as e: diff --git a/routemaster/conftest.py b/routemaster/conftest.py index d18b9b73..b815b881 100644 --- a/routemaster/conftest.py +++ b/routemaster/conftest.py @@ -8,7 +8,7 @@ import functools import contextlib import subprocess -from typing import Any, Dict +from typing import Any, Dict, Callable, Iterator, ContextManager from unittest import mock import pytest @@ -19,7 +19,10 @@ from sqlalchemy import create_engine from werkzeug.test import Client from sqlalchemy.orm import sessionmaker +from _pytest.fixtures import SubRequest +from sqlalchemy.orm.session import Session +import routemaster.config.model from routemaster import state_machine from routemaster.db import Label, History, metadata from routemaster.app import App @@ -45,6 +48,7 @@ from routemaster.logging import BaseLogger, SplitLogger, register_loggers from routemaster.webhooks import ( WebhookResult, + WebhookRunner, webhook_runner_for_state_machine, ) from routemaster.middleware import wrap_application @@ -235,7 +239,8 @@ class TestApp(App): 2. We can set a flag on access to `.db` so that we needn't bother with resetting the database if nothing has actually been changed. """ - def __init__(self, config): + + def __init__(self, config: routemaster.config.model.Config) -> None: self.config = config self.session_used = False self.logger = SplitLogger(config, loggers=register_loggers(config)) @@ -249,13 +254,13 @@ def __init__(self, config): } @property - def session(self): + def session(self) -> Session: """Start if necessary and return a shared session.""" self.session_used = True return super().session -def get_test_app(**kwargs): +def get_test_app(**kwargs) -> TestApp: """Instantiate an app with testing parameters.""" return TestApp(Config( state_machines=kwargs.get('state_machines', TEST_STATE_MACHINES), @@ -270,28 +275,28 @@ def get_test_app(**kwargs): @pytest.fixture() -def client(custom_app=None): +def client(custom_app: None = None) -> Client: """Create a werkzeug test client.""" _app = get_test_app() if custom_app is None else custom_app - server.config.app = _app + server.config.app = _app # type: ignore[attr-defined] _app.logger.init_flask(server) return Client(wrap_application(_app, server), werkzeug.Response) @pytest.fixture() -def app(**kwargs): +def app(**kwargs: Any) -> TestApp: """Create an `App` config object for testing.""" return get_test_app(**kwargs) @pytest.fixture() -def custom_app(): +def custom_app() -> Callable[..., TestApp]: """Return the test app generator so that we can pass in custom config.""" return get_test_app @pytest.fixture() -def app_env(): +def app_env() -> Dict[str, str]: """ Create a dict of environment variables. @@ -307,7 +312,7 @@ def app_env(): @pytest.fixture(autouse=True, scope='session') -def database_creation(request): +def database_creation(request: SubRequest) -> Iterator[None]: """Wrap test session in creating and destroying all required tables.""" metadata.drop_all(bind=TEST_ENGINE) metadata.create_all(bind=TEST_ENGINE) @@ -315,7 +320,7 @@ def database_creation(request): @pytest.fixture(autouse=True) -def database_clear(app): +def database_clear(app: TestApp) -> Iterator[None]: """Truncate all tables after each test.""" yield if app.session_used: @@ -328,7 +333,10 @@ def database_clear(app): @pytest.fixture() -def create_label(app, mock_test_feed): +def create_label( + app: TestApp, + mock_test_feed: Callable[[], ContextManager[None]], +) -> Callable[[str, str, Dict[str, Any]], LabelRef]: """Create a label in the database.""" def _create( @@ -348,7 +356,7 @@ def _create( @pytest.fixture() -def delete_label(app): +def delete_label(app: TestApp) -> Callable[[str, str], None]: """ Mark a label in the database as deleted. """ @@ -364,7 +372,10 @@ def _delete(name: str, state_machine_name: str) -> None: @pytest.fixture() -def create_deleted_label(create_label, delete_label): +def create_deleted_label( + create_label: Callable[[str, str, Dict[str, Any]], LabelRef], + delete_label: Callable[[str, str], None], +) -> Callable[[str, str], LabelRef]: """ Create a label in the database and then delete it. """ @@ -378,7 +389,10 @@ def _create_and_delete(name: str, state_machine_name: str) -> LabelRef: @pytest.fixture() -def mock_webhook(): +def mock_webhook() -> Callable[ + [WebhookResult], + ContextManager[Callable[[StateMachine], WebhookRunner]], +]: """Mock the test config's webhook call.""" @contextlib.contextmanager def _mock(result=WebhookResult.SUCCESS): @@ -392,7 +406,7 @@ def _mock(result=WebhookResult.SUCCESS): @pytest.fixture() -def mock_test_feed(): +def mock_test_feed() -> Callable[[Dict[str, Any]], ContextManager[None]]: """Mock out the test feed.""" @contextlib.contextmanager def _mock(data={'should_do_alternate_action': False}): @@ -414,7 +428,7 @@ def _mock(data={'should_do_alternate_action': False}): @pytest.fixture() -def assert_history(app): +def assert_history(app: TestApp) -> Callable: """Assert that the database history matches what is expected.""" def _assert(entries): with app.new_session(): @@ -432,7 +446,7 @@ def _assert(entries): @pytest.fixture() -def set_metadata(app): +def set_metadata(app: TestApp) -> Callable: """Directly set the metadata for a label in the database.""" def _inner(label, update): with app.new_session(): @@ -449,7 +463,7 @@ def _inner(label, update): @pytest.fixture() -def make_context(app): +def make_context(app: TestApp) -> Callable: """Factory for Contexts that provides sane defaults for testing.""" def _inner(**kwargs): logger = BaseLogger(app.config) @@ -491,9 +505,9 @@ def version(): @pytest.fixture() -def current_state(app): +def current_state(app: TestApp) -> Callable[[LabelRef], str]: """Get the current state of a label.""" - def _inner(label): + def _inner(label: LabelRef) -> str: with app.new_session(): return app.session.query( History.new_state, @@ -501,13 +515,14 @@ def _inner(label): label_name=label.name, label_state_machine=label.state_machine, ).order_by( - History.id.desc(), + # TODO: use the sqlalchemy mypy plugin rather than our stubs + History.id.desc(), # type: ignore[attr-defined] ).limit(1).scalar() return _inner @pytest.fixture() -def unused_tcp_port(): +def unused_tcp_port() -> int: """Returns an unused TCP port, inspired by pytest-asyncio.""" with contextlib.closing(socket.socket()) as sock: sock.bind(('127.0.0.1', 0)) @@ -515,7 +530,7 @@ def unused_tcp_port(): @pytest.fixture() -def routemaster_serve_subprocess(unused_tcp_port): +def routemaster_serve_subprocess(unused_tcp_port: int) -> Callable: """ Fixture to spawn a routemaster server as a subprocess. diff --git a/routemaster/context.py b/routemaster/context.py index 010a1f9c..0e0e1d9b 100644 --- a/routemaster/context.py +++ b/routemaster/context.py @@ -1,10 +1,25 @@ """Context definition for exit condition programs.""" import datetime -from typing import Any, Dict, Iterable, Optional, Sequence +from typing import ( + Any, + Dict, + Tuple, + Union, + Callable, + Iterable, + Optional, + Sequence, + ContextManager, +) + +import requests from routemaster.feeds import Feed from routemaster.utils import get_path +ResponseLogger = Callable[[requests.Response], None] +FeedLoggingContext = Callable[[str], ContextManager[ResponseLogger]] + class Context(object): """Execution context for exit condition programs.""" @@ -18,7 +33,7 @@ def __init__( feeds: Dict[str, Feed], accessed_variables: Iterable[str], current_history_entry: Optional[Any], - feed_logging_context, + feed_logging_context: FeedLoggingContext, ) -> None: """Create an execution context.""" if now.tzinfo is None: @@ -65,7 +80,12 @@ def _lookup_history(self, path: Sequence[str]) -> Any: 'previous_state': self.current_history_entry.old_state, }[variable_name] - def property_handler(self, property_name, value, **kwargs): + def property_handler( + self, + property_name: Union[Tuple[str, ...]], + value: Any, + **kwargs: Any, + ) -> bool: """Handle a property in execution.""" if property_name == ('passed',): epoch = kwargs['since'] @@ -82,8 +102,8 @@ def _pre_warm_feeds( self, label: str, accessed_variables: Iterable[str], - logging_context, - ): + logging_context: FeedLoggingContext, + ) -> None: for accessed_variable in accessed_variables: parts = accessed_variable.split('.') diff --git a/routemaster/cron.py b/routemaster/cron.py index ba490917..f2dae883 100644 --- a/routemaster/cron.py +++ b/routemaster/cron.py @@ -82,7 +82,7 @@ def process_job( # Bound when scheduling a specific job for a state fn: LabelStateProcessor, label_provider: LabelProvider, -): +) -> None: """Process a single instance of a single cron job.""" def _iter_labels_until_terminating( diff --git a/routemaster/exit_conditions/analysis.py b/routemaster/exit_conditions/analysis.py index c1a840ee..2b4158ea 100644 --- a/routemaster/exit_conditions/analysis.py +++ b/routemaster/exit_conditions/analysis.py @@ -1,10 +1,12 @@ """Analysis of compiled programs.""" +from typing import Any, Tuple, Union, Iterator + from routemaster.exit_conditions.operations import Operation -def find_accessed_keys(instructions): +def find_accessed_keys(instructions: Any) -> Iterator[Union[Tuple[str, str], Tuple[str], Tuple[str, str, str]]]: """Yield each key accessed under the program.""" for instruction, *args in instructions: if instruction == Operation.LOOKUP: diff --git a/routemaster/exit_conditions/error_display.py b/routemaster/exit_conditions/error_display.py index b9dde3fe..eaa49032 100644 --- a/routemaster/exit_conditions/error_display.py +++ b/routemaster/exit_conditions/error_display.py @@ -1,7 +1,8 @@ """Human-readable ParseError handling.""" +from typing import Tuple -def _find_line_containing(source, index): +def _find_line_containing(source: str, index: int) -> Tuple[int, str, int]: """Find (line number, line, offset) triple for an index into a string.""" lines = source.splitlines() @@ -20,7 +21,7 @@ def _find_line_containing(source, index): raise AssertionError("index >> len(source)") -def format_parse_error_message(*, source, error): +def format_parse_error_message(*, source, error) -> str: """Format a parse error on some source for nicer display.""" error_line_number, error_line, error_offset = _find_line_containing( source, diff --git a/routemaster/exit_conditions/evaluator.py b/routemaster/exit_conditions/evaluator.py index 30a1f61d..2650a270 100644 --- a/routemaster/exit_conditions/evaluator.py +++ b/routemaster/exit_conditions/evaluator.py @@ -1,63 +1,115 @@ """Exit condition program evaluator.""" -from routemaster.exit_conditions.operations import Operation - +import datetime +from typing import Any, Dict, List, Tuple, Union, Callable, Iterable -def _evaluate_to_bool(stack, lookup, property_handler): +from routemaster.exit_conditions.operations import Operation +from routemaster.exit_conditions.prepositions import Preposition + +Stack = List[Union[ + Any, + bool, + int, + str, + None, + datetime.datetime, + Tuple[int, int], +]] + + +def _evaluate_to_bool( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: top_of_stack = stack.pop() stack.append(bool(top_of_stack)) -def _evaluate_not(stack, lookup, property_handler): +def _evaluate_not( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: top_of_stack = stack.pop() stack.append(not top_of_stack) -def _evaluate_and(stack, lookup, property_handler): +def _evaluate_and( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs and rhs) -def _evaluate_or(stack, lookup, property_handler): +def _evaluate_or( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs or rhs) -def _evaluate_literal(stack, lookup, property_handler, value): +def _evaluate_literal( + stack: Stack, + lookup: Callable, + property_handler: Callable, + value: Any, +) -> None: stack.append(value) -def _evaluate_lookup(stack, lookup, property_handler, key): +def _evaluate_lookup( + stack: Stack, + lookup: Callable, + property_handler: Callable, + key: Tuple[str, ...], +) -> None: stack.append(lookup(key)) -def _evaluate_eq(stack, lookup, property_handler): +def _evaluate_eq( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() stack.append(lhs == rhs) -def _evaluate_lt(stack, lookup, property_handler): +def _evaluate_lt( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() - stack.append(lhs < rhs) + stack.append(lhs < rhs) # type: ignore[operator] -def _evaluate_gt(stack, lookup, property_handler): +def _evaluate_gt( + stack: Stack, + lookup: Callable, + property_handler: Callable, +) -> None: rhs = stack.pop() lhs = stack.pop() - stack.append(lhs > rhs) + stack.append(lhs > rhs) # type: ignore[operator] def _evaluate_property( - stack, - lookup, - property_handler, - property_name, - prepositions, -): + stack: Stack, + lookup: Callable, + property_handler: Callable, + property_name: Tuple[str, ...], + prepositions: Tuple[Preposition, ...], +) -> None: prepositional_arguments = {} for preposition in reversed(prepositions): prepositional_arguments[preposition.value] = stack.pop() @@ -67,7 +119,7 @@ def _evaluate_property( ) -EVALUATORS = { +EVALUATORS: Dict[Operation, Callable] = { Operation.TO_BOOL: _evaluate_to_bool, Operation.AND: _evaluate_and, Operation.OR: _evaluate_or, @@ -81,13 +133,17 @@ def _evaluate_property( } -def evaluate(instructions, lookup, property_handler): +def evaluate( + instructions: Iterable[Any], + lookup: Callable, + property_handler: Callable, +) -> bool: """ Run the instructions given in `instructions`. Returns the single result. """ - stack = [] + stack: Stack = [] for instruction, *args in instructions: EVALUATORS[instruction](stack, lookup, property_handler, *args) - return stack.pop() + return stack.pop() # type: ignore[return-value] diff --git a/routemaster/exit_conditions/exceptions.py b/routemaster/exit_conditions/exceptions.py index 9871078f..8dd537df 100644 --- a/routemaster/exit_conditions/exceptions.py +++ b/routemaster/exit_conditions/exceptions.py @@ -1,10 +1,11 @@ """Exceptions for use in exit condition handling.""" +from typing import Tuple class ParseError(Exception): """Errors that occur when tokenizing or parsing.""" - def __init__(self, message, location): + def __init__(self, message: str, location: Tuple[int, int]) -> None: """ Construct by message and location. diff --git a/routemaster/exit_conditions/peephole.py b/routemaster/exit_conditions/peephole.py index e77e03f0..f33f539e 100644 --- a/routemaster/exit_conditions/peephole.py +++ b/routemaster/exit_conditions/peephole.py @@ -1,5 +1,7 @@ """Peephole evaluator optimiser.""" +from typing import Any, List + from routemaster.exit_conditions.operations import Operation MATCHERS = [ @@ -77,7 +79,7 @@ ] -def peephole_optimise(instructions): +def peephole_optimise(instructions: Any) -> List[Any]: """Run peephole optimisations over a given instruction sequence.""" instructions = list(instructions) diff --git a/routemaster/feeds.py b/routemaster/feeds.py index 11dec059..9c337c7c 100644 --- a/routemaster/feeds.py +++ b/routemaster/feeds.py @@ -1,14 +1,16 @@ """Creation and fetching of feed data.""" import threading -from typing import Any, Dict, Callable, Optional +from typing import Any, Dict, Union, Callable, Optional, Sequence from dataclasses import dataclass import requests +from requests.sessions import Session from routemaster.utils import get_path, template_url +from routemaster.config.model import StateMachine -def feeds_for_state_machine(state_machine) -> Dict[str, 'Feed']: +def feeds_for_state_machine(state_machine: StateMachine) -> Dict[str, 'Feed']: """Get a mapping of feed prefixes to unfetched feeds.""" return { x.name: Feed(x.url, state_machine.name) # type: ignore @@ -24,7 +26,7 @@ class FeedNotFetched(Exception): _feed_sessions = threading.local() -def _get_feed_session(): +def _get_feed_session() -> Session: # We cache sessions per thread so that we can use `requests.Session`'s # underlying `urllib3` connection pooling. if not hasattr(_feed_sessions, 'session'): @@ -59,7 +61,7 @@ def prefetch( response.raise_for_status() self.data = response.json() - def lookup(self, path): + def lookup(self, path: Union[Sequence[str]]) -> Optional[Union[bool, str]]: """Lookup data from a feed's contents.""" if self.data is None: raise FeedNotFetched(self.url) diff --git a/routemaster/logging/base.py b/routemaster/logging/base.py index 3c04359f..0615b8d2 100644 --- a/routemaster/logging/base.py +++ b/routemaster/logging/base.py @@ -1,15 +1,21 @@ """Base class for logging plugins.""" import contextlib +from typing import Any, Dict, Type, Tuple, Callable, Iterator, Optional + +import requests +from flask.app import Flask + +from routemaster.config.model import State, Config, StateMachine class BaseLogger: """Base class for logging plugins.""" - def __init__(self, config, *args, **kwargs) -> None: + def __init__(self, config: Optional[Config], *args, **kwargs) -> None: self.config = config - def init_flask(self, flask_app): + def init_flask(self, flask_app: Flask) -> None: """ Entrypoint for configuring logging on the flask server. @@ -19,55 +25,69 @@ def init_flask(self, flask_app): pass @contextlib.contextmanager - def process_cron(self, state_machine, state, fn_name): + def process_cron( + self, + state_machine: StateMachine, + state: State, + fn_name: str, + ) -> Iterator[None]: """Wraps the processing of a cron job for logging purposes.""" yield @contextlib.contextmanager - def process_webhook(self, state_machine, state): + def process_webhook( + self, + state_machine: StateMachine, + state: State, + ) -> Iterator[None]: """Wraps the processing of a webhook for logging purposes.""" yield - def process_request_started(self, environ): + def process_request_started(self, environ: Dict[str, Any]) -> None: """Request started.""" pass def process_request_finished( self, - environ, + environ: Dict[str, Any], *, - status, - headers, - exc_info, - ): + status: int, + headers: Dict[str, Any], + exc_info: Optional[Tuple[Type[RuntimeError], RuntimeError, None]], + ) -> None: """Completes the processing of a request.""" pass def webhook_response( self, - state_machine, - state, - response, - ): + state_machine: StateMachine, + state: State, + response: requests.Response, + ) -> None: """Logs the receipt of a response from a webhook.""" pass @contextlib.contextmanager - def process_feed(self, state_machine, state, feed_url): + def process_feed( + self, + state_machine: StateMachine, + state: State, + feed_url: str, + ) -> Iterator[None]: """Wraps the processing of a feed for logging purposes.""" yield def feed_response( self, - state_machine, - state, - feed_url, - response, - ): + state_machine: StateMachine, + state: State, + feed_url: str, + response: requests.Response, + ) -> None: """Logs the receipt of a response from a feed.""" pass - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable[[str], None]: """Implement the Python logger API.""" if name in ( 'debug', @@ -81,5 +101,5 @@ def __getattr__(self, name): raise AttributeError(name) - def _log_handler(self, *args, **kwargs): + def _log_handler(self, *args, **kwargs) -> None: pass diff --git a/routemaster/logging/plugins.py b/routemaster/logging/plugins.py index 76afc62e..9398073a 100644 --- a/routemaster/logging/plugins.py +++ b/routemaster/logging/plugins.py @@ -1,15 +1,17 @@ """Plugin loading and configuration.""" import importlib +from typing import Any, List, Union from routemaster.config import Config, LoggingPluginConfig from routemaster.logging.base import BaseLogger +from routemaster.logging.python_logger import PythonLogger class PluginConfigurationException(Exception): """Raised to signal an invalid plugin that was loaded.""" -def register_loggers(config: Config): +def register_loggers(config: Config) -> List[Union[Any, PythonLogger, BaseLogger]]: """ Iterate through all plugins in the config file and instatiate them. """ diff --git a/routemaster/logging/python_logger.py b/routemaster/logging/python_logger.py index 2d0b8ef7..85585282 100644 --- a/routemaster/logging/python_logger.py +++ b/routemaster/logging/python_logger.py @@ -3,7 +3,9 @@ import time import logging import contextlib +from typing import Any, Dict, Type, Tuple, Callable, Iterator, Optional +from routemaster.config.model import State, StateMachine from routemaster.logging.base import BaseLogger @@ -25,7 +27,7 @@ def __init__(self, *args, log_level: str) -> None: self.logger.info(f"Started logger with level {log_level}") @contextlib.contextmanager - def process_cron(self, state_machine, state, fn_name): + def process_cron(self, state_machine: StateMachine, state: State, fn_name: str) -> Iterator[None]: """Process a cron job, logging information to the Python logger.""" self.logger.info( f"Started cron {fn_name} for state {state.name} in " @@ -46,12 +48,12 @@ def process_cron(self, state_machine, state, fn_name): def process_request_finished( self, - environ, + environ: Dict[str, Any], *, - status, - headers, - exc_info, - ): + status: int, + headers: Dict[str, Any], + exc_info: Optional[Tuple[Type[RuntimeError], RuntimeError, None]], + ) -> None: """Process a web request and log some basic info about it.""" self.info("{method} {path} {status}".format( method=environ.get('REQUEST_METHOD'), @@ -59,6 +61,6 @@ def process_request_finished( status=status, )) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable: """Fall back to the logger API.""" return getattr(self.logger, name) diff --git a/routemaster/logging/split_logger.py b/routemaster/logging/split_logger.py index 8021619c..f961268f 100644 --- a/routemaster/logging/split_logger.py +++ b/routemaster/logging/split_logger.py @@ -2,7 +2,7 @@ import functools import contextlib -from typing import List +from typing import List, Iterator from routemaster.logging.base import BaseLogger @@ -39,12 +39,12 @@ def __init__(self, *args, loggers: List[BaseLogger]) -> None: ): setattr(self, fn, functools.partial(self._log_all_ctx, fn)) - def _log_all(self, name, *args, **kwargs): + def _log_all(self, name: str, *args, **kwargs) -> None: for logger in self.loggers: getattr(logger, name)(*args, **kwargs) @contextlib.contextmanager - def _log_all_ctx(self, name, *args, **kwargs): + def _log_all_ctx(self, name: str, *args, **kwargs) -> Iterator[None]: with contextlib.ExitStack() as stack: for logger in self.loggers: logger_ctx = getattr(logger, name) diff --git a/routemaster/middleware.py b/routemaster/middleware.py index 0cf8ef58..aa90fee5 100644 --- a/routemaster/middleware.py +++ b/routemaster/middleware.py @@ -23,7 +23,7 @@ ACTIVE_MIDDLEWARES = [] -def middleware(fn: WSGIMiddleware): +def middleware(fn: WSGIMiddleware) -> Callable: """Decorator: add `fn` to ACTIVE_MIDDLEWARES.""" ACTIVE_MIDDLEWARES.append(fn) return fn diff --git a/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py b/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py index 6c7f881f..26e11fac 100644 --- a/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py +++ b/routemaster/migrations/versions/091a6e84d9ac_initial_state_machine.py @@ -2,7 +2,7 @@ Initial state machine Revision ID: 091a6e84d9ac -Revises: +Revises: """ import sqlalchemy as sa diff --git a/routemaster/server/endpoints.py b/routemaster/server/endpoints.py index 9d3038ba..00d9a4c4 100644 --- a/routemaster/server/endpoints.py +++ b/routemaster/server/endpoints.py @@ -1,10 +1,14 @@ """Core API endpoints for routemaster service.""" +from typing import Tuple, Union + import sqlalchemy import pkg_resources from flask import Flask, abort, jsonify, request +from flask.wrappers import Response from routemaster import state_machine +from routemaster.app import App from routemaster.state_machine import ( LabelRef, UnknownLabel, @@ -17,7 +21,7 @@ @server.route('/', methods=['GET']) -def status(): +def status() -> Union[Tuple[Response, int], Response]: """ Status check endpoint. @@ -27,13 +31,16 @@ def status(): - 503 Service Unavailable: if there is any detected reason why the service might not be able to serve requests. """ + by_key = pkg_resources.working_set.by_key # type: ignore[attr-defined] try: - version = pkg_resources.working_set.by_key['routemaster'].version + version = by_key['routemaster'].version except KeyError: # pragma: no cover version = 'development' + app: App = server.config.app # type: ignore[attr-defined] # see cli.py + try: - server.config.app.session.query(sqlalchemy.literal(1)).one() + app.session.query(sqlalchemy.literal(1)).one() return jsonify({ 'status': 'ok', 'state-machines': '/state-machines', @@ -48,20 +55,22 @@ def status(): @server.route('/state-machines', methods=['GET']) -def get_state_machines(): +def get_state_machines() -> Response: """ List the state machines known to this server. Successful return codes return a list of dictionaries containing at least the name of each state machine. """ + app: App = server.config.app # type: ignore[attr-defined] # see cli.py + return jsonify({ 'state-machines': [ { 'name': x.name, 'labels': f'/state-machines/{x.name}/labels', } - for x in server.config.app.config.state_machines.values() + for x in app.config.state_machines.values() ], }) @@ -70,7 +79,7 @@ def get_state_machines(): '/state-machines//labels', methods=['GET'], ) -def get_labels(state_machine_name): +def get_labels(state_machine_name: str) -> Response: """ List the labels in a state machine. @@ -81,7 +90,7 @@ def get_labels(state_machine_name): Successful return codes return a list of dictionaries containing at least the name of each label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py try: state_machine_instance = app.config.state_machines[state_machine_name] @@ -100,7 +109,7 @@ def get_labels(state_machine_name): '/state-machines//labels/', methods=['GET'], ) -def get_label(state_machine_name, label_name): +def get_label(state_machine_name: str, label_name: str) -> Response: """ Get a label within a given state machine. @@ -111,7 +120,7 @@ def get_label(state_machine_name, label_name): Successful return codes return the full metadata for the label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) try: @@ -126,6 +135,7 @@ def get_label(state_machine_name, label_name): abort(404, f"State machine '{label.state_machine}' does not exist.") state = state_machine.get_label_state(app, label) + assert state is not None # label deletion handled above return jsonify(metadata=metadata, state=state.name) @@ -133,7 +143,10 @@ def get_label(state_machine_name, label_name): '/state-machines//labels/', methods=['POST'], ) -def create_label(state_machine_name, label_name): +def create_label( + state_machine_name: str, + label_name: str, +) -> Tuple[Response, int]: """ Create a label with a given metadata, and start it in the state machine. @@ -145,14 +158,16 @@ def create_label(state_machine_name, label_name): Successful return codes return the full created metadata for the label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) data = request.get_json() try: - initial_metadata = data['metadata'] + initial_metadata = data['metadata'] # type: ignore[index] except KeyError: abort(400, "No metadata given") + except TypeError: + abort(400, "Metadata must be a mapping") try: initial_state_name = \ @@ -168,10 +183,10 @@ def create_label(state_machine_name, label_name): @server.route( - '/state-machines//labels/', # noqa + '/state-machines//labels/', # noqa methods=['PATCH'], ) -def update_label(state_machine_name, label_name): +def update_label(state_machine_name: str, label_name: str) -> Response: """ Update a label in a state machine. @@ -186,12 +201,13 @@ def update_label(state_machine_name, label_name): Successful return codes return the full new metadata for a label. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) + data = request.get_json() try: - patch_metadata = request.get_json()['metadata'] - except KeyError: + patch_metadata = data['metadata'] # type: ignore[index] + except (TypeError, KeyError): abort(400, "No new metadata") try: @@ -200,8 +216,6 @@ def update_label(state_machine_name, label_name): label, patch_metadata, ) - state = state_machine.get_label_state(app, label) - return jsonify(metadata=new_metadata, state=state.name) except UnknownStateMachine: msg = f"State machine '{state_machine_name}' does not exist" abort(404, msg) @@ -212,12 +226,16 @@ def update_label(state_machine_name, label_name): f"'{state_machine_name}'.", ) + state = state_machine.get_label_state(app, label) + assert state is not None # label deletion handled above + return jsonify(metadata=new_metadata, state=state.name) + @server.route( '/state-machines//labels/', methods=['DELETE'], ) -def delete_label(state_machine_name, label_name): +def delete_label(state_machine_name: str, label_name: str) -> Tuple[str, int]: """ Delete a label in a state machine. @@ -228,7 +246,7 @@ def delete_label(state_machine_name, label_name): - 204 No content: if the label is successfully deleted (or did not exist). - 404 Not Found: if the state machine does not exist. """ - app = server.config.app + app: App = server.config.app # type: ignore[attr-defined] # see cli.py label = LabelRef(label_name, state_machine_name) try: diff --git a/routemaster/state_machine/actions.py b/routemaster/state_machine/actions.py index 715b01a8..21a07687 100644 --- a/routemaster/state_machine/actions.py +++ b/routemaster/state_machine/actions.py @@ -4,6 +4,7 @@ import hashlib import functools +import routemaster.db.model from routemaster.db import History from routemaster.app import App from routemaster.utils import template_url @@ -95,7 +96,7 @@ def process_action( return True -def _calculate_idempotency_token(label: LabelRef, latest_history) -> str: +def _calculate_idempotency_token(label: LabelRef, latest_history: routemaster.db.model.History) -> str: """ We want to make sure that an action is only performed once. diff --git a/routemaster/state_machine/api.py b/routemaster/state_machine/api.py index 8bf40ee0..e7b0b4c1 100644 --- a/routemaster/state_machine/api.py +++ b/routemaster/state_machine/api.py @@ -159,7 +159,7 @@ def _process_transitions_for_metadata_update( label: LabelRef, state_machine: StateMachine, state_pending_update: State, -): +) -> None: with app.session.begin_nested(): lock_label(app, label) current_state = get_current_state(app, label, state_machine) @@ -224,6 +224,7 @@ def delete_label(app: App, label: LabelRef) -> None: class LabelStateProcessor(Protocol): """Type signature for the label state processor callable.""" + def __call__( self, *, @@ -247,7 +248,7 @@ def process_cron( app: App, state_machine: StateMachine, state: State, -): +) -> None: """ Cron event entrypoint. """ diff --git a/routemaster/state_machine/utils.py b/routemaster/state_machine/utils.py index 2560805b..ea76da0f 100644 --- a/routemaster/state_machine/utils.py +++ b/routemaster/state_machine/utils.py @@ -3,8 +3,19 @@ import datetime import functools import contextlib -from typing import Any, Dict, List, Tuple, Optional, Sequence, Collection +from typing import ( + Any, + Dict, + List, + Tuple, + Callable, + Iterator, + Optional, + Sequence, + Collection, +) +import requests import dateutil.tz from sqlalchemy import func @@ -20,6 +31,8 @@ UnknownStateMachine, ) +ResponseLogger = Callable[[requests.Response], None] + def get_state_machine(app: App, label: LabelRef) -> StateMachine: """Finds the state machine instance by name in the app config.""" @@ -237,7 +250,7 @@ def context_for_label( accessed_variables.append(state.next_states.path) @contextlib.contextmanager - def feed_logging_context(feed_url): + def feed_logging_context(feed_url: str) -> Iterator[ResponseLogger]: with logger.process_feed(state_machine, state, feed_url): yield functools.partial( logger.feed_response, diff --git a/routemaster/utils.py b/routemaster/utils.py index c8a777e5..b5309488 100644 --- a/routemaster/utils.py +++ b/routemaster/utils.py @@ -1,9 +1,11 @@ """Shared utilities.""" import contextlib -from typing import Any, Dict, Sequence +from typing import Any, Dict, Iterator, Sequence +from routemaster.logging.base import BaseLogger -def dict_merge(d1, d2): + +def dict_merge(d1: Dict[str, Any], d2: Dict[str, Any]) -> Dict[str, Any]: """ Recursively merge two dicts to create a new dict. @@ -36,7 +38,7 @@ def get_path(path: Sequence[str], d: Dict[str, Any]) -> Any: @contextlib.contextmanager -def suppress_exceptions(logger): +def suppress_exceptions(logger: BaseLogger) -> Iterator[None]: """Catch all exceptions and log to a provided logger.""" try: yield diff --git a/routemaster/validation.py b/routemaster/validation.py index db75bc05..fcd4ddd2 100644 --- a/routemaster/validation.py +++ b/routemaster/validation.py @@ -1,12 +1,15 @@ """Validation of state machines.""" import collections +from typing import Union import networkx from sqlalchemy import func +import routemaster.config.model from routemaster.db import History from routemaster.app import App from routemaster.config import Config, StateMachine +from routemaster.conftest import TestApp class ValidationError(Exception): @@ -14,13 +17,13 @@ class ValidationError(Exception): pass -def validate_config(app: App, config: Config): +def validate_config(app: App, config: Config) -> None: """Validate that a given config satisfies invariants.""" for state_machine in config.state_machines.values(): _validate_state_machine(app, state_machine) -def _validate_state_machine(app: App, state_machine: StateMachine): +def _validate_state_machine(app: App, state_machine: StateMachine) -> None: """Validate that a given state machine is internally consistent.""" with app.new_session(): _validate_route_start_to_end(state_machine) @@ -38,13 +41,15 @@ def _build_graph(state_machine: StateMachine) -> networkx.Graph: return graph -def _validate_route_start_to_end(state_machine): +def _validate_route_start_to_end( + state_machine: routemaster.config.model.StateMachine, +) -> None: graph = _build_graph(state_machine) if not networkx.is_connected(graph): raise ValidationError("Graph is not fully connected") -def _validate_unique_state_names(state_machine): +def _validate_unique_state_names(state_machine: routemaster.config.model.StateMachine) -> None: state_name_counts = collections.Counter([ x.name for x in state_machine.states ]) @@ -58,7 +63,9 @@ def _validate_unique_state_names(state_machine): ) -def _validate_all_states_exist(state_machine): +def _validate_all_states_exist( + state_machine: routemaster.config.model.StateMachine, +) -> None: state_names = set(x.name for x in state_machine.states) for state in state_machine.states: for destination_name in state.next_states.all_destinations(): @@ -66,14 +73,18 @@ def _validate_all_states_exist(state_machine): raise ValidationError(f"{destination_name} does not exist") -def _validate_no_labels_in_nonexistent_states(state_machine, app): +def _validate_no_labels_in_nonexistent_states( + state_machine: routemaster.config.model.StateMachine, + app: Union[App, TestApp], +) -> None: states = [x.name for x in state_machine.states] states_by_rank = app.session.query( History.label_name, History.new_state, func.row_number().over( - order_by=History.id.desc(), + # TODO: use the sqlalchemy mypy plugin rather than our stubs file + order_by=History.id.desc(), # type: ignore[attr-defined] partition_by=History.label_name, ).label('rank'), ).filter_by(