diff --git a/event_routing_backends/helpers.py b/event_routing_backends/helpers.py index bfa21d57..8784aeb4 100644 --- a/event_routing_backends/helpers.py +++ b/event_routing_backends/helpers.py @@ -93,17 +93,16 @@ def get_user(username_or_id): Returns: user object """ - user = user_id = username = None - if username_or_id: - try: - user_id = int(username_or_id) - except ValueError: - username = username_or_id + user = username = None + + if not username_or_id: + return None - if username: + try: + user = User.objects.get(id=int(username_or_id)) + except (User.DoesNotExist, ValueError): + username = username_or_id user = User.objects.filter(username=username).first() - elif user_id: - user = User.objects.filter(id=user_id).first() if username and not user: try: diff --git a/event_routing_backends/tests/test_helpers.py b/event_routing_backends/tests/test_helpers.py index 48b90a0e..0d76c123 100644 --- a/event_routing_backends/tests/test_helpers.py +++ b/event_routing_backends/tests/test_helpers.py @@ -3,25 +3,30 @@ """ from unittest.mock import patch +from ddt import data, ddt from django.test import TestCase from event_routing_backends.helpers import ( get_anonymous_user_id, get_block_id_from_event_referrer, get_course_from_id, + get_user, get_user_email, get_uuid5, ) from event_routing_backends.tests.factories import UserFactory +@ddt class TestHelpers(TestCase): """ Test the helper methods. """ + def setUp(self): super().setUp() - UserFactory.create(username='edx', email='edx@example.com') + self.edx_user = UserFactory.create(username='edx', email='edx@example.com') + UserFactory.create(username='10228945687', email='edx@example.com') def test_get_block_id_from_event_referrer_with_error(self): sample_event = { @@ -83,3 +88,24 @@ def test_get_course_from_id_unknown_course(self, mock_get_course_overviews): mock_get_course_overviews.return_value = [] with self.assertRaises(ValueError): get_course_from_id("foo") + + @data("edx", "10228945687") + def test_get_user_by_username(self, username): + """Test that the method get_user returns the right user based on given username parameter. + + Expected behavior: + - Returned user corresponds to the username. + """ + user = get_user(username) + + self.assertEqual(username, user.username) + + def test_get_user_by_id(self): + """ Test that the method get_user returns the right user based on the user id. + + Expected behavior: + - Returned user is the edx_user + """ + user = get_user(self.edx_user.id) + + self.assertEqual(self.edx_user, user)