From 4995982e733042567905131215778cdeb579dbfd Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 Sep 2023 17:35:07 +0200 Subject: [PATCH 01/10] Implement sync auto batching requests --- gql/client.py | 130 +++++++++++++++++++++++++++++++++-- tests/test_requests_batch.py | 49 +++++++++++++ 2 files changed, 172 insertions(+), 7 deletions(-) diff --git a/gql/client.py b/gql/client.py index 326442e0..18bb6682 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,7 +1,11 @@ import asyncio import logging import sys +import time import warnings +from concurrent.futures import Future +from queue import Queue +from threading import Event, Thread from typing import ( Any, AsyncGenerator, @@ -10,6 +14,7 @@ Generator, List, Optional, + Tuple, TypeVar, Union, cast, @@ -82,6 +87,8 @@ def __init__( execute_timeout: Optional[Union[int, float]] = 10, serialize_variables: bool = False, parse_results: bool = False, + batch_interval: float = 0, + batch_max: int = 10, ): """Initialize the client with the given parameters. @@ -99,6 +106,9 @@ def __init__( serialized. Used for custom scalars and/or enums. Default: False. :param parse_results: Whether gql will try to parse the serialized output sent by the backend. Can be used to unserialize custom scalars or enums. + :param batch_interval: Time to wait in seconds for batching requests together. + Batching is disabled (by default) if 0. + :param batch_max: Maximum number of requests in a single batch. """ if introspection: @@ -146,6 +156,12 @@ def __init__( self.serialize_variables = serialize_variables self.parse_results = parse_results + self.batch_interval = batch_interval + self.batch_max = batch_max + + @property + def batching_enabled(self): + return self.batch_interval != 0 def validate(self, document: DocumentNode): """:meta private:""" @@ -758,7 +774,12 @@ def connect_sync(self): return self.session def close_sync(self): - """Close the sync transport.""" + """Close the sync session and the sync transport. + + If batching is enabled, this will block until the remaining queries in the + batching queue have been processed. + """ + self.session.wait_stop() self.transport.close() def __enter__(self): @@ -779,6 +800,13 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client + if self.client.batching_enabled: + self.batch_queue: Queue = Queue() + self._batch_thread_stop_requested = False + self._batch_thread_stopped_event = Event() + self._batch_thread = Thread(target=self._batch_loop, daemon=True) + self._batch_thread.start() + def _execute( self, document: DocumentNode, @@ -818,12 +846,22 @@ def _execute( operation_name=operation_name, ) - result = self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, - ) + if self.client.batching_enabled: + request = GraphQLRequest( + document, + variable_values=variable_values, + operation_name=operation_name, + ) + future_result = self._execute_future(request) + result = future_result.result() + + else: + result = self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Unserialize the result if requested if self.client.schema: @@ -1038,6 +1076,84 @@ def execute_batch( return cast(List[Dict[str, Any]], [result.data for result in results]) + def _batch_loop(self) -> None: + """main loop of the thread used to wait for requests + to execute them in a batch""" + + stop_loop = False + + while not stop_loop: + + # First wait for a first request in from the batch queue + requests_and_futures: List[Tuple[GraphQLRequest, Future]] = [] + request_and_future: Tuple[GraphQLRequest, Future] = self.batch_queue.get() + if request_and_future is None: + break + requests_and_futures.append(request_and_future) + + # Then wait the requested batch interval except if we already + # have the maximum number of requests in the queue + if self.batch_queue.qsize() < self.client.batch_max - 1: + time.sleep(self.client.batch_interval) + + # Then get the requests which had been made during that wait interval + for _ in range(self.client.batch_max - 1): + if self.batch_queue.empty(): + break + request_and_future = self.batch_queue.get() + if request_and_future is None: + stop_loop = True + break + requests_and_futures.append(request_and_future) + + requests = [request for request, _ in requests_and_futures] + futures = [future for _, future in requests_and_futures] + + # Manually execute the requests in a batch + try: + results: List[ExecutionResult] = self._execute_batch(requests) + except Exception as exc: + for future in futures: + future.set_exception(exc) + + # Fill in the future results + for result, future in zip(results, futures): + future.set_result(result) + + # Indicate that the Thread has stopped + self._batch_thread_stopped_event.set() + + def _execute_future( + self, + request: GraphQLRequest, + ) -> Future: + """If batching is enabled, this method will put a request in the batching queue + instead of executing it directly so that the requests could be put in a batch. + """ + + assert hasattr(self, "batch_queue"), "Batching is not enabled" + assert not self._batch_thread_stop_requested, "Batching thread has been stopped" + + future: Future = Future() + self.batch_queue.put((request, future)) + + return future + + def wait_stop(self): + """Cleanup the batching thread if batching is enabled. + + Will wait until all the remaining requests in the batch processing queue + have been executed. + """ + if hasattr(self, "_batch_thread_stopped_event"): + # Send a None in the queue to indicate that the batching Thread must stop + # after having processed the remaining requests in the queue + self._batch_thread_stop_requested = True + self.batch_queue.put(None) + + # Wait for the Thread to stop + self._batch_thread_stopped_event.wait() + def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 23ab1254..58fe01fc 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -375,3 +375,52 @@ def test_code(): assert execution_results[0].extensions["key1"] == "val1" await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_auto(): + + from threading import Thread + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + batch_interval=0.01, + batch_max=3, + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + def get_continent_name(session, continent_code): + variables = { + "continent_code": continent_code, + } + + result = session.execute(query, variable_values=variables) + + name = result["continent"]["name"] + print(f"The continent with the code {continent_code} has the name: '{name}'") + + continent_codes = ["EU", "AF", "NA", "OC", "SA", "AS", "AN"] + + with client as session: + + for continent_code in continent_codes: + + thread = Thread( + target=get_continent_name, + args=( + session, + continent_code, + ), + ) + thread.start() From 7fbee893588c5a3063d286af705485752f348c0b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 Sep 2023 19:02:15 +0200 Subject: [PATCH 02/10] Fix using two sync sessions with batch enabled --- gql/client.py | 35 +++++++++++++++++++++-------------- tests/test_requests_batch.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/gql/client.py b/gql/client.py index 18bb6682..2ff0a421 100644 --- a/gql/client.py +++ b/gql/client.py @@ -755,11 +755,11 @@ def connect_sync(self): " Use 'async with Client(...) as session:' instead" ) - self.transport.connect() - if not hasattr(self, "session"): self.session = SyncClientSession(client=self) + self.session.connect() + # Get schema from transport if needed try: if self.fetch_schema_from_transport and not self.schema: @@ -768,7 +768,7 @@ def connect_sync(self): # we don't know what type of exception is thrown here because it # depends on the underlying transport; we just make sure that the # transport is closed and re-raise the exception - self.transport.close() + self.session.close() raise return self.session @@ -779,8 +779,7 @@ def close_sync(self): If batching is enabled, this will block until the remaining queries in the batching queue have been processed. """ - self.session.wait_stop() - self.transport.close() + self.session.close() def __enter__(self): return self.connect_sync() @@ -800,13 +799,6 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client - if self.client.batching_enabled: - self.batch_queue: Queue = Queue() - self._batch_thread_stop_requested = False - self._batch_thread_stopped_event = Event() - self._batch_thread = Thread(target=self._batch_loop, daemon=True) - self._batch_thread.start() - def _execute( self, document: DocumentNode, @@ -1139,8 +1131,21 @@ def _execute_future( return future - def wait_stop(self): - """Cleanup the batching thread if batching is enabled. + def connect(self): + """Connect the transport and initialize the batch threading loop if batching + is enabled.""" + + if self.client.batching_enabled: + self.batch_queue: Queue = Queue() + self._batch_thread_stop_requested = False + self._batch_thread_stopped_event = Event() + self._batch_thread = Thread(target=self._batch_loop, daemon=True) + self._batch_thread.start() + + self.transport.connect() + + def close(self): + """Close the transport and cleanup the batching thread if batching is enabled. Will wait until all the remaining requests in the batch processing queue have been executed. @@ -1154,6 +1159,8 @@ def wait_stop(self): # Wait for the Thread to stop self._batch_thread_stopped_event.wait() + self.transport.close() + def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 58fe01fc..53ae2b9a 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -424,3 +424,17 @@ def get_continent_name(session, continent_code): ), ) thread.start() + + # Doing it twice to check that everything is closing correctly + with client as session: + + for continent_code in continent_codes: + + thread = Thread( + target=get_continent_name, + args=( + session, + continent_code, + ), + ) + thread.start() From 9fd57a48a8827226f19ee1c7bccaa39aacbb438f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 Sep 2023 19:32:18 +0200 Subject: [PATCH 03/10] Fix exception handling --- gql/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gql/client.py b/gql/client.py index 2ff0a421..48d739b2 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1107,6 +1107,7 @@ def _batch_loop(self) -> None: except Exception as exc: for future in futures: future.set_exception(exc) + continue # Fill in the future results for result, future in zip(results, futures): From c310552e2c8f9292a59ed2e800ef04e084dba641 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 Sep 2023 19:32:28 +0200 Subject: [PATCH 04/10] Add tests --- tests/test_requests_batch.py | 159 ++++++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 1 deletion(-) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 53ae2b9a..0f016a46 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -30,6 +30,21 @@ '{"code":"SA","name":"South America"}]}}]' ) +query1_server_answer_twice_list = ( + "[" + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' + "]" +) + @pytest.mark.aiohttp @pytest.mark.asyncio @@ -74,6 +89,108 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_auto_batch_enabled( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_auto_batch_enabled_two_requests( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + from threading import Thread + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + def test_thread(): + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + for _ in range(2): + thread = Thread(target=test_thread) + thread.start() + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): @@ -148,6 +265,46 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_401_auto_batch_enabled( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "401 Client Error: Unauthorized" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): @@ -425,7 +582,7 @@ def get_continent_name(session, continent_code): ) thread.start() - # Doing it twice to check that everything is closing correctly + # Doing it twice to check that everything is closing and reconnecting correctly with client as session: for continent_code in continent_codes: From f75c1cd6812db827190452cb6183f7d10a250da9 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 Sep 2023 09:12:13 +0200 Subject: [PATCH 05/10] Adding some online tests to demonstrate _execute_future --- tests/test_requests_batch.py | 68 ++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 0f016a46..cb9ae6aa 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -595,3 +595,71 @@ def get_continent_name(session, continent_code): ), ) thread.start() + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_auto_execute_future(): + + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + batch_interval=0.01, + batch_max=3, + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + future_result_eu = session._execute_future(request_eu) + + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + future_result_af = session._execute_future(request_af) + + result_eu = future_result_eu.result().data + result_af = future_result_af.result().data + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_manual(): + + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" From 69492e3d54a8a5a6590b679b3fdac48c6b38c95f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 Sep 2023 10:15:14 +0200 Subject: [PATCH 06/10] Add * to execute_batch methods --- gql/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gql/client.py b/gql/client.py index 48d739b2..435d0496 100644 --- a/gql/client.py +++ b/gql/client.py @@ -257,6 +257,7 @@ def execute_sync( def execute_batch_sync( self, reqs: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -456,6 +457,7 @@ def execute( def execute_batch( self, reqs: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -969,6 +971,7 @@ def execute( def _execute_batch( self, reqs: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs, @@ -1019,6 +1022,7 @@ def _execute_batch( def execute_batch( self, reqs: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, From a6416e7f91a00164d44ea3c53065c218866b376e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 Sep 2023 10:15:55 +0200 Subject: [PATCH 07/10] Rename reqs to requests --- gql/client.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/gql/client.py b/gql/client.py index 435d0496..31482c16 100644 --- a/gql/client.py +++ b/gql/client.py @@ -256,7 +256,7 @@ def execute_sync( def execute_batch_sync( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, @@ -266,7 +266,7 @@ def execute_batch_sync( """:meta private:""" with self as session: return session.execute_batch( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -456,7 +456,7 @@ def execute( def execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, @@ -490,7 +490,7 @@ def execute_batch( else: # Sync transports return self.execute_batch_sync( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -970,7 +970,7 @@ def execute( def _execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, @@ -979,7 +979,7 @@ def _execute_batch( """Execute multiple GraphQL requests in a batch, using the sync transport, returning a list of ExecutionResult objects. - :param reqs: List of requests that will be executed. + :param requests: List of requests that will be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -990,21 +990,21 @@ def _execute_batch( # Validate document if self.client.schema: - for req in reqs: + for req in requests: self.client.validate(req.document) # Parse variable values for custom scalars if requested if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - reqs = [ + requests = [ req.serialize_variable_values(self.client.schema) if req.variable_values is not None else req - for req in reqs + for req in requests ] - results = self.transport.execute_batch(reqs, **kwargs) + results = self.transport.execute_batch(requests, **kwargs) # Unserialize the result if requested if self.client.schema: @@ -1021,7 +1021,7 @@ def _execute_batch( def execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, @@ -1034,7 +1034,7 @@ def execute_batch( Raises a TransportQueryError if an error has been returned in any ExecutionResult. - :param reqs: List of requests that will be executed. + :param requests: List of requests that will be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1047,7 +1047,7 @@ def execute_batch( # Validate and execute on the transport results = self._execute_batch( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, From 91697b363684c0d5a47e5186c2f1008e74994d8e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 Sep 2023 10:30:19 +0200 Subject: [PATCH 08/10] Adding overloads to the execute_batch methods --- gql/client.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/gql/client.py b/gql/client.py index 31482c16..1b0df4dc 100644 --- a/gql/client.py +++ b/gql/client.py @@ -254,6 +254,42 @@ def execute_sync( **kwargs, ) + @overload + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... # pragma: no cover + + @overload + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover + def execute_batch_sync( self, requests: List[GraphQLRequest], @@ -454,6 +490,42 @@ def execute( **kwargs, ) + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover + def execute_batch( self, requests: List[GraphQLRequest], @@ -1019,6 +1091,42 @@ def _execute_batch( return results + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover + def execute_batch( self, requests: List[GraphQLRequest], From 726942462c401f29ff8c379d616ceed6bd09e28b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 Sep 2023 10:48:33 +0200 Subject: [PATCH 09/10] Don't validate the doc or parse the variables twice when using automatic batching --- gql/client.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gql/client.py b/gql/client.py index 1b0df4dc..5c1edffa 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1046,6 +1046,7 @@ def _execute_batch( *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, + validate_document: Optional[bool] = True, **kwargs, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch, using @@ -1057,13 +1058,16 @@ def _execute_batch( By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. By default use the parse_results argument of the client. + :param validate_document: Whether we still need to validate the document. The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: - for req in requests: - self.client.validate(req.document) + + if validate_document: + for req in requests: + self.client.validate(req.document) # Parse variable values for custom scalars if requested if serialize_variables or ( @@ -1215,7 +1219,12 @@ def _batch_loop(self) -> None: # Manually execute the requests in a batch try: - results: List[ExecutionResult] = self._execute_batch(requests) + results: List[ExecutionResult] = self._execute_batch( + requests, + serialize_variables=False, # already done + parse_result=False, + validate_document=False, + ) except Exception as exc: for future in futures: future.set_exception(exc) From 02f249e23d4d12b1a080280d16d0b1e0165b93af Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 11 Sep 2023 21:19:36 +0200 Subject: [PATCH 10/10] Add thread.join to tests --- tests/test_requests_batch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index cb9ae6aa..1f922db7 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -162,6 +162,8 @@ async def handler(request): def test_code(): transport = RequestsHTTPTransport(url=url) + threads = [] + with Client( transport=transport, batch_interval=0.01, @@ -187,6 +189,10 @@ def test_thread(): for _ in range(2): thread = Thread(target=test_thread) thread.start() + threads.append(thread) + + for thread in threads: + thread.join() await run_sync_test(event_loop, server, test_code) @@ -581,6 +587,7 @@ def get_continent_name(session, continent_code): ), ) thread.start() + thread.join() # Doing it twice to check that everything is closing and reconnecting correctly with client as session: @@ -595,6 +602,7 @@ def get_continent_name(session, continent_code): ), ) thread.start() + thread.join() @pytest.mark.online