Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for Literal types. #534

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Tuple, TypeVar, Type)
from uuid import UUID

from typing_inspect import is_union_type # type: ignore
from typing_inspect import is_union_type, is_literal_type # type: ignore

from dataclasses_json import cfg
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
Expand Down Expand Up @@ -358,7 +358,8 @@ def _decode_dict_keys(key_type, xs, infer_missing):
# This is a special case for Python 3.7 and Python 3.8.
# By some reason, "unbound" dicts are counted
# as having key type parameter to be TypeVar('KT')
if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
# Literal types are also passed through without any decoding.
if key_type is None or key_type == Any or isinstance(key_type, TypeVar) or is_literal_type(key_type):
decode_function = key_type = (lambda x: x)
# handle a nested python dict that has tuples for keys. E.g. for
# Dict[Tuple[int], int], key_type will be typing.Tuple[int], but
Expand Down
59 changes: 54 additions & 5 deletions dataclasses_json/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from uuid import UUID
from enum import Enum

from typing_inspect import is_union_type # type: ignore
from typing_inspect import is_union_type, is_literal_type # type: ignore

from marshmallow import fields, Schema, post_load # type: ignore
from marshmallow.exceptions import ValidationError # type: ignore

from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
_ExtendedEncoder, _user_overrides_or_exts)
from dataclasses_json.utils import (_is_collection, _is_optional,
from dataclasses_json.utils import (_get_type_args, _is_collection, _is_optional,
_issubclass_safe, _timestamp_to_dt_aware,
_is_new_type, _get_type_origin,
_handle_undefined_parameters_safe,
Expand Down Expand Up @@ -130,6 +130,46 @@ def _deserialize(self, value, attr, data, **kwargs):
return None if optional_list is None else tuple(optional_list)


class _LiteralField(fields.Field):
def __init__(self, literal_values, cls, field, *args, **kwargs):
"""Create a new Literal field.

Literals allow you to specify the set of valid _values_ for a field. The field
implementation validates against these values on deserialization.

Example:
>>> @dataclass
... class DataClassWithLiteral(DataClassJsonMixin):
... read_mode: Literal["r", "w", "a"]

Args:
literal_values: A sequence of possible values for the field.
cls: The dataclass that the field belongs to.
field: The field that the schema describes.
"""
self.literal_values = literal_values
self.cls = cls
self.field = field
super().__init__(*args, **kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if self.allow_none and value is None:
return None
if value not in self.literal_values:
warnings.warn(
f'The value "{value}" is not one of the values of typing.Literal '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
f'Value will not be de-serialized properly.')
return super()._serialize(value, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
if value not in self.literal_values:
raise ValidationError(
f'Value "{value}" is not one in typing.Literal{self.literal_values} '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}).')
return super()._deserialize(value, attr, data, **kwargs)


TYPES = {
typing.Mapping: fields.Mapping,
typing.MutableMapping: fields.Mapping,
Expand Down Expand Up @@ -259,9 +299,14 @@ def inner(type_, options):
f"`dataclass_json` decorator or mixin.")
return fields.Field(**options)

origin = getattr(type_, '__origin__', type_)
args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
a is not type(None)]
origin = _get_type_origin(type_)

# Type arguments are typically types (e.g. int in list[int]) except for Literal
# types, where they are values.
if is_literal_type(type_):
args = []
else:
args = [inner(a, {}) for a in _get_type_args(type_) if a is not type(None)]

if type_ == Ellipsis:
return type_
Expand All @@ -279,6 +324,10 @@ def inner(type_, options):
if _issubclass_safe(origin, Enum):
return fields.Enum(enum=origin, by_value=True, *args, **options)

if is_literal_type(type_):
literal_values = _get_type_args(type_)
return _LiteralField(literal_values, cls, field, **options)

if is_union_type(type_):
union_types = [a for a in getattr(type_, '__args__', []) if
a is not type(None)]
Expand Down
99 changes: 99 additions & 0 deletions tests/test_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Test dataclasses_json handling of Literal types."""
import sys
import pytest

if sys.version_info < (3, 8):
pytest.skip("Literal types are only supported in Python 3.8+", allow_module_level=True)

import json
from typing import Literal, Optional, List, Dict

from dataclasses import dataclass

from dataclasses_json import dataclass_json, DataClassJsonMixin
from marshmallow.exceptions import ValidationError # type: ignore


@dataclass_json
@dataclass
class DataClassWithLiteral(DataClassJsonMixin):
numeric_literals: Literal[0, 1]
string_literals: Literal["one", "two", "three"]
mixed_literals: Literal[0, "one", 2]


with_valid_literal_json = '{"numeric_literals": 0, "string_literals": "one", "mixed_literals": 2}'
with_valid_literal_data = DataClassWithLiteral(numeric_literals=0, string_literals="one", mixed_literals=2)
with_invalid_literal_json = '{"numeric_literals": 9, "string_literals": "four", "mixed_literals": []}'
with_invalid_literal_data = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore

@dataclass_json
@dataclass
class DataClassWithNestedLiteral(DataClassJsonMixin):
list_of_literals: List[Literal[0, 1]]
dict_of_literals: Dict[Literal["one", "two", "three"], Literal[0, 1]]
optional_literal: Optional[Literal[0, 1]]

with_valid_nested_literal_json = '{"list_of_literals": [0, 1], "dict_of_literals": {"one": 0, "two": 1}, "optional_literal": 1}'
with_valid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 1], dict_of_literals={"one": 0, "two": 1}, optional_literal=1)
with_invalid_nested_literal_json = '{"list_of_literals": [0, 2], "dict_of_literals": {"one": 0, "four": 2}, "optional_literal": 2}'
with_invalid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 2], dict_of_literals={"one": 0, "four": 2}, optional_literal=2) # type: ignore

class TestEncoder:
def test_valid_literal(self):
assert with_valid_literal_data.to_dict(encode_json=True) == json.loads(with_valid_literal_json)

def test_invalid_literal(self):
assert with_invalid_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_literal_json)

def test_valid_nested_literal(self):
assert with_valid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_valid_nested_literal_json)

def test_invalid_nested_literal(self):
assert with_invalid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_nested_literal_json)


class TestSchemaEncoder:
def test_valid_literal(self):
actual = DataClassWithLiteral.schema().dumps(with_valid_literal_data)
assert json.loads(actual) == json.loads(with_valid_literal_json)

def test_invalid_literal(self):
actual = DataClassWithLiteral.schema().dumps(with_invalid_literal_data)
assert json.loads(actual) == json.loads(with_invalid_literal_json)

def test_valid_nested_literal(self):
actual = DataClassWithNestedLiteral.schema().dumps(with_valid_nested_literal_data)
assert json.loads(actual) == json.loads(with_valid_nested_literal_json)

def test_invalid_nested_literal(self):
actual = DataClassWithNestedLiteral.schema().dumps(with_invalid_nested_literal_data)
assert json.loads(actual) == json.loads(with_invalid_nested_literal_json)

class TestDecoder:
def test_valid_literal(self):
actual = DataClassWithLiteral.from_json(with_valid_literal_json)
assert actual == with_valid_literal_data

def test_invalid_literal(self):
expected = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore
actual = DataClassWithLiteral.from_json(with_invalid_literal_json)
assert actual == expected


class TestSchemaDecoder:
def test_valid_literal(self):
actual = DataClassWithLiteral.schema().loads(with_valid_literal_json)
assert actual == with_valid_literal_data

def test_invalid_literal(self):
with pytest.raises(ValidationError):
DataClassWithLiteral.schema().loads(with_invalid_literal_json)

def test_valid_nested_literal(self):
actual = DataClassWithNestedLiteral.schema().loads(with_valid_nested_literal_json)
assert actual == with_valid_nested_literal_data

def test_invalid_nested_literal(self):
with pytest.raises(ValidationError):
DataClassWithNestedLiteral.schema().loads(with_invalid_nested_literal_json)
Loading