From 47db08ee9787df2afa61b18f9099ca3636a76e7e Mon Sep 17 00:00:00 2001 From: Hanyuan Li <14177193+hanyuone@users.noreply.github.com> Date: Tue, 9 Aug 2022 13:56:22 +1000 Subject: [PATCH] Add lockout functionality (#62) * feat: moved advent folder -> puzzles, added some comments * feat(docker): start separation of dev and prod builds, add pytest functionality to backend * feat(docker): added dev/prod to frontend, transition frontend to yarn * fix: remove .vscode folder * fix(makefile): restructured makefile a bit * feat: removed .vscode folder from git * feat(auth): get rudimentary autotesting in place, created clear_database function * feat(test): added all tests for auth/register * fix(puzzle): changed blueprint in routes/puzzle.py * feat(auth): refactored registration system, database connections * fix(auth): minor changes to constructor * feat(auth): implement email verification endpoints * feat(test): using fixtures * feat(auth): finish autotests, still needs commenting * feat(auth): finished writing tests for the most part * feat(auth): complete tests for basic auth system * fix(auth): removed duplicate clear_database function * fix(auth): add basic lockout functionality * fix(auth): fix clear_database utility function * fix(auth): change requests to conform with DB * fix(auth): add basic lockout to /login route * feat(auth): add function to carry over CSRF in headers --- backend/common/database.py | 50 ++----- backend/common/redis.py | 56 ++++++- backend/database/database.py | 56 +++---- backend/database/user.py | 149 ++++++++++++------- backend/models/user.py | 248 +++++++++++++++---------------- backend/routes/auth.py | 185 +++++++++++++---------- backend/routes/user.py | 19 ++- backend/test/auth/login_test.py | 149 ++++++++++++------- backend/test/auth/logout_test.py | 47 +++--- backend/test/helpers.py | 51 +++++-- 10 files changed, 584 insertions(+), 426 deletions(-) diff --git a/backend/common/database.py b/backend/common/database.py index 6fc4bc0..a465abb 100644 --- a/backend/common/database.py +++ b/backend/common/database.py @@ -147,18 +147,6 @@ def checkInput(compName, dayNum, uid, solution): # note: for more advanced processing, we might consider having a timeout if a user tries too many things too quickly # but idk how to implement this too well -# Get all the information about a user given their uid -# Returns all information in the form of a dictionary -def getUserInfo(uid): - query = f""" - select * from Users where uid = {uid}; - """ - cur.execute(query) - - # only one entry should be returned since day number is unique - t = cur.fetchone() - return t - # Get all the information about a user's stats in a certain competition # Returns all information in the form of a list of 'solved objects' def getUserStatsPerComp(compName, uid): @@ -172,23 +160,7 @@ def getUserStatsPerComp(compName, uid): right outer join Parts p on s.pid = p.pid join Questions q on p.qid = q.qid join Competitions c on q.cid = c.cid - where s.uid = {uid} and c.name = '{compName}'; - """ - cur.execute(query) - - return cur.fetchall() - -# Get only the number of stars and points for a user. -# Returns extremely simple info -def getBasicUserStatsPerComp(compName, uid): - - # A right outer join returns all the results from the parts table, even if there is no solves - # Best to look up examples :D - # Use this information to deduce whether a user has solved a part or not - query = f""" - select u.username, u.github, s.numStars, s.score from Stats s - right outer join Users u - where s.uid = {uid} and c.name = '{compName}'; + where i.uid = {uid} and c.name = {compName}; """ cur.execute(query) @@ -217,12 +189,22 @@ def getAllCompetitions(): def updateUsername(username, uid): query = f""" update Users - set username = '{username}' + set username = {username} where uid = {uid}; """ cur.execute(query) conn.commit() - ''' - cursor.close() - conn.close() - ''' + +# DO NOT EVER EXECUTE THIS FUNCTION BRUH +def dropDatabase(): + query = f""" + SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;' + from + pg_tables WHERE schemaname = 'advent'; + """ + cur.execute(query) + conn.commit() + +def clear_database(): + conn = get_connection() + cursor = conn.cursor() diff --git a/backend/common/redis.py b/backend/common/redis.py index 30ddd3f..99e989a 100644 --- a/backend/common/redis.py +++ b/backend/common/redis.py @@ -1,6 +1,50 @@ -import redis - -# We're using Redis as a way to store codes with expiry dates - it might a bit -# overkill, but it works - -cache = redis.Redis(host="redis", port=6379, db=0) +from datetime import timedelta +import redis + +# We're using Redis as a way to store codes with expiry dates - it might a bit +# overkill, but it works + +MINUTES_IN_DAY = 1440 + +cache = redis.Redis(host="redis", port=6379, db=0) + +## EMAIL VERIFICATION + +## LOCKOUT + +def register_incorrect(id): + times = cache.get(f"attempts_{id}") + + if times is None: + times = 0 + + cache.set(f"attempts_{id}", int(times) + 1) + +def incorrect_attempts(id): + attempts = cache.get(f"attempts_{id}") + + if attempts is None: + return 0 + else: + return int(attempts) + +def calculate_time(attempts): + if attempts < 3: + return 0 + + minutes = 2 ** (attempts - 3) + + if minutes > MINUTES_IN_DAY: + return MINUTES_IN_DAY + else: + return minutes + +def block(id, time): + cache.set(f"block_{id}", "", ex=timedelta(minutes=time)) + +def is_blocked(id): + token = cache.get(f"block_{id}") + return token is not None + +def clear_redis(): + cache.flushdb() diff --git a/backend/database/database.py b/backend/database/database.py index d69c5d3..13e1c5d 100644 --- a/backend/database/database.py +++ b/backend/database/database.py @@ -1,28 +1,28 @@ -import os -from psycopg2.pool import ThreadedConnectionPool - -user = os.environ["POSTGRES_USER"] -password = os.environ["POSTGRES_PASSWORD"] -host = os.environ["POSTGRES_HOST"] -port = os.environ["POSTGRES_PORT"] -database = os.environ["POSTGRES_DB"] - -# TABLES = ["Users", "Questions", "Parts", "Competitions", "Inputs", "Solves"] - -db = ThreadedConnectionPool( - 1, 20, - user=user, - password=password, - host=host, - port=port, - database=database -) - -def clear_database(): - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute(f"""SELECT truncate_tables();""") - conn.commit() - - db.putconn(conn) +import os +from psycopg2.pool import ThreadedConnectionPool + +user = os.environ["POSTGRES_USER"] +password = os.environ["POSTGRES_PASSWORD"] +host = os.environ["POSTGRES_HOST"] +port = os.environ["POSTGRES_PORT"] +database = os.environ["POSTGRES_DB"] + +# TABLES = ["Users", "Questions", "Parts", "Competitions", "Inputs", "Solves"] + +db = ThreadedConnectionPool( + 1, 20, + user=user, + password=password, + host=host, + port=port, + database=database +) + +def clear_database(): + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute(f"""SELECT truncate_tables();""") + conn.commit() + + db.putconn(conn) diff --git a/backend/database/user.py b/backend/database/user.py index caf9238..88f6d10 100644 --- a/backend/database/user.py +++ b/backend/database/user.py @@ -1,56 +1,93 @@ -from database.database import db - - -def add_user(email, username, password) -> int: - """Adds a user to the database, returning their ID.""" - - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO Users (email, username, password) VALUES ('{email}', '{username}', '{password}')") - conn.commit() - - cursor.execute(f"SELECT uid FROM Users WHERE email = '{email}'") - id = cursor.fetchone()[0] - - db.putconn(conn) - return id - - -def fetch_user(email: str): - """Given a user's email, fetches their content from the database.""" - - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute(f"SELECT * FROM Users WHERE email = '{email}'") - result = cursor.fetchone() - - db.putconn(conn) - return result - - -def email_exists(email: str) -> bool: - """Checks if an email exists in the users table.""" - - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute(f"SELECT * FROM Users WHERE email = '{email}'") - results = cursor.fetchall() - - db.putconn(conn) - return results != [] - - -def username_exists(username: str) -> bool: - """Checks if a username is already used.""" - - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute(f"SELECT * FROM Users WHERE username = '{username}'") - results = cursor.fetchall() - - db.putconn(conn) - return results != [] +from database.database import db + +# Get all the information about a user given their uid +# Returns all information in the form of a dictionary +def get_user_info(uid): + conn = db.getconn() + + with conn.cursor() as cursor: + query = f""" + select * from Users where uid = {uid}; + """ + cursor.execute(query) + + # only one entry should be returned since day number is unique + t = cursor.fetchone() + + db.putconn(conn) + return t + + +def add_user(email, username, password) -> int: + """Adds a user to the database, returning their ID.""" + + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute("INSERT INTO Users (email, username, password) VALUES (%s, %s, %s)", + (email, username, password)) + conn.commit() + + cursor.execute("SELECT uid FROM Users WHERE email = %s", (email,)) + id = cursor.fetchone()[0] + + db.putconn(conn) + return id + + +def fetch_id(email: str): + """Given a user's email, fetches their ID.""" + + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute("SELECT uid FROM Users WHERE email = %s", (email,)) + result = cursor.fetchone() + + if result is None: + db.putconn(conn) + return None + + id = result[0] + + db.putconn(conn) + return id + + +def fetch_user(email: str): + """Given a user's email, fetches their content from the database.""" + + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM Users WHERE email = %s", (email,)) + result = cursor.fetchone() + + db.putconn(conn) + return result + + +def email_exists(email: str) -> bool: + """Checks if an email exists in the users table.""" + + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM Users WHERE email = %s", (email,)) + results = cursor.fetchall() + + db.putconn(conn) + return results != [] + + +def username_exists(username: str) -> bool: + """Checks if a username is already used.""" + + conn = db.getconn() + + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM Users WHERE username = %s", (username,)) + results = cursor.fetchall() + + db.putconn(conn) + return results != [] diff --git a/backend/models/user.py b/backend/models/user.py index ec67448..97de59a 100644 --- a/backend/models/user.py +++ b/backend/models/user.py @@ -1,126 +1,122 @@ -from datetime import timedelta -import os - -from argon2 import PasswordHasher -from argon2.exceptions import VerificationError -from email_validator import validate_email, EmailNotValidError -from itsdangerous import URLSafeTimedSerializer - -from common.exceptions import AuthError, InvalidError, RequestError -from common.redis import cache -from database.database import db -from database.user import add_user, email_exists, fetch_user, username_exists - -hasher = PasswordHasher( - time_cost=2, - memory_cost=2**15, - parallelism=1 -) - -verify_serialiser = URLSafeTimedSerializer(os.environ["FLASK_SECRET"], salt="verify") - -class User: - def __init__(self, id, email, username, password): - self.id = id - self.email = email - self.username = username - self.password = password - - # Helper methods - - @staticmethod - def hash_password(password): - return hasher.hash(password) - - # API-facing methods - - @staticmethod - def register(email, username, password): - """Given an email, username and password, creates a verification code - for that user in Redis such that we can verify that user's email.""" - - # Check for malformed input - try: - normalised = validate_email(email).email - except EmailNotValidError as e: - raise RequestError(description="Invalid email") from e - - if email_exists(normalised): - raise RequestError(description="Email already registered") - - if username_exists(username): - raise RequestError(description="Username already used") - - # Our account is good, we hash the password - hashed = hasher.hash(password) - - # Add verification code to Redis cache, with expiry date of 1 hour - code = verify_serialiser.dumps(normalised) - - data = { - "email": normalised, - "username": username, - "password": hashed - } - - # We use a pipeline here to ensure these instructions are atomic - pipeline = cache.pipeline() - - pipeline.hset(f"register:{code}", mapping=data) - pipeline.expire(f"register:{code}", timedelta(hours=1)) - - pipeline.execute() - - return code - - @staticmethod - def register_verify(token): - cache_key = f"register:{token}" - - if not cache.exists(cache_key): - raise AuthError("Token expired or does not correspond to registering user") - - result = cache.hgetall(cache_key) - stringified = {} - - for key, value in result.items(): - stringified[key.decode()] = value.decode() - - id = add_user(stringified["email"], stringified["username"], stringified["password"]) - return User(id, stringified["email"], stringified["username"], stringified["password"]) - - @staticmethod - def login(email, password): - """Logs user in with their credentials (currently email and password).""" - try: - normalised = validate_email(email).email - except EmailNotValidError as e: - raise AuthError(description="Invalid email or password") from e - - result = fetch_user(normalised) - - try: - id, email, username, github, hashed = result - hasher.verify(hashed, password) - except (TypeError, VerificationError) as e: - raise AuthError(description="Invalid email or password") from e - - return User(id, email, username, hashed) - - @staticmethod - def get(id): - """Given a user's ID, fetches all of their information from the database.""" - conn = db.getconn() - - with conn.cursor() as cursor: - cursor.execute("SELECT * FROM Users WHERE uid = %s", (id,)) - fetched = cursor.fetchall() - - if fetched == []: - raise InvalidError(description=f"Requested user ID {id} doesn't exist") - - id, email, github, username, password = fetched[0] - - db.putconn(conn) - - return User(id, email, username, password) +from datetime import timedelta +import os + +from argon2 import PasswordHasher +from argon2.exceptions import VerificationError +from email_validator import validate_email, EmailNotValidError +from itsdangerous import URLSafeTimedSerializer + +from common.exceptions import AuthError, InvalidError, RequestError +from common.redis import cache +from database.user import add_user, email_exists, fetch_user, get_user_info, username_exists + +hasher = PasswordHasher( + time_cost=2, + memory_cost=2**15, + parallelism=1 +) + +verify_serialiser = URLSafeTimedSerializer(os.environ["FLASK_SECRET"], salt="verify") + +class User: + def __init__(self, id, email, username, password, github_username=None): + self.id = id + self.email = email + self.username = username + self.password = password + + self.github_username = github_username + + # Helper methods + + @staticmethod + def hash_password(password): + return hasher.hash(password) + + # API-facing methods + + @staticmethod + def register(email, username, password): + """Given an email, username and password, creates a verification code + for that user in Redis such that we can verify that user's email.""" + + # Check for malformed input + try: + normalised = validate_email(email).email + except EmailNotValidError as e: + raise RequestError(description="Invalid email") from e + + if email_exists(normalised): + raise RequestError(description="Email already registered") + + if username_exists(username): + raise RequestError(description="Username already used") + + # Our account is good, we hash the password + hashed = hasher.hash(password) + + # Add verification code to Redis cache, with expiry date of 1 hour + code = verify_serialiser.dumps(normalised) + + data = { + "email": normalised, + "username": username, + "password": hashed + } + + # We use a pipeline here to ensure these instructions are atomic + pipeline = cache.pipeline() + + pipeline.hset(f"register:{code}", mapping=data) + pipeline.expire(f"register:{code}", timedelta(hours=1)) + + pipeline.execute() + + return code + + @staticmethod + def register_verify(token): + cache_key = f"register:{token}" + + if not cache.exists(cache_key): + raise AuthError("Token expired or does not correspond to registering user") + + result = cache.hgetall(cache_key) + stringified = {} + + for key, value in result.items(): + stringified[key.decode()] = value.decode() + + id = add_user(stringified["email"], stringified["username"], stringified["password"]) + return User(id, stringified["email"], stringified["username"], stringified["password"]) + + @staticmethod + def login(email, password): + """Logs user in with their credentials (currently email and password).""" + try: + normalised = validate_email(email).email + except EmailNotValidError as e: + raise AuthError(description="Invalid email or password") from e + + result = fetch_user(normalised) + + try: + id, email, github_username, username, hashed = result + hasher.verify(hashed, password) + except (TypeError, VerificationError) as e: + raise AuthError(description="Invalid email or password") from e + + return User(id, email, username, hashed, github_username) + + @staticmethod + def get(id): + """Given a user's ID, fetches all of their information from the database.""" + + result = get_user_info(id) + + if result is None: + raise InvalidError(description=f"Requested user ID {id} doesn't exist") + + id, email, github_username, username, password = result + + return User(id, email, username, password, github_username) diff --git a/backend/routes/auth.py b/backend/routes/auth.py index b725f10..8907577 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -1,80 +1,105 @@ -import os -from flask import Blueprint, render_template, request, jsonify -from flask_mail import Message -from flask_jwt_extended import create_access_token, set_access_cookies, unset_jwt_cookies, verify_jwt_in_request -from common.exceptions import AuthError -from common.plugins import mail -from models.user import User - -# Constants - -auth = Blueprint("auth", __name__) - -# Routes (fairly temporary here) -# TODO: add invalid login attempt protection - -@auth.route("/login", methods=["POST"]) -def login(): - json = request.get_json() - - user = User.login(json["email"], json["password"]) - token = create_access_token(identity=user) - - response = jsonify({}) - set_access_cookies(response, token) - - return response, 200 - -@auth.route("/register", methods=["POST"]) -def register(): - # TODO: convert to email verification once we get email address - json = request.get_json() - - # Fetch verification code - code = User.register(json["email"], json["username"], json["password"]) - # TODO: convert to domain of verification page once we have its address - url = f"{os.environ['TESTING_ADDRESS']}/verify/{code}" - - html = render_template("activate.html", confirm_url=url) - - # Send it over to email - message = Message( - "Account registered for Week in Wonderland", - sender="weekinwonderland@csesoc.org.au", - recipients=[json["email"]], - html=html - ) - - mail.send(message) - - response = jsonify({}) - - return response, 200 - -@auth.route("/register/verify", methods=["POST"]) -def register_verify(): - json = request.get_json() - - user = User.register_verify(json["token"]) - cookie = create_access_token(identity=user) - - response = jsonify({}) - set_access_cookies(response, cookie) - - return response, 200 - -@auth.route("/verify_token", methods=["GET"]) -def verify_token(): - try: - verify_jwt_in_request() - except: - raise AuthError("Not logged in") - - return jsonify({}), 200 - -@auth.route("/logout", methods=["POST"]) -def logout(): - response = jsonify({}) - unset_jwt_cookies(response) - - return response, 200 +import os +from flask import Blueprint, render_template, request, jsonify +from flask_mail import Message +from flask_jwt_extended import create_access_token, jwt_required, set_access_cookies, unset_jwt_cookies, verify_jwt_in_request + +from common.exceptions import AuthError +from common.plugins import mail +from common.redis import block, calculate_time, incorrect_attempts, is_blocked, register_incorrect +from database.user import fetch_id +from models.user import User + +# Constants + +auth = Blueprint("auth", __name__) + +# Routes (fairly temporary here) + +@auth.route("/login", methods=["POST"]) +def login(): + json = request.get_json() + + id = fetch_id(json["email"]) + + if is_blocked(id): + raise AuthError() + + try: + user = User.login(json["email"], json["password"]) + except Exception as e: + register_incorrect(id) + + attempts = incorrect_attempts(id) + + if attempts >= 3: + block_time = calculate_time(attempts) + block(id, block_time) + + raise e + + token = create_access_token(identity=user) + + response = jsonify({}) + set_access_cookies(response, token) + + return response, 200 + +@auth.route("/register", methods=["POST"]) +def register(): + # TODO: convert to email verification once we get email address + json = request.get_json() + + # Fetch verification code + code = User.register(json["email"], json["username"], json["password"]) + # TODO: convert to domain of verification page once we have its address + url = f"{os.environ['TESTING_ADDRESS']}/verify/{code}" + + html = render_template("activate.html", confirm_url=url) + + # Send it over to email + message = Message( + "Account registered for Week in Wonderland", + sender="weekinwonderland@csesoc.org.au", + recipients=[json["email"]], + html=html + ) + + mail.send(message) + + response = jsonify({}) + + return response, 200 + +@auth.route("/register/verify", methods=["POST"]) +def register_verify(): + json = request.get_json() + + user = User.register_verify(json["token"]) + cookie = create_access_token(identity=user) + + response = jsonify({}) + set_access_cookies(response, cookie) + + return response, 200 + +@auth.route("/verify_token", methods=["GET"]) +def verify_token(): + try: + verify_jwt_in_request() + except: + raise AuthError("Not logged in") + + return jsonify({}), 200 + +@auth.route("/logout", methods=["POST"]) +def logout(): + response = jsonify({}) + unset_jwt_cookies(response) + + return response, 200 + +@jwt_required() +@auth.route("/protected", methods=["POST"]) +def protected(): + verify_jwt_in_request() + return jsonify({}), 200 diff --git a/backend/routes/user.py b/backend/routes/user.py index ebf3e62..784370b 100644 --- a/backend/routes/user.py +++ b/backend/routes/user.py @@ -83,6 +83,23 @@ def set_name(): except: raise AuthError("Invalid token") +""" +@user.route("/user/reset_email/request", methods=["POST"]) +def reset_email_request(): + data = request.get_json() + ''' + { + token: token (in cookies) + email: string + } + ''' + try: + verify_jwt_in_request() + except: + raise AuthError("Invalid token") +======= +>>>>>>> main + # @user.route("/reset_email/request", methods=["POST"]) @@ -135,4 +152,4 @@ def set_name(): # def reset_password_request(): # json = request.get_json() # return jsonify({}) - +""" \ No newline at end of file diff --git a/backend/test/auth/login_test.py b/backend/test/auth/login_test.py index d1e432e..104761c 100644 --- a/backend/test/auth/login_test.py +++ b/backend/test/auth/login_test.py @@ -1,56 +1,93 @@ -import pytest - -# Import for pytest -from test.helpers import clear_all, db_add_user -from test.fixtures import app, client - - -def test_no_users(client): - clear_all() - - response = client.post("/auth/login", json={ - "email": "asdfghjkl@gmail.com", - "password": "foobar" - }) - - assert response.status_code == 401 - - -def test_invalid_email(client): - clear_all() - - db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") - - response = client.post("/auth/login", json={ - "email": "foobar@gmail.com", - "password": "foobaz" - }) - - assert response.status_code == 401 - - -def test_wrong_password(client): - clear_all() - - db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") - - response = client.post("/auth/login", json={ - "email": "asdfghjkl@gmail.com", - "password": "foobaz" - }) - - assert response.status_code == 401 - -def test_success(client): - clear_all() - - db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") - - response = client.post("/auth/login", json={ - "email": "asdfghjkl@gmail.com", - "password": "foobar" - }) - - assert response.status_code == 200 - - # TODO: once user profile is in, improve this test +import pytest + +# Import for pytest +from flask.testing import FlaskClient +from test.helpers import clear_all, db_add_user, generate_csrf_header +from test.fixtures import app, client + + +def test_no_users(client): + clear_all() + + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobar" + }) + + assert response.status_code == 401 + + +def test_invalid_email(client): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + response = client.post("/auth/login", json={ + "email": "foobar@gmail.com", + "password": "foobaz" + }) + + assert response.status_code == 401 + + +def test_wrong_password(client): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobaz" + }) + + assert response.status_code == 401 + +def test_success(client): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobar" + }) + + assert response.status_code == 200 + +def test_lockout(client): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + # Incorrect login 3 times + for _ in range(3): + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobaz" + }) + + assert response.status_code == 401 + + # Now when we login, it should lock user out + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobar" + }) + + assert response.status_code == 401 + +def test_protected_route(client: FlaskClient): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobar" + }) + + assert response.status_code == 200 + + response = client.post("/auth/protected", headers=generate_csrf_header(response)) + + assert response.status_code == 200 diff --git a/backend/test/auth/logout_test.py b/backend/test/auth/logout_test.py index b01497c..eff5798 100644 --- a/backend/test/auth/logout_test.py +++ b/backend/test/auth/logout_test.py @@ -1,24 +1,23 @@ -from test.helpers import clear_all, db_add_user -from test.fixtures import app, client - -def test_success(client): - clear_all() - - db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") - - # Log user in - login_response = client.post("/auth/login", json={ - "email": "asdfghjkl@gmail.com", - "password": "foobar" - }) - - assert login_response.status_code == 200 - - # Log user out - response = client.post("/auth/logout") - - assert response.status_code == 200 - - # Check there's no more cookies - assert len(client.cookie_jar) == 0 - +from test.helpers import clear_all, db_add_user +from test.fixtures import app, client + +def test_success(client): + clear_all() + + db_add_user("asdfghjkl@gmail.com", "asdf", "foobar") + + # Log user in + login_response = client.post("/auth/login", json={ + "email": "asdfghjkl@gmail.com", + "password": "foobar" + }) + + assert login_response.status_code == 200 + + # Log user out + response = client.post("/auth/logout") + + assert response.status_code == 200 + + # Check there's no more cookies + assert len(client.cookie_jar) == 0 diff --git a/backend/test/helpers.py b/backend/test/helpers.py index c3902c1..c4323d2 100644 --- a/backend/test/helpers.py +++ b/backend/test/helpers.py @@ -1,15 +1,36 @@ -from common.redis import cache -from database.database import clear_database -from database.user import add_user -from models.user import User - - -def db_add_user(email, username, password): - add_user(email, username, User.hash_password(password)) - -def clear_all(): - # Clear Redis - cache.flushdb() - - # Clear database - clear_database() +from common.redis import clear_redis +from database.database import clear_database +from database.user import add_user +from models.user import User + + +def db_add_user(email, username, password): + add_user(email, username, User.hash_password(password)) + +def clear_all(): + # Clear Redis + clear_redis() + + # Clear database + clear_database() + +def get_cookie_from_header(response, cookie_name): + cookie_headers = response.headers.getlist("Set-Cookie") + + for header in cookie_headers: + attributes = header.split(";") + + if cookie_name in attributes[0]: + cookie = {} + + for attr in attributes: + split = attr.split("=") + cookie[split[0].strip().lower()] = split[1] if len(split) > 1 else True + + return cookie + + return None + +def generate_csrf_header(response): + csrf_token = get_cookie_from_header(response, "csrf_access_token")["csrf_access_token"] + return {"X-CSRF-TOKEN": csrf_token}