Skip to content

Commit

Permalink
feat: propagate Autorization header along with Api-Key header (epam#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Jul 16, 2024
1 parent 2aca1ac commit 2e0e4dd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
38 changes: 22 additions & 16 deletions aidial_sdk/header_propagator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import types
from contextvars import ContextVar
from typing import Optional
from typing import MutableMapping, Optional

import aiohttp
import httpx
Expand Down Expand Up @@ -79,21 +79,12 @@ async def _on_aiohttp_request_start(
trace_config_ctx: types.SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
):
if not str(params.url).startswith(self._dial_url):
return

api_key_val = self._api_key.get()
if api_key_val:
params.headers["api-key"] = api_key_val
self._modify_headers(str(params.url), params.headers)

def _instrument_requests(self):
def instrumented_send(wrapped, instance, args, kwargs):
request: requests.PreparedRequest = args[0]
if request.url and request.url.startswith(self._dial_url):
api_key_val = self._api_key.get()
if api_key_val:
request.headers["api-key"] = api_key_val

self._modify_headers(request.url or "", request.headers)
return wrapped(*args, **kwargs)

wrapt.wrap_function_wrapper(requests.Session, "send", instrumented_send)
Expand All @@ -102,10 +93,7 @@ def _instrument_httpx(self):

def instrumented_build_request(wrapped, instance, args, kwargs):
request: httpx.Request = wrapped(*args, **kwargs)
if request.url and str(request.url).startswith(self._dial_url):
api_key_val = self._api_key.get()
if api_key_val:
request.headers["api-key"] = api_key_val
self._modify_headers(str(request.url), request.headers)
return request

wrapt.wrap_function_wrapper(
Expand All @@ -115,3 +103,21 @@ def instrumented_build_request(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(
httpx.AsyncClient, "build_request", instrumented_build_request
)

def _modify_headers(
self, url: str, headers: MutableMapping[str, str]
) -> None:
if url.startswith(self._dial_url):
api_key = self._api_key.get()
if api_key:
old_api_key = headers.get("api-key")
old_authz = headers.get("Authorization")

if (
old_api_key
and old_authz
and old_authz == f"Bearer {old_api_key}"
):
headers["Authorization"] = f"Bearer {api_key}"

headers["api-key"] = api_key
10 changes: 6 additions & 4 deletions tests/header_propagation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,35 @@ class Library(str, Enum):
class Request(BaseModel):
url: str
lib: Library
headers: dict


@app.post("/")
async def handle(request: Request):
url = request.url
lib = request.lib
headers = request.headers

if lib == Library.requests:
response = requests.get(url)
response = requests.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.httpx_async:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response = await client.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.httpx_sync:
with httpx.Client() as client:
response = client.get(url)
response = client.get(url, headers=headers)
status_code = response.status_code
content = response.json()

elif lib == Library.aiohttp:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
async with session.get(url, headers=headers) as response:
status_code = response.status
content = await response.json()

Expand Down
74 changes: 51 additions & 23 deletions tests/test_header_propagation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import re
from typing import Optional
from itertools import product
from typing import Mapping, Optional

import aioresponses
import httpx
Expand All @@ -10,8 +11,10 @@
import respx
from fastapi import FastAPI
from fastapi.testclient import TestClient
from requests.structures import CaseInsensitiveDict

from aidial_sdk.header_propagator import HeaderPropagator
from aidial_sdk.utils.json import remove_nones
from tests.header_propagation.client import app as sender
from tests.utils.text import removeprefix

Expand All @@ -31,16 +34,21 @@ def client():
return TestClient(app)


def _get_headers(headers: Mapping[str, str]) -> dict:
api_key = headers.get("Api-Key")
authz = headers.get("Authorization")
return remove_nones({"api-key": api_key, "authorization": authz})


@pytest.fixture
def mock_requests():
with responses.mock as mock:

def callback(request: requests.PreparedRequest):
api_key = request.headers.get("api-key")
return (
200,
{"content-type": "application/json"},
json.dumps({"api_key": api_key}),
json.dumps(_get_headers(request.headers)),
)

mock.add_callback(
Expand All @@ -59,8 +67,7 @@ def mock_httpx():

@respx.route(method="GET", host__in=HOSTS, path="/")
def handler(request: httpx.Request):
api_key = request.headers.get("api-key")
return httpx.Response(200, json={"api_key": api_key})
return httpx.Response(200, json=_get_headers(request.headers))

yield mock

Expand All @@ -70,24 +77,22 @@ def mock_aiohttp():
with aioresponses.aioresponses() as mock:

def callback(url, **kwargs) -> aioresponses.CallbackResult:
api_key = kwargs.get("headers", {}).get("api-key")
return aioresponses.CallbackResult(payload={"api_key": api_key})
headers = CaseInsensitiveDict(kwargs.get("headers", {}))
return aioresponses.CallbackResult(payload=_get_headers(headers))

mock.get(URL_PATTERN, callback=callback)
yield mock


@pytest.mark.parametrize(
"lib", ["aiohttp", "requests", "httpx_sync", "httpx_async"]
)
@pytest.mark.parametrize(
"url,key_to_send,key_to_receive",
[
(DIAL_URL, API_KEY, API_KEY),
(NON_DIAL_URL, API_KEY, None),
(DIAL_URL, None, None),
(NON_DIAL_URL, None, None),
],
"lib, url, key_to_propagate, key_for_upstream, add_authz",
product(
["aiohttp", "requests", "httpx_sync", "httpx_async"],
[DIAL_URL, NON_DIAL_URL],
["test-api-key", None],
["dummy-api-key", None],
[True, False],
),
)
def test_send_request(
client: TestClient,
Expand All @@ -96,21 +101,44 @@ def test_send_request(
mock_aiohttp,
lib: str,
url: str,
key_to_send: Optional[str],
key_to_receive: Optional[str],
key_to_propagate: Optional[str],
key_for_upstream: Optional[str],
add_authz: bool,
):
headers_to_propagate = {}
if key_to_propagate:
headers_to_propagate["api-key"] = key_to_propagate
if add_authz:
headers_to_propagate["authorization"] = f"Bearer {key_to_propagate}"

headers_for_upstream = {}
if key_for_upstream:
headers_for_upstream["api-key"] = key_for_upstream
if add_authz:
headers_for_upstream["authorization"] = f"Bearer {key_for_upstream}"

response = client.post(
"/",
json={"url": url, "lib": lib},
headers={} if key_to_send is None else {"api-key": key_to_send},
json={"url": url, "lib": lib, "headers": headers_for_upstream},
headers=headers_to_propagate,
)
assert response.status_code == 200, response.json()

expected_key = (
key_to_propagate if url == DIAL_URL else None
) or key_for_upstream

expected_headers = {}
if expected_key:
expected_headers["api-key"] = expected_key
if add_authz and key_for_upstream:
expected_headers["authorization"] = f"Bearer {expected_key}"

# NOTE: aioresponses doesn't call trace_configs in the mocked version,
# and since we are patching the request via a dedicated trace config,
# we can't test the header propagation for aiohttp.
# https://github.com/pnuckowski/aioresponses/issues/246
if lib == "aiohttp":
key_to_receive = None
expected_headers = headers_for_upstream

assert response.json() == {"api_key": key_to_receive}
assert response.json() == expected_headers

0 comments on commit 2e0e4dd

Please sign in to comment.