diff --git a/traits/api.py b/traits/api.py index 762a64e5b..e5cb0ce44 100644 --- a/traits/api.py +++ b/traits/api.py @@ -78,6 +78,7 @@ Module, Python, ReadOnly, + TypedReadOnly, Disallow, Constant, Delegate, diff --git a/traits/tests/test_typed_read_only.py b/traits/tests/test_typed_read_only.py new file mode 100644 index 000000000..0eb3f8534 --- /dev/null +++ b/traits/tests/test_typed_read_only.py @@ -0,0 +1,156 @@ +# ------------------------------------------------------------------------------ +# +# Copyright (c) 2019, Enthought, Inc. +# All rights reserved. +# +# This software is provided without warranty under the terms of the BSD +# license included in enthought/LICENSE.txt and may be redistributed only +# under the conditions described in the aforementioned license. The license +# is also available online at http://www.enthought.com/licenses/BSD.txt +# +# Thanks for using Enthought open source! +# +# Author: Ioannis Tziakos +# Date: 05/24/2019 +# +# ------------------------------------------------------------------------------ +import unittest + +from traits.api import ( + HasStrictTraits, Int, List, TraitError, TypedReadOnly, Undefined) + + +class Dummy(HasStrictTraits): + + value_1 = TypedReadOnly(List(Int)) + + value_2 = TypedReadOnly(Int) + + events_1 = List + + events_2 = List + + def _value_2_default(self): + return 8 + + def _value_1_changed(self, old, new): + self.events_1.append((old, new)) + + def _value_2_changed(self, old, new): + self.events_2.append((old, new)) + + +class TestTypedReadOnly(unittest.TestCase): + + def test_initialization(self): + # given + dummy = Dummy() + + # when/then + self.assertEqual(dummy.events_1, []) + self.assertEqual(dummy.events_2, []) + self.assertEqual(dummy.value_1, []) + self.assertEqual(dummy.value_2, 8) + + # when/then + dummy.events_2 = [] + with self.assertRaises(TraitError): + dummy.value_2 = 23 + self.assertEqual(dummy.events_2, []) + + def test_setting_value_at_initialization(self): + # given + dummy = Dummy(value_1=[1, 2, 7], value_2=46) + + # when/then + self.assertEqual(dummy.events_1, [(Undefined, [1, 2, 7])]) + self.assertEqual(dummy.events_2, [(Undefined, 46)]) + self.assertEqual(dummy.value_1, [1, 2, 7]) + self.assertEqual(dummy.value_2, 46) + + # when/then + dummy.events_2 = [] + with self.assertRaises(TraitError): + dummy.value_2 = 23 + self.assertEqual(dummy.events_2, []) + + # when/then + dummy.events_1 = [] + with self.assertRaises(TraitError): + dummy.value_1 = [] + self.assertEqual(dummy.events_1, []) + + def test_setting_value_after_initialization(self): + # given + dummy = Dummy() + + # when + dummy.value_2 = 23 + + # then + self.assertEqual(dummy.events_2, [(Undefined, 23)]) + self.assertEqual(dummy.events_1, []) + self.assertEqual(dummy.value_2, 23) + + # when + dummy.value_1 = [9] + + # then + self.assertEqual(dummy.events_1, [(Undefined, [9])]) + self.assertEqual(dummy.events_2, [(Undefined, 23)]) + self.assertEqual(dummy.value_1, [9]) + + # when/then + dummy.events_2 = [] + with self.assertRaises(TraitError): + dummy.value_2 = 23 + self.assertEqual(dummy.events_2, []) + + # when/then + dummy.events_1 = [] + with self.assertRaises(TraitError): + dummy.value_1 = [] + self.assertEqual(dummy.events_1, []) + + def test_invalid_value(self): + # given + dummy = Dummy() + + # when/then + with self.assertRaises(TraitError): + dummy.value_2 = '23' + self.assertEqual(dummy.events_1, []) + self.assertEqual(dummy.events_2, []) + + # when/then + with self.assertRaises(TraitError): + dummy.value_1 = [2, 'dff'] + self.assertEqual(dummy.events_1, []) + self.assertEqual(dummy.events_2, []) + + def test_clone(self): + # given + dummy = Dummy(value_1=[1, 2, 7], value_2=46) + dummy.events_1 = [] + dummy.events_2 = [] + + # when + cloned = dummy.clone_traits() + + # then + self.assertEqual(cloned.events_1, [(Undefined, [1, 2, 7])]) + self.assertEqual(cloned.events_2, [(Undefined, 46)]) + self.assertEqual(cloned.value_1, [1, 2, 7]) + self.assertEqual(cloned.value_2, 46) + + # when/then + cloned.events_2 = [] + with self.assertRaises(TraitError): + cloned.value_2 = 23 + self.assertEqual(cloned.events_2, []) + + # when/then + cloned.events_1 = [] + with self.assertRaises(TraitError): + cloned.value_1 = [] + self.assertEqual(cloned.events_1, []) diff --git a/traits/trait_types.py b/traits/trait_types.py index 61a1002f8..29b1e260b 100644 --- a/traits/trait_types.py +++ b/traits/trait_types.py @@ -983,6 +983,51 @@ class ReadOnly(TraitType): ReadOnly = ReadOnly() +class TypedReadOnly(TraitType): + """ A typed write once read many trait. + + The trait allows a compatible value (default is ``Any``) to be + assigned to the attribute if the current value is the + ``Undefined`` object. Once any other value is assigned, no + further assignment is allowed. Normally, the initial + assignment to the attribute is performed in the class + constructor, based on information passed to the + constructor. If the read-only value is known in advance of + run-time, use the ``Constant()`` function instead of + ``TypedReadOnly`` to define the trait. + """ + + def __init__(self, trait=Any, **metadata): + if isinstance(trait, type): + self.inner_traits = (trait(),) + else: + self.inner_traits = (trait,) + super(TypedReadOnly, self).__init__(**metadata) + + def get(self, object, name, trait): + inner_trait = self.inner_traits[0] + value = inner_trait.get_value(object, name, trait) + if value is Undefined: + return inner_trait.get_default_value()[1] + else: + return value + + def set(self, object, name, value): + cname = TraitsCache + name + old = object.__dict__.get(cname, Undefined) + if old is Undefined: + value = self.inner_traits[0].validate(object, name, value) + object.__dict__[cname] = value + object.trait_property_changed(name, old, value) + else: + message = u'Cannot set {!r} of {!r} more than once.' + raise TraitError(message.format(name, object.__class__.__name__)) + + +# Create a singleton instance as the trait: +TypedReadOnly = TypedReadOnly() + + class Disallow(TraitType): """ A trait that prevents any value from being assigned or read.