From 4a4eca1c1a8fc4e618d189f66a6e2bea5467c6d4 Mon Sep 17 00:00:00 2001 From: Raphael Odini Date: Tue, 21 Nov 2023 01:13:58 +0100 Subject: [PATCH] feat: GET /locations/id endpoint to get location details (#37) * New endpoint to fetch location data per id * Add tests & typing --- app/api.py | 11 +++++++++++ app/crud.py | 4 ++++ tests/test_api.py | 21 +++++++++++++++++++-- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/app/api.py b/app/api.py index 16d10a63..302eb976 100644 --- a/app/api.py +++ b/app/api.py @@ -207,6 +207,17 @@ def get_user_proofs( return crud.get_user_proofs(db, user=current_user) +@app.get("/locations/{location_id}", response_model=schemas.LocationBase) +async def get_location(location_id: int, db: Session = Depends(get_db)): + db_location = crud.get_location_by_id(db, id=location_id) + if not db_location: + raise HTTPException( + status_code=404, + detail=f"Location with id {location_id} not found", + ) + return db_location + + @app.get("/status") async def status_endpoint(): return {"status": "running"} diff --git a/app/crud.py b/app/crud.py index 7b4d57e6..f2b6baba 100644 --- a/app/crud.py +++ b/app/crud.py @@ -176,6 +176,10 @@ def get_location_by_osm_id_and_type( ) +def get_location_by_id(db: Session, id: int): + return db.query(Location).filter(Location.id == id).first() + + def create_location(db: Session, location: LocationCreate): db_location = Location(**location.model_dump()) db.add(db_location) diff --git a/tests/test_api.py b/tests/test_api.py index 182b39c6..57d2cd0c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,7 +8,7 @@ from app import crud from app.api import app, get_db from app.db import Base -from app.schemas import PriceCreate, UserBase +from app.schemas import LocationCreate, PriceCreate, UserBase # database setup # ------------------------------------------------------------------------------ @@ -40,7 +40,7 @@ def override_get_db(): client = TestClient(app) USER = UserBase(user_id="user1", token="user1__Utoken") - +LOCATION = LocationCreate(osm_id=3344841823, osm_type="NODE") PRICE_1 = PriceCreate( product_code="1111111111111", price=3.5, @@ -57,6 +57,12 @@ def user(db=override_get_db()): return db_user +@pytest.fixture(scope="module") +def location(db=override_get_db()): + db_location = crud.create_location(next(db), LOCATION) + return db_location + + # Tests # ------------------------------------------------------------------------------ def test_hello(): @@ -107,6 +113,8 @@ def test_get_prices(): response = client.get("/prices") assert response.status_code == 200 assert len(response.json()["items"]) == 1 + for price_field in ["location_id", "proof_id"]: + assert price_field in response.json()["items"][0] def test_get_prices_pagination(): @@ -138,3 +146,12 @@ def test_get_proofs(user): headers={"Authorization": f"Bearer {user.token}"}, ) assert response.status_code == 200 + + +def test_get_location(location): + # location exists + response = client.get(f"/locations/{location.id}") + assert response.status_code == 200 + # location does not exist + response = client.get(f"/locations/{location.id+1}") + assert response.status_code == 404