Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Enum columns backed by Python Enums #164

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/test_enum_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import enum
import sqlalchemy as sa
from tests import MultiDict, ModelFormTestCase


class TestEnumSelectField(ModelFormTestCase):
def setup_method(self, method):
super().setup_method(method)

class TestEnum(enum.Enum):
A = 'a'
B = 'b'

self.init(type_=sa.Enum(TestEnum), nullable=True)

def test_valid_options(self):
for option in ['a', 'b']:
form = self.form_class(MultiDict(test_column=option))
assert form.validate()
assert len(form.errors) == 0

def test_invalid_options(self):
for option in ['c', 'unknown']:
form = self.form_class(MultiDict(test_column=option))
assert not form.validate()
assert len(form.errors['test_column']) == 2
10 changes: 10 additions & 0 deletions tests/test_select_field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from decimal import Decimal

import six
Expand Down Expand Up @@ -84,3 +85,12 @@ def test_unicode_text_coerces_values_to_unicode_strings(self):
form = self.form_class(MultiDict({'test_column': '2.0'}))
assert form.test_column.data == u'2.0'
assert isinstance(form.test_column.data, six.text_type)

def test_enum_coerces_values_to_enums(self):
class TestEnum(enum.Enum):
A = 'a'
B = 'b'

self.init(type_=sa.Enum(TestEnum), nullable=True)
form = self.form_class(MultiDict({'test_column': 'a'}))
assert form.test_column.data == TestEnum.A
13 changes: 12 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from tests import ModelFormTestCase
from wtforms_alchemy import (
CountryField,
EnumSelectField,
ModelForm,
null_or_unicode,
PhoneNumberField,
Expand Down Expand Up @@ -136,10 +137,20 @@ def test_custom_numeric_converts_to_decimal_field(self):
def test_enum_field_converts_to_select_field(self):
choices = ['1', '2']
self.init(type_=sa.Enum(*choices))
self.assert_type('test_column', SelectField)
self.assert_type('test_column', EnumSelectField)
form = self.form_class()
assert form.test_column.choices == [(s, s) for s in choices]

def test_builtin_enum_field_converts_to_select_field(self):
class TestEnum(Enum):
A = 'a'
B = 'b'

self.init(type_=sa.Enum(TestEnum))
self.assert_type('test_column', EnumSelectField)
form = self.form_class()
assert form.test_column.choices == [('a', TestEnum.A), ('b', TestEnum.B)]

def test_nullable_enum_uses_null_or_unicode_coerce_func_by_default(self):
choices = ['1', '2']
self.init(type_=sa.Enum(*choices), nullable=True)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from datetime import datetime, time

import sqlalchemy as sa
Expand Down Expand Up @@ -265,3 +266,14 @@ def get_session():

form = ModelTestForm()
assert form.test_column.validators[2].message == 'Not unique'

def test_enum_validators(self):
class TestEnum(enum.Enum):
A = 'a'
B = 'b'

self.init(type_=sa.Enum(TestEnum), nullable=True)
form = self.form_class()

assert len(form.test_column.validators) == 1
assert isinstance(form.test_column.validators[0], Optional)
2 changes: 2 additions & 0 deletions wtforms_alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from .fields import ( # noqa
CountryField,
EnumSelectField,
GroupedQuerySelectField,
GroupedQuerySelectMultipleField,
ModelFieldList,
Expand All @@ -40,6 +41,7 @@

__all__ = (
AttributeTypeException,
EnumSelectField,
CountryField,
DateRange,
InvalidAttributeException,
Expand Down
20 changes: 20 additions & 0 deletions wtforms_alchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import unicode_literals

import enum
import operator
from itertools import groupby

Expand Down Expand Up @@ -109,6 +110,7 @@ def populate_obj(self, obj, name):
FieldList.populate_obj(self, obj, name)



class CountryField(SelectField):
def __init__(self, *args, **kwargs):
kwargs['coerce'] = Country
Expand All @@ -127,6 +129,24 @@ def _get_choices(self):
return sorted(territories, key=operator.itemgetter(1))


class EnumSelectField(SelectField):
@property
def choice_values(self):
values = []
for value, label in self.concrete_choices:
if isinstance(label, enum.Enum):
values.append(label)
elif isinstance(label, (list, tuple)):
for subvalue, sublabel in label:
if isinstance(sublabel, enum.Enum):
values.append(sublabel)
else:
values.append(subvalue)
else:
values.append(value)
return values


class QuerySelectField(SelectFieldBase):
"""
Will display a select drop-down field to choose between ORM results in a
Expand Down
12 changes: 10 additions & 2 deletions wtforms_alchemy/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
InvalidAttributeException,
UnknownTypeException
)
from .fields import CountryField, PhoneNumberField, WeekDaysField
from .fields import CountryField, EnumSelectField, PhoneNumberField, WeekDaysField
from .utils import (
choice_type_coerce_factory,
ClassMap,
Expand Down Expand Up @@ -78,7 +78,7 @@ class FormGenerator(object):
(sa.types.Boolean, BooleanField),
(sa.types.Date, DateField),
(sa.types.DateTime, DateTimeField),
(sa.types.Enum, SelectField),
(sa.types.Enum, EnumSelectField),
(sa.types.Float, FloatField),
(sa.types.Integer, IntegerField),
(sa.types.Numeric, DecimalField),
Expand Down Expand Up @@ -449,6 +449,10 @@ def select_field_kwargs(self, column):
kwargs['choices'] = choices
elif 'choices' in column.info and column.info['choices']:
kwargs['choices'] = column.info['choices']
elif issubclass(column.type.python_type, Enum):
kwargs['choices'] = [
(choice.value, choice) for choice in column.type.python_type
]
else:
kwargs['choices'] = [
(enum, enum) for enum in column.type.enums
Expand Down Expand Up @@ -588,6 +592,10 @@ def length_validator(self, column):

:param column: SQLAlchemy Column object
"""
if (isinstance(column.type, sa.types.Enum) and
issubclass(column.type.python_type, Enum)):
return None

if (
isinstance(column.type, sa.types.String) and
hasattr(column.type, 'length') and
Expand Down