Skip to content

Commit

Permalink
tokens table and tokens functionality for oauth2 stuff (revoke/provid…
Browse files Browse the repository at this point in the history
…e/..)
  • Loading branch information
HardMax71 committed Aug 29, 2024
1 parent c2f0815 commit f4fcaf5
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 81 deletions.
50 changes: 34 additions & 16 deletions server/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,72 @@
# /server/app/api/deps.py
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from pydantic import ValidationError
from jose import jwt, JWTError
from sqlalchemy.orm import Session

from public_api.permissions import PermissionName, PermissionType
from public_api.permissions.permission_manager import PermissionManager
from public_api.shared_schemas import user as user_schemas
from public_api.permissions import PermissionName, PermissionType, PermissionManager
from server.app import crud, models
from server.app.core.config import settings
from server.app.db.database import get_db

oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")


def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> models.User:
def decode_token(token: str) -> dict:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
token_data = user_schemas.TokenData(**payload)
except (jwt.JWTError, ValidationError):
if payload.get("type") != "access_token":
raise HTTPException(status_code=401, detail="Invalid token type")
return payload
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)
user = crud.user.get_by_username(db, username=token_data.username)
if not user:


def get_current_user(
db: Session = Depends(get_db),
token: str = Depends(oauth2_scheme)
) -> models.User:
try:
payload = decode_token(token)
user_id: str = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid token")
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)

user = crud.user.get(db, id=int(user_id))
if user is None:
raise HTTPException(status_code=404, detail="User not found")

token_obj = crud.token.get_by_access_token(db, token)
if token_obj is None or not token_obj.is_active:
raise HTTPException(status_code=401, detail="Token is invalid or expired")

return user


def get_current_active_user(
current_user: models.User = Depends(get_current_user),
) -> models.User:
if not crud.user.is_active(current_user):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user


def get_current_admin(
current_user: models.User = Depends(get_current_user),
) -> models.User:
if not crud.user.is_admin(current_user):
if not current_user.role.name.lower() == "admin":
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
status_code=403, detail="The user doesn't have enough privileges"
)
return current_user

Expand Down
62 changes: 34 additions & 28 deletions server/app/api/v1/endpoints/users.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# /server/app/api/v1/endpoints/users.py
from datetime import timedelta
from datetime import datetime

import pyotp
from fastapi import APIRouter, Depends, HTTPException, Body, Query, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from public_api.shared_schemas import user as user_schemas, RefreshTokenRequest
from public_api.shared_schemas import user as user_schemas
from server.app import crud, models
from server.app.api import deps
from server.app.core import security
Expand All @@ -24,14 +23,17 @@ def login(
):
user = crud.user.authenticate(db, email=form_data.username, password=form_data.password)
if not user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Incorrect email or password")
raise HTTPException(status_code=400, detail="Incorrect email or password")
if not crud.user.is_active(user):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
raise HTTPException(status_code=400, detail="Inactive user")

if user.two_factor_auth_enabled:
return user_schemas.Token(access_token="2FA_REQUIRED", token_type="bearer", refresh_token="", expires_in=0)

return create_token_for_user(user)
token = crud.token.create_user_tokens(db, user.id)
return user_schemas.Token(
access_token=token.access_token,
refresh_token=token.refresh_token,
token_type="bearer",
expires_in=(token.access_token_expires_at - int(datetime.utcnow().timestamp()))
)


@router.post("/login/2fa", response_model=user_schemas.Token)
Expand All @@ -51,10 +53,7 @@ def login_2fa(


def create_token_for_user(user: models.User) -> user_schemas.Token:
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = security.create_access_token(
user.username, expires_delta=access_token_expires
)
access_token = security.create_access_token(user.username)
refresh_token = security.create_refresh_token(user.username)
return user_schemas.Token(
access_token=access_token,
Expand All @@ -66,24 +65,31 @@ def create_token_for_user(user: models.User) -> user_schemas.Token:

@router.post("/refresh-token", response_model=user_schemas.Token)
def refresh_token(
refresh_data: RefreshTokenRequest,
db: Session = Depends(deps.get_db)
db: Session = Depends(deps.get_db),
refresh_token: str = Body(..., embed=True)
):
try:
payload = jwt.decode(
refresh_data.refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=400, detail="Invalid refresh token")
except JWTError:
raise HTTPException(status_code=400, detail="Invalid refresh token")
token = crud.token.get_by_refresh_token(db, refresh_token)
if not token or not token.is_active or token.refresh_token_expires_at < int(datetime.utcnow().timestamp()):
raise HTTPException(status_code=400, detail="Invalid or expired refresh token")

user = crud.user.get_by_username(db, username=username)
if not user:
raise HTTPException(status_code=404, detail="User not found")
crud.token.revoke_token(db, token)
new_token = crud.token.create_user_tokens(db, token.user_id)

return create_token_for_user(user)
return user_schemas.Token(
access_token=new_token.access_token,
refresh_token=new_token.refresh_token,
token_type="bearer",
expires_in=(new_token.access_token_expires_at - int(datetime.utcnow().timestamp()))
)


@router.post("/logout")
def logout(
current_user: models.User = Depends(deps.get_current_user),
db: Session = Depends(deps.get_db)
):
crud.token.revoke_all_user_tokens(db, current_user.id)
return {"detail": "Successfully logged out"}


@router.post("/register", response_model=user_schemas.UserSanitized)
Expand Down
62 changes: 33 additions & 29 deletions server/app/core/security.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# /server/app/core/security.py
from datetime import datetime, timedelta
from typing import Any, Union

from typing import Any, Dict
import uuid
from jose import jwt
from passlib.context import CryptContext

Expand All @@ -10,42 +10,46 @@
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def create_access_token(
subject: Union[str, Any], expires_delta: timedelta = None
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "username": str(subject)}
def create_token(subject: str, token_type: str, expires_delta: timedelta) -> str:
expire = datetime.utcnow() + expires_delta
to_encode: Dict[str, Any] = {
"exp": expire,
"iat": datetime.utcnow(),
"sub": subject,
"type": token_type,
"jti": str(uuid.uuid4())
}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(subject: str) -> str:
return create_token(
subject,
"access_token",
timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
)


def generate_password_reset_token(email: str) -> str:
expire = datetime.utcnow() + timedelta(hours=1)
to_encode = {"exp": expire, "email": email}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
def create_refresh_token(subject: str) -> str:
return create_token(
subject,
"refresh_token",
timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
)


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def create_refresh_token(
subject: Union[str, Any], expires_delta: timedelta = None
) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
def generate_password_reset_token(email: str) -> str:
return create_token(
email,
"password_reset",
timedelta(hours=1)
)
1 change: 1 addition & 0 deletions server/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .shipment import shipment, carrier
from .supplier import supplier
from .task import task
from .token import token
from .user import user
from .warehouse import whole_warehouse
from .yard import yard
Expand Down
45 changes: 45 additions & 0 deletions server/app/crud/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# /server/app/crud/token.py
from datetime import datetime, timedelta

from sqlalchemy.orm import Session

from server.app.core.config import settings
from server.app.core.security import create_access_token, create_refresh_token
from server.app.models import Token


class CRUDToken:
def create_user_tokens(self, db: Session, user_id: int) -> Token:
access_token = create_access_token(str(user_id))
refresh_token = create_refresh_token(str(user_id))

now = datetime.utcnow()
token = Token(
user_id=user_id,
access_token=access_token,
refresh_token=refresh_token,
access_token_expires_at=int((now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp()),
refresh_token_expires_at=int((now + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)).timestamp()),
)
db.add(token)
db.commit()
db.refresh(token)
return token

def get_by_access_token(self, db: Session, access_token: str) -> Token | None:
return db.query(Token).filter(Token.access_token == access_token, Token.is_active == True).first()

def get_by_refresh_token(self, db: Session, refresh_token: str) -> Token | None:
return db.query(Token).filter(Token.refresh_token == refresh_token, Token.is_active == True).first()

def revoke_all_user_tokens(self, db: Session, user_id: int) -> None:
db.query(Token).filter(Token.user_id == user_id).update({"is_active": False})
db.commit()

def revoke_token(self, db: Session, token: Token) -> None:
token.is_active = False
db.add(token)
db.commit()


token = CRUDToken()
Loading

0 comments on commit f4fcaf5

Please sign in to comment.