Skip to content

Commit

Permalink
WIP defer support
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Jan 14, 2025
1 parent fa5c2d0 commit 9e7d000
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 157 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

@defer 👀
16 changes: 8 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.dependencies]
python = "^3.9"
graphql-core = ">=3.2.0,<3.4.0"
graphql-core = ">=3.2.0"
typing-extensions = ">=4.5.0"
python-dateutil = "^2.7.0"
starlette = {version = ">=0.18.0", optional = true}
Expand Down Expand Up @@ -102,6 +102,7 @@ types-deprecated = "^1.2.15.20241117"
types-six = "^1.17.0.20241205"
types-pyyaml = "^6.0.12.20240917"
mypy = "^1.13.0"
graphql-core = "3.3.0a6"

[tool.poetry.group.integrations]
optional = true
Expand Down
101 changes: 96 additions & 5 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@

from graphql import GraphQLError

# TODO: only import this if exists
from graphql.execution.execute import (
ExperimentalIncrementalExecutionResults,
InitialIncrementalExecutionResult,
)
from graphql.execution.incremental_publisher import (
IncrementalDeferResult,
IncrementalResult,
IncrementalStreamResult,
SubsequentIncrementalExecutionResult,
)

from strawberry.exceptions import MissingQueryError
from strawberry.file_uploads.utils import replace_placeholders_with_files
from strawberry.http import (
Expand Down Expand Up @@ -337,6 +349,29 @@ async def run(
except MissingQueryError as e:
raise HTTPException(400, "No GraphQL query found in the request") from e

if isinstance(result, ExperimentalIncrementalExecutionResults):

async def stream():
yield "---"
response = await self.process_result(request, result.initial_result)
yield self.encode_multipart_data(response, "-")

async for value in result.subsequent_results:
response = await self.process_subsequent_result(request, value)
yield self.encode_multipart_data(response, "-")

yield "--\r\n"

return await self.create_streaming_response(
request,
stream,
sub_response,
headers={
"Transfer-Encoding": "chunked",
"Content-Type": 'multipart/mixed; boundary="-"',
},
)

if isinstance(result, SubscriptionExecutionResult):
stream = self._get_stream(request, result)

Expand All @@ -360,12 +395,15 @@ async def run(
)

def encode_multipart_data(self, data: Any, separator: str) -> str:
encoded_data = self.encode_json(data)

return "".join(
[
f"\r\n--{separator}\r\n",
"Content-Type: application/json\r\n\r\n",
self.encode_json(data),
"\n",
"\r\n",
"Content-Type: application/json; charset=utf-8\r\n",
"\r\n",
encoded_data,
f"\r\n--{separator}",
]
)

Expand Down Expand Up @@ -475,9 +513,62 @@ async def parse_http_body(
protocol=protocol,
)

def process_incremental_result(
self, request: Request, result: IncrementalResult
) -> GraphQLHTTPResponse:
if isinstance(result, IncrementalDeferResult):
return {
"data": result.data,
"errors": result.errors,
"path": result.path,
"label": result.label,
"extensions": result.extensions,
}
if isinstance(result, IncrementalStreamResult):
return {
"items": result.items,
"errors": result.errors,
"path": result.path,
"label": result.label,
"extensions": result.extensions,
}

raise ValueError(f"Unsupported incremental result type: {type(result)}")

async def process_subsequent_result(
self,
request: Request,
result: SubsequentIncrementalExecutionResult,
# TODO: use proper return type
) -> GraphQLHTTPResponse:
data = {
"incremental": [
await self.process_result(request, value)
for value in result.incremental
],
"hasNext": result.has_next,
"extensions": result.extensions,
}

return data

async def process_result(
self, request: Request, result: ExecutionResult
self,
request: Request,
result: ExecutionResult | InitialIncrementalExecutionResult,
) -> GraphQLHTTPResponse:
if isinstance(result, InitialIncrementalExecutionResult):
return {
"data": result.data,
"incremental": [
self.process_incremental_result(request, value)
for value in result.incremental
]
if result.incremental
else [],
"hasNext": result.has_next,
"extensions": result.extensions,
}
return process_result(result)

async def on_ws_connect(
Expand Down
49 changes: 26 additions & 23 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from graphql import ExecutionResult as GraphQLExecutionResult
from graphql import GraphQLError, parse
from graphql import execute as original_execute
from graphql.execution import experimental_execute_incrementally
from graphql.validation import validate

from strawberry.exceptions import MissingQueryError
Expand Down Expand Up @@ -121,16 +121,17 @@ async def _handle_execution_result(
extensions_runner: SchemaExtensionsRunner,
process_errors: ProcessErrors | None,
) -> ExecutionResult:
# Set errors on the context so that it's easier
# to access in extensions
if result.errors:
context.errors = result.errors
if process_errors:
process_errors(result.errors, context)
if isinstance(result, GraphQLExecutionResult):
result = ExecutionResult(data=result.data, errors=result.errors)
result.extensions = await extensions_runner.get_extensions_results(context)
context.result = result # type: ignore # mypy failed to deduce correct type.
# TODO: deal with this later
# # Set errors on the context so that it's easier
# # to access in extensions
# if result.errors:
# context.errors = result.errors
# if process_errors:
# process_errors(result.errors, context)
# if isinstance(result, GraphQLExecutionResult):
# result = ExecutionResult(data=result.data, errors=result.errors)
# result.extensions = await extensions_runner.get_extensions_results(context)
# context.result = result # type: ignore # mypy failed to deduce correct type.
return result


Expand Down Expand Up @@ -164,7 +165,7 @@ async def execute(
async with extensions_runner.executing():
if not execution_context.result:
result = await await_maybe(
original_execute(
experimental_execute_incrementally(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
Expand All @@ -178,16 +179,18 @@ async def execute(
execution_context.result = result
else:
result = execution_context.result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)
# TODO: deal with this later
# # Also set errors on the execution_context so that it's easier
# # to access in extensions
# breakpoint()
# if result.errors:
# execution_context.errors = result.errors

# # Run the `Schema.process_errors` function here before
# # extensions have a chance to modify them (see the MaskErrors
# # extension). That way we can log the original errors but
# # only return a sanitised version to the client.
# process_errors(result.errors, execution_context)

except (MissingQueryError, InvalidOperationTypeError):
raise
Expand Down Expand Up @@ -252,7 +255,7 @@ def execute_sync(

with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
result = experimental_execute_incrementally(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
Expand Down
12 changes: 10 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
validate_schema,
)
from graphql.execution.middleware import MiddlewareManager
from graphql.type.directives import specified_directives
from graphql.type.directives import (
GraphQLDeferDirective,
GraphQLStreamDirective,
specified_directives,
)

from strawberry import relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -194,7 +198,11 @@ class Query:
query=query_type,
mutation=mutation_type,
subscription=subscription_type if subscription else None,
directives=specified_directives + tuple(graphql_directives),
directives=(
specified_directives
+ tuple(graphql_directives)
+ (GraphQLDeferDirective, GraphQLStreamDirective)
),
types=graphql_types,
extensions={
GraphQLCoreConverter.DEFINITION_BACKREF: self,
Expand Down
7 changes: 2 additions & 5 deletions strawberry/static/graphiql.html
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@
<link
crossorigin
rel="stylesheet"
href="https://unpkg.com/[email protected]/graphiql.min.css"
integrity="sha384-yz3/sqpuplkA7msMo0FE4ekg0xdwdvZ8JX9MVZREsxipqjU4h8IRfmAMRcb1QpUy"
href="https://unpkg.com/[email protected]/graphiql.min.css"
/>

<link
Expand All @@ -77,13 +76,11 @@
<div id="graphiql" class="graphiql-container">Loading...</div>
<script
crossorigin
src="https://unpkg.com/[email protected]/graphiql.min.js"
integrity="sha384-Mjte+vxCWz1ZYCzszGHiJqJa5eAxiqI4mc3BErq7eDXnt+UGLXSEW7+i0wmfPiji"
src="https://unpkg.com/[email protected]/graphiql.min.js"
></script>
<script
crossorigin
src="https://unpkg.com/@graphiql/[email protected]/dist/index.umd.js"
integrity="sha384-2oonKe47vfHIZnmB6ZZ10vl7T0Y+qrHQF2cmNTaFDuPshpKqpUMGMc9jgj9MLDZ9"
></script>
<script>
const EXAMPLE_QUERY = `# Welcome to GraphiQL 🍓
Expand Down
Empty file.
44 changes: 44 additions & 0 deletions tests/http/incremental/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import contextlib

import pytest

from tests.http.clients.base import HttpClient


@pytest.fixture
def http_client(http_client_class: type[HttpClient]) -> HttpClient:
with contextlib.suppress(ImportError):
import django

if django.VERSION < (4, 2):
pytest.skip(reason="Django < 4.2 doesn't async streaming responses")

from tests.http.clients.django import DjangoHttpClient

if http_client_class is DjangoHttpClient:
pytest.skip(reason="(sync) DjangoHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.channels import SyncChannelsHttpClient

# TODO: why do we have a sync channels client?
if http_client_class is SyncChannelsHttpClient:
pytest.skip(reason="SyncChannelsHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.async_flask import AsyncFlaskHttpClient
from tests.http.clients.flask import FlaskHttpClient

if http_client_class is FlaskHttpClient:
pytest.skip(reason="FlaskHttpClient doesn't support streaming")

if http_client_class is AsyncFlaskHttpClient:
pytest.xfail(reason="AsyncFlaskHttpClient doesn't support streaming")

with contextlib.suppress(ImportError):
from tests.http.clients.chalice import ChaliceHttpClient

if http_client_class is ChaliceHttpClient:
pytest.skip(reason="ChaliceHttpClient doesn't support streaming")

return http_client_class()
Loading

0 comments on commit 9e7d000

Please sign in to comment.