-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhance iRail class documentation and refactor code for clarity (#45)
- Loading branch information
Showing
3 changed files
with
270 additions
and
310 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import logging | ||
import time | ||
from types import TracebackType | ||
from typing import Any, Dict, Type | ||
from typing import Any, Type | ||
|
||
from aiohttp import ClientError, ClientResponse, ClientSession | ||
|
||
|
@@ -45,7 +45,7 @@ class iRail: | |
# Available iRail API endpoints and their parameter requirements. | ||
# Each endpoint is configured with required parameters, optional parameters, and XOR | ||
# parameter groups (where exactly one parameter from the group must be provided). | ||
endpoints: Dict[str, Dict[str, Any]] = { | ||
endpoints: dict[str, dict[str, Any]] = { | ||
"stations": {}, | ||
"liveboard": {"xor": ["station", "id"], "optional": ["date", "time", "arrdep", "alerts"]}, | ||
"connections": { | ||
|
@@ -71,8 +71,8 @@ def __init__(self, lang: str = "en", session: ClientSession | None = None) -> No | |
self.last_request_time: float = time.time() | ||
self.lock: Lock = Lock() | ||
self.session: ClientSession | None = session | ||
self._owns_session = session is None # Track ownership | ||
self.etag_cache: Dict[str, str] = {} | ||
self._owns_session = session is None # Track ownership | ||
self.etag_cache: dict[str, str] = {} | ||
logger.info("iRail instance created") | ||
|
||
async def __aenter__(self) -> "iRail": | ||
|
@@ -130,11 +130,10 @@ def clear_etag_cache(self) -> None: | |
logger.info("ETag cache cleared") | ||
|
||
def _refill_tokens(self) -> None: | ||
"""Refill tokens for rate limiting based on elapsed time. | ||
This method refills both standard tokens (max 3) and burst tokens (max 5) | ||
using a token bucket algorithm. The refill rate is 3 tokens per second. | ||
"""Refill tokens for rate limiting using a token bucket algorithm. | ||
- Standard tokens: Refill rate of 3 tokens/second, max 3 tokens. | ||
- Burst tokens: Refilled only when standard tokens are full, max 5 tokens. | ||
""" | ||
logger.debug("Refilling tokens") | ||
current_time: float = time.time() | ||
|
@@ -151,9 +150,9 @@ def _refill_tokens(self) -> None: | |
async def _handle_rate_limit(self) -> None: | ||
"""Handle rate limiting using a token bucket algorithm. | ||
The implementation uses two buckets: | ||
- Normal bucket: 3 tokens/second | ||
- Burst bucket: 5 tokens/second | ||
- Standard tokens: 3 requests/second. | ||
- Burst tokens: Additional 5 requests/second for spikes. | ||
- Waits and refills tokens if both are exhausted. | ||
""" | ||
logger.debug("Handling rate limit") | ||
self._refill_tokens() | ||
|
@@ -169,18 +168,18 @@ async def _handle_rate_limit(self) -> None: | |
else: | ||
self.tokens -= 1 | ||
|
||
def _add_etag_header(self, method: str) -> Dict[str, str]: | ||
def _add_etag_header(self, method: str) -> dict[str, str]: | ||
"""Add ETag header for the given method if a cached ETag is available. | ||
Args: | ||
method (str): The API endpoint for which the header is being generated. | ||
Returns: | ||
Dict[str, str]: A dictionary containing HTTP headers, including the ETag header | ||
dict[str, str]: A dictionary containing HTTP headers, including the ETag header | ||
if a cached value exists. | ||
""" | ||
headers: Dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; [email protected])"} | ||
headers: dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; [email protected])"} | ||
if method in self.etag_cache: | ||
logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method]) | ||
headers["If-None-Match"] = self.etag_cache[method] | ||
|
@@ -226,12 +225,12 @@ def _validate_time(self, time: str | None) -> bool: | |
logger.error("Invalid time format. Expected HHMM (e.g., 1430 for 2:30 PM), got: %s", time) | ||
return False | ||
|
||
def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> bool: | ||
def _validate_params(self, method: str, params: dict[str, Any] | None = None) -> bool: | ||
"""Validate parameters for a specific iRail API endpoint based on predefined requirements. | ||
Args: | ||
method (str): The API endpoint method to validate parameters for. | ||
params (Dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None. | ||
params (dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None. | ||
Returns: | ||
bool: True if parameters are valid, False otherwise. | ||
|
@@ -289,12 +288,12 @@ def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> | |
|
||
return True | ||
|
||
async def _handle_success_response(self, response: ClientResponse, method: str) -> Dict[str, Any] | None: | ||
async def _handle_success_response(self, response: ClientResponse, method: str) -> dict[str, Any] | None: | ||
"""Handle a successful API response.""" | ||
if "Etag" in response.headers: | ||
self.etag_cache[method] = response.headers["Etag"] | ||
try: | ||
json_data: Dict[str, Any] | None = await response.json() | ||
json_data: dict[str, Any] | None = await response.json() | ||
if not json_data: | ||
logger.warning("Empty response received") | ||
return json_data | ||
|
@@ -303,8 +302,8 @@ async def _handle_success_response(self, response: ClientResponse, method: str) | |
return None | ||
|
||
async def _handle_response( | ||
self, response: ClientResponse, method: str, args: Dict[str, Any] | None = None | ||
) -> Dict[str, Any] | None: | ||
self, response: ClientResponse, method: str, args: dict[str, Any] | None = None | ||
) -> dict[str, Any] | None: | ||
"""Handle the API response based on status code.""" | ||
if response.status == 429: | ||
retry_after: int = int(response.headers.get("Retry-After", 1)) | ||
|
@@ -326,7 +325,7 @@ async def _handle_response( | |
logger.error("Request failed with status code: %s, response: %s", response.status, await response.text()) | ||
return None | ||
|
||
async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> Dict[str, Any] | None: | ||
async def _do_request(self, method: str, args: dict[str, Any] | None = None) -> dict[str, Any] | None: | ||
"""Send an asynchronous request to the specified iRail API endpoint. | ||
This method handles API requests with rate limiting, parameter validation, | ||
|
@@ -368,7 +367,7 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> | |
if args: | ||
params.update(args) | ||
|
||
request_headers: Dict[str, str] = self._add_etag_header(method) | ||
request_headers: dict[str, str] = self._add_etag_header(method) | ||
|
||
try: | ||
async with self.session.get(url, params=params, headers=request_headers) as response: | ||
|
@@ -429,7 +428,7 @@ async def get_liveboard( | |
print(f"Liveboard for Brussels-South: {liveboard}") | ||
""" | ||
extra_params: Dict[str, Any] = { | ||
extra_params: dict[str, Any] = { | ||
"station": station, | ||
"id": id, | ||
"date": date, | ||
|
@@ -473,7 +472,7 @@ async def get_connections( | |
print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}") | ||
""" | ||
extra_params: Dict[str, Any] = { | ||
extra_params: dict[str, Any] = { | ||
"from": from_station, | ||
"to": to_station, | ||
"date": date, | ||
|
@@ -504,7 +503,7 @@ async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = Fal | |
vehicle_info = await client.get_vehicle("BE.NMBS.IC1832") | ||
""" | ||
extra_params: Dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"} | ||
extra_params: dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"} | ||
vehicle_response_dict = await self._do_request( | ||
"vehicle", {k: v for k, v in extra_params.items() if v is not None} | ||
) | ||
|
@@ -527,7 +526,7 @@ async def get_composition(self, id: str, data: str | None = None) -> Composition | |
composition = await client.get_composition('S51507') | ||
""" | ||
extra_params: Dict[str, str | None] = {"id": id, "data": data} | ||
extra_params: dict[str, str | None] = {"id": id, "data": data} | ||
composition_response_dict = await self._do_request( | ||
"composition", {k: v for k, v in extra_params.items() if v is not None} | ||
) | ||
|
Oops, something went wrong.