diff --git a/ansible_base/oauth2_provider/fixtures.py b/ansible_base/oauth2_provider/fixtures.py index 97f97cc55..7a8e44b0b 100644 --- a/ansible_base/oauth2_provider/fixtures.py +++ b/ansible_base/oauth2_provider/fixtures.py @@ -65,16 +65,17 @@ def oauth2_application_password(randname): @pytest.fixture def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user): """ - 2-tuple with (token object with hashed token, plaintext token) + 3-tuple with (token object with hashed token, plaintext token, plaintext_refresh_token) """ url = get_relative_url('token-list') response = admin_api_client.post(url, {'application': oauth2_application[0].pk}) assert response.status_code == 201 plaintext_token = response.data['token'] + plaintext_refresh_token = response.data['refresh_token'] hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256) token = OAuth2AccessToken.objects.get(token=hashed_token) - return (token, plaintext_token) + return (token, plaintext_token, plaintext_refresh_token) @copy_fixture(copies=3) diff --git a/ansible_base/oauth2_provider/models/refresh_token.py b/ansible_base/oauth2_provider/models/refresh_token.py index 078a87cf9..dd498dfea 100644 --- a/ansible_base/oauth2_provider/models/refresh_token.py +++ b/ansible_base/oauth2_provider/models/refresh_token.py @@ -1,9 +1,12 @@ +import hashlib + import oauth2_provider.models as oauth2_models from django.conf import settings from django.db import models from django.utils.translation import gettext_lazy as _ from ansible_base.lib.abstract_models.common import CommonModel +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.models import prevent_search activitystream = object @@ -21,3 +24,8 @@ class Meta(oauth2_models.AbstractRefreshToken.Meta): token = prevent_search(models.CharField(max_length=255)) updated = None # Tracked in CommonModel with 'modified', no need for this + + def save(self, *args, **kwargs): + if not self.pk: + self.token = hash_string(self.token, hasher=hashlib.sha256) + super().save(*args, **kwargs) diff --git a/ansible_base/oauth2_provider/serializers/token.py b/ansible_base/oauth2_provider/serializers/token.py index e1edb74eb..427e96e18 100644 --- a/ansible_base/oauth2_provider/serializers/token.py +++ b/ansible_base/oauth2_provider/serializers/token.py @@ -20,10 +20,11 @@ logger = logging.getLogger("ansible_base.oauth2_provider.serializers.token") -class BaseOAuth2TokenSerializer(CommonModelSerializer): +class OAuth2TokenSerializer(CommonModelSerializer): refresh_token = SerializerMethodField() unencrypted_token = None # Only used in POST so we can return the token in the response + unencrypted_refresh_token = None # Only used in POST so we can return the refresh token in the response class Meta: model = OAuth2AccessToken @@ -45,9 +46,10 @@ def to_representation(self, instance): request = self.context.get('request', None) ret = super().to_representation(instance) if request and request.method == 'POST': - # If we're creating the token, show it. Otherwise, show the encrypted string - # which is the default from the supermethod. + # If we're creating the token, show it. Otherwise, show the encrypted string. ret['token'] = self.unencrypted_token + else: + ret['token'] = ENCRYPTED_STRING return ret def get_refresh_token(self, obj) -> Optional[str]: @@ -56,7 +58,7 @@ def get_refresh_token(self, obj) -> Optional[str]: if not obj.refresh_token: return None elif request and request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') + return self.unencrypted_refresh_token else: return ENCRYPTED_STRING except ObjectDoesNotExist: @@ -78,16 +80,6 @@ def validate_scope(self, value): raise ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(SCOPES)) return value - def create(self, validated_data): - validated_data['user'] = self.context['request'].user - self.unencrypted_token = validated_data.get('token') # So we don't have to decrypt it - try: - return super().create(validated_data) - except AccessDeniedError as e: - raise PermissionDenied(str(e)) - - -class OAuth2TokenSerializer(BaseOAuth2TokenSerializer): def create(self, validated_data): current_user = get_current_user() validated_data['token'] = generate_token() @@ -95,10 +87,23 @@ def create(self, validated_data): if expires_delta == 0: logger.warning("OAUTH2_PROVIDER.ACCESS_TOKEN_EXPIRE_SECONDS was set to 0, creating token that has already expired") validated_data['expires'] = now() + timedelta(seconds=expires_delta) - obj = super().create(validated_data) + validated_data['user'] = self.context['request'].user + self.unencrypted_token = validated_data.get('token') # Before it is hashed + + try: + obj = super().create(validated_data) + except AccessDeniedError as e: + raise PermissionDenied(str(e)) + if obj.application and obj.application.user: obj.user = obj.application.user obj.save() if obj.application: - OAuth2RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj) + self.unencrypted_refresh_token = generate_token() + OAuth2RefreshToken.objects.create( + user=current_user, + token=self.unencrypted_refresh_token, + application=obj.application, + access_token=obj, + ) return obj diff --git a/ansible_base/oauth2_provider/views/token.py b/ansible_base/oauth2_provider/views/token.py index 39e087089..69395cca5 100644 --- a/ansible_base/oauth2_provider/views/token.py +++ b/ansible_base/oauth2_provider/views/token.py @@ -1,3 +1,4 @@ +import hashlib from datetime import timedelta from django.utils.timezone import now @@ -5,6 +6,7 @@ from oauthlib import oauth2 from rest_framework.viewsets import ModelViewSet +from ansible_base.lib.utils.hashing import hash_string from ansible_base.lib.utils.settings import get_setting from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken @@ -28,7 +30,8 @@ def create_token_response(self, request): # This code detects and auto-expires them on refresh grant # requests. if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST: - refresh_token = OAuth2RefreshToken.objects.filter(token=request.POST['refresh_token']).first() + hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256) + refresh_token = OAuth2RefreshToken.objects.filter(token=hashed_refresh_token).first() if refresh_token: expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0) if refresh_token.created + timedelta(seconds=expire_seconds) < now(): @@ -38,7 +41,23 @@ def create_token_response(self, request): # oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response # (we override this so we can implement our own error handling to be compatible with AWX) - uri, http_method, body, headers = core._extract_params(request) + + # This is really, really ugly. Modify the request to hash the refresh_token + # but only long enough for the oauth lib to do its magic. + did_hash_refresh_token = False + old_post = request.POST + if 'refresh_token' in request.POST: + did_hash_refresh_token = True + request.POST = request.POST.copy() # so it's mutable + hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256) + request.POST['refresh_token'] = hashed_refresh_token + + try: + uri, http_method, body, headers = core._extract_params(request) + finally: + if did_hash_refresh_token: + request.POST = old_post + extra_credentials = core._get_extra_credentials(request) try: headers, body, status = core.server.create_token_response(uri, http_method, body, headers, extra_credentials) diff --git a/test_app/tests/oauth2_provider/views/test_token.py b/test_app/tests/oauth2_provider/views/test_token.py index 84e3b9d00..012b4da8a 100644 --- a/test_app/tests/oauth2_provider/views/test_token.py +++ b/test_app/tests/oauth2_provider/views/test_token.py @@ -273,7 +273,8 @@ def test_oauth2_token_create(oauth2_application, admin_api_client, admin_user): assert 'updated' not in response.data hashed_token = hash_string(response.data['token'], hasher=hashlib.sha256) token = OAuth2AccessToken.objects.get(token=hashed_token) - refresh_token = OAuth2RefreshToken.objects.get(token=response.data['refresh_token']) + hashed_refresh_token = hash_string(response.data['refresh_token'], hasher=hashlib.sha256) + refresh_token = OAuth2RefreshToken.objects.get(token=hashed_refresh_token) assert token.application == oauth2_application assert refresh_token.application == oauth2_application assert token.user == admin_user @@ -345,12 +346,13 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok """ app = oauth2_application[0] secret = oauth2_application[1] - refresh_token = oauth2_admin_access_token[0].refresh_token + refresh_token = oauth2_admin_access_token[2] + refresh_token_obj = oauth2_admin_access_token[0].refresh_token url = get_relative_url('token') data = { 'grant_type': 'refresh_token', - 'refresh_token': refresh_token.token, + 'refresh_token': refresh_token, } resp = unauthenticated_api_client.post( url, @@ -359,8 +361,8 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()}, ) assert resp.status_code == 201 - assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() - original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token) + assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists() + original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token_obj.token) assert oauth2_admin_access_token not in OAuth2AccessToken.objects.all() assert OAuth2AccessToken.objects.count() == 1 @@ -372,11 +374,12 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok new_token = json_resp['access_token'] new_token_hashed = hash_string(new_token, hasher=hashlib.sha256) new_refresh_token = json_resp['refresh_token'] + new_refresh_token_hashed = hash_string(new_refresh_token, hasher=hashlib.sha256) assert OAuth2AccessToken.objects.filter(token=new_token_hashed).count() == 1 # checks that RefreshTokens are rotated (new RefreshToken issued) - assert OAuth2RefreshToken.objects.filter(token=new_refresh_token).count() == 1 - new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token) + assert OAuth2RefreshToken.objects.filter(token=new_refresh_token_hashed).count() == 1 + new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token_hashed) assert not new_refresh_obj.revoked @@ -387,7 +390,8 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 """ app = oauth2_application[0] secret = oauth2_application[1] - refresh_token = oauth2_admin_access_token[0].refresh_token + refresh_token = oauth2_admin_access_token[2] + refresh_token_obj = oauth2_admin_access_token[0].refresh_token settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 1 settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] = 1 @@ -397,7 +401,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 url = get_relative_url('token') data = { 'grant_type': 'refresh_token', - 'refresh_token': refresh_token.token, + 'refresh_token': refresh_token, } response = admin_api_client.post( url, @@ -407,7 +411,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2 ) assert response.status_code == 403 assert b'The refresh token has expired.' in response.content - assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists() + assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists() assert OAuth2AccessToken.objects.count() == 1 assert OAuth2RefreshToken.objects.count() == 1