Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change OAUTH fastAPI class to HTTPBearer #17

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ NEUROAGENT_GENERATIVE__OPENAI__TOKEN=
# Important but not required
NEUROAGENT_AGENT__MODEL=

NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY=

# Useful but not required.
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
args: [--branch, master, --branch, main]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
rev: v0.6.7
hooks:
- id: ruff
args: [ --fix ]
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- Switched from OAUTH2 security on FASTAPI to HTTPBearer.

### Added
- Add get morphoelectric (me) model tool

Expand Down
23 changes: 4 additions & 19 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Literal, Optional

from dotenv import dotenv_values
from fastapi.openapi.models import OAuthFlowPassword, OAuthFlows
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand Down Expand Up @@ -59,15 +58,6 @@ def user_info_endpoint(self) -> str | None:
else:
return None

@property
def flows(self) -> OAuthFlows:
"""Define the flow to override Fastapi's one."""
return OAuthFlows(
password=OAuthFlowPassword(
tokenUrl=self.token_endpoint,
),
)

@property
def server_url(self) -> str:
"""Server url."""
Expand Down Expand Up @@ -126,8 +116,6 @@ class SettingsKnowledgeGraph(BaseModel):
"""Knowledge graph API settings."""

base_url: str
token: SecretStr | None = None
use_token: bool = False
download_hierarchy: bool = False
br_saving_path: pathlib.Path | str = str(
pathlib.Path(__file__).parent / "data" / "brainregion_hierarchy.json"
Expand Down Expand Up @@ -230,14 +218,11 @@ def check_consistency(self) -> "Settings":
model validator is run during instantiation.

"""
# If you don't enforce keycloak auth, you need a way to communicate with the APIs the tools leverage
if not self.keycloak.password and not self.keycloak.validate_token:
if not self.knowledge_graph.use_token:
raise ValueError("if no password is provided, please use token auth.")
if not self.knowledge_graph.token:
raise ValueError(
"No auth method provided for knowledge graph related queries."
" Please set either a password or use a fixed token."
)
raise ValueError(
"Need an auth method for subsequent APIs called by the tools."
)

return self

Expand Down
38 changes: 21 additions & 17 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, Any, AsyncIterator, Iterator

from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer
from fastapi.security import HTTPBearer
from httpx import AsyncClient, HTTPStatusError
from keycloak import KeycloakOpenID
from langchain_openai import ChatOpenAI
Expand Down Expand Up @@ -40,10 +40,17 @@

logger = logging.getLogger(__name__)

auth = OAuth2PasswordBearer(
tokenUrl="/token", # Will be overriden
auto_error=False,
)

class HTTPBearerDirect(HTTPBearer):
"""HTTPBearer class that returns directly the token in the call."""

async def __call__(self, request: Request) -> str | None: # type: ignore
"""Intercept the bearer token in the headers."""
auth_credentials = await super().__call__(request)
return auth_credentials.credentials if auth_credentials else None


auth = HTTPBearerDirect(auto_error=False)


@cache
Expand Down Expand Up @@ -168,18 +175,15 @@ def get_kg_token(
if token:
return token
else:
if not settings.knowledge_graph.use_token:
instance = KeycloakOpenID(
server_url=settings.keycloak.server_url,
realm_name=settings.keycloak.realm,
client_id=settings.keycloak.client_id,
)
return instance.token(
username=settings.keycloak.username,
password=settings.keycloak.password.get_secret_value(), # type: ignore
)["access_token"]
else:
return settings.knowledge_graph.token.get_secret_value() # type: ignore
instance = KeycloakOpenID(
server_url=settings.keycloak.server_url,
realm_name=settings.keycloak.realm,
client_id=settings.keycloak.client_id,
)
return instance.token(
username=settings.keycloak.username,
password=settings.keycloak.password.get_secret_value(), # type: ignore
)["access_token"]


def get_literature_tool(
Expand Down
4 changes: 0 additions & 4 deletions src/neuroagent/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from neuroagent import __version__
from neuroagent.app.config import Settings
from neuroagent.app.dependencies import (
auth,
get_agent_memory,
get_cell_types_kg_hierarchy,
get_connection_string,
Expand Down Expand Up @@ -72,9 +71,6 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
"""Read environment (settings of the application)."""
# hacky but works: https://github.com/tiangolo/fastapi/issues/425
app_settings = fastapi_app.dependency_overrides.get(get_settings, get_settings)()
if app_settings.keycloak.validate_token:
auth.model.flows = app_settings.keycloak.flows # type: ignore

engine = fastapi_app.dependency_overrides.get(get_engine, get_engine)(
app_settings, get_connection_string(app_settings)
)
Expand Down
15 changes: 6 additions & 9 deletions tests/app/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ def test_required(monkeypatch, patch_required_env):
assert settings.tools.literature.url == "https://fake_url"
assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1"
assert settings.openai.token.get_secret_value() == "dummy"
assert settings.knowledge_graph.use_token
assert settings.knowledge_graph.token.get_secret_value() == "token"

# make sure not case sensitive
monkeypatch.delenv("NEUROAGENT_TOOLS__LITERATURE__URL")
Expand Down Expand Up @@ -44,8 +42,6 @@ def test_setup_tools(monkeypatch, patch_required_env):
assert settings.tools.kg_morpho_features.search_size == 20
assert settings.keycloak.username == "user"
assert settings.keycloak.password.get_secret_value() == "pass"
assert settings.knowledge_graph.use_token
assert settings.knowledge_graph.token.get_secret_value() == "token"


def test_check_consistency(monkeypatch):
Expand All @@ -58,17 +54,18 @@ def test_check_consistency(monkeypatch):
Settings()

monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__TOKEN", "dummy")
monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN", "true")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true")

with pytest.raises(ValueError):
Settings()

monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN", "false")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "false")

with pytest.raises(ValueError):
Settings()

monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN", "Hello")
monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "http://fake_nexus.com")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "true")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "Hello")

with pytest.raises(ValueError):
Settings()
Settings()
1 change: 1 addition & 0 deletions tests/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def save_dummy(*args, **kwargs):
with (
patch("neuroagent.app.main.get_update_kg_hierarchy", new=save_dummy),
patch("neuroagent.app.main.get_cell_types_kg_hierarchy", new=save_dummy),
patch("neuroagent.app.main.get_kg_token", new=lambda *args, **kwargs: "dev"),
):
# The with statement triggers the startup.
with TestClient(app) as test_client:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def patch_required_env(monkeypatch):
"NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "https://fake_url/api/nexus/v1"
)
monkeypatch.setenv("NEUROAGENT_OPENAI__TOKEN", "dummy")
monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN", "token")
monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN", "true")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "False")
monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "password")


@pytest.fixture(params=["sqlite", "postgresql"], name="db_connection")
Expand Down
Loading