diff --git a/envisage/extension_point.py b/envisage/extension_point.py index 4391e576b..835edb8cc 100644 --- a/envisage/extension_point.py +++ b/envisage/extension_point.py @@ -11,11 +11,14 @@ # Standard library imports. +from functools import wraps import inspect +import warnings import weakref # Enthought library imports. from traits.api import List, TraitType, Undefined, provides +from traits.trait_list_object import TraitList # Local imports. from .i_extension_point import IExtensionPoint @@ -156,14 +159,14 @@ def __repr__(self): def get(self, obj, trait_name): """ Trait type getter. """ + cache_name = _get_cache_name(trait_name) + if cache_name not in obj.__dict__: + _update_cache(obj, trait_name) - extension_registry = self._get_extension_registry(obj) - - # Get the extensions to this extension point. - extensions = extension_registry.get_extensions(self.id) - - # Make sure the contributions are of the appropriate type. - return self.trait_type.validate(obj, trait_name, extensions) + value = obj.__dict__[cache_name] + # validate again + self.trait_type.validate(obj, trait_name, value[:]) + return value def set(self, obj, name, value): """ Trait type setter. """ @@ -192,21 +195,41 @@ def connect(self, obj, trait_name): """ def listener(extension_registry, event): - """ Listener called when an extension point is changed. """ - - # If an index was specified then we fire an '_items' changed event. + """ Listener called when an extension point is changed. + + Parameters + ---------- + extension_registry : IExtensionRegistry + Registry that maintains the extensions. + event : ExtensionPointChangedEvent + Event created for the change. + If the event.index is None, this means the entire extensions + list is set to a new value. If the event.index is not None, + some portion of the list has been modified. + """ if event.index is not None: - name = trait_name + "_items" - old = Undefined - new = event + # We know where in the list is changed. + + # Mutate the _ExtensionPointValue to fire ListChangeEvent + # expected from observing item change. + getattr(obj, trait_name)._sync_values(event) + + # For on_trait_change('name_items') + obj.trait_property_changed( + trait_name + "_items", Undefined, event + ) - # Otherwise, we fire a normal trait changed event. else: - name = trait_name - old = event.removed - new = event.added + # The entire list has changed. We reset the cache and fire a + # normal trait changed event. + _update_cache(obj, trait_name) - obj.trait_property_changed(name, old, new) + # In case the cache was created first and the registry is then mutated + # before this ``connect`` is called, the internal cache would be in + # an inconsistent state. This also has the side-effect of firing + # another change event, hence allowing future changes to be observed + # without having to access the trait first. + _update_cache(obj, trait_name) extension_registry = self._get_extension_registry(obj) @@ -250,3 +273,166 @@ def _get_extension_registry(self, obj): ) return extension_registry + + +def _warn_if_not_internal(func): + """ Decorator for instance methods of _ExtensionPointValue such that its + effect is nullified if the function is not called with the _internal_use + flag set to true. + """ + + @wraps(func) + def decorated(object, *args, **kwargs): + if not object._internal_use: + warnings.warn( + "Extension point cannot be mutated directly.", + RuntimeWarning, + stacklevel=2, + ) + # This restores the existing behavior where the operation + # is acted on a list object that is not persisted. + return func(TraitList(iter(object)), *args, **kwargs) + return func(object, *args, **kwargs) + + return decorated + + +class _ExtensionPointValue(TraitList): + """ _ExtensionPointValue is the list being returned while retrieving the + attribute value for an ExtensionPoint trait. + + This list returned for an ExtensionPoint acts as a proxy to query + extensions in an ExtensionRegistry for a given extension point id. Users of + ExtensionPoint expect to handle a list-like object, and expect to be able + to listen to "mutation" on the list. The ExtensionRegistry remains to be + the source of truth as to what extensions are available for a given + extension point ID. + + Users are not expected to mutate the list directly. All mutations to + extensions are expected to go through the extension registry to maintain + consistency. With that, all methods for mutating the list are nullified, + unless it is used internally. + + The requirement to support ``observe("name:items")`` means this list, + associated with `name`, cannot be a property that gets recomputed on every + access (enthought/traits#624), it needs to be cached. As with any + cached quantity, it needs to be synchronized with the ExtensionRegistry. + + Note that the list can only be synchronized with the extension registry + when the listeners are connected (see ``ExtensionPoint.connect``). + + Parameters + ---------- + iterable : iterable + Iterable providing the items for the list + """ + + def __new__(cls, *args, **kwargs): + # Methods such as 'append' or 'extend' may be called during unpickling. + # Initialize internal flag to true which gets changed back to false + # in __init__. + self = super().__new__(cls) + self._internal_use = True + return self + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Flag to control access for mutating the list. Only internal + # code can mutate the list. See _sync_values + self._internal_use = False + + def _sync_values(self, event): + """ Given an ExtensionPointChangedEvent, modify the values in this list + to match. This is an internal method only used by Envisage code. + + Parameters + ---------- + event : ExtenstionPointChangedEvent + Event being fired for extension point values changed (typically + via the extension registry) + """ + self._internal_use = True + try: + if isinstance(event.index, slice): + if event.added: + self[event.index] = event.added + else: + del self[event.index] + else: + slice_ = slice( + event.index, event.index + len(event.removed) + ) + self[slice_] = event.added + finally: + self._internal_use = False + + __delitem__ = _warn_if_not_internal(TraitList.__delitem__) + __iadd__ = _warn_if_not_internal(TraitList.__iadd__) + __imul__ = _warn_if_not_internal(TraitList.__imul__) + __setitem__ = _warn_if_not_internal(TraitList.__setitem__) + append = _warn_if_not_internal(TraitList.append) + clear = _warn_if_not_internal(TraitList.clear) + extend = _warn_if_not_internal(TraitList.extend) + insert = _warn_if_not_internal(TraitList.insert) + pop = _warn_if_not_internal(TraitList.pop) + remove = _warn_if_not_internal(TraitList.remove) + reverse = _warn_if_not_internal(TraitList.reverse) + sort = _warn_if_not_internal(TraitList.sort) + + +def _get_extensions(object, name): + """ Return the extensions reported by the extension registry for the + given object and the name of a trait whose type is an ExtensionPoint. + + Parameters + ---------- + object : HasTraits + Object on which an ExtensionPoint is defined + name : str + Name of the trait whose trait type is an ExtensionPoint. + + Returns + ------- + extensions : list + All the extensions for the extension point. + """ + extension_point = object.trait(name).trait_type + extension_registry = extension_point._get_extension_registry(object) + + # Get the extensions to this extension point. + return extension_registry.get_extensions(extension_point.id) + + +def _get_cache_name(trait_name): + """ Return the attribute name on the object for storing the cached + extension point value associated with a given trait. + + Parameters + ---------- + trait_name : str + The name of the trait for which ExtensionPoint is defined. + """ + return "__envisage_{}".format(trait_name) + + +def _update_cache(obj, trait_name): + """ Update the internal cached value for the extension point and + fire change event. + + Parameters + ---------- + obj : HasTraits + The object on which an ExtensionPoint is defined. + trait_name : str + The name of the trait for which ExtensionPoint is defined. + """ + cache_name = _get_cache_name(trait_name) + old = obj.__dict__.get(cache_name, Undefined) + new = ( + _ExtensionPointValue( + _get_extensions(obj, trait_name) + ) + ) + obj.__dict__[cache_name] = new + obj.trait_property_changed(trait_name, old, new) diff --git a/envisage/tests/test_extension_point.py b/envisage/tests/test_extension_point.py index b1c6257ee..82b969a6f 100644 --- a/envisage/tests/test_extension_point.py +++ b/envisage/tests/test_extension_point.py @@ -10,12 +10,16 @@ """ Tests for extension points. """ # Standard library imports. +import pickle import unittest +import weakref + +from traits.api import Undefined # Enthought library imports. from envisage.api import Application, ExtensionPoint -from envisage.api import ExtensionRegistry -from traits.api import HasTraits, Int, List, TraitError +from envisage.api import ExtensionRegistry, IExtensionRegistry +from traits.api import HasTraits, Instance, Int, List, TraitError class TestBase(HasTraits): @@ -24,6 +28,16 @@ class TestBase(HasTraits): extension_registry = None +class ClassWithExtensionPoint(HasTraits): + """ Class with an ExtensionPoint for testing purposes. + Defined at the module level for pickability. + """ + + extension_registry = Instance(IExtensionRegistry) + + x = ExtensionPoint(List(Int), id="my.ep") + + class ExtensionPointTestCase(unittest.TestCase): """ Tests for extension points. """ @@ -120,7 +134,7 @@ def test_mutate_extension_point_no_effect(self): registry.add_extension_point(self._create_extension_point("my.ep")) # Set the extensions. - registry.set_extensions("my.ep", [1, 2, 3]) + registry.set_extensions("my.ep", [1, 2, 3, 0, 5]) # Declare a class that consumes the extension. class Foo(TestBase): @@ -128,13 +142,42 @@ class Foo(TestBase): # when f = Foo() - f.x.append(42) + + with self.assertWarns(RuntimeWarning): + f.x.append(42) + + with self.assertWarns(RuntimeWarning): + f.x.clear() + + with self.assertWarns(RuntimeWarning): + f.x.extend((100, 101)) + + with self.assertWarns(RuntimeWarning): + f.x.insert(0, 1) + + with self.assertWarns(RuntimeWarning): + f.x.pop() + + with self.assertWarns(RuntimeWarning): + f.x.remove(1) + + with self.assertWarns(RuntimeWarning): + f.x[0] = 99 + + with self.assertWarns(RuntimeWarning): + del f.x[0:2] + + with self.assertWarns(RuntimeWarning): + f.x.reverse() + + with self.assertWarns(RuntimeWarning): + f.x.sort() # then # The registry is not changed, and the extension point is still the # same as before - self.assertEqual(registry.get_extensions("my.ep"), [1, 2, 3]) - self.assertEqual(f.x, [1, 2, 3]) + self.assertEqual(registry.get_extensions("my.ep"), [1, 2, 3, 0, 5]) + self.assertEqual(f.x.copy(), [1, 2, 3, 0, 5]) def test_untyped_extension_point(self): """ untyped extension point """ @@ -205,6 +248,32 @@ class Foo(TestBase): with self.assertRaises(TraitError): getattr(f, "x") + def test_invalid_extension_point_after_mutation(self): + """ Test extension point becomes invalid later. """ + + registry = self.registry + + # Add an extension point. + registry.add_extension_point(self._create_extension_point("my.ep")) + + # Declare a class that consumes the extension. + class Foo(TestBase): + x = ExtensionPoint(List(Int), id="my.ep") + + # Make sure we get a trait error because the type of the extension + # doesn't match that of the extension point. + f = Foo() + ExtensionPoint.connect_extension_point_traits(f) + + # This is okay, the list is empty. + f.x + + registry.set_extensions("my.ep", "xxx") + + # Now this should fail. + with self.assertRaises(TraitError): + getattr(f, "x") + def test_extension_point_with_no_id(self): """ extension point with no Id """ @@ -253,6 +322,80 @@ class Foo(TestBase): self.assertEqual([42], registry.get_extensions("my.ep")) + def test_set_typed_extension_point_emit_change(self): + """ Test change event is emitted for setting the extension point """ + + registry = self.registry + + # Add an extension point. + registry.add_extension_point(self._create_extension_point("my.ep")) + + # Declare a class that consumes the extension. + class Foo(TestBase): + x = ExtensionPoint(List(Int), id="my.ep") + + on_trait_change_events = [] + + def on_trait_change_handler(*args): + on_trait_change_events.append(args) + + observed_events = [] + + f = Foo() + f.on_trait_change(on_trait_change_handler, "x") + f.observe(observed_events.append, "x") + + # when + ExtensionPoint.connect_extension_point_traits(f) + + # then + self.assertEqual(len(on_trait_change_events), 1) + self.assertEqual(len(observed_events), 1) + event, = observed_events + self.assertEqual(event.object, f) + self.assertEqual(event.name, "x") + self.assertEqual(event.old, Undefined) + self.assertEqual(event.new, []) + + def test_object_garbage_collectable(self): + """ object can be garbage collected after disconnecting listeners.""" + registry = self.registry + + # Add an extension point. + registry.add_extension_point(self._create_extension_point("my.ep")) + + # Declare a class that consumes the extension. + class Foo(TestBase): + x = ExtensionPoint(List(Int), id="my.ep") + + f = Foo() + object_ref = weakref.ref(f) + + # when + ExtensionPoint.connect_extension_point_traits(f) + ExtensionPoint.disconnect_extension_point_traits(f) + del f + + # then + self.assertIsNone(object_ref()) + + def test_object_pickability(self): + # Add an extension point. + self.registry.add_extension_point(ExtensionPoint(id="my.ep")) + + # An object is created, connected to the registry and have the + # extension point created. + f = ClassWithExtensionPoint(extension_registry=self.registry) + ExtensionPoint.connect_extension_point_traits(f) + self.registry.set_extensions("my.ep", [1, 2, 3]) + self.assertEqual(f.x, [1, 2, 3]) + + # then + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + serialized = pickle.dumps(f.x, protocol=protocol) + deserialized = pickle.loads(serialized) + self.assertEqual(deserialized, [1, 2, 3]) + def test_extension_point_str_representation(self): """ test the string representation of the extension point """ ep_repr = "ExtensionPoint(id={!r})" diff --git a/envisage/tests/test_extension_point_changed.py b/envisage/tests/test_extension_point_changed.py index 69cfc1d84..e3a6d855a 100644 --- a/envisage/tests/test_extension_point_changed.py +++ b/envisage/tests/test_extension_point_changed.py @@ -51,18 +51,59 @@ def test_mutate_extension_point_no_events(self): """ Mutation will not emit change event for name_items """ a = PluginA() + b = PluginB() + c = PluginC() + a.on_trait_change(listener, "x_items") + events = [] + a.observe(events.append, "x:items") + + application = TestApplication(plugins=[a, b, c]) + application.start() + + # when + with self.assertWarns(RuntimeWarning): + a.x.append(42) + + # then + self.assertIsNone(listener.obj) + self.assertEqual(len(events), 0) + + def test_mutate_extension_point_then_modify_from_registry(self): + """ Mutating the extension point does nothing and should not cause + subsequent change event information to become inconsistent. + """ + a = PluginA() b = PluginB() c = PluginC() + a.on_trait_change(listener, "x_items") + events = [] + a.observe(events.append, "x:items") + application = TestApplication(plugins=[a, b, c]) application.start() # when - a.x.append(42) + with self.assertWarns(RuntimeWarning): + a.x.clear() # then self.assertIsNone(listener.obj) + self.assertEqual(len(events), 0) + + # when + # Append a contribution. + b.x.append(4) + + # then + self.assertEqual(a.x, [1, 2, 3, 4, 98, 99, 100]) + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 3) + self.assertEqual(event.added, [4]) + self.assertEqual(event.removed, []) def test_append(self): """ append """ @@ -75,11 +116,6 @@ def test_append(self): application = TestApplication(plugins=[a, b, c]) application.start() - # fixme: If the extension point has not been accessed then the - # provider extension registry can't work out what has changed, so it - # won't fire a changed event. - self.assertEqual([1, 2, 3, 98, 99, 100], a.x) - # Append a contribution. b.x.append(4) @@ -105,6 +141,30 @@ def test_append(self): self.assertEqual([], listener.new.removed) self.assertEqual(3, listener.new.index) + def test_append_with_observe(self): + """ append with observe """ + + a = PluginA() + b = PluginB() + c = PluginC() + + events = [] + a.observe(events.append, "x:items") + + application = TestApplication(plugins=[a, b, c]) + application.start() + + # Append a contribution. + b.x.append(4) + + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 3) + self.assertEqual(event.added, [4]) + self.assertEqual(event.removed, []) + def test_remove(self): """ remove """ @@ -116,11 +176,6 @@ def test_remove(self): application = TestApplication(plugins=[a, b, c]) application.start() - # fixme: If the extension point has not been accessed then the - # provider extension registry can't work out what has changed, so it - # won't fire a changed event. - self.assertEqual([1, 2, 3, 98, 99, 100], a.x) - # Remove a contribution. b.x.remove(3) @@ -146,6 +201,30 @@ def test_remove(self): self.assertEqual([3], listener.new.removed) self.assertEqual(2, listener.new.index) + def test_remove_with_observe(self): + """ remove with observing items change. """ + + a = PluginA() + b = PluginB() + c = PluginC() + + events = [] + a.observe(events.append, "x:items") + + application = TestApplication(plugins=[a, b, c]) + application.start() + + # Remove a contribution. + b.x.remove(3) + + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 2) + self.assertEqual(event.added, []) + self.assertEqual(event.removed, [3]) + def test_assign_empty_list(self): """ assign empty list """ @@ -157,11 +236,6 @@ def test_assign_empty_list(self): application = TestApplication(plugins=[a, b, c]) application.start() - # fixme: If the extension point has not been accessed then the - # provider extension registry can't work out what has changed, so it - # won't fire a changed event. - self.assertEqual([1, 2, 3, 98, 99, 100], a.x) - # Assign an empty list to one of the plugin's contributions. b.x = [] @@ -188,37 +262,29 @@ def test_assign_empty_list(self): self.assertEqual(0, listener.new.index.start) self.assertEqual(3, listener.new.index.stop) - def test_assign_empty_list_no_event(self): - """ assign empty list no event """ + def test_assign_empty_list_with_observe(self): + """ assign an empty list to a plugin triggers a list change event.""" a = PluginA() - a.on_trait_change(listener, "x_items") b = PluginB() c = PluginC() + events = [] + a.observe(events.append, "x:items") + application = TestApplication(plugins=[a, b, c]) application.start() # Assign an empty list to one of the plugin's contributions. b.x = [] - # Make sure we pick up the correct contribution via the application. - extensions = application.get_extensions("a.x") - extensions.sort() - - self.assertEqual(3, len(extensions)) - self.assertEqual([98, 99, 100], extensions) - - # Make sure we pick up the correct contribution via the plugin. - extensions = a.x[:] - extensions.sort() - - self.assertEqual(3, len(extensions)) - self.assertEqual([98, 99, 100], extensions) - - # We shouldn't get a trait event here because we haven't accessed the - # extension point yet! - self.assertEqual(None, listener.obj) + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.added, []) + self.assertEqual(event.removed, [1, 2, 3]) + self.assertEqual(event.index, 0) def test_assign_non_empty_list(self): """ assign non-empty list """ @@ -231,11 +297,6 @@ def test_assign_non_empty_list(self): application = TestApplication(plugins=[a, b, c]) application.start() - # fixme: If the extension point has not been accessed then the - # provider extension registry can't work out what has changed, so it - # won't fire a changed event. - self.assertEqual([1, 2, 3, 98, 99, 100], a.x) - # Keep the old values for later slicing check source_values = list(a.x) @@ -276,6 +337,30 @@ def test_assign_non_empty_list(self): self.assertEqual(0, listener.new.index.start) self.assertEqual(3, listener.new.index.stop) + def test_assign_non_empty_list_with_observe(self): + """ assign non-empty list """ + + a = PluginA() + b = PluginB() + c = PluginC() + + events = [] + a.observe(events.append, "x:items") + + application = TestApplication(plugins=[a, b, c]) + application.start() + + # Assign a non-empty list to one of the plugin's contributions. + b.x = [2, 4, 6, 8] + + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 0) + self.assertEqual(event.added, [2, 4, 6, 8]) + self.assertEqual(event.removed, [1, 2, 3]) + def test_add_plugin(self): """ add plugin """ @@ -327,6 +412,31 @@ def test_add_plugin(self): self.assertEqual([], listener.new.removed) self.assertEqual(3, listener.new.index) + def test_add_plugin_with_observe(self): + """ add plugin with observe """ + + a = PluginA() + b = PluginB() + c = PluginC() + + events = [] + a.observe(events.append, "x:items") + + # Start off with just two of the plugins. + application = TestApplication(plugins=[a, b]) + application.start() + + # Now add the other plugin. + application.add_plugin(c) + + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 3) + self.assertEqual(event.added, [98, 99, 100]) + self.assertEqual(event.removed, []) + def test_remove_plugin(self): """ remove plugin """ @@ -377,6 +487,80 @@ def test_remove_plugin(self): self.assertEqual([1, 2, 3], listener.new.removed) self.assertEqual(0, listener.new.index) + def test_remove_plugin_with_observe(self): + """ remove plugin with observe """ + + a = PluginA() + b = PluginB() + c = PluginC() + + events = [] + a.observe(events.append, "x:items") + + # Start off with just two of the plugins. + application = TestApplication(plugins=[a, b, c]) + application.start() + + # Now remove one plugin. + application.remove_plugin(b) + + # then + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 0) + self.assertEqual(event.added, []) + self.assertEqual(event.removed, [1, 2, 3]) + + def test_race_condition(self): + """ Test the extension point being modified before the application + starts, changes before starting the application are not notified. + """ + a = PluginA() + b = PluginB() + c = PluginC() + application = TestApplication(plugins=[a, b, c]) + + events = [] + a.observe(events.append, "x:items") + + # This sets the cache. + self.assertEqual(a.x, [1, 2, 3, 98, 99, 100]) + + # Now we mutate the registry, but the application has not started. + b.x = [4, 5, 6] + + # then + # The values are not synchronized. + self.assertEqual(a.x, [1, 2, 3, 98, 99, 100]) + + # application has not started, no events. + self.assertEqual(len(events), 0) + + # Now we start the application, which connects the listener. + application.start() + + # then + self.assertEqual(a.x, [4, 5, 6, 98, 99, 100]) + + # Change the value again. + b.x = [1, 2] + + # then + self.assertEqual(a.x, [1, 2, 98, 99, 100]) + + # The mutation occurred before application starting is not reported. + self.assertEqual(len(events), 1) + event, = events + self.assertEqual(event.object, a.x) + self.assertEqual(event.index, 0) + self.assertEqual(event.added, [1, 2]) + self.assertEqual(event.removed, [4, 5, 6]) + + +class TestExtensionPointChangedEvent(unittest.TestCase): + """ Test ExtensionPointChangedEvent object.""" + def test_extension_point_change_event_str_representation(self): """ test string representation of the ExtensionPointChangedEvent class """ diff --git a/setup.py b/setup.py index ac2388714..0884e5e8c 100644 --- a/setup.py +++ b/setup.py @@ -309,7 +309,7 @@ def get_long_description(): "demo_examples = envisage.examples._etsdemo_info:info", ], }, - install_requires=["apptools", "setuptools", "traits"], + install_requires=["apptools", "setuptools", "traits>=6.1"], extras_require={ "docs": ["enthought-sphinx-theme", "Sphinx>=2.1.0,!=3.2.0"], "ipython": ["ipykernel", "tornado"],