Skip to content

Commit

Permalink
feat(python-sdk): batch_check updates (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhamzeh authored Nov 2, 2023
2 parents ac7d46b + 1eeb2b5 commit 660bb33
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
22 changes: 15 additions & 7 deletions config/clients/python/template/client/client.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -514,17 +514,22 @@ class OpenFgaClient():
)
return api_response

{{#asyncio}}async {{/asyncio}}def _single_batch_check(self, body: ClientCheckRequest, options: dict[str, str] = None): # noqa: E501
{{#asyncio}}async {{/asyncio}}def _single_batch_check(self, body: ClientCheckRequest, semaphore: asyncio.Semaphore, options: dict[str, str] = None): # noqa: E501
"""
Run a single batch request and return body in a SingleBatchCheckResponse
:param body - ClientCheckRequest defining check request
:param authorization_model_id(options) - Overrides the authorization model id in the configuration
"""
{{#asyncio}}await semaphore.acquire(){{/asyncio}}
try:
api_response = {{#asyncio}}await {{/asyncio}}self.check(body, options)
return BatchCheckResponse(allowed=api_response.allowed, request=body, response=api_response, error=None)
except Exception as err:
return BatchCheckResponse(allowed=False, request=body, response=None, error=err)
{{#asyncio}}
finally:
semaphore.release()
{{/asyncio}}

{{#asyncio}}async {{/asyncio}}def batch_check(self, body: List[ClientCheckRequest], options: dict[str, str] = None): # noqa: E501
"""
Expand All @@ -546,18 +551,21 @@ class OpenFgaClient():
max_parallel_requests = {{ clientMaxMethodParallelRequests }}
if options is not None and "max_parallel_requests" in options:
max_parallel_requests = options["max_parallel_requests"]

{{#asyncio}}
sem = asyncio.Semaphore(max_parallel_requests)
batch_check_coros = [self._single_batch_check(request, sem, options) for request in body]
batch_check_response = await asyncio.gather(*batch_check_coros)
{{/asyncio}}
{{^asyncio}}
# Break the batch into chunks
request_batches = _chuck_array(body, max_parallel_requests)
batch_check_response = []
for request_batch in request_batches:
{{#asyncio}}
request = [self._single_batch_check(i, options) for i in request_batch]
response = await asyncio.gather(*request)
{{/asyncio}}
{{^asyncio}}
response = [self._single_batch_check(i, options) for i in request_batch]
{{/asyncio}}
batch_check_response.extend(response)
{{/asyncio}}

return batch_check_response

{{#asyncio}}async {{/asyncio}}def expand(self, body: ClientExpandRequest, options: dict[str, str] = None): # noqa: E501
Expand Down
16 changes: 10 additions & 6 deletions config/clients/python/template/client/client_sync.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ from {{packageName}}.models.write_authorization_model_request import WriteAuthor
from {{packageName}}.models.write_request import WriteRequest
from {{packageName}}.validation import is_well_formed_ulid_string

import time
import uuid
from typing import List
from concurrent.futures import ThreadPoolExecutor

CLIENT_METHOD_HEADER = "X-OpenFGA-Client-Method"
CLIENT_BULK_REQUEST_ID_HEADER = "X-OpenFGA-Client-Bulk-Request-Id"
Expand Down Expand Up @@ -526,12 +526,16 @@ class OpenFgaClient():
max_parallel_requests = {{ clientMaxMethodParallelRequests }}
if options is not None and "max_parallel_requests" in options:
max_parallel_requests = options["max_parallel_requests"]
# Break the batch into chunks
request_batches = _chuck_array(body, max_parallel_requests)

batch_check_response = []
for request_batch in request_batches:
response = [self._single_batch_check(i, options) for i in request_batch]
batch_check_response.extend(response)

def single_batch_check(request):
return self._single_batch_check(request, options)

with ThreadPoolExecutor(max_workers=max_parallel_requests) as executor:
for response in executor.map(single_batch_check, body):
batch_check_response.append(response)

return batch_check_response

def expand(self, body: ClientExpandRequest, options: dict[str, str] = None): # noqa: E501
Expand Down

0 comments on commit 660bb33

Please sign in to comment.