Skip to content

Commit

Permalink
Allow searching by name
Browse files Browse the repository at this point in the history
  • Loading branch information
k-nut committed Jun 24, 2024
1 parent 866285e commit 1f50803
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
15 changes: 14 additions & 1 deletion app/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def apply(self, query):
column_to_filter = getattr(models.School, self.key)
return query.filter(column_to_filter.in_(self.values))

class TextMatchFilter(Filter):
supported_keys = ['name']

def apply(self, query):
column_to_filter = getattr(models.School, self.key)
return query.filter(column_to_filter.icontains(self.values))

class UpdateTimestampFilter(Filter):
supported_keys = ['update_timestamp']

Expand Down Expand Up @@ -67,7 +74,13 @@ def apply(self, query):


class SchoolFilter:
filter_classes = [StateFilter, BasicFilter, LatLonSorter, BoundingBoxFilter, UpdateTimestampFilter]
filter_classes = [StateFilter,
BasicFilter,
LatLonSorter,
BoundingBoxFilter,
UpdateTimestampFilter,
TextMatchFilter
]

def __init__(self, params):
self.used_filters = []
Expand Down
4 changes: 4 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def read_schools(skip: int = 0,
state: Optional[List[State]] = Query(None),
school_type: Optional[List[str]] = Query(None),
legal_status: Optional[List[str]] = Query(None),
name: Optional[str] = Query(None,
description="Allows searching for names of schools."
"Searches for case-insensitive substrings."),
by_lat: Optional[float] = Query(None,
description="Allows ordering result by distance from a geographical point."
"Must be used in combination with `by_lon`"
Expand Down Expand Up @@ -66,6 +69,7 @@ def read_schools(skip: int = 0,
"school_type": school_type,
"legal_status": legal_status,
"update_timestamp": update_timestamp,
"name": name,
}
if by_lat or by_lon:
if not (by_lon and by_lat):
Expand Down
23 changes: 23 additions & 0 deletions test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,29 @@ def test_schools_by_update_date(self, client, db):
assert response.status_code == 200
assert len(response.json()) == 1

def test_schools_by_name(self, client, db):
# Arrange
for school in [
SchoolFactory.create(name='Schule am kleinen Deich'),
SchoolFactory.create(name='Schule an der Dorfstraßse'),
]:
db.add(school)
db.commit()

# Act
unfiltered_response = client.get("/schools")
deich_response = client.get("/schools?name=deich")
no_match_response = client.get("/schools?name=nicht%20da")

# Assert
assert unfiltered_response.status_code == 200
assert deich_response.status_code == 200
assert no_match_response.status_code == 200

assert len(unfiltered_response.json()) == 2
assert len(deich_response.json()) == 1
assert len(no_match_response.json()) == 0

def test_get_single_no_result(self, client, db):
# Arrange
self.__setup_schools(db)
Expand Down

0 comments on commit 1f50803

Please sign in to comment.