Skip to content

Commit

Permalink
Also hash refresh tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Rick Elrod <[email protected]>
  • Loading branch information
relrod committed Nov 9, 2024
1 parent 231a5af commit 9ba2d06
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 30 deletions.
5 changes: 3 additions & 2 deletions ansible_base/oauth2_provider/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ def oauth2_application_password(randname):
@pytest.fixture
def oauth2_admin_access_token(oauth2_application, admin_api_client, admin_user):
"""
2-tuple with (token object with hashed token, plaintext token)
3-tuple with (token object with hashed token, plaintext token, plaintext_refresh_token)
"""
url = get_relative_url('token-list')
response = admin_api_client.post(url, {'application': oauth2_application[0].pk})
assert response.status_code == 201

plaintext_token = response.data['token']
plaintext_refresh_token = response.data['refresh_token']
hashed_token = hash_string(plaintext_token, hasher=hashlib.sha256)
token = OAuth2AccessToken.objects.get(token=hashed_token)
return (token, plaintext_token)
return (token, plaintext_token, plaintext_refresh_token)


@copy_fixture(copies=3)
Expand Down
8 changes: 8 additions & 0 deletions ansible_base/oauth2_provider/models/refresh_token.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import hashlib

import oauth2_provider.models as oauth2_models
from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _

from ansible_base.lib.abstract_models.common import CommonModel
from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.models import prevent_search

activitystream = object
Expand All @@ -21,3 +24,8 @@ class Meta(oauth2_models.AbstractRefreshToken.Meta):

token = prevent_search(models.CharField(max_length=255))
updated = None # Tracked in CommonModel with 'modified', no need for this

def save(self, *args, **kwargs):
if not self.pk:
self.token = hash_string(self.token, hasher=hashlib.sha256)
super().save(*args, **kwargs)
37 changes: 21 additions & 16 deletions ansible_base/oauth2_provider/serializers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
logger = logging.getLogger("ansible_base.oauth2_provider.serializers.token")


class BaseOAuth2TokenSerializer(CommonModelSerializer):
class OAuth2TokenSerializer(CommonModelSerializer):
refresh_token = SerializerMethodField()

unencrypted_token = None # Only used in POST so we can return the token in the response
unencrypted_refresh_token = None # Only used in POST so we can return the refresh token in the response

class Meta:
model = OAuth2AccessToken
Expand All @@ -45,9 +46,10 @@ def to_representation(self, instance):
request = self.context.get('request', None)
ret = super().to_representation(instance)
if request and request.method == 'POST':
# If we're creating the token, show it. Otherwise, show the encrypted string
# which is the default from the supermethod.
# If we're creating the token, show it. Otherwise, show the encrypted string.
ret['token'] = self.unencrypted_token
else:
ret['token'] = ENCRYPTED_STRING
return ret

def get_refresh_token(self, obj) -> Optional[str]:
Expand All @@ -56,7 +58,7 @@ def get_refresh_token(self, obj) -> Optional[str]:
if not obj.refresh_token:
return None
elif request and request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
return self.unencrypted_refresh_token
else:
return ENCRYPTED_STRING
except ObjectDoesNotExist:
Expand All @@ -78,27 +80,30 @@ def validate_scope(self, value):
raise ValidationError(_('Must be a simple space-separated string with allowed scopes {}.').format(SCOPES))
return value

def create(self, validated_data):
validated_data['user'] = self.context['request'].user
self.unencrypted_token = validated_data.get('token') # So we don't have to decrypt it
try:
return super().create(validated_data)
except AccessDeniedError as e:
raise PermissionDenied(str(e))


class OAuth2TokenSerializer(BaseOAuth2TokenSerializer):
def create(self, validated_data):
current_user = get_current_user()
validated_data['token'] = generate_token()
expires_delta = get_setting('OAUTH2_PROVIDER', {}).get('ACCESS_TOKEN_EXPIRE_SECONDS', 0)
if expires_delta == 0:
logger.warning("OAUTH2_PROVIDER.ACCESS_TOKEN_EXPIRE_SECONDS was set to 0, creating token that has already expired")
validated_data['expires'] = now() + timedelta(seconds=expires_delta)
obj = super().create(validated_data)
validated_data['user'] = self.context['request'].user
self.unencrypted_token = validated_data.get('token') # Before it is hashed

try:
obj = super().create(validated_data)
except AccessDeniedError as e:
raise PermissionDenied(str(e))

if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application:
OAuth2RefreshToken.objects.create(user=current_user, token=generate_token(), application=obj.application, access_token=obj)
self.unencrypted_refresh_token = generate_token()
OAuth2RefreshToken.objects.create(
user=current_user,
token=self.unencrypted_refresh_token,
application=obj.application,
access_token=obj,
)
return obj
23 changes: 21 additions & 2 deletions ansible_base/oauth2_provider/views/token.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import hashlib
from datetime import timedelta

from django.utils.timezone import now
from oauth2_provider import views as oauth_views
from oauthlib import oauth2
from rest_framework.viewsets import ModelViewSet

from ansible_base.lib.utils.hashing import hash_string
from ansible_base.lib.utils.settings import get_setting
from ansible_base.lib.utils.views.django_app_api import AnsibleBaseDjangoAppApiView
from ansible_base.oauth2_provider.models import OAuth2AccessToken, OAuth2RefreshToken
Expand All @@ -28,7 +30,8 @@ def create_token_response(self, request):
# This code detects and auto-expires them on refresh grant
# requests.
if request.POST.get('grant_type') == 'refresh_token' and 'refresh_token' in request.POST:
refresh_token = OAuth2RefreshToken.objects.filter(token=request.POST['refresh_token']).first()
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256)
refresh_token = OAuth2RefreshToken.objects.filter(token=hashed_refresh_token).first()
if refresh_token:
expire_seconds = get_setting('OAUTH2_PROVIDER', {}).get('REFRESH_TOKEN_EXPIRE_SECONDS', 0)
if refresh_token.created + timedelta(seconds=expire_seconds) < now():
Expand All @@ -38,7 +41,23 @@ def create_token_response(self, request):

# oauth2_provider.oauth2_backends.OAuthLibCore.create_token_response
# (we override this so we can implement our own error handling to be compatible with AWX)
uri, http_method, body, headers = core._extract_params(request)

# This is really, really ugly. Modify the request to hash the refresh_token
# but only long enough for the oauth lib to do its magic.
did_hash_refresh_token = False
old_post = request.POST
if 'refresh_token' in request.POST:
did_hash_refresh_token = True
request.POST = request.POST.copy() # so it's mutable
hashed_refresh_token = hash_string(request.POST['refresh_token'], hasher=hashlib.sha256)
request.POST['refresh_token'] = hashed_refresh_token

try:
uri, http_method, body, headers = core._extract_params(request)
finally:
if did_hash_refresh_token:
request.POST = old_post

extra_credentials = core._get_extra_credentials(request)
try:
headers, body, status = core.server.create_token_response(uri, http_method, body, headers, extra_credentials)
Expand Down
24 changes: 14 additions & 10 deletions test_app/tests/oauth2_provider/views/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def test_oauth2_token_create(oauth2_application, admin_api_client, admin_user):
assert 'updated' not in response.data
hashed_token = hash_string(response.data['token'], hasher=hashlib.sha256)
token = OAuth2AccessToken.objects.get(token=hashed_token)
refresh_token = OAuth2RefreshToken.objects.get(token=response.data['refresh_token'])
hashed_refresh_token = hash_string(response.data['refresh_token'], hasher=hashlib.sha256)
refresh_token = OAuth2RefreshToken.objects.get(token=hashed_refresh_token)
assert token.application == oauth2_application
assert refresh_token.application == oauth2_application
assert token.user == admin_user
Expand Down Expand Up @@ -345,12 +346,13 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok
"""
app = oauth2_application[0]
secret = oauth2_application[1]
refresh_token = oauth2_admin_access_token[0].refresh_token
refresh_token = oauth2_admin_access_token[2]
refresh_token_obj = oauth2_admin_access_token[0].refresh_token

url = get_relative_url('token')
data = {
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
'refresh_token': refresh_token,
}
resp = unauthenticated_api_client.post(
url,
Expand All @@ -359,8 +361,8 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok
headers={'Authorization': 'Basic ' + base64.b64encode(f"{app.client_id}:{secret}".encode()).decode()},
)
assert resp.status_code == 201
assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists()
original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token)
assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists()
original_refresh_token = OAuth2RefreshToken.objects.get(token=refresh_token_obj.token)
assert oauth2_admin_access_token not in OAuth2AccessToken.objects.all()
assert OAuth2AccessToken.objects.count() == 1

Expand All @@ -372,11 +374,12 @@ def test_oauth2_refresh_access_token(oauth2_application, oauth2_admin_access_tok
new_token = json_resp['access_token']
new_token_hashed = hash_string(new_token, hasher=hashlib.sha256)
new_refresh_token = json_resp['refresh_token']
new_refresh_token_hashed = hash_string(new_refresh_token, hasher=hashlib.sha256)

assert OAuth2AccessToken.objects.filter(token=new_token_hashed).count() == 1
# checks that RefreshTokens are rotated (new RefreshToken issued)
assert OAuth2RefreshToken.objects.filter(token=new_refresh_token).count() == 1
new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token)
assert OAuth2RefreshToken.objects.filter(token=new_refresh_token_hashed).count() == 1
new_refresh_obj = OAuth2RefreshToken.objects.get(token=new_refresh_token_hashed)
assert not new_refresh_obj.revoked


Expand All @@ -387,7 +390,8 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2
"""
app = oauth2_application[0]
secret = oauth2_application[1]
refresh_token = oauth2_admin_access_token[0].refresh_token
refresh_token = oauth2_admin_access_token[2]
refresh_token_obj = oauth2_admin_access_token[0].refresh_token

settings.OAUTH2_PROVIDER['REFRESH_TOKEN_EXPIRE_SECONDS'] = 1
settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] = 1
Expand All @@ -397,7 +401,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2
url = get_relative_url('token')
data = {
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
'refresh_token': refresh_token,
}
response = admin_api_client.post(
url,
Expand All @@ -407,7 +411,7 @@ def test_oauth2_refresh_token_expiration_is_respected(oauth2_application, oauth2
)
assert response.status_code == 403
assert b'The refresh token has expired.' in response.content
assert OAuth2RefreshToken.objects.filter(token=refresh_token).exists()
assert OAuth2RefreshToken.objects.filter(token=refresh_token_obj.token).exists()
assert OAuth2AccessToken.objects.count() == 1
assert OAuth2RefreshToken.objects.count() == 1

Expand Down

0 comments on commit 9ba2d06

Please sign in to comment.