diff --git a/README.rst b/README.rst index 81c46f90..b7cb1883 100644 --- a/README.rst +++ b/README.rst @@ -28,6 +28,7 @@ Features - Elasticsearch auto mapping from django models fields. - Complex field type support (ObjectField, NestedField). - Index fast using `parallel` indexing. +- Bulk operations support. - Requirements - Django >= 1.11 diff --git a/django_elasticsearch_dsl/db/__init__.py b/django_elasticsearch_dsl/db/__init__.py new file mode 100644 index 00000000..ed287ab3 --- /dev/null +++ b/django_elasticsearch_dsl/db/__init__.py @@ -0,0 +1,4 @@ +from .managers import ( # noqa + DjangoElasticsearchDslManagerMixin, + DjangoElasticsearchDslModelManager +) diff --git a/django_elasticsearch_dsl/db/managers.py b/django_elasticsearch_dsl/db/managers.py new file mode 100644 index 00000000..9150598f --- /dev/null +++ b/django_elasticsearch_dsl/db/managers.py @@ -0,0 +1,123 @@ +from django.db import models + +from .utils import get_queryset_by_ids +from ..registries import registry + + +class DjangoElasticsearchDslManagerMixin(object): + """Elasticsearch DSL manager mixin for processing mass work with objects. + + Performs normalization by supported types and causes updating the + search engine appropriately. + + It acts similarly to a signal processor. + """ + _registry = registry + + def _normalize_results(self, result): + if isinstance(result, models.Model): + return [result] + elif isinstance(result, (list, models.QuerySet)): + return result + else: + raise TypeError( + "Incorrect results type. " + "Expected 'django.db.models.Model', " + " or 'django.db.models.Queryset', " + "but got %s" % type(result) + ) + + def _handle_save(self, result): + """Handle save. + + Given a many model instances, update the objects in the index. + Update the related objects either. + """ + results = self._normalize_results(result) + + self._registry.update(results) + self._registry.update_related(results, many=True) + + def _handle_pre_delete(self, result): + """Handle removing of objects from related models instances. + + We need to do this before the real delete otherwise the relation + doesn't exist anymore, and we can't get the related models instances. + """ + results = self._normalize_results(result) + + self._registry.delete_related( + results, + many=True, + raise_on_error=False, + ) + + def _handle_delete(self, result): + """Handle delete. + + Given a many model instances, delete the objects in the index. + """ + results = self._normalize_results(result) + + self._registry.delete( + results, + raise_on_error=False, + ) + + +class DjangoElasticsearchDslModelManager(models.QuerySet, + DjangoElasticsearchDslManagerMixin): + """Django Elasticsearch Dsl model manager. + + Working with possible bulk operations, updates documents accordingly. + """ + + def bulk_create(self, objs, *args, **kwargs): + """Bulk create. + + Calls `handle_save` after saving is completed + """ + result = super().bulk_create(objs, *args, **kwargs) + self._handle_save(result) + return result + + def bulk_update(self, objs, *args, **kwargs): + """Bulk update. + + Calls `handle_save` after saving is completed + """ + result = super().bulk_update(objs, *args, **kwargs) + self._handle_save(objs) + return result + + def update(self, **kwargs): + """Update. + + Calls `handle_save` after saving is completed + """ + ids = list(self.values_list("id", flat=True)) + result = super().update(**kwargs) + if not ids: + return result + self._handle_save(get_queryset_by_ids(self.model, ids)) + return result + + def delete(self): + """Delete. + + Calls `handle_pre_delete` before performing the deletion. + + After deleting it causes `handle_delete`. + """ + objs = get_queryset_by_ids( + self.model, + list(self.values_list("id", flat=True)) + ) + self._handle_pre_delete(objs) + objs = list(objs) + + result = super().delete() + + self._handle_delete(objs) + + return result diff --git a/django_elasticsearch_dsl/db/utils.py b/django_elasticsearch_dsl/db/utils.py new file mode 100644 index 00000000..449816bf --- /dev/null +++ b/django_elasticsearch_dsl/db/utils.py @@ -0,0 +1,9 @@ +from typing import List + +from django.db import models + + +def get_queryset_by_ids(model: models.Model, ids: List[int]): + return model.objects.filter( + id__in=ids + ) diff --git a/django_elasticsearch_dsl/fields.py b/django_elasticsearch_dsl/fields.py index 2652b68e..7ed255ca 100644 --- a/django_elasticsearch_dsl/fields.py +++ b/django_elasticsearch_dsl/fields.py @@ -88,6 +88,15 @@ def get_value_from_instance(self, instance, field_value_to_ignore=None): if instance == field_value_to_ignore: return None + elif isinstance(field_value_to_ignore, models.QuerySet) and \ + isinstance(instance, models.Model) and \ + field_value_to_ignore.contains(instance): + return None + elif isinstance(field_value_to_ignore, models.QuerySet) and \ + isinstance(instance, models.QuerySet): + instance = instance.exclude( + id__in=field_value_to_ignore.values_list("id", flat=True) + ) # convert lazy object like lazy translations to string if isinstance(instance, Promise): diff --git a/django_elasticsearch_dsl/registries.py b/django_elasticsearch_dsl/registries.py index 72510610..4890de89 100644 --- a/django_elasticsearch_dsl/registries.py +++ b/django_elasticsearch_dsl/registries.py @@ -5,6 +5,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ImproperlyConfigured +from django.db import models from elasticsearch_dsl import AttrDict from six import itervalues, iterkeys, iteritems @@ -88,22 +89,47 @@ def register_document(self, document): return document def _get_related_doc(self, instance): - for model in self._related_models.get(instance.__class__, []): + instance_cls = self._get_cls_from_instance(instance) + for model in self._related_models.get(instance_cls, []): for doc in self._models[model]: - if instance.__class__ in doc.django.related_models: + if instance_cls in doc.django.related_models: yield doc + def _get_cls_from_instance(self, instance): + """ + Get class from instance. + + Supports getting a class from a model, list, or queryset. + """ + if isinstance(instance, models.Model): + return instance.__class__ + elif isinstance(instance, list): + return instance[0].__class__ + elif isinstance(instance, models.QuerySet): + return instance.model + else: + return None + def update_related(self, instance, **kwargs): """ Update docs that have related_models. + + The `many` flag has been introduced to handle mass updates of objects. """ if not DEDConfig.autosync_enabled(): return + many = kwargs.pop("many", False) for doc in self._get_related_doc(instance): doc_instance = doc() try: - related = doc_instance.get_instances_from_related(instance) + if many: + related = doc_instance.get_instances_from_many_related( + self._get_cls_from_instance(instance), + instance + ) + else: + related = doc_instance.get_instances_from_related(instance) except ObjectDoesNotExist: related = None @@ -113,14 +139,23 @@ def update_related(self, instance, **kwargs): def delete_related(self, instance, **kwargs): """ Remove `instance` from related models. + + The `many` flag has been introduced to handle mass updates of objects. """ if not DEDConfig.autosync_enabled(): return + many = kwargs.pop("many", False) for doc in self._get_related_doc(instance): doc_instance = doc(related_instance_to_ignore=instance) try: - related = doc_instance.get_instances_from_related(instance) + if many: + related = doc_instance.get_instances_from_many_related( + self._get_cls_from_instance(instance), + instance + ) + else: + related = doc_instance.get_instances_from_related(instance) except ObjectDoesNotExist: related = None @@ -135,8 +170,9 @@ def update(self, instance, **kwargs): if not DEDConfig.autosync_enabled(): return - if instance.__class__ in self._models: - for doc in self._models[instance.__class__]: + instance_cls = self._get_cls_from_instance(instance) + if instance_cls in self._models: + for doc in self._models[instance_cls]: if not doc.django.ignore_signals: doc().update(instance, **kwargs) diff --git a/django_elasticsearch_dsl/signals.py b/django_elasticsearch_dsl/signals.py index 35a631c4..f9ba485e 100644 --- a/django_elasticsearch_dsl/signals.py +++ b/django_elasticsearch_dsl/signals.py @@ -47,7 +47,7 @@ def handle_m2m_changed(self, sender, instance, action, **kwargs): if action in ('post_add', 'post_remove', 'post_clear'): self.handle_save(sender, instance) elif action in ('pre_remove', 'pre_clear'): - self.handle_pre_delete(sender, instance) + self.handle_pre_delete(sender, instance, origin=kwargs['model']()) def handle_save(self, sender, instance, **kwargs): """Handle save. @@ -62,15 +62,23 @@ def handle_pre_delete(self, sender, instance, **kwargs): """Handle removing of instance object from related models instance. We need to do this before the real delete otherwise the relation doesn't exists anymore and we can't get the related models instance. + + Disabling distribution for deletion cases other + than deletion by entity. """ - registry.delete_related(instance) + if isinstance(kwargs.get("origin"), models.Model): + registry.delete_related(instance) def handle_delete(self, sender, instance, **kwargs): """Handle delete. Given an individual model instance, delete the object from index. + + Disabling distribution for deletion cases other + than deletion by entity. """ - registry.delete(instance, raise_on_error=False) + if isinstance(kwargs.get("origin"), models.Model): + registry.delete(instance, raise_on_error=False) class RealTimeSignalProcessor(BaseSignalProcessor): diff --git a/tests/documents.py b/tests/documents.py index 08ef83f1..248e5483 100644 --- a/tests/documents.py +++ b/tests/documents.py @@ -2,7 +2,17 @@ from django_elasticsearch_dsl import Document, fields from django_elasticsearch_dsl.registries import registry -from .models import Ad, Category, Car, Manufacturer, Article +from .models import ( + Ad, + Category, + Car, + Manufacturer, + Article, + CarBulkManager, + CategoryBulkManager, + AdBulkManager, + ManufacturerBulkManager +) index_settings = { 'number_of_shards': 1, @@ -180,3 +190,120 @@ def generate_id(cls, article): ad_index = AdDocument._index car_index = CarDocument._index + + +@registry.register_document +class CarBulkDocument(Document): + + manufacturer = fields.ObjectField(properties={ + 'name': fields.TextField(), + 'country': fields.TextField(), + }) + + ads = fields.NestedField(properties={ + 'description': fields.TextField(analyzer=html_strip), + 'title': fields.TextField(), + 'pk': fields.IntegerField(), + }) + + categories = fields.NestedField(properties={ + 'title': fields.TextField(), + 'slug': fields.TextField(), + 'icon': fields.FileField(), + }) + + class Django: + model = CarBulkManager + related_models = [AdBulkManager, + ManufacturerBulkManager, + CategoryBulkManager] + fields = [ + 'name', + 'launched', + 'type', + ] + + class Index: + name = 'test_cars_bulk' + settings = index_settings + + def get_queryset(self): + return super(CarBulkDocument, self).get_queryset().select_related( + 'manufacturer') + + def get_instances_from_many_related(self, cls, related_instance): + if isinstance(related_instance, list): + if cls == AdBulkManager: + return CarBulkManager.objects.filter( + id__in=[ + item.car_id for item in related_instance + ] + ) + elif cls == ManufacturerBulkManager: + return CarBulkManager.objects.filter( + manufacturer_id__in=[ + item.id for item in related_instance + ] + ) + elif cls == CategoryBulkManager: + return CarBulkManager.objects.filter( + categories__id__in=[ + item.id for item in related_instance + ] + ) + else: + if cls == AdBulkManager: + return CarBulkManager.objects.filter( + id__in=related_instance.values_list("car_id", flat=True) + ) + elif cls == ManufacturerBulkManager: + return CarBulkManager.objects.filter( + manufacturer_id__in=related_instance.values_list( + "id", flat=True + ) + ) + elif cls == CategoryBulkManager: + return CarBulkManager.objects.filter( + categories__id__in=related_instance.values_list( + "id", flat=True + ) + ) + + +@registry.register_document +class ManufacturerBulkDocument(Document): + country = fields.TextField() + + class Django: + model = ManufacturerBulkManager + fields = [ + 'name', + 'created', + 'country_code', + 'logo', + ] + + class Index: + name = 'test_manufacturers_bulk' + settings = index_settings + + +@registry.register_document +class AdBulkDocument(Document): + description = fields.TextField( + analyzer=html_strip, + fields={'raw': fields.KeywordField()} + ) + + class Django: + model = AdBulkManager + fields = [ + 'title', + 'created', + 'modified', + 'url', + ] + + class Index: + name = 'test_ads_bulk' + settings = index_settings diff --git a/tests/fixtures.py b/tests/fixtures.py index 4d3bad9e..ba427e74 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -45,5 +45,6 @@ class Django: Doc.get_queryset = Mock(return_value=mock_qs) if _related_models: Doc.get_instances_from_related = Mock() + Doc.get_instances_from_many_related = Mock() return Doc diff --git a/tests/models.py b/tests/models.py index d1bd4f11..37968125 100644 --- a/tests/models.py +++ b/tests/models.py @@ -9,6 +9,8 @@ from django.utils.translation import gettext_lazy as _ from six import python_2_unicode_compatible +from django_elasticsearch_dsl.db import DjangoElasticsearchDslModelManager + @python_2_unicode_compatible class Car(models.Model): @@ -106,3 +108,90 @@ class Meta: def __str__(self): return self.slug + + +@python_2_unicode_compatible +class CarBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + TYPE_CHOICES = ( + ('se', "Sedan"), + ('br', "Break"), + ('4x', "4x4"), + ('co', "Coupé"), + ) + + name = models.CharField(max_length=255) + launched = models.DateField() + type = models.CharField( + max_length=2, + choices=TYPE_CHOICES, + default='se', + ) + manufacturer = models.ForeignKey( + 'ManufacturerBulkManager', null=True, on_delete=models.SET_NULL + ) + categories = models.ManyToManyField('CategoryBulkManager') + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class ManufacturerBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + name = models.CharField(max_length=255, default=_("Test lazy tanslation")) + country_code = models.CharField(max_length=2) + created = models.DateField() + logo = models.ImageField(blank=True) + + class meta: + app_label = 'tests' + + def country(self): + return COUNTRIES.get(self.country_code, self.country_code) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class CategoryBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + title = models.CharField(max_length=255) + slug = models.CharField(max_length=255) + icon = models.ImageField(blank=True) + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.title + + +@python_2_unicode_compatible +class AdBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + title = models.CharField(max_length=255) + description = models.TextField() + created = models.DateField(auto_now_add=True) + modified = models.DateField(auto_now=True) + url = models.URLField() + car = models.ForeignKey( + 'CarBulkManager', + related_name='ads', + null=True, + on_delete=models.SET_NULL + ) + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.title diff --git a/tests/test_integration.py b/tests/test_integration.py index 0dce997c..a7657c89 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,9 +23,22 @@ CarWithPrepareDocument, ArticleDocument, ArticleWithSlugAsIdDocument, - index_settings + index_settings, + CarBulkDocument, + ManufacturerBulkDocument +) +from .models import ( + Car, + Manufacturer, + Ad, + Category, + Article, + COUNTRIES, + CarBulkManager, + ManufacturerBulkManager, + AdBulkManager, + CategoryBulkManager ) -from .models import Car, Manufacturer, Ad, Category, Article, COUNTRIES @unittest.skipUnless(is_es_online(), 'Elasticsearch is offline') @@ -394,3 +407,158 @@ def test_custom_document_id(self): "using a custom id is broken".format(article_slug) ) self.assertEqual(es_obj.slug, article.slug) + + +@unittest.skipUnless(is_es_online(), 'Elasticsearch is offline') +class IntegrationBulkOperationConfTestCase(ESTestCase, TestCase): + + def setUp(self): + super().setUp() + + manufacturers = ManufacturerBulkManager.objects.bulk_create([ + ManufacturerBulkManager( + name="Peugeot", created=datetime(1900, 10, 9, 0, 0), + country_code="FR", logo='logo.jpg' + ) + ]) + self.manufacturer = manufacturers[0] + + cars = CarBulkManager.objects.bulk_create([ + CarBulkManager( + name="508", launched=datetime(2010, 9, 9, 0, 0), + manufacturer=self.manufacturer + ), + CarBulkManager( + name="208", launched=datetime(2010, 10, 9, 0, 0), + manufacturer=self.manufacturer + ), + CarBulkManager( + name="308", launched=datetime(2010, 11, 9, 0, 0) + ) + ]) + self.car1 = cars[0] + self.car2 = cars[1] + self.car3 = cars[2] + + self.assertEqual(self.car1.name, "508") + self.assertEqual(self.car2.name, "208") + self.assertEqual(self.car3.name, "308") + + categories = CategoryBulkManager.objects.bulk_create([ + CategoryBulkManager( + title="Category 1", slug="category-1", icon="icon.jpeg" + ), + CategoryBulkManager(title="Category 2", slug="category-2") + ]) + self.category1 = categories[0] + self.category2 = categories[1] + + self.assertEqual(self.category1.title, "Category 1") + self.assertEqual(self.category2.title, "Category 2") + + self.car2.categories.add(self.category1) + self.car2.save() + + self.car3.categories.add(self.category1, self.category2) + self.car3.save() + + ads = AdBulkManager.objects.bulk_create([ + AdBulkManager( + title=_("Ad number 1"), url="www.ad1.com", + description="My super ad description 1", + car=self.car1 + ), + AdBulkManager( + title="Ad number 2", url="www.ad2.com", + description="My super ad descriptio 2", + car=self.car1 + ) + ]) + self.ad1 = ads[0] + self.ad2 = ads[1] + + self.assertEqual(self.ad1.title, _("Ad number 1")) + self.assertEqual(self.ad2.title, "Ad number 2") + + def test_docs_are_updated_by_bulk_operations(self): + old_car2_name = self.car2.name + car2_name = "1008" + + s = CarBulkDocument.search().query("match", name=old_car2_name) + self.assertEqual(s.count(), 1) + + s = CarBulkDocument.search().query("match", name=car2_name) + self.assertEqual(s.count(), 0) + + CarBulkManager.objects.filter(id=self.car2.id).update(name=car2_name) + + s = CarBulkDocument.search().query("match", name=old_car2_name) + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search().query("match", name=car2_name) + self.assertEqual(s.count(), 1) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + car3_doc = s.execute()[0] + self.assertEqual(car3_doc.manufacturer.name, None) + + CarBulkManager.objects.filter( + id=self.car3.id + ).update(manufacturer_id=self.manufacturer.pk) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + car3_doc = s.execute()[0] + self.assertEqual(car3_doc.manufacturer.name, self.manufacturer.name) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + self.assertEqual(s.count(), 1) + + CarBulkManager.objects.filter(id=self.car3.id).delete() + s = CarBulkDocument.search().query("match", name=self.car3.name) + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search() + self.assertEqual(s.count(), 2) + + s = ManufacturerBulkDocument.search() + self.assertEqual(s.count(), 1) + + ManufacturerBulkManager.objects.all().delete() + s = ManufacturerBulkDocument.search() + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search() + for result in s.execute(): + self.assertEqual(result.manufacturer.name, None) + + def test_related_docs_are_updated_by_bulk_operations(self): + ManufacturerBulkManager.objects.filter(id=self.manufacturer.id).update( + name="Citroen" + ) + + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(car2_doc.manufacturer.name, 'Citroen') + self.assertEqual(len(car2_doc.ads), 0) + + ad3 = AdBulkManager.objects.bulk_create([ + AdBulkManager(title=_("Ad number 3"), url="www.ad3.com", + description="My super ad description 3", + car=self.car2) + ]) + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(len(car2_doc.ads), 1) + + AdBulkManager.objects.filter(id__in=[ad.id for ad in ad3]).delete() + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(len(car2_doc.ads), 0) + + ManufacturerBulkManager.objects.filter( + id=self.manufacturer.id + ).delete() + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + + self.assertEqual(car2_doc.manufacturer.name, None) diff --git a/tests/test_registries.py b/tests/test_registries.py index c72894cb..ef13d68b 100644 --- a/tests/test_registries.py +++ b/tests/test_registries.py @@ -144,3 +144,187 @@ def test_autosync(self): self.assertFalse(self.doc_a1.update.called) settings.ELASTICSEARCH_DSL_AUTOSYNC = True + + +class DocumentRegistryBulkOperationsTestCase(WithFixturesMixin, TestCase): + """ + Test case for working with bulk operations. + """ + + def setUp(self) -> None: + self.registry = DocumentRegistry() + self.index_1 = Index(name='index_1') + self.index_2 = Index(name='index_2') + + self.doc_a1 = self._generate_doc_mock(self.ModelA, self.index_1) + self.doc_a2 = self._generate_doc_mock(self.ModelA, self.index_1) + self.doc_b1 = self._generate_doc_mock(self.ModelB, self.index_2) + self.doc_c1 = self._generate_doc_mock(self.ModelC, self.index_1) + + def test_update_instances(self): + """ + Checking for `update`. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = self.ModelA.objects.all() + self.registry.update(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances) + self.doc_a2.update.assert_called_once_with(instances) + + def test_update_instances_as_list(self): + """ + Checking for `update` where instances is list. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = [self.ModelA()] + self.registry.update(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances) + self.doc_a2.update.assert_called_once_with(instances) + + def test_update_related_instances(self): + """ + Checking the correct call of the get function from + related objects. + """ + doc_d1 = self._generate_doc_mock( + self.ModelD, self.index_1, + _related_models=[self.ModelE, self.ModelB] + ) + doc_d2 = self._generate_doc_mock( + self.ModelD, self.index_1, _related_models=[self.ModelE] + ) + + instances_e = self.ModelE.objects.all() + instances_b = self.ModelB.objects.all() + related_instances = self.ModelD.objects.all() + + doc_d2.get_instances_from_many_related.return_value = related_instances + doc_d1.get_instances_from_many_related.return_value = related_instances + self.registry.update_related(instances_e, many=True) + + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_called_once_with(related_instances) + + doc_d1.get_instances_from_many_related.reset_mock() + doc_d1.update.reset_mock() + doc_d2.get_instances_from_many_related.reset_mock() + doc_d2.update.reset_mock() + + self.registry.update_related(instances_b, many=True) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelB, instances_b + ) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_not_called() + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_not_called() + + def test_update_related_instances_not_defined(self): + """ + Checking the correct call, if the function of + getting objects from related is not defined. + """ + doc_d1 = self._generate_doc_mock(_model=self.ModelD, index=self.index_1, + _related_models=[self.ModelE]) + + instances = self.ModelE.objects.all() + + doc_d1.get_instances_from_related.return_value = None + self.registry.update_related(instances) + + doc_d1.get_instances_from_related.assert_called_once_with(instances) + doc_d1.update.assert_not_called() + + def test_delete_instances(self): + """ + Checking the correct call `delete`. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = self.ModelA.objects.all() + self.registry.delete(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances, action='delete') + self.doc_a2.update.assert_called_once_with(instances, action='delete') + + def test_delete_related_instances(self): + """ + Checking the correct call `delete_related`. + + The signature is similar to `update_related`. + """ + doc_d1 = self._generate_doc_mock( + self.ModelD, self.index_1, + _related_models=[self.ModelE, self.ModelB] + ) + doc_d2 = self._generate_doc_mock( + self.ModelD, self.index_1, _related_models=[self.ModelE] + ) + + instances_e = self.ModelE.objects.all() + instances_b = self.ModelB.objects.all() + related_instances = self.ModelD.objects.all() + + doc_d2.get_instances_from_many_related.return_value = related_instances + doc_d1.get_instances_from_many_related.return_value = related_instances + self.registry.delete_related(instances_e, many=True) + + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_called_once_with(related_instances) + + doc_d1.get_instances_from_many_related.reset_mock() + doc_d1.update.reset_mock() + doc_d2.get_instances_from_many_related.reset_mock() + doc_d2.update.reset_mock() + + self.registry.delete_related(instances_b, many=True) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelB, instances_b + ) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_not_called() + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_not_called() + + def test_autosync(self): + settings.ELASTICSEARCH_DSL_AUTOSYNC = False + + instances = self.ModelA.objects.all() + self.registry.update(instances) + self.assertFalse(self.doc_a1.update.called) + + settings.ELASTICSEARCH_DSL_AUTOSYNC = True