Skip to content

Commit

Permalink
Add authentication on price endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Nov 13, 2023
1 parent c90c3f6 commit fba1110
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
16 changes: 14 additions & 2 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ async def create_token(user_id: str):
return f"{user_id}__U{str(uuid.uuid4())}"


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
if token and '__U' in token:
current_user: schemas.UserBase = crud.update_user_by_token(db, token=token) # type: ignore
if current_user:
return current_user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)


# App startup & shutdown
# ------------------------------------------------------------------------------
@app.on_event("startup")
Expand Down Expand Up @@ -113,8 +125,8 @@ async def authentication(form_data: Annotated[OAuth2PasswordRequestForm, Depends


@app.post("/prices", response_model=schemas.PriceBase)
async def create_price(price: schemas.PriceCreate):
db_price = crud.create_price(db, price=price) # type: ignore
async def create_price(price: schemas.PriceCreate, current_user: schemas.UserBase = Depends(get_current_user)):
db_price = crud.create_price(db, price=price, user=current_user) # type: ignore
return db_price


Expand Down
15 changes: 13 additions & 2 deletions app/crud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import func

from app.models import Price
from app.models import User
Expand Down Expand Up @@ -29,6 +30,16 @@ def create_user(db: Session, user: UserBase):
return db_user


def update_user_by_token(db: Session, token: str):
db_user = get_user_by_token(db, token=token)
if db_user:
db.query(User).filter(User.user_id == db_user.user_id).update({"last_used": func.now()})
db.commit()
db.refresh(db_user)
return db_user
return False


def delete_user(db: Session, user_id: UserBase):
db_user = get_user_by_user_id(db, user_id=user_id)
if db_user:
Expand All @@ -38,8 +49,8 @@ def delete_user(db: Session, user_id: UserBase):
return False


def create_price(db: Session, price: PriceCreate):
db_price = Price(**price.dict())
def create_price(db: Session, price: PriceCreate, user: UserBase):
db_price = Price(**price.dict(), owner=user.user_id)
db.add(db_price)
db.commit()
db.refresh(db_price)
Expand Down

0 comments on commit fba1110

Please sign in to comment.