From e91546e6c80f966014367c3f3ecc2ec246ca8051 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Fri, 17 Jan 2025 15:31:10 -0600 Subject: [PATCH] Pyright Improvements (#932) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Pull Request ## Title Pyright improvements ______________________________________________________________________ ## Description Pyright is a type checker that ships with VSCode's Pylance by default. It is billed as a faster, though less complete, version of mypy. As such it gets a few things a little differently that mypy and alerts in VSCode. This PR fixes those ("standard") alerts and removes the mypy extension from VSCode's default extensions for MLOS in favor of just using pyright (there's no sense in running both interactively). We do not enable pyright's "strict" mode. Additionally, it enables pyright in pre-commit rules to ensure those fixes remain. We leave the rest of the mypy checks as well since they are still useful. A list of some of the types of fixes: - TypeDict initialization checks for Tunables - Check that json.loads() returns a dict and not a list (e.g.) - Replace ConcreteOptimizer TypeVar with a TypeAlias - Add BoundMethod protocol for checking __self__ attribute - Ensure correct type inference in a number of places - Add `...` to Protocol methods to make pyright aware of the lack of method body. - Fix a few type annotations ______________________________________________________________________ ## Type of Change - 🛠️ Bug fix - 🔄 Refactor ______________________________________________________________________ ## Testing - Additional CI checks as described above. ______________________________________________________________________ --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .devcontainer/devcontainer.json | 1 - .pre-commit-config.yaml | 17 ++- .vscode/extensions.json | 1 - .vscode/settings.json | 3 +- conda-envs/mlos.yml | 1 + doc/source/conf.py | 4 + .../mlos_bench/environments/mock_env.py | 6 +- .../mlos_bench/optimizers/mock_optimizer.py | 2 +- .../mlos_bench/services/base_service.py | 3 +- .../services/remote/ssh/ssh_fileshare.py | 2 + .../services/remote/ssh/ssh_host_service.py | 1 + .../services/types/authenticator_type.py | 6 + .../mlos_bench/services/types/bound_method.py | 24 ++++ .../services/types/config_loader_type.py | 10 ++ .../services/types/host_ops_type.py | 6 + .../services/types/host_provisioner_type.py | 6 + .../services/types/local_exec_type.py | 7 +- .../types/network_provisioner_type.py | 5 + .../mlos_bench/services/types/os_ops_type.py | 5 + .../services/types/remote_config_type.py | 4 + .../services/types/remote_exec_type.py | 4 + .../services/types/vm_provisioner_type.py | 9 ++ .../mlos_bench/storage/sql/experiment.py | 6 +- .../composite_env_service_test.py | 2 + .../tests/event_loop_context_test.py | 10 +- .../mlos_bench/tests/launcher_run_test.py | 28 +++-- .../mlos_bench/tests/services/mock_service.py | 4 + .../services/remote/mock/mock_vm_service.py | 2 + .../tests/tunables/tunable_definition_test.py | 21 ++++ .../tunables/tunable_distributions_test.py | 1 + .../tunables/tunable_slice_references_test.py | 3 + .../tests/tunables/tunables_assign_test.py | 1 + .../tests/tunables/tunables_copy_test.py | 2 + mlos_bench/mlos_bench/tunables/tunable.py | 116 ++++++++++++++---- mlos_bench/setup.py | 3 +- mlos_core/mlos_core/__init__.py | 2 +- mlos_core/mlos_core/optimizers/__init__.py | 17 +-- .../bayesian_optimizers/smac_optimizer.py | 8 +- .../mlos_core/spaces/adapters/__init__.py | 20 +-- .../mlos_core/spaces/adapters/llamatune.py | 5 +- .../mlos_core/spaces/converters/flaml.py | 5 +- mlos_core/mlos_core/spaces/converters/util.py | 11 +- mlos_core/mlos_core/tests/__init__.py | 5 +- .../tests/optimizers/optimizer_test.py | 3 +- .../adapters/space_adapter_factory_test.py | 9 +- .../mlos_core/tests/spaces/spaces_test.py | 4 +- mlos_core/setup.py | 3 +- mlos_viz/mlos_viz/base.py | 2 +- mlos_viz/setup.py | 3 +- pyproject.toml | 6 + 50 files changed, 323 insertions(+), 106 deletions(-) create mode 100644 mlos_bench/mlos_bench/services/types/bound_method.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 91373d000a5..f9cbbc4a384 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -68,7 +68,6 @@ "huntertran.auto-markdown-toc", "ibm.output-colorizer", "lextudio.restructuredtext", - "matangover.mypy", "ms-azuretools.vscode-docker", "ms-python.black-formatter", "ms-python.pylint", diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3bbb03cc335..8e947e111bd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: # Let pre-commit.ci automatically update PRs with formatting fixes. autofix_prs: true # skip local hooks - they should be managed manually via conda-envs/*.yml - skip: [mypy, pylint, pycodestyle] + skip: [mypy, pylint, pycodestyle, pyright] autoupdate_schedule: monthly autoupdate_commit_msg: | [pre-commit.ci] pre-commit autoupdate @@ -15,6 +15,7 @@ ci: See Also: - https://github.com/microsoft/MLOS/blob/main/conda-envs/mlos.yml - https://pypi.org/project/mypy/ + - https://pypi.org/project/pyright/ - https://pypi.org/project/pylint/ - https://pypi.org/project/pycodestyle/ @@ -140,6 +141,20 @@ repos: (?x)^( doc/source/conf.py )$ + - id: pyright + name: pyright + entry: pyright + language: system + types: [python] + require_serial: true + exclude: | + (?x)^( + doc/source/conf.py| + mlos_core/setup.py| + mlos_bench/setup.py| + mlos_viz/setup.py| + conftest.py + )$ - id: mypy name: mypy entry: mypy diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 92725b501c0..ed56e7c520d 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -12,7 +12,6 @@ "huntertran.auto-markdown-toc", "ibm.output-colorizer", "lextudio.restructuredtext", - "matangover.mypy", "ms-azuretools.vscode-docker", "ms-python.black-formatter", "ms-python.pylint", diff --git a/.vscode/settings.json b/.vscode/settings.json index 406600fa4aa..23e43fed683 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -170,6 +170,5 @@ "python.testing.unittestEnabled": false, "debugpy.debugJustMyCode": false, "python.analysis.autoImportCompletions": true, - "python.analysis.supportRestructuredText": true, - "python.analysis.typeCheckingMode": "standard" + "python.analysis.supportRestructuredText": true } diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml index 4d8f82390c9..96834c879d0 100644 --- a/conda-envs/mlos.yml +++ b/conda-envs/mlos.yml @@ -28,6 +28,7 @@ dependencies: - pylint==3.3.3 - tomlkit - mypy==1.14.1 + - pyright==1.1.392.post0 - pandas-stubs - types-beautifulsoup4 - types-colorama diff --git a/doc/source/conf.py b/doc/source/conf.py index 5ed19fc45e6..5ab3ce8bd0d 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -29,6 +29,10 @@ from sphinx.application import Sphinx as SphinxApp from sphinx.environment import BuildEnvironment +# Note: doc requirements aren't installed by default. +# To install them, run `pip install -r doc/requirements.txt` + + sys.path.insert(0, os.path.abspath("../../mlos_core/mlos_core")) sys.path.insert(1, os.path.abspath("../../mlos_bench/mlos_bench")) sys.path.insert(1, os.path.abspath("../../mlos_viz/mlos_viz")) diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index 23ec37a6392..4ddd9dee1c0 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -64,8 +64,8 @@ def __init__( # pylint: disable=too-many-arguments seed = int(self.config.get("mock_env_seed", -1)) self._run_random = random.Random(seed or None) if seed >= 0 else None self._status_random = random.Random(seed or None) if seed >= 0 else None - self._range = self.config.get("mock_env_range") - self._metrics = self.config.get("mock_env_metrics", ["score"]) + self._range: tuple[int, int] | None = self.config.get("mock_env_range") + self._metrics: list[str] | None = self.config.get("mock_env_metrics", ["score"]) self._is_ready = True def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue]: @@ -80,7 +80,7 @@ def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue if self._range: score = self._range[0] + score * (self._range[1] - self._range[0]) - return {metric: score for metric in self._metrics} + return {metric: float(score) for metric in self._metrics or []} def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]: """ diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 69176ebc8d8..1e138284b1e 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -32,7 +32,7 @@ def __init__( self._random: dict[str, Callable[[Tunable], TunableValue]] = { "categorical": lambda tunable: rnd.choice(tunable.categories), "float": lambda tunable: rnd.uniform(*tunable.range), - "int": lambda tunable: rnd.randint(*tunable.range), + "int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)), } def bulk_register( diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 24e9d493078..41eebfbb98b 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -14,6 +14,7 @@ from typing import Any, Literal from mlos_bench.config.schemas import ConfigSchema +from mlos_bench.services.types.bound_method import BoundMethod from mlos_bench.services.types.config_loader_type import SupportsConfigLoading from mlos_bench.util import instantiate_from_config @@ -278,7 +279,7 @@ def register(self, services: dict[str, Callable] | list[Callable]) -> None: for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service) + if isinstance(svc_method, BoundMethod) and isinstance(svc_method.__self__, Service) } def export(self) -> dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index b03006e431c..8d9d35883af 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -26,6 +26,8 @@ class CopyMode(Enum): class SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" + # pylint: disable=too-many-ancestors + async def _start_file_copy( self, params: dict, diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index 46cae33d2b9..1b0c7e38231 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -24,6 +24,7 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): """Helper methods to manage machines via SSH.""" + # pylint: disable=too-many-ancestors # pylint: disable=too-many-instance-attributes def __init__( diff --git a/mlos_bench/mlos_bench/services/types/authenticator_type.py b/mlos_bench/mlos_bench/services/types/authenticator_type.py index 45b056119f0..e3f2135d146 100644 --- a/mlos_bench/mlos_bench/services/types/authenticator_type.py +++ b/mlos_bench/mlos_bench/services/types/authenticator_type.py @@ -14,6 +14,9 @@ class SupportsAuth(Protocol[T_co]): """Protocol interface for authentication for the cloud services.""" + # Needed by pyright + # pylint: disable=unnecessary-ellipsis,redundant-returns-doc + def get_access_token(self) -> str: """ Get the access token for cloud services. @@ -23,6 +26,7 @@ def get_access_token(self) -> str: access_token : str Access token. """ + ... def get_auth_headers(self) -> dict: """ @@ -33,6 +37,7 @@ def get_auth_headers(self) -> dict: access_header : dict HTTP header containing the access token. """ + ... def get_credential(self) -> T_co: """ @@ -43,3 +48,4 @@ def get_credential(self) -> T_co: credential : T_co Cloud-specific credential object. """ + ... diff --git a/mlos_bench/mlos_bench/services/types/bound_method.py b/mlos_bench/mlos_bench/services/types/bound_method.py new file mode 100644 index 00000000000..8e6179cffe3 --- /dev/null +++ b/mlos_bench/mlos_bench/services/types/bound_method.py @@ -0,0 +1,24 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +"""Protocol representing a bound method.""" + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class BoundMethod(Protocol): + """A callable method bound to an object.""" + + # pylint: disable=too-few-public-methods + # pylint: disable=unnecessary-ellipsis + + @property + def __self__(self) -> Any: + """The self object of the bound method.""" + ... + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Call the bound method.""" + ... diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 71b68947f9e..97f7c98e216 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -23,6 +23,9 @@ class SupportsConfigLoading(Protocol): """Protocol interface for helper functions to lookup and load configs.""" + # Needed by pyright + # pylint: disable=unnecessary-ellipsis,redundant-returns-doc + def get_config_paths(self) -> list[str]: """ Gets the list of config paths this service will search for config files. @@ -31,6 +34,7 @@ def get_config_paths(self) -> list[str]: ------- list[str] """ + ... def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None) -> str: """ @@ -49,6 +53,7 @@ def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None) path : str An actual path to the config or script. """ + ... def load_config( self, @@ -71,6 +76,7 @@ def load_config( config : Union[dict, list[dict]] Free-format dictionary that contains the configuration. """ + ... def build_environment( # pylint: disable=too-many-arguments self, @@ -108,6 +114,7 @@ def build_environment( # pylint: disable=too-many-arguments env : Environment An instance of the `Environment` class initialized with `config`. """ + ... def load_environment( self, @@ -140,6 +147,7 @@ def load_environment( env : Environment A new benchmarking environment. """ + ... def load_environment_list( self, @@ -173,6 +181,7 @@ def load_environment_list( env : list[Environment] A list of new benchmarking environments. """ + ... def load_services( self, @@ -198,3 +207,4 @@ def load_services( service : Service A collection of service methods. """ + ... diff --git a/mlos_bench/mlos_bench/services/types/host_ops_type.py b/mlos_bench/mlos_bench/services/types/host_ops_type.py index 4f649166fc6..29e1cd0f941 100644 --- a/mlos_bench/mlos_bench/services/types/host_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/host_ops_type.py @@ -14,6 +14,8 @@ class SupportsHostOps(Protocol): """Protocol interface for Host/VM boot operations.""" + # pylint: disable=unnecessary-ellipsis + def start_host(self, params: dict) -> tuple["Status", dict]: """ Start a Host/VM. @@ -29,6 +31,7 @@ def start_host(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]: """ @@ -47,6 +50,7 @@ def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dict]: """ @@ -65,6 +69,7 @@ def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dic A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_host_operation(self, params: dict) -> tuple["Status", dict]: """ @@ -85,3 +90,4 @@ def wait_host_operation(self, params: dict) -> tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 2a343877abc..d979029def6 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -14,6 +14,8 @@ class SupportsHostProvisioning(Protocol): """Protocol interface for Host/VM provisioning operations.""" + # pylint: disable=unnecessary-ellipsis + def provision_host(self, params: dict) -> tuple["Status", dict]: """ Check if Host/VM is ready. Deploy a new Host/VM, if necessary. @@ -31,6 +33,7 @@ def provision_host(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status", dict]: """ @@ -52,6 +55,7 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... def deprovision_host(self, params: dict) -> tuple["Status", dict]: """ @@ -68,6 +72,7 @@ def deprovision_host(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def deallocate_host(self, params: dict) -> tuple["Status", dict]: """ @@ -88,3 +93,4 @@ def deallocate_host(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 12b8cafdf61..831870c194d 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -24,6 +24,9 @@ class SupportsLocalExec(Protocol): vs the target environment. Used in LocalEnv and provided by LocalExecService. """ + # Needed by pyright + # pylint: disable=unnecessary-ellipsis,redundant-returns-doc + def local_exec( self, script_lines: Iterable[str], @@ -49,6 +52,7 @@ def local_exec( (return_code, stdout, stderr) : (int, str, str) A 3-tuple of return code, stdout, and stderr of the script process. """ + ... def temp_dir_context( self, @@ -59,7 +63,7 @@ def temp_dir_context( Parameters ---------- - path : str + path : str | None A path to the temporary directory. Create a new one if None. Returns @@ -67,3 +71,4 @@ def temp_dir_context( temp_dir_context : tempfile.TemporaryDirectory Temporary directory context to use in the `with` clause. """ + ... diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 58248ae486a..00f63a51419 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -14,6 +14,8 @@ class SupportsNetworkProvisioning(Protocol): """Protocol interface for Network provisioning operations.""" + # pylint: disable=unnecessary-ellipsis + def provision_network(self, params: dict) -> tuple["Status", dict]: """ Check if Network is ready. Deploy a new Network, if necessary. @@ -31,6 +33,7 @@ def provision_network(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_network_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status", dict]: """ @@ -52,6 +55,7 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> tuple["Sta Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... def deprovision_network( self, @@ -75,3 +79,4 @@ def deprovision_network( A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... diff --git a/mlos_bench/mlos_bench/services/types/os_ops_type.py b/mlos_bench/mlos_bench/services/types/os_ops_type.py index 43c00f7f84c..f3b3c127136 100644 --- a/mlos_bench/mlos_bench/services/types/os_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/os_ops_type.py @@ -14,6 +14,8 @@ class SupportsOSOps(Protocol): """Protocol interface for Host/OS operations.""" + # pylint: disable=unnecessary-ellipsis + def shutdown(self, params: dict, force: bool = False) -> tuple["Status", dict]: """ Initiates a (graceful) shutdown of the Host/VM OS. @@ -31,6 +33,7 @@ def shutdown(self, params: dict, force: bool = False) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def reboot(self, params: dict, force: bool = False) -> tuple["Status", dict]: """ @@ -49,6 +52,7 @@ def reboot(self, params: dict, force: bool = False) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_os_operation(self, params: dict) -> tuple["Status", dict]: """ @@ -69,3 +73,4 @@ def wait_os_operation(self, params: dict) -> tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index 7a96e16fb40..09a1e8c20fe 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -14,6 +14,8 @@ class SupportsRemoteConfig(Protocol): """Protocol interface for configuring cloud services.""" + # pylint: disable=unnecessary-ellipsis + def configure(self, config: dict[str, Any], params: dict[str, Any]) -> tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. @@ -31,6 +33,7 @@ def configure(self, config: dict[str, Any], params: dict[str, Any]) -> tuple["St A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def is_config_pending(self, config: dict[str, Any]) -> tuple["Status", dict]: """ @@ -49,3 +52,4 @@ def is_config_pending(self, config: dict[str, Any]) -> tuple["Status", dict]: If "isConfigPendingReboot" is set to True, rebooting a VM is necessary. Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ + ... diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index d90134070b7..052ad9be7da 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -19,6 +19,8 @@ class SupportsRemoteExec(Protocol): on a remote host OS. """ + # pylint: disable=unnecessary-ellipsis + def remote_exec( self, script: Iterable[str], @@ -46,6 +48,7 @@ def remote_exec( A pair of Status and result. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def get_remote_exec_results(self, config: dict) -> tuple["Status", dict]: """ @@ -64,3 +67,4 @@ def get_remote_exec_results(self, config: dict) -> tuple["Status", dict]: A pair of Status and result. Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} """ + ... diff --git a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py index e56be3d6083..a67e9f0274a 100644 --- a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py @@ -14,6 +14,8 @@ class SupportsVMOps(Protocol): """Protocol interface for VM provisioning operations.""" + # pylint: disable=unnecessary-ellipsis + def vm_provision(self, params: dict) -> tuple["Status", dict]: """ Check if VM is ready. Deploy a new VM, if necessary. @@ -31,6 +33,7 @@ def vm_provision(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_vm_deployment(self, is_setup: bool, params: dict) -> tuple["Status", dict]: """ @@ -51,6 +54,7 @@ def wait_vm_deployment(self, is_setup: bool, params: dict) -> tuple["Status", di Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... def vm_start(self, params: dict) -> tuple["Status", dict]: """ @@ -67,6 +71,7 @@ def vm_start(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def vm_stop(self, params: dict) -> tuple["Status", dict]: """ @@ -83,6 +88,7 @@ def vm_stop(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def vm_restart(self, params: dict) -> tuple["Status", dict]: """ @@ -99,6 +105,7 @@ def vm_restart(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def vm_deprovision(self, params: dict) -> tuple["Status", dict]: """ @@ -115,6 +122,7 @@ def vm_deprovision(self, params: dict) -> tuple["Status", dict]: A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ + ... def wait_vm_operation(self, params: dict) -> tuple["Status", dict]: """ @@ -135,3 +143,4 @@ def wait_vm_operation(self, params: dict) -> tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ + ... diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index f902571e9fa..62daa0232c2 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -293,9 +293,11 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute( + new_config_result = conn.execute( self._schema.config.insert().values(config_hash=config_hash) - ).inserted_primary_key[0] + ).inserted_primary_key + assert new_config_result + config_id: int = new_config_result[0] save_params( conn, self._schema.config_param, diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index 0d81ec78477..b8059a3f67e 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -10,6 +10,7 @@ from mlos_bench.environments.composite_env import CompositeEnv from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.local_exec import LocalExecService +from mlos_bench.services.types.local_exec_type import SupportsLocalExec from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import path_join @@ -58,6 +59,7 @@ def test_composite_services(composite_env: CompositeEnv) -> None: for i, path in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") + assert isinstance(service, SupportsLocalExec) with service.temp_dir_context() as temp_dir: assert os.path.samefile(temp_dir, path) os.rmdir(path) diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index cc6cd825b3d..7eff0fe7451 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -137,9 +137,15 @@ def test_event_loop_context() -> None: ): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: - assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 + assert ( + len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) # pyright: ignore + == 0 + ) assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") - assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 + assert ( + len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) # pyright: ignore + == 0 + ) with pytest.raises( AssertionError diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 61c4679f027..4974a75b80b 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -83,9 +83,11 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ _launch_main_app( root_path, local_exec_service, - " --config cli/mock-bench.jsonc" - " --trial_config_repeat_count 5" - " --mock_env_seed -1", # Deterministic Mock Environment. + ( + " --config cli/mock-bench.jsonc" + " --trial_config_repeat_count 5" + " --mock_env_seed -1" # Deterministic Mock Environment. + ), [ f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", ], @@ -102,10 +104,12 @@ def test_launch_main_app_bench_values( _launch_main_app( root_path, local_exec_service, - " --config cli/mock-bench.jsonc" - " --tunable_values tunable-values/tunable-values-example.jsonc" - " --trial_config_repeat_count 5" - " --mock_env_seed -1", # Deterministic Mock Environment. + ( + " --config cli/mock-bench.jsonc" + " --tunable_values tunable-values/tunable-values-example.jsonc" + " --trial_config_repeat_count 5" + " --mock_env_seed -1" # Deterministic Mock Environment. + ), [ f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", ], @@ -119,10 +123,12 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic _launch_main_app( root_path, local_exec_service, - "--config cli/mock-opt.jsonc" - " --trial_config_repeat_count 3" - " --max_suggestions 3" - " --mock_env_seed 42", # Noisy Mock Environment. + ( + "--config cli/mock-opt.jsonc" + " --trial_config_repeat_count 3" + " --max_suggestions 3" + " --mock_env_seed 42" # Noisy Mock Environment. + ), [ # Iteration 1: Expect first value to be the baseline f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 29cebc95387..868cd24712b 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -18,11 +18,15 @@ class SupportsSomeMethod(Protocol): """Protocol for some_method.""" + # pylint: disable=unnecessary-ellipsis + def some_method(self) -> str: """some_method.""" + ... def some_other_method(self) -> str: """some_other_method.""" + ... class MockServiceBase(Service, SupportsSomeMethod): diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index 62e5a2804a9..86a1e27deb1 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -17,6 +17,8 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps): """Mock VM service for testing.""" + # pylint: disable=too-many-ancestors + def __init__( self, config: dict[str, Any] | None = None, diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index a7b12c26921..3ce11ce82ba 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -27,6 +27,7 @@ def test_categorical_required_params() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -42,6 +43,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) tunable = Tunable(name="test", config=config) assert tunable.weights == [25, 25, 50] @@ -57,6 +59,7 @@ def test_categorical_weights_wrong_count() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -72,6 +75,7 @@ def test_categorical_weights_wrong_values() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -87,6 +91,7 @@ def test_categorical_wrong_params() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -102,6 +107,7 @@ def test_categorical_disallow_special_values() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -173,6 +179,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name=f"test_{tunable_type}", config=config) @@ -188,6 +195,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(AssertionError): Tunable(name=f"test_{tunable_type}", config=config) @@ -203,6 +211,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name=f"test_{tunable_type}", config=config) @@ -221,6 +230,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) + assert isinstance(config, dict) tunable = Tunable(name="test", config=config) assert tunable.special == [0] assert tunable.weights == [0.1] @@ -239,6 +249,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) + assert isinstance(config, dict) tunable = Tunable(name="test", config=config) expected = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] assert tunable.quantization_bins == len(expected) @@ -258,6 +269,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) + assert isinstance(config, dict) tunable = Tunable(name="test", config=config) assert tunable.is_log @@ -274,6 +286,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -294,6 +307,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) + assert isinstance(config, dict) tunable = Tunable(name="test", config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok @@ -314,6 +328,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -331,6 +346,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -348,6 +364,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -364,6 +381,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -382,6 +400,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -398,6 +417,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) @@ -412,5 +432,6 @@ def test_bad_type() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test_bad_type", config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index 54f08e17092..2c8e0a6bfdc 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -90,5 +90,6 @@ def test_numerical_distribution_unsupported(tunable_type: str) -> None: }} """ config = json.loads(json_config) + assert isinstance(config, dict) with pytest.raises(ValueError): Tunable(name="test", config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py index 9d267d4e16c..e06f81f9ea2 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py @@ -81,6 +81,7 @@ def test_overlapping_group_merge_tunable_groups(tunable_groups_config: dict) -> """ other_tunables_config = json.loads(other_tunables_json) + assert isinstance(other_tunables_config, dict) other_tunables = TunableGroups(other_tunables_config) with pytest.raises(ValueError): @@ -110,6 +111,7 @@ def test_bad_extended_merge_tunable_group(tunable_groups_config: dict) -> None: """ other_tunables_config = json.loads(other_tunables_json) + assert isinstance(other_tunables_config, dict) other_tunables = TunableGroups(other_tunables_config) with pytest.raises(ValueError): @@ -138,6 +140,7 @@ def test_good_extended_merge_tunable_group(tunable_groups_config: dict) -> None: """ other_tunables_config = json.loads(other_tunables_json) + assert isinstance(other_tunables_config, dict) other_tunables = TunableGroups(other_tunables_config) assert "new-param" not in parent_tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index e4ddbed28be..77f7710b6d4 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -148,6 +148,7 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) + assert isinstance(config, dict) categorical_tunable = Tunable(name="categorical_test", config=config) assert categorical_tunable assert categorical_tunable.category == "foo" diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py index c5395fcb16f..87e97929d7d 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py @@ -37,6 +37,8 @@ def test_copy_covariant_group(covariant_group: CovariantTunableGroup) -> None: new_value = [x for x in tunable.categories if x != tunable.category][0] elif tunable.is_numerical: new_value = tunable.numerical_value + 1 + else: + raise ValueError(f"{tunable=} :: unsupported tunable type.") covariant_group_copy[tunable] = new_value assert covariant_group_copy.is_updated() assert not covariant_group.is_updated() diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index fa08faa898c..69267bddcaf 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # """Tunable parameter definition.""" -import collections import copy import logging from collections.abc import Iterable, Sequence @@ -38,25 +37,40 @@ """Tunable value distribution type.""" -class DistributionDict(TypedDict, total=False): - """A typed dict for tunable parameters' distributions.""" +class DistributionDictOpt(TypedDict, total=False): + """ + A TypedDict for a :py:class:`.Tunable` parameter's optional ``distribution``'s + config parameters. + + Mostly used by type checking. These are the types expected to be received from the + json config. + """ - type: DistributionName params: dict[str, float] | None -class TunableDict(TypedDict, total=False): +class DistributionDict(DistributionDictOpt): """ - A typed dict for tunable parameters. + A TypedDict for a :py:class:`.Tunable` parameter's required ``distribution``'s + config parameters. - Mostly used for mypy type checking. + Mostly used by type checking. These are the types expected to be received from the + json config. + """ + + type: DistributionName - These are the types expected to be received from the json config. + +class TunableDictOpt(TypedDict, total=False): """ + A TypedDict for a :py:class:`.Tunable` parameter's optional config parameters. - type: TunableValueTypeName + Mostly used for mypy type checking. These are the types expected to be received from + the json config. + """ + + # Optional fields description: str | None - default: TunableValue values: list[str | None] | None range: Sequence[int] | Sequence[float] | None quantization_bins: int | None @@ -69,17 +83,64 @@ class TunableDict(TypedDict, total=False): meta: dict[str, Any] +class TunableDict(TunableDictOpt): + """ + A TypedDict for a :py:class:`.Tunable` parameter's required config parameters. + + Mostly used for mypy type checking. These are the types expected to be received from + the json config. + """ + + # Required fields + type: TunableValueTypeName + default: TunableValue + + +def tunable_dict_from_dict(config: dict[str, Any]) -> TunableDict: + """ + Creates a TunableDict from a regular dict. + + Parameters + ---------- + config : dict[str, Any] + A regular dict that represents a TunableDict. + + Returns + ------- + TunableDict + """ + _type = config.get("type") + if _type not in Tunable.DTYPE: + raise ValueError(f"Invalid parameter type: {_type}") + _meta = config.get("meta", {}) + return TunableDict( + type=_type, + description=config.get("description"), + default=config.get("default"), + values=config.get("values"), + range=config.get("range"), + quantization_bins=config.get("quantization_bins"), + log=config.get("log"), + distribution=config.get("distribution"), + special=config.get("special"), + values_weights=config.get("values_weights"), + special_weights=config.get("special_weights"), + range_weight=config.get("range_weight"), + meta=_meta, + ) + + class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-methods """A tunable parameter definition and its current value.""" - # Maps tunable types to their corresponding Python types by name. - _DTYPE: dict[TunableValueTypeName, TunableValueType] = { + DTYPE: dict[TunableValueTypeName, TunableValueType] = { "int": int, "float": float, "categorical": str, } + """Maps Tunable types to their corresponding Python types by name.""" - def __init__(self, name: str, config: TunableDict): + def __init__(self, name: str, config: dict): """ Create an instance of a new tunable parameter. @@ -95,25 +156,26 @@ def __init__(self, name: str, config: TunableDict): :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and their configuration. """ + t_config = tunable_dict_from_dict(config) if not isinstance(name, str) or "!" in name: # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name - self._type: TunableValueTypeName = config["type"] # required - if self._type not in self._DTYPE: + self._type: TunableValueTypeName = t_config["type"] # required + if self._type not in self.DTYPE: raise ValueError(f"Invalid parameter type: {self._type}") - self._description = config.get("description") - self._default = config["default"] + self._description = t_config.get("description") + self._default = t_config["default"] self._default = self.dtype(self._default) if self._default is not None else self._default - self._values = config.get("values") + self._values = t_config.get("values") if self._values: self._values = [str(v) if v is not None else v for v in self._values] - self._meta: dict[str, Any] = config.get("meta", {}) + self._meta: dict[str, Any] = t_config.get("meta", {}) self._range: tuple[int, int] | tuple[float, float] | None = None - self._quantization_bins: int | None = config.get("quantization_bins") - self._log: bool | None = config.get("log") + self._quantization_bins: int | None = t_config.get("quantization_bins") + self._log: bool | None = t_config.get("log") self._distribution: DistributionName | None = None self._distribution_params: dict[str, float] = {} - distr = config.get("distribution") + distr = t_config.get("distribution") if distr: self._distribution = distr["type"] # required self._distribution_params = distr.get("params") or {} @@ -122,11 +184,11 @@ def __init__(self, name: str, config: TunableDict): assert len(config_range) == 2, f"Invalid range: {config_range}" config_range = (config_range[0], config_range[1]) self._range = config_range - self._special: list[int] | list[float] = config.get("special") or [] + self._special: list[int] | list[float] = t_config.get("special") or [] self._weights: list[float] = ( - config.get("values_weights") or config.get("special_weights") or [] + t_config.get("values_weights") or t_config.get("special_weights") or [] ) - self._range_weight: float | None = config.get("range_weight") + self._range_weight: float | None = t_config.get("range_weight") self._current_value = None self._sanity_check() self.value = self._default @@ -150,7 +212,7 @@ def _sanity_check_categorical(self) -> None: """ # pylint: disable=too-complex assert self.is_categorical - if not (self._values and isinstance(self._values, collections.abc.Iterable)): + if not (self._values and isinstance(self._values, Iterable)): raise ValueError(f"Must specify values for the categorical type tunable {self}") if self._range is not None: raise ValueError(f"Range must be None for the categorical type tunable {self}") @@ -523,7 +585,7 @@ def dtype(self) -> TunableValueType: dtype : type Data type of the tunable - one of {int, float, str}. """ - return self._DTYPE[self._type] + return self.DTYPE[self._type] @property def is_categorical(self) -> bool: diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index 1ff3dbc461e..7a3ccbc1021 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -25,7 +25,8 @@ warning(f"version.py not found, using dummy VERSION={VERSION}") try: - from setuptools_scm import get_version + # Note: setuptools_scm is typically only installed as a part of the build process. + from setuptools_scm import get_version # pyright: ignore[reportMissingImports] version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: diff --git a/mlos_core/mlos_core/__init__.py b/mlos_core/mlos_core/__init__.py index cb19057a7d3..26d4adbf81c 100644 --- a/mlos_core/mlos_core/__init__.py +++ b/mlos_core/mlos_core/__init__.py @@ -111,7 +111,7 @@ franca for data science. - :py:meth:`mlos_core.optimizers.OptimizerFactory.create` is a factory function - that creates a new :py:type:`~mlos_core.optimizers.ConcreteOptimizer` instance + that creates a new :py:attr:`~mlos_core.optimizers.ConcreteOptimizer` instance To do this it uses the :py:class:`~mlos_core.optimizers.OptimizerType` enum to specify which underlying optimizer to use (e.g., diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index 9bec7a5ef8d..3b3786e1e30 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -29,7 +29,6 @@ """ from enum import Enum -from typing import TypeVar import ConfigSpace @@ -71,19 +70,9 @@ class will be used. """ -# To make mypy happy, we need to define a type variable for each optimizer type. -# https://github.com/python/mypy/issues/12952 -# ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) -# To address this, we add a test for complete coverage of the enum. - -ConcreteOptimizer = TypeVar( - "ConcreteOptimizer", - RandomOptimizer, - FlamlOptimizer, - SmacOptimizer, -) +ConcreteOptimizer = RandomOptimizer | FlamlOptimizer | SmacOptimizer """ -Type variable for concrete optimizer classes. +Type alias for concrete optimizer classes. (e.g., :class:`~mlos_core.optimizers.bayesian_optimizers.smac_optimizer.SmacOptimizer`, etc.) """ @@ -108,7 +97,7 @@ def create( # pylint: disable=too-many-arguments optimizer_kwargs: dict | None = None, space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, space_adapter_kwargs: dict | None = None, - ) -> ConcreteOptimizer: # type: ignore[type-var] + ) -> ConcreteOptimizer: """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 910b75d7175..e39216f33a2 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -192,9 +192,11 @@ def __init__( initial_design_args["n_configs"] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - "Number of random initial configs (%d) is " - "greater than 25%% of max_trials (%d). " - "Consider setting max_ratio to avoid SMAC overriding n_random_init.", + ( + "Number of random initial configs (%d) is " + "greater than 25%% of max_trials (%d). " + "Consider setting max_ratio to avoid SMAC overriding n_random_init." + ), n_random_init, max_trials, ) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 16a2c3065fd..2cd29526400 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -32,7 +32,6 @@ """ from enum import Enum -from typing import TypeVar import ConfigSpace @@ -58,19 +57,8 @@ class SpaceAdapterType(Enum): """An instance of :class:`.LlamaTuneAdapter` class will be used.""" -# To make mypy happy, we need to define a type variable for each optimizer type. -# https://github.com/python/mypy/issues/12952 -# ConcreteSpaceAdapter = TypeVar( -# "ConcreteSpaceAdapter", -# *[member.value for member in SpaceAdapterType], -# ) -# To address this, we add a test for complete coverage of the enum. -ConcreteSpaceAdapter = TypeVar( - "ConcreteSpaceAdapter", - IdentityAdapter, - LlamaTuneAdapter, -) -"""Type variable for concrete SpaceAdapter classes (e.g., +ConcreteSpaceAdapter = IdentityAdapter | LlamaTuneAdapter +"""Type alias for concrete SpaceAdapter classes (e.g., :class:`~mlos_core.spaces.adapters.identity_adapter.IdentityAdapter`, etc.) """ @@ -86,9 +74,9 @@ class SpaceAdapterFactory: def create( *, parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_type: SpaceAdapterType | None = SpaceAdapterType.IDENTITY, space_adapter_kwargs: dict | None = None, - ) -> ConcreteSpaceAdapter: # type: ignore[type-var] + ) -> ConcreteSpaceAdapter: """ Create a new space adapter instance, given the parameter space and potential space adapter options. diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 5dea393d765..73478de9a24 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -14,6 +14,7 @@ `_. """ import os +from typing import Any from warnings import warn import ConfigSpace @@ -565,7 +566,9 @@ def _try_generate_approx_inverse_mapping(self) -> None: # Compute pseudo-inverse matrix try: - inv_matrix: npt.NDArray = pinv(proj_matrix) + _inv = pinv(proj_matrix) + assert _inv is not None and not isinstance(_inv, tuple) + inv_matrix: npt.NDArray[np.floating[Any]] = _inv self._pinv_matrix = inv_matrix except LinAlgError as err: raise RuntimeError( diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 7b203b42ad6..f82334bab98 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -12,14 +12,15 @@ import flaml.tune import flaml.tune.sample import numpy as np +from flaml.tune.sample import Domain if TYPE_CHECKING: from ConfigSpace.hyperparameters import Hyperparameter -FlamlDomain: TypeAlias = flaml.tune.sample.Domain +FlamlDomain: TypeAlias = Domain """Flaml domain type alias.""" -FlamlSpace: TypeAlias = dict[str, flaml.tune.sample.Domain] +FlamlSpace: TypeAlias = dict[str, Domain] """Flaml space type alias - a `dict[str, FlamlDomain]`""" diff --git a/mlos_core/mlos_core/spaces/converters/util.py b/mlos_core/mlos_core/spaces/converters/util.py index de0edb7cd1b..9b97a11f329 100644 --- a/mlos_core/mlos_core/spaces/converters/util.py +++ b/mlos_core/mlos_core/spaces/converters/util.py @@ -41,7 +41,11 @@ def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter: # No quantization requested. # Remove any previously applied patches. if hasattr(dist, "sample_vector_mlos_orig"): - setattr(dist, "sample_vector", dist.sample_vector_mlos_orig) + setattr( + dist, + "sample_vector", + dist.sample_vector_mlos_orig, # pyright: ignore[reportAttributeAccessIssue] + ) delattr(dist, "sample_vector_mlos_orig") return hp @@ -61,7 +65,10 @@ def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter: dist, "sample_vector", lambda n, *, seed=None: quantize( - dist.sample_vector_mlos_orig(n, seed=seed), + dist.sample_vector_mlos_orig( # pyright: ignore[reportAttributeAccessIssue] + n, + seed=seed, + ), bounds=(dist.lower_vectorized, dist.upper_vectorized), bins=quantization_bins, ), diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index 65f5d696866..d5856689615 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -6,7 +6,8 @@ from importlib import import_module from pkgutil import walk_packages -from typing import TypeAlias, TypeVar +from types import ModuleType +from typing import TypeVar # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -15,7 +16,7 @@ T = TypeVar("T") -def get_all_submodules(pkg: TypeAlias) -> list[str]: +def get_all_submodules(pkg: ModuleType) -> list[str]: """ Imports all submodules for a package and returns their names. diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 2106ba64192..8913e976450 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -172,8 +172,7 @@ def objective(inp: float) -> pd.Series: ) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """Test that all optimizer types are listed in the ConcreteOptimizer constraints.""" - # pylint: disable=no-member - assert optimizer_type.value in ConcreteOptimizer.__constraints__ + assert optimizer_type.value in ConcreteOptimizer.__args__ @pytest.mark.parametrize( diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 0e65f815901..94cb64fe964 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -28,10 +28,11 @@ *list(SpaceAdapterType), ], ) -def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: - """Test that all optimizer types are listed in the ConcreteOptimizer constraints.""" - # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ +def test_concrete_space_adapter_type(space_adapter_type: SpaceAdapterType) -> None: + """Test that all spaceadapter types are listed in the ConcreteSpaceAdapter + constraints. + """ + assert space_adapter_type.value in ConcreteSpaceAdapter.__args__ @pytest.mark.parametrize( diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index 7e1b8d0d8fe..46d89e2ee9a 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -179,10 +179,10 @@ class TestFlamlConversion(BaseConversion): def sample( self, - config_space: FlamlSpace, # type: ignore[override] + config_space: OptimizerSpace, n_samples: int = 1, ) -> npt.NDArray: - assert isinstance(config_space, dict) + assert isinstance(config_space, dict) # FlamlSpace assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) ret: npt.NDArray = np.array( [domain.sample(size=n_samples) for domain in config_space.values()] diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 81fe47c292e..b9ba6ac3c40 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -25,7 +25,8 @@ warning(f"version.py not found, using dummy VERSION={VERSION}") try: - from setuptools_scm import get_version + # Note: setuptools_scm is typically only installed as a part of the build process. + from setuptools_scm import get_version # pyright: ignore[reportMissingImports] version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index ca990fe33db..8360baf320d 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -493,7 +493,7 @@ def plot_top_n_configs( data=top_n_config_results_df, x=groupby_column, y=orderby_col, - legend=None, + legend=False, ax=axis, ) plt.grid() diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index c4c56ca9365..e646d549e92 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -25,7 +25,8 @@ warning(f"version.py not found, using dummy VERSION={VERSION}") try: - from setuptools_scm import get_version + # Note: setuptools_scm is typically only installed as a part of the build process. + from setuptools_scm import get_version # pyright: ignore[reportMissingImports] version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: diff --git a/pyproject.toml b/pyproject.toml index a0504b034e7..69861825ae9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,4 +81,10 @@ exclude = [ "doc/build/html", "doc/build/doctrees", "htmlcov", + "mlos_*/build/", + "mlos_*/dist/", + "mlos_*/setup.py" ] +typeCheckingMode = "standard" +reportMissingTypeStubs = false # not as granular as mypy to override on a per package basis +#reportMissingImports = false # somewhat expected for build files and docs