Skip to content

Commit

Permalink
feat(client): use asyncio semaphore for streaming batch check
Browse files Browse the repository at this point in the history
  • Loading branch information
booniepepper committed Nov 2, 2023
1 parent 5d67f5c commit fc3bd65
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions openfga_sdk/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,17 +511,20 @@ async def check(self, body: ClientCheckRequest, options: dict[str, str] = None):
)
return api_response

async def _single_batch_check(self, body: ClientCheckRequest, options: dict[str, str] = None): # noqa: E501
async def _single_batch_check(self, body: ClientCheckRequest, semaphore: asyncio.Semaphore, options: dict[str, str] = None): # noqa: E501
"""
Run a single batch request and return body in a SingleBatchCheckResponse
:param body - ClientCheckRequest defining check request
:param authorization_model_id(options) - Overrides the authorization model id in the configuration
"""
await semaphore.acquire()
try:
api_response = await self.check(body, options)
return BatchCheckResponse(allowed=api_response.allowed, request=body, response=api_response, error=None)
except Exception as err:
return BatchCheckResponse(allowed=False, request=body, response=None, error=err)
finally:
semaphore.release()

async def batch_check(self, body: List[ClientCheckRequest], options: dict[str, str] = None): # noqa: E501
"""
Expand All @@ -543,13 +546,11 @@ async def batch_check(self, body: List[ClientCheckRequest], options: dict[str, s
max_parallel_requests = 10
if options is not None and "max_parallel_requests" in options:
max_parallel_requests = options["max_parallel_requests"]
# Break the batch into chunks
request_batches = _chuck_array(body, max_parallel_requests)
batch_check_response = []
for request_batch in request_batches:
request = [self._single_batch_check(i, options) for i in request_batch]
response = await asyncio.gather(*request)
batch_check_response.extend(response)

sem = asyncio.Semaphore(max_parallel_requests)
batch_check_coros = [self._single_batch_check(request, sem, options) for request in body]
batch_check_response = [await coro for coro in batch_check_coros]

return batch_check_response

async def expand(self, body: ClientExpandRequest, options: dict[str, str] = None): # noqa: E501
Expand Down

0 comments on commit fc3bd65

Please sign in to comment.