Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(API): allow anyone to access proof data #606

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 15 additions & 38 deletions open_prices/api/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,17 @@ def setUpTestData(cls):

def test_proof_list(self):
# anonymous
response = self.client.get(self.url)
self.assertEqual(response.status_code, 403)
# wrong token
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}X"}
)
self.assertEqual(response.status_code, 403)
# authenticated
# thanks to select_related and prefetch_related, we only have 6
# thanks to select_related and prefetch_related, we only have 3
# queries:
# - 1 to get the fetch the user session
# - 1 to update the session
# - 1 to get the user
# - 1 to count the number of proofs of the user
# - 1 to get the proofs and their associated locations (select_related)
# - 1 to get the associated proof predictions (prefetch_related)
with self.assertNumQueries(6):
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
with self.assertNumQueries(3):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
data = response.data
self.assertEqual(data["total"], 2) # only user's proofs
self.assertEqual(len(data["items"]), 2)
self.assertEqual(data["total"], 3)
self.assertEqual(len(data["items"]), 3)
item = data["items"][0]
self.assertEqual(item["id"], self.proof.id) # default order
self.assertIn("predictions", item)
Expand Down Expand Up @@ -122,10 +109,8 @@ def setUpTestData(cls):

def test_proof_list_order_by(self):
url = self.url + "?order_by=-price_count"
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
self.assertEqual(response.data["total"], 2)
response = self.client.get(url)
self.assertEqual(response.data["total"], 3)
self.assertEqual(response.data["items"][0]["price_count"], 50)


Expand All @@ -146,12 +131,16 @@ def setUpTestData(cls):

def test_proof_list_filter_by_type(self):
url = self.url + "?type=RECEIPT"
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session.token}"}
)
response = self.client.get(url)
self.assertEqual(response.data["total"], 1)
self.assertEqual(response.data["items"][0]["price_count"], 15)

def test_proof_list_filter_by_owner(self):
url = self.url + f"?owner={self.user_session.user.user_id}"
response = self.client.get(url)
self.assertEqual(response.data["total"], 2)
self.assertEqual(response.data["items"][0]["price_count"], 15)


class ProofDetailApiTest(TestCase):
@classmethod
Expand All @@ -166,23 +155,11 @@ def setUpTestData(cls):
def test_proof_detail(self):
# 404
url = reverse("api:proofs-detail", args=[999])
response = self.client.get(
url, headers={"Authorization": f"Bearer {self.user_session_1.token}"}
)
response = self.client.get(url)
self.assertEqual(response.status_code, 404)
self.assertEqual(response.data["detail"], "No Proof matches the given query.")
# anonymous
response = self.client.get(self.url)
self.assertEqual(response.status_code, 403)
# wrong token
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}X"}
)
self.assertEqual(response.status_code, 403)
# authenticated
response = self.client.get(
self.url, headers={"Authorization": f"Bearer {self.user_session_1.token}"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["id"], self.proof.id)

Expand Down
30 changes: 15 additions & 15 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rest_framework import filters, mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.permissions import IsAuthenticatedOrReadOnly
from rest_framework.request import Request
from rest_framework.response import Response

Expand All @@ -31,28 +31,28 @@ class ProofViewSet(
viewsets.GenericViewSet,
):
authentication_classes = [CustomAuthentication]
permission_classes = [IsAuthenticated]
permission_classes = [IsAuthenticatedOrReadOnly]
http_method_names = ["get", "post", "patch", "delete"] # disable "put"
queryset = Proof.objects.none()
queryset = Proof.objects.all()
serializer_class = ProofFullSerializer
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
filterset_class = ProofFilter
ordering_fields = ["date", "price_count", "created"]
ordering = ["created"]

def get_queryset(self):
# only return proofs owned by the current user
if self.request.user.is_authenticated:
queryset = Proof.objects.filter(owner=self.request.user.user_id)
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return queryset.select_related("location").prefetch_related(
"predictions"
)
return queryset
if self.request.method in ["GET"]:
# Select all proofs along with their locations using a select
# related query (1 single query)
# Then prefetch all the predictions related to the proof using
# a prefetch related query (only 1 query for all proofs)
return self.queryset.select_related("location").prefetch_related(
"predictions"
)
elif self.request.method in ["PATCH", "DELETE"]:
# only return proofs owned by the current user
if self.request.user.is_authenticated:
return self.queryset.filter(owner=self.request.user.user_id)
return self.queryset

def get_serializer_class(self):
Expand Down
Loading