diff --git a/app/api.py b/app/api.py index a80c9309..5098e7c9 100644 --- a/app/api.py +++ b/app/api.py @@ -55,6 +55,18 @@ async def create_token(user_id: str): return f"{user_id}__U{str(uuid.uuid4())}" +async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): + if token and '__U' in token: + current_user: schemas.UserBase = crud.update_user_by_token(db, token=token) # type: ignore + if current_user: + return current_user + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # App startup & shutdown # ------------------------------------------------------------------------------ @app.on_event("startup") @@ -113,8 +125,8 @@ async def authentication(form_data: Annotated[OAuth2PasswordRequestForm, Depends @app.post("/prices", response_model=schemas.PriceBase) -async def create_price(price: schemas.PriceCreate): - db_price = crud.create_price(db, price=price) # type: ignore +async def create_price(price: schemas.PriceCreate, current_user: schemas.UserBase = Depends(get_current_user)): + db_price = crud.create_price(db, price=price, user=current_user) # type: ignore return db_price diff --git a/app/crud.py b/app/crud.py index 23af2efe..e1738bfd 100644 --- a/app/crud.py +++ b/app/crud.py @@ -1,4 +1,5 @@ from sqlalchemy.orm import Session +from sqlalchemy.sql import func from app.models import Price from app.models import User @@ -29,6 +30,16 @@ def create_user(db: Session, user: UserBase): return db_user +def update_user_by_token(db: Session, token: str): + db_user = get_user_by_token(db, token=token) + if db_user: + db.query(User).filter(User.user_id == db_user.user_id).update({"last_used": func.now()}) + db.commit() + db.refresh(db_user) + return db_user + return False + + def delete_user(db: Session, user_id: UserBase): db_user = get_user_by_user_id(db, user_id=user_id) if db_user: @@ -38,8 +49,8 @@ def delete_user(db: Session, user_id: UserBase): return False -def create_price(db: Session, price: PriceCreate): - db_price = Price(**price.dict()) +def create_price(db: Session, price: PriceCreate, user: UserBase): + db_price = Price(**price.dict(), owner=user.user_id) db.add(db_price) db.commit() db.refresh(db_price)