diff --git a/ansible_base/oauth2_provider/models/access_token.py b/ansible_base/oauth2_provider/models/access_token.py index 89998ff39..204a3bb33 100644 --- a/ansible_base/oauth2_provider/models/access_token.py +++ b/ansible_base/oauth2_provider/models/access_token.py @@ -1,5 +1,6 @@ import oauth2_provider.models as oauth2_models from django.conf import settings +from django.core.exceptions import ValidationError from django.db import connection, models from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ @@ -10,6 +11,18 @@ from ansible_base.lib.utils.settings import get_setting from ansible_base.oauth2_provider.utils import is_external_account +SCOPES = ['read', 'write'] + + +def validate_scope(value): + given_scopes = value.split(' ') + if not given_scopes: + raise ValidationError(_('Scope must be a simple space-separated string with allowed scopes: %(scopes)s') % {'scopes': ', '.join(SCOPES)}) + for scope in given_scopes: + if scope not in SCOPES: + raise ValidationError(_('Invalid scope: %(scope)s. Must be one of: %(scopes)s') % {'scope': scope, 'scopes': ', '.join(SCOPES)}) + + activitystream = object if 'ansible_base.activitystream' in settings.INSTALLED_APPS: from ansible_base.activitystream.models import AuditableModel @@ -52,10 +65,10 @@ class Meta(oauth2_models.AbstractAccessToken.Meta): editable=False, ) scope = models.CharField( - blank=True, default='write', max_length=32, help_text=_("Allowed scopes, further restricts user's permissions. Must be a simple space-separated string with allowed scopes ['read', 'write']."), + validators=[validate_scope], ) token = prevent_search( models.CharField( diff --git a/test_app/tests/oauth2_provider/views/test_token.py b/test_app/tests/oauth2_provider/views/test_token.py index b268943b8..15d1594f2 100644 --- a/test_app/tests/oauth2_provider/views/test_token.py +++ b/test_app/tests/oauth2_provider/views/test_token.py @@ -425,3 +425,33 @@ def test_oauth2_tokens_list_for_user( response = admin_api_client.get(url) assert response.status_code == 200 assert len(response.data['results']) == 6 + + +@pytest.mark.parametrize( + 'given,error', + [ + ('read write', None), + ('read', None), + ('write', None), + ('read write foo', 'Invalid scope: foo'), + ('foo', 'Invalid scope: foo'), + ('', None), # default scope is 'write' + ], +) +@pytest.mark.django_db +def test_oauth2_token_scope_validator(user_api_client, given, error): + """ + Ensure that the scope validator works as expected. + """ + + url = reverse("token-list") + + # Create PAT + data = { + 'description': 'new PAT', + 'scope': given, + } + response = user_api_client.post(url, data=data) + assert response.status_code == 400 if error else 201 + if error: + assert error in str(response.data['scope'][0])