diff --git a/xpublish_edr/__init__.py b/xpublish_edr/__init__.py
index 1711770..75668a9 100644
--- a/xpublish_edr/__init__.py
+++ b/xpublish_edr/__init__.py
@@ -1,5 +1,5 @@
 """
-xpublish_edr is not a real package, just a set of best practices examples.
+Xpublish routers for the OGC EDR API.
 """
 
 from xpublish_edr.plugin import CfEdrPlugin
diff --git a/xpublish_edr/formats/to_covjson.py b/xpublish_edr/formats/to_covjson.py
index 3a5c4d2..e06a9ba 100644
--- a/xpublish_edr/formats/to_covjson.py
+++ b/xpublish_edr/formats/to_covjson.py
@@ -16,6 +16,14 @@
 
 import numpy as np
 import xarray as xr
+from fastapi.responses import JSONResponse
+
+
+class CovJSONResponse(JSONResponse):
+    """CovJSON response type"""
+
+    # https://docs.ogc.org/cs/21-069r2/21-069r2.html#_b8b17e78-0147-4b58-8ade-a19465b57abc
+    media_type = "application/vnd.cov+json"
 
 
 class Domain(TypedDict):
@@ -74,7 +82,7 @@ def invert_cf_dims(ds):
     return inverted
 
 
-def to_cf_covjson(ds: xr.Dataset) -> CovJSON:
+def to_cf_covjson(ds: xr.Dataset) -> CovJSONResponse:
     """Transform an xarray dataset to CoverageJSON using CF conventions"""
 
     covjson: CovJSON = {
@@ -164,4 +172,4 @@
 
         covjson["ranges"][var] = cov_range
 
-    return covjson
+    return CovJSONResponse(content=covjson)
diff --git a/xpublish_edr/formats/to_csv.py b/xpublish_edr/formats/to_csv.py
index 436f81c..99e3d06 100644
--- a/xpublish_edr/formats/to_csv.py
+++ b/xpublish_edr/formats/to_csv.py
@@ -5,7 +5,7 @@
 from fastapi import Response
 
 
-def to_csv(ds: xr.Dataset):
+def to_csv(ds: xr.Dataset) -> Response:
     """Return a CSV response from an xarray dataset"""
     ds = ds.squeeze()
     df = ds.to_pandas()
diff --git a/xpublish_edr/formats/to_netcdf.py b/xpublish_edr/formats/to_netcdf.py
index 1580771..01039c1 100644
--- a/xpublish_edr/formats/to_netcdf.py
+++ b/xpublish_edr/formats/to_netcdf.py
@@ -8,7 +8,7 @@
 from fastapi import Response
 
 
-def to_netcdf(ds: xr.Dataset):
+def to_netcdf(ds: xr.Dataset) -> Response:
     """Return a NetCDF response from a dataset"""
     with TemporaryDirectory() as tmpdir:
         path = Path(tmpdir) / "position.nc"
diff --git a/xpublish_edr/plugin.py b/xpublish_edr/plugin.py
index fab7bb0..26978f6 100644
--- a/xpublish_edr/plugin.py
+++ b/xpublish_edr/plugin.py
@@ -3,11 +3,15 @@
 """
 import importlib
 import logging
-from typing import List, Optional
+from functools import cache
+from typing import Hashable, List, Optional, Tuple
 
+import cachey
+import dask
 import xarray as xr
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from xpublish import Dependencies, Plugin, hookimpl
+from xpublish.utils.cache import CostTimer
 
 from .formats.to_covjson import to_cf_covjson
 from .query import EDRQuery, edr_query, edr_query_params
@@ -15,6 +19,19 @@
 logger = logging.getLogger("cf_edr")
 
 
+def cache_key_from_request(
+    route: str,
+    request: Request,
+    query: EDRQuery,
+    dataset: xr.Dataset,
+) -> Tuple[Hashable, ...]:
+    """Generate a cache key from the request and query parameters"""
+    with dask.config.set({"tokenize.ensure-deterministic": True}):
+        ds_token = dask.base.tokenize(dataset)
+    return (route, request, query, ds_token)
+
+
+@cache
 def position_formats():
     """
     Return response format functions from registered
@@ -71,85 +88,108 @@ def get_position(
             request: Request,
             query: EDRQuery = Depends(edr_query),
             dataset: xr.Dataset = Depends(deps.dataset),
+            cache: cachey.Cache = Depends(deps.cache),
         ):
             """
             Returns position data based on WKT `Point(lon lat)` coordinates
 
             Extra selecting/slicing parameters can be provided as extra query parameters
             """
-            try:
-                ds = dataset.cf.sel(X=query.point.x, Y=query.point.y, method="nearest")
-            except KeyError:
-                raise HTTPException(
-                    status_code=404,
-                    detail="Dataset does not have CF Convention compliant metadata",
-                )
+            cache_key = cache_key_from_request("position", request, query, dataset)
+            response: Optional[Response] = cache.get(cache_key)
 
-            if query.z:
-                ds = dataset.cf.sel(Z=query.z, method="nearest")
-
-            if query.datetime:
-                datetimes = query.datetime.split("/")
+            if response is not None:
+                logger.debug(f"Cache hit for {cache_key}")
+                return response
 
+            with CostTimer() as ct:
                 try:
-                    if len(datetimes) == 1:
-                        ds = ds.cf.sel(T=datetimes[0], method="nearest")
-                    elif len(datetimes) == 2:
-                        ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
-                    else:
-                        raise HTTPException(
-                            status_code=404,
-                            detail="Invalid datetimes submitted",
-                        )
-                except ValueError as e:
-                    logger.error("Error with datetime", exc_info=True)
-                    raise HTTPException(
-                        status_code=404,
-                        detail=f"Invalid datetime ({e})",
-                    ) from e
-
-            if query.parameters:
-                try:
-                    ds = ds.cf[query.parameters.split(",")]
-                except KeyError as e:
+                    ds = dataset.cf.sel(
+                        X=query.point.x,
+                        Y=query.point.y,
+                        method="nearest",
+                    )
+                except KeyError:
                     raise HTTPException(
                         status_code=404,
-                        detail=f"Invalid variable: {e}",
+                        detail="Dataset does not have CF Convention compliant metadata",
                     )
 
-            logger.debug(f"Dataset filtered by query params {ds}")
+                if query.z:
+                    ds = dataset.cf.sel(Z=query.z, method="nearest")
+
+                if query.datetime:
+                    datetimes = query.datetime.split("/")
+
+                    try:
+                        if len(datetimes) == 1:
+                            ds = ds.cf.sel(T=datetimes[0], method="nearest")
+                        elif len(datetimes) == 2:
+                            ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
+                        else:
+                            raise HTTPException(
+                                status_code=404,
+                                detail="Invalid datetimes submitted",
+                            )
+                    except ValueError as e:
+                        logger.error("Error with datetime", exc_info=True)
+                        raise HTTPException(
+                            status_code=404,
+                            detail=f"Invalid datetime ({e})",
+                        ) from e
 
-            query_params = dict(request.query_params)
-            for query_param in request.query_params:
-                if query_param in edr_query_params:
-                    del query_params[query_param]
+                if query.parameters:
+                    try:
+                        ds = ds.cf[query.parameters.split(",")]
+                    except KeyError as e:
+                        raise HTTPException(
+                            status_code=404,
+                            detail=f"Invalid variable: {e}",
+                        )
 
-            method: Optional[str] = "nearest"
+                logger.debug(f"Dataset filtered by query params {ds}")
 
-            for key, value in query_params.items():
-                split_value = value.split("/")
-                if len(split_value) == 1:
-                    continue
-                elif len(split_value) == 2:
-                    query_params[key] = slice(split_value[0], split_value[1])
-                    method = None
-                else:
-                    raise HTTPException(404, f"Too many values for selecting {key}")
+                query_params = dict(request.query_params)
+                for query_param in request.query_params:
+                    if query_param in edr_query_params:
+                        del query_params[query_param]
 
-            ds = ds.sel(query_params, method=method)
+                method: Optional[str] = "nearest"
 
-            if query.format:
-                try:
-                    format_fn = position_formats()[query.format]
-                except KeyError:
-                    raise HTTPException(
-                        404,
-                        f"{query.format} is not a valid format for EDR position queries. "
-                        "Get `./formats` for valid formats",
-                    )
+                for key, value in query_params.items():
+                    split_value = value.split("/")
+                    if len(split_value) == 1:
+                        continue
+                    elif len(split_value) == 2:
+                        query_params[key] = slice(split_value[0], split_value[1])
+                        method = None
+                    else:
+                        raise HTTPException(
+                            404,
+                            f"Too many values for selecting {key}",
+                        )
 
-                return format_fn(ds)
+                ds = ds.sel(query_params, method=method)
 
-            return to_cf_covjson(ds)
+                if query.format:
+                    try:
+                        format_fn = position_formats()[query.format]
+                    except KeyError:
+                        raise HTTPException(
+                            404,
+                            f"{query.format} is not a valid format for EDR position queries. "
+                            "Get `./formats` for valid formats",
+                        )
+                else:
+                    format_fn = to_cf_covjson
+
+            response = format_fn(ds)
+            cache.put(
+                cache_key,
+                response,
+                ct.time,
+                int(response.headers["content-length"]),
+            )
+            return response
 
         return router
diff --git a/xpublish_edr/query.py b/xpublish_edr/query.py
index 22b1d1f..ed776b7 100644
--- a/xpublish_edr/query.py
+++ b/xpublish_edr/query.py
@@ -29,6 +29,19 @@ def point(self):
         """Shapely point from WKT query params"""
         return wkt.loads(self.coords)
 
+    def __hash__(self):
+        """Hash based on query parameters"""
+        return hash(
+            (
+                self.coords,
+                self.z,
+                self.datetime,
+                self.parameters,
+                self.crs,
+                self.format,
+            ),
+        )
+
 
 def edr_query(
     coords: str = Query(