feat: batch check updates (#41)
rhamzeh authored Nov 2, 2023
2 parents d27eb1b + 095c3dd commit d8f2d42
Showing 2 changed files with 19 additions and 14 deletions.
17 changes: 9 additions & 8 deletions openfga_sdk/client/client.py
@@ -511,17 +511,20 @@ async def check(self, body: ClientCheckRequest, options: dict[str, str] = None):
         )
         return api_response

-    async def _single_batch_check(self, body: ClientCheckRequest, options: dict[str, str] = None):  # noqa: E501
+    async def _single_batch_check(self, body: ClientCheckRequest, semaphore: asyncio.Semaphore, options: dict[str, str] = None):  # noqa: E501
         """
         Run a single batch request and return body in a SingleBatchCheckResponse
         :param body - ClientCheckRequest defining check request
         :param authorization_model_id(options) - Overrides the authorization model id in the configuration
         """
+        await semaphore.acquire()
         try:
             api_response = await self.check(body, options)
             return BatchCheckResponse(allowed=api_response.allowed, request=body, response=api_response, error=None)
         except Exception as err:
             return BatchCheckResponse(allowed=False, request=body, response=None, error=err)
+        finally:
+            semaphore.release()

     async def batch_check(self, body: List[ClientCheckRequest], options: dict[str, str] = None):  # noqa: E501
         """
@@ -543,13 +546,11 @@ async def batch_check(self, body: List[ClientCheckRequest], options: dict[str, str] = None):
         max_parallel_requests = 10
         if options is not None and "max_parallel_requests" in options:
             max_parallel_requests = options["max_parallel_requests"]
-        # Break the batch into chunks
-        request_batches = _chuck_array(body, max_parallel_requests)
-        batch_check_response = []
-        for request_batch in request_batches:
-            request = [self._single_batch_check(i, options) for i in request_batch]
-            response = await asyncio.gather(*request)
-            batch_check_response.extend(response)
+
+        sem = asyncio.Semaphore(max_parallel_requests)
+        batch_check_coros = [self._single_batch_check(request, sem, options) for request in body]
+        batch_check_response = await asyncio.gather(*batch_check_coros)
+
         return batch_check_response

     async def expand(self, body: ClientExpandRequest, options: dict[str, str] = None):  # noqa: E501
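
For reference, a minimal usage sketch of the updated async batch_check, based on the SDK's documented client API. The endpoint, store id, and tuple values below are placeholders, and import paths or configuration fields may differ between SDK releases:

import asyncio

from openfga_sdk.client import ClientConfiguration, OpenFgaClient
from openfga_sdk.client.models import ClientCheckRequest


async def main():
    # Illustrative configuration only; newer releases accept api_url instead of api_scheme/api_host.
    configuration = ClientConfiguration(
        api_scheme="http",
        api_host="localhost:8080",                 # hypothetical OpenFGA endpoint
        store_id="01HEXAMPLESTOREID0000000000",    # hypothetical store id
    )
    async with OpenFgaClient(configuration) as fga_client:
        requests = [
            ClientCheckRequest(user="user:anne", relation="viewer", object="document:roadmap"),
            ClientCheckRequest(user="user:bob", relation="viewer", object="document:roadmap"),
        ]
        # max_parallel_requests caps in-flight checks via the new asyncio.Semaphore.
        responses = await fga_client.batch_check(requests, options={"max_parallel_requests": 5})
        for r in responses:
            print(r.request.user, r.allowed, r.error)


asyncio.run(main())

Note that a failed check is not raised: as the diff above shows, it comes back as a BatchCheckResponse with allowed=False and the exception attached to error.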
16 changes: 10 additions & 6 deletions openfga_sdk/sync/client/client.py
@@ -40,9 +40,9 @@
 from openfga_sdk.models.write_request import WriteRequest
 from openfga_sdk.validation import is_well_formed_ulid_string

-import time
 import uuid
 from typing import List
+from concurrent.futures import ThreadPoolExecutor

 CLIENT_METHOD_HEADER = "X-OpenFGA-Client-Method"
 CLIENT_BULK_REQUEST_ID_HEADER = "X-OpenFGA-Client-Bulk-Request-Id"
@@ -543,12 +543,16 @@ def batch_check(self, body: List[ClientCheckRequest], options: dict[str, str] = None):
         max_parallel_requests = 10
         if options is not None and "max_parallel_requests" in options:
             max_parallel_requests = options["max_parallel_requests"]
-        # Break the batch into chunks
-        request_batches = _chuck_array(body, max_parallel_requests)
+
         batch_check_response = []
-        for request_batch in request_batches:
-            response = [self._single_batch_check(i, options) for i in request_batch]
-            batch_check_response.extend(response)
+
+        def single_batch_check(request):
+            return self._single_batch_check(request, options)
+
+        with ThreadPoolExecutor(max_workers=max_parallel_requests) as executor:
+            for response in executor.map(single_batch_check, body):
+                batch_check_response.append(response)
+
         return batch_check_response

     def expand(self, body: ClientExpandRequest, options: dict[str, str] = None):  # noqa: E501
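
The sync client gets the same concurrency cap from a bounded thread pool instead of a semaphore. A standalone sketch of that pattern, independent of the SDK (check_one and the request list below are stand-ins for _single_batch_check and the batch_check body):

from concurrent.futures import ThreadPoolExecutor


def check_one(request: str) -> str:
    # Stand-in for _single_batch_check: never raises, returns one result per input.
    return f"checked {request}"


requests = [f"req-{i}" for i in range(25)]
max_parallel_requests = 10  # same knob the SDK reads from options

results = []
# executor.map runs at most max_workers calls at a time and yields results
# in input order, mirroring how batch_check fans out over the request body.
with ThreadPoolExecutor(max_workers=max_parallel_requests) as executor:
    for response in executor.map(check_one, requests):
        results.append(response)

print(results[:3])

Like asyncio.gather in the async client, executor.map preserves request order, so each response lines up with the request that produced it.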
