-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsecurity.py
226 lines (198 loc) · 9.58 KB
/
security.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
Most of this security code has been well tested in (and stolen from) https://github.com/Intility/fastapi-azure-auth,
of which I am the author. However, this specific project was written in a day or two for fun, not for enterprise
security. If you're using this library as inspiration for anything, please keep that in mind.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
from fastapi.security.base import SecurityBase
from httpx import AsyncClient
from jose import ExpiredSignatureError, jwk, jwt
from jose.backends.cryptography_backend import CryptographyRSAKey
from jose.exceptions import JWTClaimsError, JWTError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.requests import Request
from app.api.dependencies import yield_db_session
from app.core.config import settings
from app.models.klepp import User
from app.schemas.schemas_v1.user import User as CognitoUser
class InvalidAuth(HTTPException):
    """
    401 Unauthorized error raised when a request fails authentication.

    The error always carries a `WWW-Authenticate: Bearer` header.

    :param detail: Reason for the failure, passed to `HTTPException` as `detail`.
    """

    def __init__(self, detail: str) -> None:
        bearer_challenge = {'WWW-Authenticate': 'Bearer'}
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers=bearer_challenge,
        )
# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(name=__name__)
class OpenIdConfig:
    """
    Cached OpenID Connect configuration of the configured Cognito user pool.

    Holds the pool's issuer and its RSA public signing keys. Both are refreshed from the pool's
    `.well-known/openid-configuration` endpoint when they have never been loaded or are over 24 hours old.
    """

    def __init__(self) -> None:
        # Time of the last successful load; `None` until the configuration has been fetched once.
        self._config_timestamp: datetime | None = None
        self.openid_url = (
            f'https://cognito-idp.{settings.AWS_REGION}.amazonaws.com/'
            f'{settings.AWS_USER_POOL_ID}/.well-known/openid-configuration'
        )
        # Assigned by the first successful `load_config()`.
        self.issuer: str
        # Public signing keys keyed by their JWKS `kid`. Starts empty so a lookup made before the first
        # successful load finds no key instead of raising AttributeError.
        self.signing_keys: dict[str, CryptographyRSAKey] = {}

    async def load_config(self) -> None:
        """
        Load config from the OpenID endpoint if it has never been loaded or is over 24 hours old.

        :raises HTTPException: 401 if a refresh fails after the configuration was loaded at least once.
        :raises RuntimeError: if the configuration has never been loaded and fetching it fails.
        """
        refresh_time = datetime.now() - timedelta(hours=24)
        if not self._config_timestamp or self._config_timestamp < refresh_time:
            try:
                log.debug('Loading Cognito OpenID configuration.')
                await self._load_openid_config()
                self._config_timestamp = datetime.now()
            except Exception as error:
                log.exception('Unable to fetch OpenID configuration from Cognito. Error: %s', error)
                # We can't fetch an up to date openid-config, so authentication will not work.
                if self._config_timestamp:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail='Connection to Cognito is down. Unable to fetch provider configuration',
                        headers={'WWW-Authenticate': 'Bearer'},
                    ) from error
                raise RuntimeError(f'Unable to fetch provider information. {error}') from error
            log.info('Loaded settings from Cognito.')
            log.info('Issuer: %s', self.issuer)

    async def _load_openid_config(self) -> None:
        """
        Fetch the OpenID configuration and its JWKS, then store the issuer and signing keys.

        Both documents are fetched, and all keys are built, before any attribute is updated. A failure
        part-way therefore leaves the previously loaded issuer and keys untouched.
        """
        async with AsyncClient(timeout=10) as client:
            log.info('Fetching OpenID Connect config from %s', self.openid_url)
            openid_response = await client.get(self.openid_url)
            openid_response.raise_for_status()
            openid_cfg = openid_response.json()
            issuer = openid_cfg['issuer']
            jwks_uri = openid_cfg['jwks_uri']
            log.info('Fetching jwks from %s', jwks_uri)
            jwks_response = await client.get(jwks_uri)
            jwks_response.raise_for_status()
            keys = jwks_response.json()['keys']
        self._load_keys(keys)
        self.issuer = issuer

    def _load_keys(self, keys: list[dict[str, Any]]) -> None:
        """
        Build RS256 public keys from the JWKS entries marked for signature use and store them by `kid`.

        The new mapping replaces `self.signing_keys` only after every key has been constructed.

        :param keys: The `keys` list of a JWKS document.
        """
        signing_keys: dict[str, CryptographyRSAKey] = {}
        for key in keys:
            # Only care about keys that are used for signatures, not encryption
            if key.get('use') != 'sig':
                continue
            log.debug('Loading public key from certificate: %s', key)
            cert_obj = jwk.construct(key, 'RS256')
            if kid := key.get('kid'):
                signing_keys[kid] = cert_obj.public_key()
        self.signing_keys = signing_keys
class CognitoAuthorizationCodeBearerBase(SecurityBase):
    """
    FastAPI security dependency that reads a Cognito bearer token from the request and validates it.

    On success, the validated claims are returned as a `CognitoUser` and also stored on `request.state.user`.
    With `auto_error=False`, HTTP authentication errors are turned into a `None` return value.
    """

    def __init__(self, auto_error: bool = True) -> None:
        self.auto_error = auto_error
        self.openid_config: OpenIdConfig = OpenIdConfig()
        # The inner scheme always raises on a missing token. `__call__` turns that into `None` itself
        # when `self.auto_error` is False.
        self.oauth = OAuth2AuthorizationCodeBearer(
            authorizationUrl='https://auth.klepp.me/oauth2/authorize',
            tokenUrl='https://auth.klepp.me/oauth2/token',
            scopes={'openid': 'openid'},
            scheme_name='CognitoAuth',
            auto_error=True,
        )
        self.model = self.oauth.model
        self.scheme_name: str = 'Cognito'

    @staticmethod
    def _check_scopes(claims: dict[str, Any], required_scopes: list[str]) -> None:
        """
        Ensure the token's claims grant every scope in `required_scopes`.

        Cognito access tokens list granted scopes as a space-separated string in the `scope` claim.
        `scp` (the Azure-style claim the original implementation read) is accepted as a fallback.

        :raises InvalidAuth: if the scope claim is not a string, or a required scope is missing.
        """
        if not required_scopes:
            return
        token_scope_string = claims.get('scope', claims.get('scp', ''))
        if not isinstance(token_scope_string, str):
            raise InvalidAuth('Token contains invalid formatted scopes')
        token_scopes = token_scope_string.split()
        for scope in required_scopes:
            if scope not in token_scopes:
                raise InvalidAuth('Required scope missing')

    async def __call__(self, request: Request, security_scopes: SecurityScopes) -> CognitoUser | None:
        """
        Extend the OAuth2 bearer extraction to also validate the token against Cognito.

        :returns: The validated user, or `None` on an HTTP auth error when `auto_error` is False.
        :raises HTTPException: (including `InvalidAuth`) on a missing/invalid token when `auto_error` is True.
        :raises RuntimeError: propagated from `load_config` if the provider configuration was never loaded.
        """
        try:
            access_token = await self.oauth(request=request)
            try:
                # Extract header information of the token.
                header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
                claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
            except Exception as error:
                # Never log the raw token: a corrupted-looking token may still be a valid credential.
                log.warning('Malformed token received. Error: %s', error, exc_info=True)
                raise InvalidAuth(detail='Invalid token format') from error

            # Checked on unverified claims to fail fast. The scope claim is covered by the signature
            # verified below, so a forged scope cannot get past validation.
            self._check_scopes(claims, security_scopes.scopes)

            # Load new config if old
            await self.openid_config.load_config()

            # Use the `kid` from the header to find a matching signing key to use. The lookup stays inside
            # the try so an unexpected header shape (e.g. an unhashable `kid`) becomes a 401, not a 500.
            try:
                if key := self.openid_config.signing_keys.get(header.get('kid', '')):
                    # Signature, iat, exp, iss and sub are required and verified; aud and nbf are neither
                    # required nor verified.
                    options = {
                        'verify_signature': True,
                        'verify_aud': False,
                        'verify_iat': True,
                        'verify_exp': True,
                        'verify_nbf': False,
                        'verify_iss': True,
                        'verify_sub': True,
                        'verify_jti': True,
                        'verify_at_hash': True,
                        'require_aud': False,
                        'require_iat': True,
                        'require_exp': True,
                        'require_nbf': False,
                        'require_iss': True,
                        'require_sub': True,
                        'require_jti': False,
                        'require_at_hash': False,
                        'leeway': 0,
                    }
                    # Validate token
                    token = jwt.decode(
                        access_token,
                        key=key,  # noqa
                        algorithms=['RS256'],
                        issuer=self.openid_config.issuer,
                        options=options,
                    )
                    # Attach the user to the request. Can be accessed through `request.state.user`
                    user: CognitoUser = CognitoUser(**token)
                    request.state.user = user
                    return user
            except JWTClaimsError as error:
                log.info('Token contains invalid claims. %s', error)
                raise InvalidAuth(detail='Token contains invalid claims') from error
            except ExpiredSignatureError as error:
                log.info('Token signature has expired. %s', error)
                raise InvalidAuth(detail='Token signature has expired') from error
            except JWTError as error:
                log.warning('Invalid token. Error: %s', error, exc_info=True)
                raise InvalidAuth(detail='Unable to validate token') from error
            except Exception as error:
                # Extra failsafe in case of a bug in a future version of the jwt library
                log.exception('Unable to process jwt token. Uncaught error: %s', error)
                raise InvalidAuth(detail='Unable to process token') from error
            log.warning('Unable to verify token. No signing keys found')
            raise InvalidAuth(detail='Unable to verify token, no signing keys found')
        except HTTPException:
            # `InvalidAuth` is an `HTTPException`, so this covers both our errors and the inner scheme's.
            if not self.auto_error:
                return None
            raise
# Dependency that raises an HTTP auth error when the request is not authenticated.
cognito_scheme = CognitoAuthorizationCodeBearerBase(auto_error=True)
# Dependency that returns `None` instead of raising an HTTP auth error for unauthenticated requests.
cognito_scheme_or_anonymous = CognitoAuthorizationCodeBearerBase(auto_error=False)
async def cognito_signed_in(
    cognito_user: CognitoUser = Depends(cognito_scheme),
    db_session: AsyncSession = Depends(yield_db_session),
) -> User:
    """
    Return the DB user matching the signed-in Cognito user's username, creating it if it doesn't exist.

    A newly created user is committed and refreshed before being returned.
    """
    lookup = await db_session.exec(select(User).where(User.name == cognito_user.username))  # type: ignore
    existing_user = lookup.one_or_none()
    if existing_user:
        return existing_user  # type: ignore

    created_user = User(name=cognito_user.username)
    db_session.add(created_user)
    await db_session.commit()
    await db_session.refresh(created_user)
    return created_user