diff --git a/tests/test_enum_field.py b/tests/test_enum_field.py new file mode 100644 index 0000000..c415bcc --- /dev/null +++ b/tests/test_enum_field.py @@ -0,0 +1,26 @@ +import enum +import sqlalchemy as sa +from tests import MultiDict, ModelFormTestCase + + +class TestEnumSelectField(ModelFormTestCase): + def setup_method(self, method): + super().setup_method(method) + + class TestEnum(enum.Enum): + A = 'a' + B = 'b' + + self.init(type_=sa.Enum(TestEnum), nullable=True) + + def test_valid_options(self): + for option in ['a', 'b']: + form = self.form_class(MultiDict(test_column=option)) + assert form.validate() + assert len(form.errors) == 0 + + def test_invalid_options(self): + for option in ['c', 'unknown']: + form = self.form_class(MultiDict(test_column=option)) + assert not form.validate() + assert len(form.errors['test_column']) == 2 diff --git a/tests/test_select_field.py b/tests/test_select_field.py index 64e060e..a52ce15 100644 --- a/tests/test_select_field.py +++ b/tests/test_select_field.py @@ -1,3 +1,4 @@ +import enum from decimal import Decimal import six @@ -84,3 +85,12 @@ def test_unicode_text_coerces_values_to_unicode_strings(self): form = self.form_class(MultiDict({'test_column': '2.0'})) assert form.test_column.data == u'2.0' assert isinstance(form.test_column.data, six.text_type) + + def test_enum_coerces_values_to_enums(self): + class TestEnum(enum.Enum): + A = 'a' + B = 'b' + + self.init(type_=sa.Enum(TestEnum), nullable=True) + form = self.form_class(MultiDict({'test_column': 'a'})) + assert form.test_column.data == TestEnum.A \ No newline at end of file diff --git a/tests/test_types.py b/tests/test_types.py index 04e16cd..35364e5 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -38,6 +38,7 @@ from tests import ModelFormTestCase from wtforms_alchemy import ( CountryField, + EnumSelectField, ModelForm, null_or_unicode, PhoneNumberField, @@ -136,10 +137,20 @@ def test_custom_numeric_converts_to_decimal_field(self): def test_enum_field_converts_to_select_field(self): choices = ['1', '2'] self.init(type_=sa.Enum(*choices)) - self.assert_type('test_column', SelectField) + self.assert_type('test_column', EnumSelectField) form = self.form_class() assert form.test_column.choices == [(s, s) for s in choices] + def test_builtin_enum_field_converts_to_select_field(self): + class TestEnum(Enum): + A = 'a' + B = 'b' + + self.init(type_=sa.Enum(TestEnum)) + self.assert_type('test_column', EnumSelectField) + form = self.form_class() + assert form.test_column.choices == [('a', TestEnum.A), ('b', TestEnum.B)] + def test_nullable_enum_uses_null_or_unicode_coerce_func_by_default(self): choices = ['1', '2'] self.init(type_=sa.Enum(*choices), nullable=True) diff --git a/tests/test_validators.py b/tests/test_validators.py index 389b29a..ead4768 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,3 +1,4 @@ +import enum from datetime import datetime, time import sqlalchemy as sa @@ -265,3 +266,14 @@ def get_session(): form = ModelTestForm() assert form.test_column.validators[2].message == 'Not unique' + + def test_enum_validators(self): + class TestEnum(enum.Enum): + A = 'a' + B = 'b' + + self.init(type_=sa.Enum(TestEnum), nullable=True) + form = self.form_class() + + assert len(form.test_column.validators) == 1 + assert isinstance(form.test_column.validators[0], Optional) \ No newline at end of file diff --git a/wtforms_alchemy/__init__.py b/wtforms_alchemy/__init__.py index 3e532d6..39a62af 100644 --- a/wtforms_alchemy/__init__.py +++ b/wtforms_alchemy/__init__.py @@ -19,6 +19,7 @@ ) from .fields import ( # noqa CountryField, + EnumSelectField, GroupedQuerySelectField, GroupedQuerySelectMultipleField, ModelFieldList, @@ -40,6 +41,7 @@ __all__ = ( AttributeTypeException, + EnumSelectField, CountryField, DateRange, InvalidAttributeException, diff --git a/wtforms_alchemy/fields.py b/wtforms_alchemy/fields.py index 65abb3b..3e48689 100644 --- a/wtforms_alchemy/fields.py +++ b/wtforms_alchemy/fields.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import enum import operator from itertools import groupby @@ -109,6 +110,7 @@ def populate_obj(self, obj, name): FieldList.populate_obj(self, obj, name) + class CountryField(SelectField): def __init__(self, *args, **kwargs): kwargs['coerce'] = Country @@ -127,6 +129,24 @@ def _get_choices(self): return sorted(territories, key=operator.itemgetter(1)) +class EnumSelectField(SelectField): + @property + def choice_values(self): + values = [] + for value, label in self.concrete_choices: + if isinstance(label, enum.Enum): + values.append(label) + elif isinstance(label, (list, tuple)): + for subvalue, sublabel in label: + if isinstance(sublabel, enum.Enum): + values.append(sublabel) + else: + values.append(subvalue) + else: + values.append(value) + return values + + class QuerySelectField(SelectFieldBase): """ Will display a select drop-down field to choose between ORM results in a diff --git a/wtforms_alchemy/generator.py b/wtforms_alchemy/generator.py index 1b1f2aa..ee5786e 100644 --- a/wtforms_alchemy/generator.py +++ b/wtforms_alchemy/generator.py @@ -46,7 +46,7 @@ InvalidAttributeException, UnknownTypeException ) -from .fields import CountryField, PhoneNumberField, WeekDaysField +from .fields import CountryField, EnumSelectField, PhoneNumberField, WeekDaysField from .utils import ( choice_type_coerce_factory, ClassMap, @@ -78,7 +78,7 @@ class FormGenerator(object): (sa.types.Boolean, BooleanField), (sa.types.Date, DateField), (sa.types.DateTime, DateTimeField), - (sa.types.Enum, SelectField), + (sa.types.Enum, EnumSelectField), (sa.types.Float, FloatField), (sa.types.Integer, IntegerField), (sa.types.Numeric, DecimalField), @@ -449,6 +449,10 @@ def select_field_kwargs(self, column): kwargs['choices'] = choices elif 'choices' in column.info and column.info['choices']: kwargs['choices'] = column.info['choices'] + elif issubclass(column.type.python_type, Enum): + kwargs['choices'] = [ + (choice.value, choice) for choice in column.type.python_type + ] else: kwargs['choices'] = [ (enum, enum) for enum in column.type.enums @@ -588,6 +592,10 @@ def length_validator(self, column): :param column: SQLAlchemy Column object """ + if (isinstance(column.type, sa.types.Enum) and + issubclass(column.type.python_type, Enum)): + return None + if ( isinstance(column.type, sa.types.String) and hasattr(column.type, 'length') and