From f02d22d50e7025978315c24d181a84e23f64b6b7 Mon Sep 17 00:00:00 2001 From: Louis Brunner Date: Tue, 21 Jan 2025 18:28:56 +0000 Subject: [PATCH] feat(python): add support for the upcoming auth update (#207) Co-authored-by: Dominic Reber <71256590+domire8@users.noreply.github.com> Co-authored-by: Enrico Eberhard <32450951+eeberhard@users.noreply.github.com> --- python/CHANGELOG.md | 6 +- python/examples/with_api_key.py | 12 + python/pyproject.toml | 30 +- python/src/aica_api/client.py | 456 +++++++++++++++++++++--------- python/src/aica_api/sio_client.py | 40 ++- 5 files changed, 395 insertions(+), 149 deletions(-) create mode 100644 python/examples/with_api_key.py diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 12299c97..8f455113 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -12,6 +12,10 @@ Release Versions: - [1.0.1](#101) - [1.0.0](#100) +## Upcoming changes + +- feat(python): add support for the upcoming auth update (#207) + ## 3.0.0 Version 3.0.0 of the AICA API client is compatible with the new AICA Core version 4.0. It supports additional methods to @@ -65,4 +69,4 @@ Version 1.0.1 fixes a relative import issue. ## 1.0.0 Version 1.0.0 marks the version for the first software release. From now on, all changes must be well documented and -semantic versioning must be maintained to reflect patch, minor or major changes. \ No newline at end of file +semantic versioning must be maintained to reflect patch, minor or major changes. diff --git a/python/examples/with_api_key.py b/python/examples/with_api_key.py new file mode 100644 index 00000000..21e6202f --- /dev/null +++ b/python/examples/with_api_key.py @@ -0,0 +1,12 @@ +import os + +from aica_api.client import AICA + +client = AICA( + api_key=os.getenv('AICA_API_KEY'), +) + +assert client.check() +print(f'Application state: {client.get_application_state().text}') +print(f'Application state: {client.load_component("def").text}') +print(client.wait_for_component('abc', 'loaded')) diff --git a/python/pyproject.toml b/python/pyproject.toml index 0b4fdb67..809b374f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,25 +5,33 @@ build-backend = "hatchling.build" [project] name = "aica_api" version = "3.0.0" -authors = [ - { name="Enrico Eberhard", email="enrico@aica.tech" }, -] +authors = [{ name = "Enrico Eberhard", email = "enrico@aica.tech" }] description = "A client utility for the AICA API" readme = "README.md" requires-python = ">=3.7" dependencies = [ - "deprecation ~= 2.1.0", - "python-socketio[client] ~= 5.11.0", - "pyyaml ~= 6.0.1", - "requests ~= 2.28.1", - "semver ~= 3.0.2" + "deprecation ~= 2.1.0", + "python-socketio[client] ~= 5.11.0", + "pyyaml ~= 6.0.1", + "requests ~= 2.28.1", + "semver ~= 3.0.2", ] classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", ] [project.urls] "Homepage" = "https://github.com/aica-technology/api" "Bug Tracker" = "https://github.com/aica-technology/api/issues" + + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["D212", "D400", "D415", "FA100", "G004"] + +[tool.ruff.format] +quote-style = "single" diff --git a/python/src/aica_api/client.py b/python/src/aica_api/client.py index b24a8a03..23197e28 100644 --- a/python/src/aica_api/client.py +++ b/python/src/aica_api/client.py @@ -1,28 +1,31 @@ -from deprecation import deprecated +import importlib.metadata +import os +import urllib.parse from functools import wraps -from logging import getLogger, INFO -from typing import Union, List +from logging import INFO, getLogger +from typing import List, Optional, Union -import os import requests import semver import yaml - -import importlib.metadata +from deprecation import deprecated from aica_api.sio_client import read_until - CLIENT_VERSION = importlib.metadata.version('aica_api') class AICA: - """ - API client for AICA applications. - """ + """API client for AICA applications.""" # noinspection HttpUrlsUsage - def __init__(self, url: str = 'localhost', port: Union[str, int] = '8080', log_level = INFO): + def __init__( + self, + url: str = 'localhost', + port: Union[str, int] = '8080', + log_level=INFO, + api_key: Optional[str] = None, + ): """ Construct the API client with the address of the AICA application. @@ -44,8 +47,10 @@ def __init__(self, url: str = 'localhost', port: Union[str, int] = '8080', log_l self._logger.setLevel(log_level) self._protocol = None self._core_version = None + self.__api_key = api_key + self.__token = None - def _endpoint(self, endpoint=''): + def _endpoint(self, endpoint: str = '') -> str: """ Build the request address for a given endpoint. @@ -54,7 +59,93 @@ def _endpoint(self, endpoint=''): """ if self._protocol is None: self.protocol() - return f'{self._address}/{self._protocol}/{endpoint}' + return self.__raw_endpoint(f'{self._protocol}/{endpoint}') + + def __raw_endpoint(self, endpoint: str) -> str: + return f'{self._address}/{endpoint}' + + def __ensure_token(self) -> None: + """Authenticate with the API and store the result in self.__token.""" + err = ' The function call may fail due to lack of authentication.' + has_version, is_compatible = self._check_version( + None, + '>=4.3.0', + err_incompatible=err, + err_undefined=err, + ) + if not has_version or not is_compatible: + return + if self.__token is not None: + return + res = requests.post(self._endpoint('auth/login'), headers={'Authorization': f'Bearer {self.__api_key}'}) + res.raise_for_status() + self.__token = res.json()['token'] + + def _sio_auth(self) -> Optional[str]: + # FIXME: doesn't handle token expiration + if self.__api_key is not None: + self.__ensure_token() + return self.__token + + @staticmethod + def _safe_uri(uri: str) -> str: + """ + Make a string safe for use in a URI by encoding special characters. + + :param uri: The URI to sanitize + :return: The sanitized URI + """ + return urllib.parse.quote_plus(uri) + + def _request( + self, + method: str, + endpoint: str, + *, + params: Optional[dict] = None, + json: Optional[dict] = None, + ) -> requests.Response: + headers = None + retry = 2 + res = None + while retry > 0: + if self.__api_key is not None: + self.__ensure_token() + headers = {'Authorization': f'Bearer {self.__token}'} + res = requests.request( + method, self._endpoint(endpoint), params=params, json=json, headers=headers, timeout=5 + ) + retry -= 1 + if res.status_code == 401: + if self.__api_key is None: + break + self.__token = None + return res + + def _check_version( + self, + name: Optional[str], + requirement: str, + *, + err_undefined: str = '', + err_incompatible: str = '', + ) -> tuple[bool, bool]: + fname = f'The function {name}' if name is not None else 'This function' + if self._core_version is None and self.core_version() is None: + self._logger.warning( + f'{fname} requires AICA Core version {requirement}, ' + f'but the current Core version is unknown.{err_undefined}' + ) + return False, False + + if not semver.match(self._core_version, requirement): + self._logger.error( + f'{fname} requires AICA Core version {requirement}, ' + f'but the current AICA Core version is {self._core_version}.{err_incompatible}' + ) + return True, False + + return True, True @staticmethod def _requires_core_version(version): @@ -73,18 +164,14 @@ def my_new_endpoint() def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - if self._core_version is None and self.core_version() is None: - self._logger.warning(f'The function {func.__name__} requires AICA Core version {version}, ' - f'but the current Core version is unknown. The function call behavior ' - f'may be undefined.') - return func(self, *args, **kwargs) - - if not semver.match(self._core_version, version): - self._logger.error(f'The function {func.__name__} requires AICA Core version {version}, ' - f'but the current AICA Core version is {self._core_version}. The function ' - f'will not be called') + has_version, is_compatible = self._check_version( + func.__name__, + version, + err_undefined=' The function call behavior may be undefined.', + err_incompatible=' The function will not be called.', + ) + if not is_compatible: return None - return func(self, *args, **kwargs) return wrapper @@ -102,13 +189,14 @@ def api_version(self) -> Union[str, None]: self._logger.debug(f'AICA API server version identified as {api_server_version}') return api_server_version except requests.exceptions.RequestException: - self._logger.error(f'Error connecting to the API server at {self._address}! ' - f'Check that AICA Core is running and configured with the right address.') - except requests.exceptions.JSONDecodeError: - self._logger.error(f'Error getting version details! Expected JSON data in the response body.') + self._logger.error( + f'Error connecting to the API server at {self._address}! ' + f'Check that AICA Core is running and configured with the right address.' + ) except KeyError as e: self._logger.error( - f'Error getting version details! Expected a map of `signed_packages` to include `aica_api_server`: {e}') + f'Error getting version details! Expected a map of `signed_packages` to include `aica_api_server`: {e}' + ) return None def core_version(self) -> Union[str, None]: @@ -119,14 +207,18 @@ def core_version(self) -> Union[str, None]: """ core_version = None try: - core_version = requests.get(f'{self._address}/version').json() + core_version = requests.get(self.__raw_endpoint('version')).json() except requests.exceptions.RequestException: - self._logger.error(f'Error connecting to the API server at {self._address}! ' - f'Check that AICA Core is running and configured with the right address.') + self._logger.error( + f'Error connecting to the API server at {self._address}! ' + f'Check that AICA Core is running and configured with the right address.' + ) if not semver.Version.is_valid(f'{core_version}'): - self._logger.warning(f'Invalid format for the AICA Core version {core_version}! This could be a result ' - f'of an internal or pre-release build of AICA Core.') + self._logger.warning( + f'Invalid format for the AICA Core version {core_version}! This could be a result ' + f'of an internal or pre-release build of AICA Core.' + ) core_version = None self._core_version = core_version @@ -148,12 +240,14 @@ def protocol(self) -> Union[str, None]: :return: The version of the API protocol or None in case of connection failure """ try: - self._protocol = requests.get(f'{self._address}/protocol').json() + self._protocol = requests.get(self.__raw_endpoint('protocol')).json() self._logger.debug(f'API protocol version identified as {self._protocol}') return self._protocol except requests.exceptions.RequestException: - self._logger.error(f'Error connecting to the API server at {self._address}! ' - f'Check that AICA Core is running and configured with the right address.') + self._logger.error( + f'Error connecting to the API server at {self._address}! ' + f'Check that AICA Core is running and configured with the right address.' + ) return None def check(self) -> bool: @@ -164,9 +258,11 @@ def check(self) -> bool: """ if self._protocol is None and self.protocol() is None: return False - elif self._protocol != "v2": - self._logger.error(f'The detected API protocol version {self._protocol} is not supported by this client' - f'(v{self.client_version()}). Please refer to the compatibility table.') + elif self._protocol != 'v2': + self._logger.error( + f'The detected API protocol version {self._protocol} is not supported by this client' + f'(v{self.client_version()}). Please refer to the compatibility table.' + ) return False if self._core_version is None and self.core_version() is None: @@ -175,25 +271,39 @@ def check(self) -> bool: version_info = semver.parse_version_info(self._core_version) if version_info.major == 4: + if version_info.minor > 2 and self.__api_key is None: + self._logger.warning( + f'The detected AICA Core version v{self._core_version} requires an API key for ' + f'authentication. Please provide an API key to the client for this version.' + ) + return False return True elif version_info.major > 4: - self._logger.error(f'The detected AICA Core version v{self._core_version} is newer than the maximum AICA ' - f'Core version supported by this client (v{self.client_version()}). Please upgrade the ' - f'Python API client version for newer versions of Core.') + self._logger.error( + f'The detected AICA Core version v{self._core_version} is newer than the maximum AICA ' + f'Core version supported by this client (v{self.client_version()}). Please upgrade the ' + f'Python API client version for newer versions of Core.' + ) return False elif version_info.major == 3: - self._logger.error(f'The detected AICA Core version v{self._core_version} is older than the minimum AICA ' - f'Core version supported by this client (v{self.client_version()}). Please downgrade ' - f'the Python API client to version v2.1.0 for API server versions v3.X.') + self._logger.error( + f'The detected AICA Core version v{self._core_version} is older than the minimum AICA ' + f'Core version supported by this client (v{self.client_version()}). Please downgrade ' + f'the Python API client to version v2.1.0 for API server versions v3.X.' + ) return False elif version_info.major == 2: - self._logger.error(f'The detected AICA Core version v{self._core_version} is older than the minimum AICA ' - f'Core version supported by this client (v{self.client_version()}). Please downgrade ' - f'the Python API client to version v1.2.0 for API server versions v2.X.') + self._logger.error( + f'The detected AICA Core version v{self._core_version} is older than the minimum AICA ' + f'Core version supported by this client (v{self.client_version()}). Please downgrade ' + f'the Python API client to version v1.2.0 for API server versions v2.X.' + ) return False else: - self._logger.error(f'The detected AICA Core version v{self._core_version} is deprecated and not supported ' - f'by this API client!') + self._logger.error( + f'The detected AICA Core version v{self._core_version} is deprecated and not supported ' + f'by this API client!' + ) return False def license(self) -> requests.Response: @@ -203,7 +313,7 @@ def license(self) -> requests.Response: Use `license().json()` to extract the map of license details from the response object. """ - return requests.get(self._endpoint('license')) + return self._request('GET', 'license') def component_descriptions(self) -> requests.Response: """ @@ -211,7 +321,7 @@ def component_descriptions(self) -> requests.Response: Use `component_descriptions().json()` to extract the map of descriptions from the response object. """ - return requests.get(self._endpoint('components')) + return self._request('GET', 'components') def controller_descriptions(self) -> requests.Response: """ @@ -219,10 +329,14 @@ def controller_descriptions(self) -> requests.Response: Use `controller_descriptions().json()` to extract the map of descriptions from the response object. """ - return requests.get(self._endpoint('controllers')) + return self._request('GET', 'controllers') - @deprecated(deprecated_in='3.0.0', removed_in='4.0.0', current_version=CLIENT_VERSION, - details='Use the call_component_service function instead') + @deprecated( + deprecated_in='3.0.0', + removed_in='4.0.0', + current_version=CLIENT_VERSION, + details='Use the call_component_service function instead', + ) def call_service(self, component: str, service: str, payload: str) -> requests.Response: """ Call a service on a component. @@ -241,9 +355,9 @@ def call_component_service(self, component: str, service: str, payload: str) -> :param service: The name of the service :param payload: The service payload, formatted according to the respective service description """ - endpoint = 'application/components/' + component + '/service/' + service - data = {"payload": payload} - return requests.put(self._endpoint(endpoint), json=data) + endpoint = f'application/components/{AICA._safe_uri(component)}/service/{AICA._safe_uri(service)}' + data = {'payload': payload} + return self._request('PUT', endpoint, json=data) def call_controller_service(self, hardware: str, controller: str, service: str, payload: str) -> requests.Response: """ @@ -254,15 +368,15 @@ def call_controller_service(self, hardware: str, controller: str, service: str, :param service: The name of the service :param payload: The service payload, formatted according to the respective service description """ - endpoint = 'application/hardware/' + hardware + '/controller/' + controller + '/service/' + service - data = {"payload": payload} - return requests.put(self._endpoint(endpoint), json=data) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}/controller/{AICA._safe_uri(controller)}/service/{AICA._safe_uri(service)}' + data = {'payload': payload} + return self._request('PUT', endpoint, json=data) def get_application_state(self) -> requests.Response: """ Get the application state """ - return requests.get(self._endpoint('application/state')) + return self._request('GET', 'application/state') def load_component(self, component: str) -> requests.Response: """ @@ -271,8 +385,8 @@ def load_component(self, component: str) -> requests.Response: :param component: The name of the component to load """ - endpoint = 'application/components/' + component - return requests.put(self._endpoint(endpoint)) + endpoint = f'application/components/{AICA._safe_uri(component)}' + return self._request('PUT', endpoint) def load_controller(self, hardware: str, controller: str) -> requests.Response: """ @@ -282,8 +396,8 @@ def load_controller(self, hardware: str, controller: str) -> requests.Response: :param hardware: The name of the hardware interface :param controller: The name of the controller to load """ - endpoint = 'application/hardware/' + hardware + '/controller/' + controller - return requests.put(self._endpoint(endpoint)) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}/controller/{AICA._safe_uri(controller)}' + return self._request('PUT', endpoint) def load_hardware(self, hardware: str) -> requests.Response: """ @@ -292,16 +406,16 @@ def load_hardware(self, hardware: str) -> requests.Response: :param hardware: The name of the hardware interface to load """ - endpoint = 'application/hardware/' + hardware - return requests.put(self._endpoint(endpoint)) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}' + return self._request('PUT', endpoint) def pause_application(self) -> requests.Response: """ Pause the current application. This prevents any events from being triggered or handled, but does not pause the periodic execution of active components. """ - endpoint = 'application/state/transition?action=pause' - return requests.put(self._endpoint(endpoint)) + endpoint = 'application/state/transition' + return self._request('PUT', endpoint, params={'action': 'pause'}) def set_application(self, payload: str) -> requests.Response: """ @@ -309,30 +423,32 @@ def set_application(self, payload: str) -> requests.Response: :param payload: The filepath of an application or the application content as a YAML-formatted string """ - if payload.endswith(".yaml") and os.path.isfile(payload): - with open(payload, "r") as file: + if payload.endswith('.yaml') and os.path.isfile(payload): + with open(payload, 'r') as file: payload = yaml.safe_load(file) - data = { - "payload": payload - } - return requests.put(self._endpoint('application'), json=data) + data = {'payload': payload} + return self._request('PUT', 'application', json=data) def start_application(self) -> requests.Response: """ Start the AICA application engine. """ endpoint = 'application/state/transition?action=start' - return requests.put(self._endpoint(endpoint)) + return self._request('PUT', endpoint) def stop_application(self) -> requests.Response: """ Stop and reset the AICA application engine, removing all components and hardware interfaces. """ - endpoint = 'application/state/transition?action=stop' - return requests.put(self._endpoint(endpoint)) + endpoint = 'application/state/transition' + return self._request('PUT', endpoint, params={'action': 'stop'}) - def set_component_parameter(self, component: str, parameter: str, value: Union[ - bool, int, float, bool, List[bool], List[int], List[float], List[str]]) -> requests.Response: + def set_component_parameter( + self, + component: str, + parameter: str, + value: Union[bool, int, float, bool, List[bool], List[int], List[float], List[str]], + ) -> requests.Response: """ Set a parameter on a component. @@ -340,12 +456,17 @@ def set_component_parameter(self, component: str, parameter: str, value: Union[ :param parameter: The name of the parameter :param value: The value of the parameter """ - endpoint = 'application/components/' + component + '/parameter/' + parameter - data = {"value": value} - return requests.put(self._endpoint(endpoint), json=data) + endpoint = f'application/components/{AICA._safe_uri(component)}/parameter/{AICA._safe_uri(parameter)}' + data = {'value': value} + return self._request('PUT', endpoint, json=data) - def set_controller_parameter(self, hardware: str, controller: str, parameter: str, value: Union[ - bool, int, float, bool, List[bool], List[int], List[float], List[str]]) -> requests.Response: + def set_controller_parameter( + self, + hardware: str, + controller: str, + parameter: str, + value: Union[bool, int, float, bool, List[bool], List[int], List[float], List[str]], + ) -> requests.Response: """ Set a parameter on a controller. @@ -354,9 +475,9 @@ def set_controller_parameter(self, hardware: str, controller: str, parameter: st :param parameter: The name of the parameter :param value: The value of the parameter """ - endpoint = 'application/hardware/' + hardware + '/controller/' + controller + '/parameter/' + parameter - data = {"value": value} - return requests.put(self._endpoint(endpoint), json=data) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}/controller/{AICA._safe_uri(controller)}/parameter/{AICA._safe_uri(parameter)}' + data = {'value': value} + return self._request('PUT', endpoint, json=data) def set_lifecycle_transition(self, component: str, transition: str) -> requests.Response: """ @@ -370,12 +491,16 @@ def set_lifecycle_transition(self, component: str, transition: str) -> requests. :param component: The name of the component :param transition: The lifecycle transition label """ - endpoint = 'application/components/' + component + '/lifecycle/transition' - data = {"transition": transition} - return requests.put(self._endpoint(endpoint), json=data) + endpoint = f'application/components/{AICA._safe_uri(component)}/lifecycle/transition' + data = {'transition': transition} + return self._request('PUT', endpoint, json=data) - def switch_controllers(self, hardware: str, activate: Union[None, List[str]] = None, - deactivate: Union[None, List[str]] = None) -> requests.Response: + def switch_controllers( + self, + hardware: str, + activate: Union[None, List[str]] = None, + deactivate: Union[None, List[str]] = None, + ) -> requests.Response: """ Activate and deactivate the controllers for a given hardware interface. @@ -383,12 +508,12 @@ def switch_controllers(self, hardware: str, activate: Union[None, List[str]] = N :param activate: A list of controllers to activate :param deactivate: A list of controllers to deactivate """ - endpoint = 'application/hardware/' + hardware + '/controllers' + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}/controllers' params = { - "activate": [] if not activate else activate, - "deactivate": [] if not deactivate else deactivate + 'activate': [] if not activate else activate, + 'deactivate': [] if not deactivate else deactivate, } - return requests.put(self._endpoint(endpoint), params=params) + return self._request('PUT', endpoint, params=params) def unload_component(self, component: str) -> requests.Response: """ @@ -397,8 +522,8 @@ def unload_component(self, component: str) -> requests.Response: :param component: The name of the component to unload """ - endpoint = 'application/components/' + component - return requests.delete(self._endpoint(endpoint)) + endpoint = f'application/components/{AICA._safe_uri(component)}' + return self._request('DELETE', endpoint) def unload_controller(self, hardware: str, controller: str) -> requests.Response: """ @@ -408,8 +533,8 @@ def unload_controller(self, hardware: str, controller: str) -> requests.Response :param hardware: The name of the hardware interface :param controller: The name of the controller to unload """ - endpoint = 'application/hardware/' + hardware + '/controller/' + controller - return requests.delete(self._endpoint(endpoint)) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}/controller/{AICA._safe_uri(controller)}' + return self._request('DELETE', endpoint) def unload_hardware(self, hardware: str) -> requests.Response: """ @@ -418,14 +543,14 @@ def unload_hardware(self, hardware: str) -> requests.Response: :param hardware: The name of the hardware interface to unload """ - endpoint = 'application/hardware/' + hardware - return requests.delete(self._endpoint(endpoint)) + endpoint = f'application/hardware/{AICA._safe_uri(hardware)}' + return self._request('DELETE', endpoint) def get_application(self) -> requests.Response: """ Get the currently set application """ - return requests.get(self._endpoint("application")) + return self._request('GET', 'application') @_requires_core_version('>=4.0.0') def manage_sequence(self, sequence_name: str, action: str): @@ -437,8 +562,8 @@ def manage_sequence(self, sequence_name: str, action: str): :param sequence_name: The name of the sequence :param action: The sequence action label """ - endpoint = f"application/sequences/{sequence_name}?action={action}" - return requests.put(self._endpoint(endpoint)) + endpoint = f'application/sequences/{AICA._safe_uri(sequence_name)}' + return self._request('PUT', endpoint, params={'action': AICA._safe_uri(action)}) def wait_for_component(self, component: str, state: str, timeout: Union[None, int, float] = None) -> bool: """ @@ -450,8 +575,17 @@ def wait_for_component(self, component: str, state: str, timeout: Union[None, in :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the component is in the intended state before the timeout duration, False otherwise """ - return read_until(lambda data: data[component]['state'] == state, url=self._address, namespace='/v2/components', - event='component_data', timeout=timeout) is not None + return ( + read_until( + lambda data: data[component]['state'] == state, + url=self._address, + namespace='/v2/components', + event='component_data', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) @_requires_core_version('>=3.1.0') def wait_for_hardware(self, hardware: str, state: str, timeout: Union[None, int, float] = None) -> bool: @@ -464,12 +598,26 @@ def wait_for_hardware(self, hardware: str, state: str, timeout: Union[None, int, :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the hardware is in the intended state before the timeout duration, False otherwise """ - return read_until(lambda data: data[hardware]['state'] == state, url=self._address, namespace='/v2/hardware', - event='hardware_data', timeout=timeout) is not None + return ( + read_until( + lambda data: data[hardware]['state'] == state, + url=self._address, + namespace='/v2/hardware', + event='hardware_data', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) @_requires_core_version('>=3.1.0') - def wait_for_controller(self, hardware: str, controller: str, state: str, - timeout: Union[None, int, float] = None) -> bool: + def wait_for_controller( + self, + hardware: str, + controller: str, + state: str, + timeout: Union[None, int, float] = None, + ) -> bool: """ Wait for a controller to be in a particular state. Controllers can be in any of the following states: ['unloaded', 'loaded', 'active', 'finalized'] @@ -480,12 +628,22 @@ def wait_for_controller(self, hardware: str, controller: str, state: str, :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the controller is in the intended state before the timeout duration, False otherwise """ - return read_until(lambda data: data[hardware]['controllers'][controller]['state'] == state, url=self._address, - namespace='/v2/hardware', event='hardware_data', timeout=timeout) is not None + return ( + read_until( + lambda data: data[hardware]['controllers'][controller]['state'] == state, + url=self._address, + namespace='/v2/hardware', + event='hardware_data', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) @_requires_core_version('>=3.1.0') - def wait_for_component_predicate(self, component: str, predicate: str, - timeout: Union[None, int, float] = None) -> bool: + def wait_for_component_predicate( + self, component: str, predicate: str, timeout: Union[None, int, float] = None + ) -> bool: """ Wait until a component predicate is true. @@ -494,12 +652,26 @@ def wait_for_component_predicate(self, component: str, predicate: str, :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the predicate is true before the timeout duration, False otherwise """ - return read_until(lambda data: data[component]['predicates'][predicate], url=self._address, - namespace='/v2/components', event='component_data', timeout=timeout) is not None + return ( + read_until( + lambda data: data[component]['predicates'][predicate], + url=self._address, + namespace='/v2/components', + event='component_data', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) @_requires_core_version('>=3.1.0') - def wait_for_controller_predicate(self, hardware: str, controller: str, predicate: str, - timeout: Union[None, int, float] = None) -> bool: + def wait_for_controller_predicate( + self, + hardware: str, + controller: str, + predicate: str, + timeout: Union[None, int, float] = None, + ) -> bool: """ Wait until a controller predicate is true. @@ -509,9 +681,17 @@ def wait_for_controller_predicate(self, hardware: str, controller: str, predicat :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the predicate is true before the timeout duration, False otherwise """ - return read_until(lambda data: data[hardware]['controllers'][controller]['predicates'][predicate], - url=self._address, namespace='/v2/hardware', event='hardware_data', - timeout=timeout) is not None + return ( + read_until( + lambda data: data[hardware]['controllers'][controller]['predicates'][predicate], + url=self._address, + namespace='/v2/hardware', + event='hardware_data', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) def wait_for_condition(self, condition: str, timeout=None) -> bool: """ @@ -521,8 +701,17 @@ def wait_for_condition(self, condition: str, timeout=None) -> bool: :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the condition is true before the timeout duration, False otherwise """ - return read_until(lambda data: data[condition], url=self._address, namespace='/v2/conditions', - event='conditions', timeout=timeout) is not None + return ( + read_until( + lambda data: data[condition], + url=self._address, + namespace='/v2/conditions', + event='conditions', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) @_requires_core_version('>=4.0.0') def wait_for_sequence(self, sequence: str, state: str, timeout=None) -> bool: @@ -535,5 +724,14 @@ def wait_for_sequence(self, sequence: str, state: str, timeout=None) -> bool: :param timeout: Timeout duration in seconds. If set to None, block indefinitely :return: True if the condition is true before the timeout duration, False otherwise """ - return read_until(lambda data: data[sequence]['state'] == state, url=self._address, namespace='/v2/sequences', - event='sequences', timeout=timeout) is not None + return ( + read_until( + lambda data: data[sequence]['state'] == state, + url=self._address, + namespace='/v2/sequences', + event='sequences', + timeout=timeout, + auth=self._sio_auth(), + ) + is not None + ) diff --git a/python/src/aica_api/sio_client.py b/python/src/aica_api/sio_client.py index d8488669..609b5abf 100644 --- a/python/src/aica_api/sio_client.py +++ b/python/src/aica_api/sio_client.py @@ -1,12 +1,17 @@ import time -from typing import Callable, Union +from typing import Callable, Optional, Union import socketio from socketio.exceptions import ConnectionError, TimeoutError -def read_once(url: str = 'http://0.0.0.0:5000', namespace: str = '/', event: str = '*', - timeout: Union[None, int, float] = 5) -> Union[None, dict]: +def read_once( + url: str = 'http://0.0.0.0:5000', + namespace: str = '/', + event: str = '*', + timeout: Union[None, int, float] = 5, + auth: Optional[str] = None, +) -> Union[None, dict]: """ Listen for and return the first Socket.IO event on a specified URL and namespace within a time limited period @@ -16,11 +21,24 @@ def read_once(url: str = 'http://0.0.0.0:5000', namespace: str = '/', event: str :param timeout: The timeout in seconds to listen for an event. If set to None, block indefinitely :return: The received event data, or None if the connection or event listener timed out """ - return read_until(lambda data: True, url=url, namespace=namespace, event=event, timeout=timeout) + return read_until( + lambda data: True, + url=url, + namespace=namespace, + event=event, + timeout=timeout, + auth=auth, + ) -def read_until(callback: Callable[[dict], bool], url: str = 'http://0.0.0.0:5000', namespace: str = '/', - event: str = '*', timeout: Union[None, int, float] = 5) -> Union[None, dict]: +def read_until( + callback: Callable[[dict], bool], + url: str = 'http://0.0.0.0:5000', + namespace: str = '/', + event: str = '*', + timeout: Union[None, int, float] = 5, + auth: Optional[str] = None, +) -> Union[None, dict]: """ Listen for and return the first Socket.IO event that validates against a callback function on a specified URL and namespace within a time limited period @@ -38,9 +56,15 @@ def user_callback(data: dict) -> bool: with socketio.SimpleClient() as sio: try: - sio.connect(url, namespace=namespace, wait_timeout=timeout) + sio.connect( + url, + socketio_path='/ws/socket.io', + namespace=namespace, + wait_timeout=timeout, + auth={'auth': f'Bearer {auth}'}, + ) except ConnectionError: - print(f"Could not connect!") + print('Could not connect!') return None start_time = time.time()