Skip to content

Commit

Permalink
Fixing some lint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Gomez-Gonzalez committed Jul 13, 2024
1 parent 3d21c80 commit f0c5bce
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 52 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pylint
pip install -r requirements.txt
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
pylint --disable=missing-docstring $(git ls-files '*.py')
153 changes: 135 additions & 18 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,58 @@
from fastapi import FastAPI, Body, Depends, HTTPException, status
from typing import Annotated
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
"""API methods."""

from sqlmodel import Session, SQLModel, select
from datetime import timedelta
from typing import Annotated

import app.model as model
import app.auth.auth_handler as auth_handler
import app.auth.crypto as crypto
from fastapi import Body, Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session, select

from app.auth import auth_handler
from app.auth import crypto
from app import model

app_obj = FastAPI()


@app_obj.on_event("startup")
def on_startup():
model.create_db_and_tables()

# route handlers


@app_obj.post("/token")
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Session = Depends(model.get_db_session)):
user = model.UserLogin(email=form_data.username, password=form_data.password)
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
session: Session = Depends(model.get_db_session)):
"""
Login API to authenticate a user and generate an access token.
This function takes user credentials from the request body
and validates them against the database.
If the credentials are valid, it generates an access token
with a specific expiration time
and returns it along with the token type.
Args:
form_data: An instance of `OAuth2PasswordRequestForm` containing
user credentials.
Retrieved from the request body using Depends.
session: A SQLAlchemy database session object. Obtained using
Depends from `model.get_db_session`.
Raises:
HTTPException: If the username or password is incorrect (400 Bad Request).
Returns:
A `model.Token` object containing the access token and token type.
"""
user = model.UserLogin(email=form_data.username,
password=form_data.password)
db_user = auth_handler.check_and_get_user(user, session)
if not db_user:
raise HTTPException(status_code=400, detail="Incorrect username or password")
raise HTTPException(
status_code=400, detail="Incorrect username or password")
token = auth_handler.create_access_token(db_user, timedelta(minutes=30))
return model.Token(access_token=token, token_type="bearer")

Expand All @@ -31,14 +61,57 @@ async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], sess
async def get_screenshots(
*,
session: Session = Depends(model.get_db_session),
current_user_id: model.UserId = Depends(auth_handler.get_current_user_id)):
statement = select(model.Screenshot).where(model.Screenshot.owner_id == current_user_id.id)
current_user_id: model.UserId = Depends(auth_handler.get_current_user_id)):
"""
Retrieves a list of screenshots owned by the currently authenticated user.
This API endpoint fetches all screenshots from the database that
belong to the user identified by the access token provided in the
request. It requires user authentication
through the `Depends` dependency on `auth_handler.get_current_user_id`.
Args:
session: A SQLAlchemy database session object. Obtained using
Depends from `model.get_db_session`.
current_user_id: An instance of `model.UserId` containing the
authenticated user's ID.
Retrieved using Depends from `auth_handler.get_current_user_id`.
Returns:
A list of `model.Screenshot` objects representing the user's screenshots.
"""
statement = select(model.Screenshot).where(
model.Screenshot.owner_id == current_user_id.id)
return session.exec(statement)


@app_obj.get("/screenshots/{id}", tags=["screenshots"], response_model=model.Screenshot)
async def get_single_screenshot(*, session: Session = Depends(model.get_db_session), id: str):
q = select(model.Screenshot).where(model.Screenshot.external_id == id)
async def get_single_screenshot(
*,
session: Session = Depends(model.get_db_session),
screenshot_id: str):
"""
Retrieves a specific screenshot by its external ID.
This API endpoint fetches a single screenshot identified by
the provided `screenshot_id`
from the database. It checks if the screenshot exists and raises
a 404 Not Found exception if it's not found.
Args:
session: A SQLAlchemy database session object. Obtained using
Depends from `model.get_db_session`.
screenshot_id: The external ID of the screenshot to retrieve.
Returns:
A single `model.Screenshot` object if the screenshot is found,
otherwise returns None.
Raises:
HTTPException: If the screenshot with the provided ID is not found (404 Not Found).
"""
q = select(model.Screenshot).where(
model.Screenshot.external_id == screenshot_id)
screenshot = session.exec(q).first()
if not screenshot:
raise HTTPException(status_code=404, detail="Screenshot not found")
Expand All @@ -50,20 +123,64 @@ async def add_screenshots(
*,
session: Session = Depends(model.get_db_session),
current_user_id: model.UserId = Depends(auth_handler.get_current_user_id),
screenshot: model.ScreenshotCreate):
screenshot_db = model.Screenshot.model_validate(screenshot, update={"owner_id": current_user_id.id})
screenshot: model.ScreenshotCreate):
"""
Creates a new screenshot for the currently authenticated user.
This API endpoint allows users to add new screenshots. It requires user
authentication through the `Depends` dependency on
`auth_handler.get_current_user_id`. The provided
`screenshot` data is validated against the `model.ScreenshotCreate` schema.
Args:
session: A SQLAlchemy database session object. Obtained using
Depends from `model.get_db_session`.
current_user_id: An instance of `model.UserId` containing
the authenticated user's ID.
Retrieved using Depends from `auth_handler.get_current_user_id`.
screenshot: An instance of `model.ScreenshotCreate` containing the
data for the new screenshot.
Returns:
A `model.Screenshot` object representing the newly created screenshot.
"""
screenshot_db = model.Screenshot.model_validate(
screenshot, update={"owner_id": current_user_id.id})
screenshot_db.external_id = crypto.generate_random_base64_string(32)
session.add(screenshot_db)
session.commit()
session.refresh(screenshot_db)
return screenshot_db


@app_obj.post("/user/signup", tags=["user"])
async def create_user(*, session: Session = Depends(model.get_db_session), user: model.UserCreate = Body(...)):
async def create_user(
*,
session: Session = Depends(model.get_db_session),
user: model.UserCreate = Body(...)):
"""
Creates a new user account.
This API endpoint allows users to register and create new accounts. The
provided `user` data is validated against the `model.UserCreate` schema.
The password is hashed before saving it to the database for security
reasons.
Args:
session: A SQLAlchemy database session object (Obtained using Depends).
user: An instance of `model.UserCreate` containing the new
user's information.
Returns:
A `model.Token` object containing the access token and token type
upon successful registration.
"""
db_user = model.User.model_validate(user)
# Hash password before saving it
db_user.password = crypto.get_password_hash(db_user.password)
session.add(db_user)
session.commit()
session.refresh(db_user)
return model.Token(access_token=auth_handler.create_access_token(db_user), token_type="bearer")
return model.Token(
access_token=auth_handler.create_access_token(db_user),
token_type="bearer")
46 changes: 29 additions & 17 deletions app/auth/auth_handler.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,38 @@
""" Handles authentication.
"""

import os
from typing import Dict, Annotated
from datetime import datetime, timedelta, timezone
from typing import Union
from typing import Annotated, Union

import jwt
from decouple import config
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlmodel import select

import app.model as model
import app.auth.crypto as crypto

from app import model
from app.auth import crypto

JWT_SECRET = os.getenv("JWT_SECRET")
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def check_and_get_user(data: model.UserLogin, session: model.Session):
user = session.query(model.User).filter(model.User.email == data.email).first()
""" Returns the User object if the credentials are correct or None.
"""
q = select(model.User).where(model.User.email == data.email)
user = session.exec(q).first()
if user:
if crypto.verify_password(data.password, user.password):
return user
return None


def create_access_token(user: model.User, expires_delta: Union[timedelta, None] = None):
""" Create JWT access token.
"""
to_encode = {
"sub": {'id': user.id, 'email': user.email}
}
Expand All @@ -36,24 +44,28 @@ def create_access_token(user: model.User, expires_delta: Union[timedelta, None]
encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
return encoded_jwt


def decode_jwt(token: str) -> dict:
""" Decodes JWT token.
"""
try:
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
decoded_token = jwt.decode(
token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return decoded_token
except:
except jwt.InvalidTokenError:
return None


async def get_current_user_id(token: Annotated[str, Depends(oauth2_scheme)]):
""" Dependency for authentication that returns the user id.
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
decoded_token = decode_jwt(token)
if not decoded_token:
raise credentials_exception
user_id = model.UserId(**decoded_token['sub'])
return user_id
except:
decoded_token = decode_jwt(token)
if not decoded_token:
raise credentials_exception
user_id = model.UserId(**decoded_token['sub'])
return user_id
49 changes: 36 additions & 13 deletions app/auth/crypto.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,55 @@
from passlib.context import CryptContext
""" Basic crypto and hash utilities.
"""

import base64
import os

from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str):
"""
Verifies a plain text password against a hashed password.
Args:
plain_password: The password to be verified in plain text.
hashed_password: The hashed password to compare against.
Returns:
True if the passwords match, False otherwise.
"""
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str):
"""
Hashes a plain text password.
Args:
password: The password to be hashed.
Returns:
The hashed password.
"""
return pwd_context.hash(password)


def generate_random_base64_string(length: int):
"""Generates a random string of the specified length and encodes it in base64 (URL-safe).
"""Generates a random string of the specified length and encodes it in base64 (URL-safe).
Args:
length: The desired length of the random string (in bytes).
Args:
length: The desired length of the random string (in bytes).
Returns:
A string containing the random data encoded in base64 (URL-safe).
"""
# Generate random bytes using os.urandom()
random_bytes = os.urandom(length)
Returns:
A string containing the random data encoded in base64 (URL-safe).
"""
# Generate random bytes using os.urandom()
random_bytes = os.urandom(length)

# Encode the random bytes in base64 (URL-safe) format
encoded_string = base64.urlsafe_b64encode(random_bytes).decode()
# Encode the random bytes in base64 (URL-safe) format
encoded_string = base64.urlsafe_b64encode(random_bytes).decode()

# Remove trailing newline character (optional)
return encoded_string.rstrip("\n")
# Remove trailing newline character (optional)
return encoded_string.rstrip("\n")
11 changes: 8 additions & 3 deletions app/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
""" Model classes and data persistence.
"""

import os
from pydantic import BaseModel, Field, EmailStr
from typing import Optional, Union

from sqlmodel import Field, Session, SQLModel, Relationship, create_engine
from pydantic import BaseModel, EmailStr
from sqlmodel import Field, Session, SQLModel, create_engine

connect_args = {"check_same_thread": False}
engine = create_engine(
Expand All @@ -11,10 +14,12 @@
connect_args=connect_args)

def get_db_session():
"""Returns DB session."""
with Session(engine) as session:
yield session

def create_db_and_tables():
"""Creates database with schema."""
SQLModel.metadata.create_all(engine)

class Token(BaseModel):
Expand Down Expand Up @@ -49,4 +54,4 @@ class ScreenshotBase(ScreenshotCreate):
#owner: User | None = Relationship(back_populates="screenshots")

class Screenshot(ScreenshotBase, table=True):
id: int = Field(default=None, primary_key=True)
id: int = Field(default=None, primary_key=True)

0 comments on commit f0c5bce

Please sign in to comment.