diff --git a/programs/apps/api/jwt_decode_handler.py b/programs/apps/api/jwt_decode_handler.py new file mode 100644 index 0000000..2104742 --- /dev/null +++ b/programs/apps/api/jwt_decode_handler.py @@ -0,0 +1,42 @@ +""" +Custom JWT decoding function for django_rest_framework jwt package. + +Adds logging to facilitate debugging of InvalidTokenErrors. Also +requires "exp" and "iat" claims to be present - the base package +doesn't expose settings to enforce this. +""" +import logging + +import jwt +from rest_framework_jwt.settings import api_settings + +logger = logging.getLogger(__name__) + + +def decode(token): + """ + Ensure InvalidTokenErrors are logged for diagnostic purposes, before + failing authentication. + """ + + options = { + 'verify_exp': api_settings.JWT_VERIFY_EXPIRATION, + 'require_exp': True, + 'require_iat': True, + } + + try: + return jwt.decode( + token, + api_settings.JWT_SECRET_KEY, + api_settings.JWT_VERIFY, + options=options, + leeway=api_settings.JWT_LEEWAY, + audience=api_settings.JWT_AUDIENCE, + issuer=api_settings.JWT_ISSUER, + algorithms=[api_settings.JWT_ALGORITHM] + ) + except jwt.InvalidTokenError as exc: + exc_type = u'{}.{}'.format(exc.__class__.__module__, exc.__class__.__name__) + logger.info("raised_invalid_token: exc_type=%r, exc_detail=%r", exc_type, exc.message) + raise diff --git a/programs/apps/api/tests/test_authentication.py b/programs/apps/api/tests/test_authentication.py index c483f1e..e1d80a9 100644 --- a/programs/apps/api/tests/test_authentication.py +++ b/programs/apps/api/tests/test_authentication.py @@ -1,20 +1,25 @@ """ Tests for REST API Authentication """ +import time + import ddt from django.contrib.auth.models import Group from django.db import IntegrityError from django.test import TestCase import mock from rest_framework.exceptions import AuthenticationFailed +from rest_framework.test import APIRequestFactory from programs.apps.api.authentication import JwtAuthentication, pipeline_set_user_roles +from programs.apps.api.jwt_decode_handler import api_settings as drf_jwt_settings from programs.apps.api.v1.tests.mixins import JwtMixin from programs.apps.core.constants import Role from programs.apps.core.models import User from programs.apps.core.tests.factories import UserFactory +@ddt.ddt class TestJWTAuthentication(JwtMixin, TestCase): """ Test id_token authentication used with the browseable API. @@ -53,6 +58,44 @@ def test_user_creation_failure(self): user = authentication.authenticate_credentials({'preferred_username': self.USERNAME}) self.assertEqual(user.username, self.USERNAME) + @ddt.data(('exp', -1), ('iat', 1)) + @ddt.unpack + def test_leeway(self, claim, offset): + """ + Verify that the service allows the specified amount of leeway (in + seconds) when nonzero and validating "exp" and "iat" claims. + """ + authentication = JwtAuthentication() + user = UserFactory() + jwt_value = self.generate_id_token(user, **{claim: int(time.time()) + offset}) + request = APIRequestFactory().get('dummy', HTTP_AUTHORIZATION='JWT {}'.format(jwt_value)) + + # with no leeway, these requests should not be authenticated + with mock.patch.object(drf_jwt_settings, 'JWT_LEEWAY', 0): + with self.assertRaises(AuthenticationFailed): + authentication.authenticate(request) + + # with enough leeway, these requests should be authenticated + with mock.patch.object(drf_jwt_settings, 'JWT_LEEWAY', abs(offset)): + self.assertEqual( + (user, jwt_value), + authentication.authenticate(request) + ) + + @ddt.data('exp', 'iat') + def test_required_claims(self, claim): + """ + Verify that tokens that do not carry 'exp' or 'iat' claims are rejected + """ + authentication = JwtAuthentication() + user = UserFactory() + jwt_payload = self.default_payload(user) + del jwt_payload[claim] + jwt_value = self.generate_token(jwt_payload) + request = APIRequestFactory().get('dummy', HTTP_AUTHORIZATION='JWT {}'.format(jwt_value)) + with self.assertRaises(AuthenticationFailed): + authentication.authenticate(request) + @ddt.ddt class TestPipelineUserRoles(TestCase): diff --git a/programs/apps/api/v1/tests/mixins.py b/programs/apps/api/v1/tests/mixins.py index 3056a95..49cb090 100644 --- a/programs/apps/api/v1/tests/mixins.py +++ b/programs/apps/api/v1/tests/mixins.py @@ -27,13 +27,20 @@ def generate_token(self, payload, secret=None): token = jwt.encode(payload, secret) return token - def generate_id_token(self, user, admin=False, ttl=0): + def generate_id_token(self, user, admin=False, ttl=1, **overrides): """Generate a JWT id_token that looks like the ones currently returned by the edx oidc provider.""" + payload = self.default_payload(user=user, admin=admin, ttl=ttl) + payload.update(overrides) + return self.generate_token(payload) + + def default_payload(self, user, admin=False, ttl=1): + """Generate a bare payload, in case tests need to manipulate + it directly before encoding.""" now = int(time()) - return self.generate_token({ + return { "iss": self.JWT_ISSUER, "sub": user.pk, "aud": self.JWT_AUDIENCE, @@ -47,7 +54,7 @@ def generate_id_token(self, user, admin=False, ttl=0): "name": user.full_name, "given_name": "", "family_name": "", - }) + } class AuthClientMixin(object): diff --git a/programs/settings/base.py b/programs/settings/base.py index 8e02acf..4e2edae 100644 --- a/programs/settings/base.py +++ b/programs/settings/base.py @@ -242,6 +242,8 @@ 'JWT_ISSUER': None, 'JWT_PAYLOAD_GET_USERNAME_HANDLER': lambda d: d.get('preferred_username'), 'JWT_AUDIENCE': None, + 'JWT_LEEWAY': 1, + 'JWT_DECODE_HANDLER': 'programs.apps.api.jwt_decode_handler.decode', } # END AUTHENTICATION CONFIGURATION