Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add sync batching to requests sync transport #431

Merged
merged 15 commits into from
Sep 5, 2023
Merged
2 changes: 2 additions & 0 deletions gql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .__version__ import __version__
from .client import Client
from .gql import gql
from .graphql_request import GraphQLRequest

__all__ = [
"__version__",
"gql",
"Client",
"GraphQLRequest",
]
174 changes: 165 additions & 9 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Dict,
Generator,
List,
Optional,
TypeVar,
Union,
Expand All @@ -27,6 +28,8 @@
validate,
)

from gql.graphql_request import GraphQLRequest
itolosa marked this conversation as resolved.
Show resolved Hide resolved

from .transport.async_transport import AsyncTransport
from .transport.exceptions import TransportClosed, TransportQueryError
from .transport.local_schema import LocalSchemaTransport
Expand Down Expand Up @@ -236,6 +239,24 @@ def execute_sync(
**kwargs,
)

def execute_batch_sync(
    self,
    reqs: List[GraphQLRequest],
    serialize_variables: Optional[bool] = None,
    parse_result: Optional[bool] = None,
    get_execution_result: bool = False,
    **kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
    """:meta private:"""
    # Open a session for the duration of the batch and delegate the
    # actual work to the session's execute_batch implementation.
    with self as session:
        batch_results = session.execute_batch(
            reqs,
            serialize_variables=serialize_variables,
            parse_result=parse_result,
            get_execution_result=get_execution_result,
            **kwargs,
        )
    return batch_results

@overload
async def execute_async(
self,
Expand Down Expand Up @@ -375,7 +396,6 @@ def execute(
"""

if isinstance(self.transport, AsyncTransport):

# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
Expand Down Expand Up @@ -418,6 +438,48 @@ def execute(
**kwargs,
)

def execute_batch(
    self,
    reqs: List[GraphQLRequest],
    serialize_variables: Optional[bool] = None,
    parse_result: Optional[bool] = None,
    get_execution_result: bool = False,
    **kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
    """Execute the provided requests against the remote server using
    the transport provided during init.

    This function **WILL BLOCK** until the result is received from the server.

    This method will:

    - connect using the transport to get a session
    - execute the GraphQL requests on the transport session
    - close the session and close the connection to the server

    If you want to perform multiple executions, it is better to use
    the context manager to keep a session active.

    The extra arguments passed in the method will be passed to the transport
    execute method.
    """
    # Batching over an async transport is not supported; bail out early.
    if isinstance(self.transport, AsyncTransport):
        raise NotImplementedError("Batching is not implemented for async yet.")

    # Sync transport: open a session and run the whole batch through it.
    return self.execute_batch_sync(
        reqs,
        serialize_variables=serialize_variables,
        parse_result=parse_result,
        get_execution_result=get_execution_result,
        **kwargs,
    )

@overload
def subscribe_async(
self,
Expand Down Expand Up @@ -476,7 +538,6 @@ async def subscribe_async(
]:
""":meta private:"""
async with self as session:

generator = session.subscribe(
document,
variable_values=variable_values,
Expand Down Expand Up @@ -600,7 +661,6 @@ def subscribe(
pass

except (KeyboardInterrupt, Exception, GeneratorExit):

# Graceful shutdown
asyncio.ensure_future(async_generator.aclose(), loop=loop)

Expand Down Expand Up @@ -661,11 +721,9 @@ async def close_async(self):
await self.transport.close()

async def __aenter__(self):

return await self.connect_async()

async def __aexit__(self, exc_type, exc, tb):

await self.close_async()

def connect_sync(self):
Expand Down Expand Up @@ -705,7 +763,6 @@ def close_sync(self):
self.transport.close()

def __enter__(self):

return self.connect_sync()

def __exit__(self, *args):
Expand Down Expand Up @@ -880,6 +937,108 @@ def execute(

return result.data

def _execute_batch(
self,
reqs: List[GraphQLRequest],
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
**kwargs,
) -> List[ExecutionResult]:
"""Execute the provided requests synchronously using
the sync transport, returning a list of ExecutionResult objects.

:param reqs: List of requests that will be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
By default use the parse_results argument of the client.

The extra arguments are passed to the transport execute method."""

# Validate document
if self.client.schema:
for req in reqs:
self.client.validate(req.document)

# Parse variable values for custom scalars if requested
if serialize_variables or (
serialize_variables is None and self.client.serialize_variables
):
reqs = [
req.serialize_variable_values(self.client.schema)
if req.variable_values is not None
else req
for req in reqs
]

results = self.transport.execute_batch(reqs, **kwargs)

# Unserialize the result if requested
if self.client.schema:
if parse_result or (parse_result is None and self.client.parse_results):
for result in results:
result.data = parse_result_fn(
self.client.schema,
req.document,
result.data,
operation_name=req.operation_name,
)

return results

def execute_batch(
    self,
    reqs: List[GraphQLRequest],
    serialize_variables: Optional[bool] = None,
    parse_result: Optional[bool] = None,
    get_execution_result: bool = False,
    **kwargs,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
    """Execute the provided requests synchronously using
    the sync transport, sending them all to the server in a single batch.

    Raises a TransportQueryError if an error has been returned in any
    ExecutionResult.

    :param reqs: List of requests that will be executed.
    :param serialize_variables: whether the variable values should be
        serialized. Used for custom scalars and/or enums.
        By default use the serialize_variables argument of the client.
    :param parse_result: Whether gql will unserialize the result.
        By default use the parse_results argument of the client.
    :param get_execution_result: return the full ExecutionResult instance instead of
        only the "data" field. Necessary if you want to get the "extensions" field.

    The extra arguments are passed to the transport execute method."""

    # Validate and execute on the transport
    execution_results = self._execute_batch(
        reqs,
        serialize_variables=serialize_variables,
        parse_result=parse_result,
        **kwargs,
    )

    # Surface the first server-side error (if any) as a TransportQueryError
    for execution_result in execution_results:
        if execution_result.errors:
            raise TransportQueryError(
                str_first_element(execution_result.errors),
                errors=execution_result.errors,
                data=execution_result.data,
                extensions=execution_result.extensions,
            )

        assert (
            execution_result.data is not None
        ), "Transport returned an ExecutionResult without data or errors"

    if not get_execution_result:
        # Callers usually only want the "data" payload of each result
        return cast(
            List[Dict[str, Any]],
            [execution_result.data for execution_result in execution_results],
        )

    return execution_results

def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.

Expand Down Expand Up @@ -966,7 +1125,6 @@ async def _subscribe(

try:
async for result in inner_generator:

if self.client.schema:
if parse_result or (
parse_result is None and self.client.parse_results
Expand Down Expand Up @@ -1070,7 +1228,6 @@ async def subscribe(
try:
# Validate and subscribe on the transport
async for result in inner_generator:

# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(
Expand Down Expand Up @@ -1343,7 +1500,6 @@ async def _connection_loop(self):
"""

while True:

# Connect to the transport with the retry decorator
# By default it should keep retrying until it connect
await self._connect_with_retries()
Expand Down
37 changes: 37 additions & 0 deletions gql/graphql_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

from graphql import DocumentNode, GraphQLSchema

from .utilities import serialize_variable_values


@dataclass(frozen=True)
class GraphQLRequest:
    """A single GraphQL request: the query document bundled with its
    optional variable values and operation name."""

    document: DocumentNode
    """GraphQL query as AST Node object."""

    variable_values: Optional[Dict[str, Any]] = None
    """Dictionary of input parameters (Default: None)."""

    operation_name: Optional[str] = None
    """
    Name of the operation that shall be executed.
    Only required in multi-operation documents (Default: None).
    """

    def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
        """Return a copy of this request whose variable values have been
        serialized against *schema* (custom scalars and/or enums)."""
        assert self.variable_values

        serialized_values = serialize_variable_values(
            schema=schema,
            document=self.document,
            variable_values=self.variable_values,
            operation_name=self.operation_name,
        )

        return GraphQLRequest(
            document=self.document,
            variable_values=serialized_values,
            operation_name=self.operation_name,
        )
2 changes: 1 addition & 1 deletion gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ async def execute(
document: DocumentNode,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extra_args: Dict[str, Any] = None,
extra_args: Optional[Dict[str, Any]] = None,
upload_files: bool = False,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server
Expand Down
Loading