diff --git a/offchain/metadata/fetchers/metadata_fetcher.py b/offchain/metadata/fetchers/metadata_fetcher.py
index 07129b4..96ed404 100644
--- a/offchain/metadata/fetchers/metadata_fetcher.py
+++ b/offchain/metadata/fetchers/metadata_fetcher.py
@@ -1,4 +1,3 @@
-import cgi
 from typing import Optional, Tuple, Union
 
 import httpx
@@ -8,6 +7,7 @@
 from offchain.metadata.adapters import Adapter, AdapterConfig, DEFAULT_ADAPTER_CONFIGS
 from offchain.metadata.fetchers.base_fetcher import BaseFetcher
 from offchain.metadata.registries.fetcher_registry import FetcherRegistry
+from offchain.utils.utils import parse_content_type
 
 
 @FetcherRegistry.register
@@ -48,7 +48,7 @@ def set_max_retries(self, max_retries: int):  # type: ignore[no-untyped-def]
         Args:
             max_retries (int): new maximum number of request retries.
         """
-        self.max_retries = max_retries
+        self.max_retries = max_retries  # pragma: no cover
 
     def set_timeout(self, timeout: int):  # type: ignore[no-untyped-def]
         """Setter function for timeout
@@ -56,7 +56,7 @@ def set_timeout(self, timeout: int):  # type: ignore[no-untyped-def]
         Args:
             timeout (int): new request timeout in seconds.
         """
-        self.timeout = timeout
+        self.timeout = timeout  # pragma: no cover
 
     def _get_async_adapter_for_uri(self, uri: str) -> Optional[Adapter]:
         if self.async_adapter_configs is None:
@@ -64,19 +64,16 @@ def _get_async_adapter_for_uri(self, uri: str) -> Optional[Adapter]:
             return None
 
         for async_adapter_config in self.async_adapter_configs:
-            if any(
-                uri.startswith(prefix) for prefix in async_adapter_config.mount_prefixes
-            ):
+            if any(uri.startswith(prefix) for prefix in async_adapter_config.mount_prefixes):
                 logger.debug(
-                    f"Selected {async_adapter_config.adapter_cls.__name__} for making async http requests for uri={uri}"  # noqa: E501
+                    f"Selected {async_adapter_config.adapter_cls.__name__} for making async http requests for uri={uri}"
+                    # noqa: E501
                 )
                 return async_adapter_config.adapter_cls(
                     host_prefixes=async_adapter_config.host_prefixes,
                     **async_adapter_config.kwargs,
                 )
-        logger.warning(
-            f"Unable to selected an adapter for async http requests for uri={uri}"
-        )
+        logger.warning(f"Unable to select an adapter for async http requests for uri={uri}")
         return None
 
     def _head(self, uri: str):  # type: ignore[no-untyped-def]
@@ -89,16 +86,10 @@ async def _gen(self, uri: str, method: Optional[str] = "GET") -> httpx.Response:
         async_adapter = self._get_async_adapter_for_uri(uri)
         if async_adapter is not None:
             if method == "HEAD":
-                return await async_adapter.gen_head(
-                    url=uri, timeout=self.timeout, sess=self.async_sess
-                )
+                return await async_adapter.gen_head(url=uri, timeout=self.timeout, sess=self.async_sess)
             else:
-                return await async_adapter.gen_send(
-                    url=uri, timeout=self.timeout, sess=self.async_sess
-                )
-        return await self.async_sess.get(
-            uri, timeout=self.timeout, follow_redirects=True
-        )
+                return await async_adapter.gen_send(url=uri, timeout=self.timeout, sess=self.async_sess)
+        return await self.async_sess.get(uri, timeout=self.timeout, follow_redirects=True)
 
     async def _gen_head(self, uri: str) -> httpx.Response:
         return await self._gen(uri=uri, method="HEAD")
@@ -122,13 +113,11 @@ def fetch_mime_type_and_size(self, uri: str) -> Tuple[str, int]:
             size = headers.get("content-length", 0)
             content_type = headers.get("content-type") or headers.get("Content-Type")
             if content_type is not None:
-                content_type, _ = cgi.parse_header(content_type)
+                content_type = parse_content_type(content_type)
 
             return content_type, size
         except Exception as e:
-            logger.error(
-                f"Failed to fetch content-type and size from uri {uri}. Error: {e}"
-            )
+            logger.error(f"Failed to fetch content-type and size from uri {uri}. Error: {e}")
             raise
 
     async def gen_fetch_mime_type_and_size(self, uri: str) -> Tuple[str, int]:
@@ -150,13 +139,11 @@ async def gen_fetch_mime_type_and_size(self, uri: str) -> Tuple[str, int]:
             size = headers.get("content-length", 0)
             content_type = headers.get("content-type") or headers.get("Content-Type")
             if content_type is not None:
-                content_type, _ = cgi.parse_header(content_type)
+                content_type = parse_content_type(content_type)
 
             return content_type, size
         except Exception as e:
-            logger.error(
-                f"Failed to fetch content-type and size from uri {uri}. Error: {e}"
-            )
+            logger.error(f"Failed to fetch content-type and size from uri {uri}. Error: {e}")
             raise
 
     def fetch_content(self, uri: str) -> Union[dict, str]:  # type: ignore[type-arg]
diff --git a/offchain/utils/utils.py b/offchain/utils/utils.py
index 7c9cf91..87bb874 100644
--- a/offchain/utils/utils.py
+++ b/offchain/utils/utils.py
@@ -28,7 +28,9 @@ async def wrapped(*args, **kwargs):  # type: ignore[no-untyped-def]
             try:
                 return await asyncio.wait_for(fn(*args, **kwargs), timeout=timeout)
             except Exception:
-                msg = f"Caught exception while executing async function {fn}: {traceback.format_exc()}"  # noqa: E501
+                msg = (
+                    f"Caught exception while executing async function {fn}: {traceback.format_exc()}"  # noqa: E501
+                )
                 if i + 1 == attempt:
                     logger.error(msg)
                     if not silent:
@@ -40,3 +42,11 @@ async def wrapped(*args, **kwargs):  # type: ignore[no-untyped-def]
         return wrapped
 
     return wrapper
+
+
+def parse_content_type(header_string: str) -> str:
+    from email.message import EmailMessage
+
+    msg = EmailMessage()
+    msg["content-type"] = header_string
+    return msg.get_content_type()
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index e6b66c1..a3762f7 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from offchain.utils.utils import safe_async_runner
+from offchain.utils.utils import safe_async_runner, parse_content_type
 
 
 def build_coro(ret_val: int, delay: float):  # type: ignore[no-untyped-def]
@@ -16,31 +16,22 @@ async def coro():  # type: ignore[no-untyped-def]
 
 @pytest.mark.asyncio
 async def test_secure_gather_runner_happy_path():  # type: ignore[no-untyped-def]
-    results = await asyncio.gather(
-        *[
-            safe_async_runner()(build_coro(ret_val, delay=0.1))()
-            for ret_val in range(10)
-        ]
-    )
+    results = await asyncio.gather(*[safe_async_runner()(build_coro(ret_val, delay=0.1))() for ret_val in range(10)])
     assert results == list(range(10))
 
 
 @pytest.mark.asyncio
 async def test_secure_gather_runner_timeout():  # type: ignore[no-untyped-def]
     coros = [safe_async_runner(timeout=0.2)(build_coro(ret_val=-1, delay=0.3))()] + [
-        safe_async_runner(timeout=0.2)(build_coro(ret_val, delay=0.1))()
-        for ret_val in range(10)
+        safe_async_runner(timeout=0.2)(build_coro(ret_val, delay=0.1))() for ret_val in range(10)
     ]
     # expect raise timeout error, because the first coro is erroring out
     with pytest.raises(asyncio.TimeoutError):
         await asyncio.gather(*coros)
 
     # when silent the runs, we should get results for all other coroutines
-    coros = [
-        safe_async_runner(timeout=0.2, silent=True)(build_coro(ret_val=-1, delay=0.3))()
-    ] + [
-        safe_async_runner(timeout=0.2, silent=True)(build_coro(ret_val, delay=0.1))()
-        for ret_val in range(10)
+    coros = [safe_async_runner(timeout=0.2, silent=True)(build_coro(ret_val=-1, delay=0.3))()] + [
+        safe_async_runner(timeout=0.2, silent=True)(build_coro(ret_val, delay=0.1))() for ret_val in range(10)
     ]
     results = await asyncio.gather(*coros)
     assert results == [None] + list(range(10))
@@ -60,3 +51,15 @@ async def test_secure_gather_runner_retry():  # type: ignore[no-untyped-def]
     )
     duration = time.time() - start
     assert duration == pytest.approx(0.1 * 5, rel=0.05)
+
+
+@pytest.mark.parametrize(
+    "header_string, expected",
+    [
+        ('application/json; charset="utf8"', "application/json"),
+        ("application/ld+json", "application/ld+json"),
+        ("application/x-www-form-urlencoded; boundary=something", "application/x-www-form-urlencoded"),
+    ],
+)
+def test_parse_content_type(header_string: str, expected: str):
+    assert parse_content_type(header_string) == expected
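
The parse_content_type helper added in offchain/utils/utils.py stands in for cgi.parse_header: the cgi module was deprecated in Python 3.11 (PEP 594) and removed in Python 3.13. Below is a minimal standalone sketch of the same technique, duplicated outside the diff for reference; the asserted values mirror the new test cases. Note that EmailMessage.get_content_type() lowercases the result and falls back to "text/plain" for headers it cannot parse.

    # Illustrative sketch only: mirrors the parse_content_type helper above.
    from email.message import EmailMessage

    def parse_content_type(header_string: str) -> str:
        # EmailMessage parses the Content-Type value, and get_content_type()
        # returns just the "maintype/subtype" portion, dropping parameters
        # such as charset or boundary.
        msg = EmailMessage()
        msg["content-type"] = header_string
        return msg.get_content_type()

    assert parse_content_type('application/json; charset="utf8"') == "application/json"
    assert parse_content_type("application/ld+json") == "application/ld+json"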