Skip to content

Commit

Permalink
Enhance iRail class documentation and refactor code for clarity (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjorim authored Jan 28, 2025
1 parent 1c87930 commit 161e295
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 310 deletions.
51 changes: 25 additions & 26 deletions pyrail/irail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import time
from types import TracebackType
from typing import Any, Dict, Type
from typing import Any, Type

from aiohttp import ClientError, ClientResponse, ClientSession

Expand Down Expand Up @@ -45,7 +45,7 @@ class iRail:
# Available iRail API endpoints and their parameter requirements.
# Each endpoint is configured with required parameters, optional parameters, and XOR
# parameter groups (where exactly one parameter from the group must be provided).
endpoints: Dict[str, Dict[str, Any]] = {
endpoints: dict[str, dict[str, Any]] = {
"stations": {},
"liveboard": {"xor": ["station", "id"], "optional": ["date", "time", "arrdep", "alerts"]},
"connections": {
Expand All @@ -71,8 +71,8 @@ def __init__(self, lang: str = "en", session: ClientSession | None = None) -> No
self.last_request_time: float = time.time()
self.lock: Lock = Lock()
self.session: ClientSession | None = session
self._owns_session = session is None # Track ownership
self.etag_cache: Dict[str, str] = {}
self._owns_session = session is None # Track ownership
self.etag_cache: dict[str, str] = {}
logger.info("iRail instance created")

async def __aenter__(self) -> "iRail":
Expand Down Expand Up @@ -130,11 +130,10 @@ def clear_etag_cache(self) -> None:
logger.info("ETag cache cleared")

def _refill_tokens(self) -> None:
"""Refill tokens for rate limiting based on elapsed time.
This method refills both standard tokens (max 3) and burst tokens (max 5)
using a token bucket algorithm. The refill rate is 3 tokens per second.
"""Refill tokens for rate limiting using a token bucket algorithm.
- Standard tokens: Refill rate of 3 tokens/second, max 3 tokens.
- Burst tokens: Refilled only when standard tokens are full, max 5 tokens.
"""
logger.debug("Refilling tokens")
current_time: float = time.time()
Expand All @@ -151,9 +150,9 @@ def _refill_tokens(self) -> None:
async def _handle_rate_limit(self) -> None:
"""Handle rate limiting using a token bucket algorithm.
The implementation uses two buckets:
- Normal bucket: 3 tokens/second
- Burst bucket: 5 tokens/second
- Standard tokens: 3 requests/second.
- Burst tokens: Additional 5 requests/second for spikes.
- Waits and refills tokens if both are exhausted.
"""
logger.debug("Handling rate limit")
self._refill_tokens()
Expand All @@ -169,18 +168,18 @@ async def _handle_rate_limit(self) -> None:
else:
self.tokens -= 1

def _add_etag_header(self, method: str) -> Dict[str, str]:
def _add_etag_header(self, method: str) -> dict[str, str]:
"""Add ETag header for the given method if a cached ETag is available.
Args:
method (str): The API endpoint for which the header is being generated.
Returns:
Dict[str, str]: A dictionary containing HTTP headers, including the ETag header
dict[str, str]: A dictionary containing HTTP headers, including the ETag header
if a cached value exists.
"""
headers: Dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; [email protected])"}
headers: dict[str, str] = {"User-Agent": "pyRail (https://github.com/tjorim/pyrail; [email protected])"}
if method in self.etag_cache:
logger.debug("Adding If-None-Match header with value: %s", self.etag_cache[method])
headers["If-None-Match"] = self.etag_cache[method]
Expand Down Expand Up @@ -226,12 +225,12 @@ def _validate_time(self, time: str | None) -> bool:
logger.error("Invalid time format. Expected HHMM (e.g., 1430 for 2:30 PM), got: %s", time)
return False

def _validate_params(self, method: str, params: Dict[str, Any] | None = None) -> bool:
def _validate_params(self, method: str, params: dict[str, Any] | None = None) -> bool:
"""Validate parameters for a specific iRail API endpoint based on predefined requirements.
Args:
method (str): The API endpoint method to validate parameters for.
params (Dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None.
params (dict[str, Any], optional): Dictionary of parameters to validate. Defaults to None.
Returns:
bool: True if parameters are valid, False otherwise.
Expand Down Expand Up @@ -289,12 +288,12 @@ def _validate_params(self, method: str, params: Dict[str, Any] | None = None) ->

return True

async def _handle_success_response(self, response: ClientResponse, method: str) -> Dict[str, Any] | None:
async def _handle_success_response(self, response: ClientResponse, method: str) -> dict[str, Any] | None:
"""Handle a successful API response."""
if "Etag" in response.headers:
self.etag_cache[method] = response.headers["Etag"]
try:
json_data: Dict[str, Any] | None = await response.json()
json_data: dict[str, Any] | None = await response.json()
if not json_data:
logger.warning("Empty response received")
return json_data
Expand All @@ -303,8 +302,8 @@ async def _handle_success_response(self, response: ClientResponse, method: str)
return None

async def _handle_response(
self, response: ClientResponse, method: str, args: Dict[str, Any] | None = None
) -> Dict[str, Any] | None:
self, response: ClientResponse, method: str, args: dict[str, Any] | None = None
) -> dict[str, Any] | None:
"""Handle the API response based on status code."""
if response.status == 429:
retry_after: int = int(response.headers.get("Retry-After", 1))
Expand All @@ -326,7 +325,7 @@ async def _handle_response(
logger.error("Request failed with status code: %s, response: %s", response.status, await response.text())
return None

async def _do_request(self, method: str, args: Dict[str, Any] | None = None) -> Dict[str, Any] | None:
async def _do_request(self, method: str, args: dict[str, Any] | None = None) -> dict[str, Any] | None:
"""Send an asynchronous request to the specified iRail API endpoint.
This method handles API requests with rate limiting, parameter validation,
Expand Down Expand Up @@ -368,7 +367,7 @@ async def _do_request(self, method: str, args: Dict[str, Any] | None = None) ->
if args:
params.update(args)

request_headers: Dict[str, str] = self._add_etag_header(method)
request_headers: dict[str, str] = self._add_etag_header(method)

try:
async with self.session.get(url, params=params, headers=request_headers) as response:
Expand Down Expand Up @@ -429,7 +428,7 @@ async def get_liveboard(
print(f"Liveboard for Brussels-South: {liveboard}")
"""
extra_params: Dict[str, Any] = {
extra_params: dict[str, Any] = {
"station": station,
"id": id,
"date": date,
Expand Down Expand Up @@ -473,7 +472,7 @@ async def get_connections(
print(f"Connections from Antwerpen-Centraal to Brussel-Centraal: {connections}")
"""
extra_params: Dict[str, Any] = {
extra_params: dict[str, Any] = {
"from": from_station,
"to": to_station,
"date": date,
Expand Down Expand Up @@ -504,7 +503,7 @@ async def get_vehicle(self, id: str, date: str | None = None, alerts: bool = Fal
vehicle_info = await client.get_vehicle("BE.NMBS.IC1832")
"""
extra_params: Dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"}
extra_params: dict[str, Any] = {"id": id, "date": date, "alerts": "true" if alerts else "false"}
vehicle_response_dict = await self._do_request(
"vehicle", {k: v for k, v in extra_params.items() if v is not None}
)
Expand All @@ -527,7 +526,7 @@ async def get_composition(self, id: str, data: str | None = None) -> Composition
composition = await client.get_composition('S51507')
"""
extra_params: Dict[str, str | None] = {"id": id, "data": data}
extra_params: dict[str, str | None] = {"id": id, "data": data}
composition_response_dict = await self._do_request(
"composition", {k: v for k, v in extra_params.items() if v is not None}
)
Expand Down
Loading

0 comments on commit 161e295

Please sign in to comment.