diff --git a/aiocouch/exception.py b/aiocouch/exception.py index e6947bb..deb73bf 100644 --- a/aiocouch/exception.py +++ b/aiocouch/exception.py @@ -112,10 +112,16 @@ class ExpectationFailedError(ValueError): pass +class ClientResponseError(aiohttp.ClientResponseError): + def __init__(self, reason, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.reason = reason + + def raise_for_endpoint( endpoint: Endpoint, message: str, - exception: aiohttp.ClientResponseError, + exception: ClientResponseError, exception_type: Optional[Type[Exception]] = None, ) -> NoReturn: if exception_type is None: @@ -143,6 +149,9 @@ def raise_for_endpoint( message_input = {} + with suppress(AttributeError): + message_input["reason"] = exception.reason + message_input["reason"] = message_input.get("reason", exception.message) with suppress(AttributeError): message_input["id"] = endpoint.id message_input["endpoint"] = endpoint.endpoint @@ -165,7 +174,7 @@ def decorator_raises(func: FuncT) -> FuncT: async def wrapper(endpoint: Endpoint, *args: Any, **kwargs: Any) -> Any: try: return await func(endpoint, *args, **kwargs) - except aiohttp.ClientResponseError as exception: + except ClientResponseError as exception: if status == exception.status: raise_for_endpoint(endpoint, message, exception, exception_type) raise exception @@ -186,7 +195,7 @@ async def wrapper( try: async for data in func(endpoint, *args, **kwargs): yield data - except aiohttp.ClientResponseError as exception: + except ClientResponseError as exception: if status == exception.status: raise_for_endpoint(endpoint, message, exception, exception_type) raise exception diff --git a/aiocouch/remote.py b/aiocouch/remote.py index c5fa91b..c54349a 100644 --- a/aiocouch/remote.py +++ b/aiocouch/remote.py @@ -37,7 +37,7 @@ import aiohttp from . import database, document -from .exception import NotFoundError, generator_raises, raises +from .exception import NotFoundError, generator_raises, raises, ClientResponseError from .typing import JsonDict @@ -160,7 +160,21 @@ async def _request( async with self._http_session.request( method, url=f"{self._server}{path}", **kwargs ) as resp: - resp.raise_for_status() + if not resp.ok: + reason = None + with suppress(Exception): + reason = (await resp.json())["reason"] + # Copied from aiohttp v3.9.5 raise_for_status(): + assert resp.reason is not None + resp.release() + raise ClientResponseError( + reason, + resp.request_info, + resp.history, + status=resp.status, + message=resp.reason, + headers=resp.headers, + ) return ( HTTPResponse(resp), await resp.json() if return_json else await resp.read(), @@ -179,7 +193,21 @@ async def _streamed_request( async with self._http_session.request( method, url=f"{self._server}{path}", **kwargs ) as resp: - resp.raise_for_status() + if not resp.ok: + reason = None + with suppress(Exception): + reason = (await resp.json())["reason"] + # Copied from aiohttp v3.9.5 raise_for_status(): + assert resp.reason is not None + resp.release() + raise ClientResponseError( + reason, + resp.request_info, + resp.history, + status=resp.status, + message=resp.reason, + headers=resp.headers, + ) async for line in resp.content: # this should only happen for empty lines @@ -187,6 +215,7 @@ async def _streamed_request( yield json.loads(line) @raises(401, "Invalid credentials") + @raises(403, "Access forbidden: {reason}") async def _all_dbs(self, **params: Any) -> List[str]: _, json = await self._get("/_all_dbs", params) assert not isinstance(json, bytes) @@ -203,12 +232,14 @@ async def close(self) -> None: await asyncio.sleep(0.250 if has_ssl_conn else 0) @raises(401, "Invalid credentials") + @raises(403, "Access forbidden: {reason}") async def _info(self) -> JsonDict: _, json = await self._get("/") assert not isinstance(json, bytes) return json @raises(401, "Authentication failed, check provided credentials.") + @raises(403, "Access forbidden: {reason}") async def _check_session(self) -> RequestResult: return await self._get("/_session") @@ -223,19 +254,19 @@ def endpoint(self) -> str: return f"/{_quote_id(self.id)}" @raises(401, "Invalid credentials") - @raises(403, "Read permission required") + @raises(403, "Access forbidden: {reason}") async def _exists(self) -> bool: try: await self._remote._head(self.endpoint) return True - except aiohttp.ClientResponseError as e: + except ClientResponseError as e: if e.status == 404: return False else: raise e @raises(401, "Invalid credentials") - @raises(403, "Read permission required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Requested database not found ({id})") async def _get(self) -> JsonDict: _, json = await self._remote._get(self.endpoint) @@ -244,6 +275,7 @@ async def _get(self) -> JsonDict: @raises(400, "Invalid database name") @raises(401, "CouchDB Server Administrator privileges required") + @raises(403, "Access forbidden: {reason}") @raises(412, "Database already exists") async def _put(self, **params: Any) -> JsonDict: _, json = await self._remote._put(self.endpoint, params=params) @@ -252,13 +284,14 @@ async def _put(self, **params: Any) -> JsonDict: @raises(400, "Invalid database name or forgotten document id by accident") @raises(401, "CouchDB Server Administrator privileges required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Database doesn't exist or invalid database name ({id})") async def _delete(self) -> None: await self._remote._delete(self.endpoint) @raises(400, "The request provided invalid JSON data or invalid query parameter") @raises(401, "Read permission required") - @raises(403, "Read permission required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Invalid database name") @raises(415, "Bad Content-Type value") async def _bulk_get(self, docs: List[str], **params: Any) -> JsonDict: @@ -270,7 +303,7 @@ async def _bulk_get(self, docs: List[str], **params: Any) -> JsonDict: @raises(400, "The request provided invalid JSON data") @raises(401, "Invalid credentials") - @raises(403, "Write permission required") + @raises(403, "Access forbidden: {reason}") @raises(417, "At least one document was rejected by the validation function") async def _bulk_docs(self, docs: List[JsonDict], **data: Any) -> JsonDict: data["docs"] = docs @@ -280,7 +313,7 @@ async def _bulk_docs(self, docs: List[JsonDict], **data: Any) -> JsonDict: @raises(400, "Invalid request") @raises(401, "Read privilege required for document '{id}'") - @raises(403, "Read permission required") + @raises(403, "Access forbidden: {reason}") @raises(500, "Query execution failed", RuntimeError) async def _find(self, selector: Any, **data: Any) -> JsonDict: data["selector"] = selector @@ -290,6 +323,7 @@ async def _find(self, selector: Any, **data: Any) -> JsonDict: @raises(400, "Invalid request") @raises(401, "Admin permission required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Database not found") @raises(500, "Execution error") async def _index(self, index: JsonDict, **data: Any) -> JsonDict: @@ -299,20 +333,21 @@ async def _index(self, index: JsonDict, **data: Any) -> JsonDict: return json @raises(401, "Invalid credentials") - @raises(403, "Permission required") + @raises(403, "Access forbidden: {reason}") async def _get_security(self) -> JsonDict: _, json = await self._remote._get(f"{self.endpoint}/_security") assert not isinstance(json, bytes) return json @raises(401, "Invalid credentials") - @raises(403, "Permission required") + @raises(403, "Access forbidden: {reason}") async def _put_security(self, doc: JsonDict) -> JsonDict: _, json = await self._remote._put(f"{self.endpoint}/_security", doc) assert not isinstance(json, bytes) return json @generator_raises(400, "Invalid request") + @generator_raises(403, "Access forbidden: {reason}") async def _changes(self, **params: Any) -> AsyncGenerator[JsonDict, None]: if "feed" in params and params["feed"] == "continuous": params.setdefault("heartbeat", True) @@ -329,6 +364,7 @@ async def _changes(self, **params: Any) -> AsyncGenerator[JsonDict, None]: yield result @raises(400, "Invalid database or JSON payload") + @raises(403, "Access forbidden: {reason}") @raises(415, "Bad Content-Type header value") @raises(500, "Internal server error or timeout") async def _purge(self, docs: JsonDict, **params: Any) -> JsonDict: @@ -350,13 +386,13 @@ def endpoint(self) -> str: return f"{self._database.endpoint}/{_quote_id(self.id)}" @raises(401, "Read privilege required for document '{id}'") - @raises(403, "Read privilege required for document '{id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Document {id} was not found") async def _head(self) -> None: await self._database._remote._head(self.endpoint) @raises(401, "Read privilege required for document '{id}'") - @raises(403, "Read privilege required for document '{id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Document {id} was not found") async def _info(self) -> JsonDict: response, _ = await self._database._remote._head(self.endpoint) @@ -376,7 +412,7 @@ async def _exists(self) -> bool: @raises(400, "The format of the request or revision was invalid") @raises(401, "Read privilege required for document '{id}'") - @raises(403, "Read privilege required for document '{id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Document {id} was not found") async def _get(self, **params: Any) -> JsonDict: _, json = await self._database._remote._get(self.endpoint, params) @@ -385,7 +421,7 @@ async def _get(self, **params: Any) -> JsonDict: @raises(400, "The format of the request or revision was invalid") @raises(401, "Write privilege required for document '{id}'") - @raises(403, "Write privilege required for document '{id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Specified database or document ID doesn't exists ({endpoint})") @raises( 409, @@ -401,7 +437,7 @@ async def _put( @raises(400, "Invalid request body or parameters") @raises(401, "Write privilege required for document '{id}'") - @raises(403, "Write privilege required for document '{id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Specified database or document ID doesn't exists ({endpoint})") @raises( 409, "Specified revision ({rev}) is not the latest for target document '{id}'" @@ -414,7 +450,7 @@ async def _delete(self, rev: str, **params: Any) -> Tuple[HTTPResponse, JsonDict @raises(400, "Invalid request body or parameters") @raises(401, "Read or write privileges required") - @raises(403, "Read or write privileges required") + @raises(403, "Access forbidden: {reason}") @raises( 404, "Specified database, document ID or revision doesn't exists ({endpoint})" ) @@ -444,7 +480,7 @@ def endpoint(self) -> str: return f"{self._document.endpoint}/{_quote_id(self.id)}" @raises(401, "Read privilege required for document '{document_id}'") - @raises(403, "Read privilege required for document '{document_id}'") + @raises(403, "Access forbidden: {reason}") async def _exists(self) -> bool: try: response, _ = await self._document._database._remote._head( @@ -452,7 +488,7 @@ async def _exists(self) -> bool: ) self.content_type = response.headers["Content-Type"] return True - except aiohttp.ClientResponseError as e: + except ClientResponseError as e: if e.status == 404: return False else: @@ -460,7 +496,7 @@ async def _exists(self) -> bool: @raises(400, "Invalid request parameters") @raises(401, "Read privilege required for document '{document_id}'") - @raises(403, "Read privilege required for document '{document_id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Document '{document_id}' or attachment '{id}' doesn't exists") async def _get(self, **params: Any) -> bytes: response, data = await self._document._database._remote._get_bytes( @@ -472,7 +508,7 @@ async def _get(self, **params: Any) -> bytes: @raises(400, "Invalid request body or parameters") @raises(401, "Write privilege required for document '{document_id}'") - @raises(403, "Write privilege required for document '{document_id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Document '{document_id}' doesn't exists") @raises( 409, "Specified revision {document_rev} is not the latest for target document" @@ -490,7 +526,7 @@ async def _put( @raises(400, "Invalid request body or parameters") @raises(401, "Write privilege required for document '{document_id}'") - @raises(403, "Write privilege required for document '{document_id}'") + @raises(403, "Access forbidden: {reason}") @raises(404, "Specified database or document ID doesn't exists ({endpoint})") @raises( 409, "Specified revision {document_rev} is not the latest for target document" @@ -519,7 +555,7 @@ def endpoint(self) -> str: @raises(400, "Invalid request") @raises(401, "Read privileges required") - @raises(403, "Read privileges required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Specified database, design document or view is missing") async def _get(self, **params: Any) -> JsonDict: _, json = await self._database._remote._get(self.endpoint, params) @@ -528,7 +564,7 @@ async def _get(self, **params: Any) -> JsonDict: @raises(400, "Invalid request") @raises(401, "Write privileges required") - @raises(403, "Write privileges required") + @raises(403, "Access forbidden: {reason}") @raises(404, "Specified database, design document or view is missing") async def _post(self, keys: List[str], **params: Any) -> JsonDict: _, json = await self._database._remote._post( diff --git a/tests/test_exception.py b/tests/test_exception.py index ee21d60..a3c7d17 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -2,7 +2,6 @@ import pytest from aiohttp.client import RequestInfo -from aiohttp.client_exceptions import ClientResponseError from aiocouch.couchdb import JsonDict from aiocouch.exception import ( @@ -14,6 +13,7 @@ PreconditionFailedError, UnauthorizedError, UnsupportedMediaTypeError, + ClientResponseError, generator_raises, raises, ) @@ -27,49 +27,45 @@ class CustomError(Exception): class DummyEndpoint: - @property - def endpoint(self) -> str: - return "endpoint" - @raises(400, "bad thing") async def raise_bad_request(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=400) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=400) @raises(401, "bad thing") async def raise_unauthorized(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=401) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=401) - @raises(403, "bad thing") + @raises(403, "Access forbidden: {reason}") async def raise_forbidden(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=403) + raise ClientResponseError("a reason", cast(RequestInfo, None), (), status=403) @raises(404, "bad thing") async def raise_not_found(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=404) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=404) @raises(409, "bad thing") async def raise_conflict(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=409) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=409) @raises(412, "bad thing") async def raise_precondition_failed(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=412) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=412) @raises(415, "bad thing") async def raise_unsupported_media(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=415) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=415) @raises(417, "bad thing") async def raise_expectation_failed(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=417) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=417) @raises(500, "bad thing", CustomError) async def raise_custom(self) -> NoReturn: - raise ClientResponseError(cast(RequestInfo, None), (), status=500) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=500) @generator_raises(400, "bad thing") async def raise_in_generator(self) -> AsyncGenerator[JsonDict, None]: - raise ClientResponseError(cast(RequestInfo, None), (), status=400) + raise ClientResponseError(None, cast(RequestInfo, None), (), status=400) yield None @@ -81,7 +77,7 @@ async def test_raises() -> None: with pytest.raises(UnauthorizedError): await dummy.raise_unauthorized() - with pytest.raises(ForbiddenError): + with pytest.raises(ForbiddenError, match="Access forbidden: a reason"): await dummy.raise_forbidden() with pytest.raises(NotFoundError):