diff --git a/src/rdflib_endpoint/__main__.py b/src/rdflib_endpoint/__main__.py index f86990f..4d38cc2 100644 --- a/src/rdflib_endpoint/__main__.py +++ b/src/rdflib_endpoint/__main__.py @@ -19,11 +19,12 @@ def cli() -> None: @click.option("--host", default="localhost", help="Host of the SPARQL endpoint") @click.option("--port", default=8000, help="Port of the SPARQL endpoint") @click.option("--store", default="default", help="Store used by RDFLib: default or Oxigraph") -def serve(files: List[str], host: str, port: int, store: str) -> None: - run_serve(files, host, port, store) +@click.option("--enable-update", is_flag=True, help="Enable SPARQL updates") +def serve(files: List[str], host: str, port: int, store: str, enable_update: bool) -> None: + run_serve(files, host, port, store, enable_update) -def run_serve(files: List[str], host: str, port: int, store: str = "default") -> None: +def run_serve(files: List[str], host: str, port: int, store: str = "default", enable_update: bool = False) -> None: if store == "oxigraph": store = store.capitalize() g = ConjunctiveGraph(store=store) @@ -41,6 +42,7 @@ def run_serve(files: List[str], host: str, port: int, store: str = "default") -> app = SparqlEndpoint( graph=g, + enable_update=enable_update, example_query="""PREFIX rdf: PREFIX rdfs: SELECT * WHERE { diff --git a/src/rdflib_endpoint/sparql_router.py b/src/rdflib_endpoint/sparql_router.py index 1f3f0d9..89c2ce1 100644 --- a/src/rdflib_endpoint/sparql_router.py +++ b/src/rdflib_endpoint/sparql_router.py @@ -1,4 +1,5 @@ import logging +import os import re from importlib import resources from typing import Any, Callable, Dict, List, Optional, Union @@ -8,7 +9,7 @@ from fastapi import APIRouter, Query, Request, Response from fastapi.responses import JSONResponse from rdflib import RDF, ConjunctiveGraph, Dataset, Graph, Literal, URIRef -from rdflib.plugins.sparql import prepareQuery +from rdflib.plugins.sparql import prepareQuery, prepareUpdate from rdflib.plugins.sparql.evaluate import evalPart from rdflib.plugins.sparql.evalutils import _eval from rdflib.plugins.sparql.parserutils import CompValue @@ -203,14 +204,22 @@ def __init__( description=self.example_markdown, responses=api_responses, ) - async def sparql_endpoint(request: Request, query: Optional[str] = Query(None)) -> Response: + async def sparql_endpoint( + request: Request, query: Optional[str] = Query(None), update: Optional[str] = None # Not supported for GET + ) -> Response: """ Send a SPARQL query to be executed through HTTP GET operation. :param request: The HTTP GET request :param query: SPARQL query input. """ - if not query: + if query and update: + return JSONResponse( + status_code=400, + content={"message": "Cannot do both query and update"}, + ) + + if not query and not update: if str(request.headers["accept"]).startswith("text/html"): return self.serve_yasgui() # If not asking HTML, return the SPARQL endpoint service description @@ -236,72 +245,79 @@ async def sparql_endpoint(request: Request, query: Optional[str] = Query(None)) graph_ns = dict(self.graph.namespaces()) - try: - # Query the graph with the custom functions loaded - parsed_query = prepareQuery(query, initNs=graph_ns) - query_operation = re.sub(r"(\w)([A-Z])", r"\1 \2", parsed_query.algebra.name) - except Exception as e: - logging.error("Error parsing the SPARQL query: " + str(e)) - return JSONResponse( - status_code=400, - content={"message": "Error parsing the SPARQL query"}, - ) - - # TODO: RDFLib doesn't support SPARQL insert (Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, found 'INSERT') - # But we could implement it by doing a CONSTRUCT, and adding the resulting triples to the graph - # if not self.enable_update: - # if query_operation == "Insert Query" or query_operation == "Delete Query": - # return JSONResponse(status_code=403, content={"message": "INSERT and DELETE queries are not allowed."}) - # if os.getenv('RDFLIB_APIKEY') and (query_operation == "Insert Query" or query_operation == "Delete Query"): - # if apikey != os.getenv('RDFLIB_APIKEY'): - # return JSONResponse(status_code=403, content={"message": "Wrong API KEY."}) - - try: - query_results = self.graph.query(query, processor=self.processor) - except Exception as e: - logging.error("Error executing the SPARQL query on the RDFLib Graph: " + str(e)) - return JSONResponse( - status_code=400, - content={"message": "Error executing the SPARQL query on the RDFLib Graph"}, - ) - - # Format and return results depending on Accept mime type in request header - mime_types = parse_accept_header(request.headers.get("accept", DEFAULT_CONTENT_TYPE)) - - # Handle cases that are more complicated, like it includes multiple - # types, extra information, etc. - output_mime_type = DEFAULT_CONTENT_TYPE - for mime_type in mime_types: - if mime_type in CONTENT_TYPE_TO_RDFLIB_FORMAT: - output_mime_type = mime_type - # Use the first mime_type that matches - break - - # Handle mime type for construct queries - if query_operation == "Construct Query": - if output_mime_type in {"application/json", "text/csv"}: - output_mime_type = "text/turtle" - # TODO: support JSON-LD for construct query? - # g.serialize(format='json-ld', indent=4) - elif output_mime_type == "application/xml": - output_mime_type = "application/rdf+xml" - else: - pass # TODO what happens here? - - try: - rdflib_format = CONTENT_TYPE_TO_RDFLIB_FORMAT[output_mime_type] - response = Response( - query_results.serialize(format=rdflib_format), - media_type=output_mime_type, - ) - except Exception as e: - logging.error("Error serializing the SPARQL query results with RDFLib: %s", e) - return JSONResponse( - status_code=422, - content={"message": "Error serializing the SPARQL query results"}, - ) - else: - return response + if query: + try: + parsed_query = prepareQuery(query, initNs=graph_ns) + query_results = self.graph.query(parsed_query, processor=self.processor) + + # Format and return results depending on Accept mime type in request header + mime_types = parse_accept_header(request.headers.get("accept", DEFAULT_CONTENT_TYPE)) + + # Handle cases that are more complicated, like it includes multiple + # types, extra information, etc. + output_mime_type = DEFAULT_CONTENT_TYPE + for mime_type in mime_types: + if mime_type in CONTENT_TYPE_TO_RDFLIB_FORMAT: + output_mime_type = mime_type + # Use the first mime_type that matches + break + + query_operation = re.sub(r"(\w)([A-Z])", r"\1 \2", parsed_query.algebra.name) + + # Handle mime type for construct queries + if query_operation == "Construct Query": + if output_mime_type in {"application/json", "text/csv"}: + output_mime_type = "text/turtle" + # TODO: support JSON-LD for construct query? + # g.serialize(format='json-ld', indent=4) + elif output_mime_type == "application/xml": + output_mime_type = "application/rdf+xml" + else: + pass # TODO what happens here? + + try: + rdflib_format = CONTENT_TYPE_TO_RDFLIB_FORMAT[output_mime_type] + response = Response( + query_results.serialize(format=rdflib_format), + media_type=output_mime_type, + ) + except Exception as e: + logging.error("Error serializing the SPARQL query results with RDFLib: %s", e) + return JSONResponse( + status_code=422, + content={"message": "Error serializing the SPARQL query results"}, + ) + else: + return response + except Exception as e: + logging.error("Error executing the SPARQL query on the RDFLib Graph: " + str(e)) + return JSONResponse( + status_code=400, + content={"message": "Error executing the SPARQL query on the RDFLib Graph"}, + ) + else: # update + if not self.enable_update: + return JSONResponse( + status_code=403, content={"message": "INSERT and DELETE queries are not allowed."} + ) + if rdflib_apikey := os.environ.get("RDFLIB_APIKEY"): + authorized = False + if auth_header := request.headers.get("Authorization"): # noqa: SIM102 + if auth_header.startswith("Bearer ") and auth_header[7:] == rdflib_apikey: + authorized = True + if not authorized: + return JSONResponse(status_code=403, content={"message": "Invalid API KEY."}) + try: + prechecked_update: str = update # type: ignore + parsed_update = prepareUpdate(prechecked_update, initNs=graph_ns) + self.graph.update(parsed_update, "sparql") + return Response(status_code=204) + except Exception as e: + logging.error("Error executing the SPARQL update on the RDFLib Graph: " + str(e)) + return JSONResponse( + status_code=400, + content={"message": "Error executing the SPARQL update on the RDFLib Graph"}, + ) @self.post( path, @@ -309,21 +325,31 @@ async def sparql_endpoint(request: Request, query: Optional[str] = Query(None)) description=self.example_markdown, responses=api_responses, ) - async def post_sparql_endpoint(request: Request, query: Optional[str] = Query(None)) -> Response: + async def post_sparql_endpoint(request: Request) -> Response: """Send a SPARQL query to be executed through HTTP POST operation. :param request: The HTTP POST request with a .body() - :param query: SPARQL query input. """ - if not query: - # Handle federated query services which provide the query in the body - query_body = await request.body() - body = query_body.decode("utf-8") - parsed_query = parse.parse_qsl(body) - for params in parsed_query: - if params[0] == "query": - query = parse.unquote(params[1]) - return await sparql_endpoint(request, query) + request_body = await request.body() + body = request_body.decode("utf-8") + content_type = request.headers.get("content-type") + if content_type == "application/sparql-query": + query = body + update = None + elif content_type == "application/sparql-update": + query = None + update = body + elif content_type == "application/x-www-form-urlencoded": + request_params = parse.parse_qsl(body) + query_params = [kvp[1] for kvp in request_params if kvp[0] == "query"] + query = parse.unquote(query_params[0]) if query_params else None + update_params = [kvp[1] for kvp in request_params if kvp[0] == "update"] + update = parse.unquote(update_params[0]) if update_params else None + else: + # Response with the service description + query = None + update = None + return await sparql_endpoint(request, query, update) def eval_custom_functions(self, ctx: QueryContext, part: CompValue) -> List[Any]: """Retrieve variables from a SPARQL-query, then execute registered SPARQL functions diff --git a/tests/test_example_app.py b/tests/test_example_app.py index 4c397fc..b64da52 100644 --- a/tests/test_example_app.py +++ b/tests/test_example_app.py @@ -28,7 +28,7 @@ def test_custom_concat(): response = endpoint.post( "/", - data="query=" + custom_concat_query, + data={"query": custom_concat_query}, headers={"accept": "application/json"}, ) assert response.status_code == 200 diff --git a/tests/test_oxrdflib.py b/tests/test_oxrdflib.py index 79af6b5..9fef6a8 100644 --- a/tests/test_oxrdflib.py +++ b/tests/test_oxrdflib.py @@ -33,25 +33,25 @@ def test_custom_concat_json(): assert response.status_code == 200 assert response.json()["results"]["bindings"][0]["label"]["value"] == "test value" - response = endpoint.post("/", data="query=" + label_select, headers={"accept": "application/json"}) + response = endpoint.post("/", data={"query": label_select}, headers={"accept": "application/json"}) assert response.status_code == 200 assert response.json()["results"]["bindings"][0]["label"]["value"] == "test value" def test_select_noaccept_xml(): - response = endpoint.post("/", data="query=" + label_select) + response = endpoint.post("/", data={"query": label_select}) assert response.status_code == 200 # assert response.json()['results']['bindings'][0]['concat']['value'] == "Firstlast" def test_select_csv(): - response = endpoint.post("/", data="query=" + label_select, headers={"accept": "text/csv"}) + response = endpoint.post("/", data={"query": label_select}, headers={"accept": "text/csv"}) assert response.status_code == 200 # assert response.json()['results']['bindings'][0]['concat']['value'] == "Firstlast" def test_fail_select_turtle(): - response = endpoint.post("/", data="query=" + label_select, headers={"accept": "text/turtle"}) + response = endpoint.post("/", data={"query": label_select}, headers={"accept": "text/turtle"}) assert response.status_code == 422 # assert response.json()['results']['bindings'][0]['concat']['value'] == "Firstlast" diff --git a/tests/test_rdflib_endpoint.py b/tests/test_rdflib_endpoint.py index 91ab6d7..db5dafa 100644 --- a/tests/test_rdflib_endpoint.py +++ b/tests/test_rdflib_endpoint.py @@ -1,12 +1,28 @@ +import pytest from example.app.main import custom_concat from fastapi.testclient import TestClient +from rdflib import RDFS, Graph, Literal, URIRef from rdflib_endpoint import SparqlEndpoint +graph = Graph() + + +@pytest.fixture(autouse=True) +def clear_graph(): + # Workaround to clear graph without putting + # graph, app and endpoint into a fixture + # and modifying the test fixture usage. + for triple in graph: + graph.remove(triple) + + app = SparqlEndpoint( + graph=graph, functions={ "https://w3id.org/um/sparql-functions/custom_concat": custom_concat, - } + }, + enable_update=True, ) endpoint = TestClient(app) @@ -29,25 +45,73 @@ def test_service_description(): def test_custom_concat_json(): response = endpoint.get("/", params={"query": concat_select}, headers={"accept": "application/json"}) - print(response.json()) + # print(response.json()) + assert response.status_code == 200 + assert response.json()["results"]["bindings"][0]["concat"]["value"] == "Firstlast" + + response = endpoint.post("/", data={"query": concat_select}, headers={"accept": "application/json"}) assert response.status_code == 200 assert response.json()["results"]["bindings"][0]["concat"]["value"] == "Firstlast" - response = endpoint.post("/", data="query=" + concat_select, headers={"accept": "application/json"}) + response = endpoint.post( + "/", data=concat_select, headers={"accept": "application/json", "content-type": "application/sparql-query"} + ) assert response.status_code == 200 assert response.json()["results"]["bindings"][0]["concat"]["value"] == "Firstlast" def test_select_noaccept_xml(): - response = endpoint.post("/", data="query=" + concat_select) + response = endpoint.post("/", data={"query": concat_select}) assert response.status_code == 200 def test_select_csv(): - response = endpoint.post("/", data="query=" + concat_select, headers={"accept": "text/csv"}) + response = endpoint.post("/", data={"query": concat_select}, headers={"accept": "text/csv"}) assert response.status_code == 200 +label_patch = """ +PREFIX rdfs: +DELETE { ?subject rdfs:label "foo" } +INSERT { ?subject rdfs:label "bar" } +WHERE { ?subject rdfs:label "foo" } +""" + + +@pytest.mark.parametrize( + "api_key,key_provided,param_method", + [ + (api_key, key_provided, param_method) + for api_key in [None, "key"] + for key_provided in [True, False] + for param_method in ["body_form", "body_direct"] + ], +) +def test_sparql_update(api_key, key_provided, param_method, monkeypatch): + if api_key: + monkeypatch.setenv("RDFLIB_APIKEY", api_key) + subject = URIRef("http://server.test/subject") + headers = {} + if key_provided: + headers["Authorization"] = "Bearer key" + graph.add((subject, RDFS.label, Literal("foo"))) + if param_method == "body_form": + request_args = {"data": {"update": label_patch}} + else: + # direct + headers["content-type"] = "application/sparql-update" + request_args = {"data": label_patch} + response = endpoint.post("/", headers=headers, **request_args) + if api_key is None or key_provided: + assert response.status_code == 204 + assert (subject, RDFS.label, Literal("foo")) not in graph + assert (subject, RDFS.label, Literal("bar")) in graph + else: + assert response.status_code == 403 + assert (subject, RDFS.label, Literal("foo")) in graph + assert (subject, RDFS.label, Literal("bar")) not in graph + + def test_multiple_accept_return_json(): response = endpoint.get( "/", @@ -69,7 +133,7 @@ def test_multiple_accept_return_json2(): def test_fail_select_turtle(): - response = endpoint.post("/", data="query=" + concat_select, headers={"accept": "text/turtle"}) + response = endpoint.post("/", data={"query": concat_select}, headers={"accept": "text/turtle"}) assert response.status_code == 422 # assert response.json()['results']['bindings'][0]['concat']['value'] == "Firstlast" @@ -78,7 +142,7 @@ def test_concat_construct_turtle(): # expected to return turtle response = endpoint.post( "/", - data="query=" + custom_concat_construct, + data={"query": custom_concat_construct}, headers={"accept": "application/json"}, ) assert response.status_code == 200 @@ -89,7 +153,7 @@ def test_concat_construct_xml(): # expected to return turtle response = endpoint.post( "/", - data="query=" + custom_concat_construct, + data={"query": custom_concat_construct}, headers={"accept": "application/xml"}, ) assert response.status_code == 200