Skip to content

Commit

Permalink
adapt EnlyzeClient to use PlatformApiClient
Browse files Browse the repository at this point in the history
  • Loading branch information
denizs committed Aug 16, 2024
1 parent 34659d6 commit 35202b2
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 125 deletions.
55 changes: 24 additions & 31 deletions src/enlyze/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union
from uuid import UUID

import enlyze.api_clients.timeseries.models as timeseries_api_models
import enlyze.api_client.models as platform_api_models
import enlyze.models as user_models
from enlyze.api_clients.production_runs.client import ProductionRunsApiClient
from enlyze.api_clients.production_runs.models import ProductionRun
from enlyze.api_clients.timeseries.client import TimeseriesApiClient
from enlyze.api_client.client import PlatformApiClient
from enlyze.constants import (
ENLYZE_BASE_URL,
MAXIMUM_NUMBER_OF_VARIABLES_PER_TIMESERIES_REQUEST,
Expand All @@ -28,8 +26,8 @@


def _get_timeseries_data_from_pages(
pages: Iterator[timeseries_api_models.TimeseriesData],
) -> Optional[timeseries_api_models.TimeseriesData]:
pages: Iterator[platform_api_models.TimeseriesData],
) -> Optional[platform_api_models.TimeseriesData]:
try:
timeseries_data = next(pages)
except StopIteration:
Expand Down Expand Up @@ -90,19 +88,14 @@ class EnlyzeClient:
"""

def __init__(self, token: str, *, _base_url: str | None = None) -> None:
self._timeseries_api_client = TimeseriesApiClient(
token=token,
base_url=_base_url or ENLYZE_BASE_URL,
)
self._production_runs_api_client = ProductionRunsApiClient(
token=token,
base_url=_base_url or ENLYZE_BASE_URL,
self._platform_api_client = PlatformApiClient(
token=token, base_url=_base_url or ENLYZE_BASE_URL
)

def _get_sites(self) -> Iterator[timeseries_api_models.Site]:
def _get_sites(self) -> Iterator[platform_api_models.Site]:
"""Get all sites from the API"""
return self._timeseries_api_client.get_paginated(
"sites", timeseries_api_models.Site
return self._platform_api_client.get_paginated(
"sites", platform_api_models.Site
)

@cache
Expand All @@ -119,10 +112,10 @@ def get_sites(self) -> list[user_models.Site]:
"""
return [site.to_user_model() for site in self._get_sites()]

def _get_machines(self) -> Iterator[timeseries_api_models.Machine]:
def _get_machines(self) -> Iterator[platform_api_models.Machine]:
"""Get all machines from the API"""
return self._timeseries_api_client.get_paginated(
"appliances", timeseries_api_models.Machine
return self._platform_api_client.get_paginated(
"machines", platform_api_models.Machine
)

@cache
Expand All @@ -144,13 +137,13 @@ def get_machines(
"""

if site:
sites_by_id = {site._id: site}
sites_by_uuid = {site.uuid: site}
else:
sites_by_id = {site._id: site for site in self.get_sites()}
sites_by_uuid = {site.uuid: site for site in self.get_sites()}

machines = []
for machine_api in self._get_machines():
site_ = sites_by_id.get(machine_api.site)
site_ = sites_by_uuid.get(machine_api.site)
if not site_:
continue

Expand All @@ -160,11 +153,11 @@ def get_machines(

def _get_variables(
self, machine_uuid: UUID
) -> Iterator[timeseries_api_models.Variable]:
) -> Iterator[platform_api_models.Variable]:
"""Get variables for a machine from the API."""
return self._timeseries_api_client.get_paginated(
return self._platform_api_client.get_paginated(
"variables",
timeseries_api_models.Variable,
platform_api_models.Variable,
params={"appliance": str(machine_uuid)},
)

Expand Down Expand Up @@ -195,7 +188,7 @@ def _get_paginated_timeseries(
end: datetime,
variables: Sequence[str],
resampling_interval: Optional[int],
) -> Iterator[timeseries_api_models.TimeseriesData]:
) -> Iterator[platform_api_models.TimeseriesData]:
params: dict[str, Any] = {
"appliance": machine_uuid,
"start_datetime": start.isoformat(),
Expand All @@ -206,8 +199,8 @@ def _get_paginated_timeseries(
if resampling_interval:
params["resampling_interval"] = resampling_interval

return self._timeseries_api_client.get_paginated(
"timeseries", timeseries_api_models.TimeseriesData, params=params
return self._platform_api_client.get_paginated(
"timeseries", platform_api_models.TimeseriesData, params=params
)

def _get_timeseries(
Expand Down Expand Up @@ -356,7 +349,7 @@ def _get_production_runs(
machine: Optional[UUID] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
) -> Iterator[ProductionRun]:
) -> Iterator[platform_api_models.ProductionRun]:
"""Get production runs from the API."""

filters = {
Expand All @@ -367,8 +360,8 @@ def _get_production_runs(
"end": end.isoformat() if end else None,
}
params = {k: v for k, v in filters.items() if v is not None}
return self._production_runs_api_client.get_paginated(
"production-runs", ProductionRun, params=params
return self._platform_api_client.get_paginated(
"production-runs", platform_api_models.ProductionRun, params=params
)

def get_production_runs(
Expand Down
Loading

0 comments on commit 35202b2

Please sign in to comment.