From a0ec77d7058bf6863845925aeda1992b0b257b65 Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 21:43:40 +0700 Subject: [PATCH 1/7] Described manager for working with bulk operations on objects --- django_elasticsearch_dsl/db/__init__.py | 4 + django_elasticsearch_dsl/db/managers.py | 120 ++++++++++++++++++++++++ django_elasticsearch_dsl/db/utils.py | 9 ++ 3 files changed, 133 insertions(+) create mode 100644 django_elasticsearch_dsl/db/__init__.py create mode 100644 django_elasticsearch_dsl/db/managers.py create mode 100644 django_elasticsearch_dsl/db/utils.py diff --git a/django_elasticsearch_dsl/db/__init__.py b/django_elasticsearch_dsl/db/__init__.py new file mode 100644 index 00000000..c0db935f --- /dev/null +++ b/django_elasticsearch_dsl/db/__init__.py @@ -0,0 +1,4 @@ +from .managers import ( + DjangoElasticsearchDslManagerMixin, + DjangoElasticsearchDslModelManager +) diff --git a/django_elasticsearch_dsl/db/managers.py b/django_elasticsearch_dsl/db/managers.py new file mode 100644 index 00000000..304a6b2c --- /dev/null +++ b/django_elasticsearch_dsl/db/managers.py @@ -0,0 +1,120 @@ +from typing import List, Union + +from django.db import models + +from .utils import get_queryset_by_ids +from ..registries import registry + + +class DjangoElasticsearchDslManagerMixin(object): + """Elasticsearch DSL manager mixin for processing mass work with objects. + + Performs normalization by supported types and causes updating the + search engine appropriately. + + It acts similarly to a signal processor. + """ + _registry = registry + + def _normalize_results(self, result) -> Union[List[models.Model], models.QuerySet]: + if isinstance(result, models.Model): + return [result] + elif isinstance(result, (list, models.QuerySet)): + return result + else: + raise TypeError( + "Incorrect results type. " + "Expected 'django.db.models.Model', or 'django.db.models.Queryset', " + "but got %s" % type(result) + ) + + def _handle_save(self, result): + """Handle save. 
+ + Given a many model instances, update the objects in the index. + Update the related objects either. + """ + results = self._normalize_results(result) + + self._registry.update(results) + self._registry.update_related(results, many=True) + + def _handle_pre_delete(self, result): + """Handle removing of objects from related models instances. + + We need to do this before the real delete otherwise the relation + doesn't exist anymore, and we can't get the related models instances. + """ + results = self._normalize_results(result) + + self._registry.delete_related( + results, + many=True, + raise_on_error=False, + ) + + def _handle_delete(self, result): + """Handle delete. + + Given a many model instances, delete the objects in the index. + """ + results = self._normalize_results(result) + + self._registry.delete( + results, + raise_on_error=False, + ) + + +class DjangoElasticsearchDslModelManager(models.QuerySet, DjangoElasticsearchDslManagerMixin): + """Django Elasticsearch Dsl model manager. + + Working with possible bulk operations, updates documents accordingly. + """ + + def bulk_create(self, objs, *args, **kwargs): + """Bulk create. + + Calls `handle_save` after saving is completed + """ + result = super().bulk_create(objs, *args, **kwargs) + self._handle_save(result) + return result + + def bulk_update(self, objs, *args, **kwargs): + """Bulk update. + + Calls `handle_save` after saving is completed + """ + result = super().bulk_update(objs, *args, **kwargs) + self._handle_save(objs) + return result + + def update(self, **kwargs): + """Update. + + Calls `handle_save` after saving is completed + """ + ids = list(self.values_list("id", flat=True)) + result = super().update(**kwargs) + if not ids: + return result + self._handle_save(get_queryset_by_ids(self.model, ids)) + return result + + def delete(self): + """Delete. + + Calls `handle_pre_delete` before performing the deletion. + + After deleting it causes `handle_delete`. 
+ """ + objs = get_queryset_by_ids(self.model, list(self.values_list("id", flat=True))) + self._handle_pre_delete(objs) + objs = list(objs) + + result = super().delete() + + self._handle_delete(objs) + + return result diff --git a/django_elasticsearch_dsl/db/utils.py b/django_elasticsearch_dsl/db/utils.py new file mode 100644 index 00000000..449816bf --- /dev/null +++ b/django_elasticsearch_dsl/db/utils.py @@ -0,0 +1,9 @@ +from typing import List + +from django.db import models + + +def get_queryset_by_ids(model: models.Model, ids: List[int]): + return model.objects.filter( + id__in=ids + ) From a8f655ddffb77c2884ab10de301bb766809b32ad Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 21:45:28 +0700 Subject: [PATCH 2/7] corrected registries to work with updates on bulk operations --- django_elasticsearch_dsl/registries.py | 48 ++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/django_elasticsearch_dsl/registries.py b/django_elasticsearch_dsl/registries.py index 72510610..4890de89 100644 --- a/django_elasticsearch_dsl/registries.py +++ b/django_elasticsearch_dsl/registries.py @@ -5,6 +5,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ImproperlyConfigured +from django.db import models from elasticsearch_dsl import AttrDict from six import itervalues, iterkeys, iteritems @@ -88,22 +89,47 @@ def register_document(self, document): return document def _get_related_doc(self, instance): - for model in self._related_models.get(instance.__class__, []): + instance_cls = self._get_cls_from_instance(instance) + for model in self._related_models.get(instance_cls, []): for doc in self._models[model]: - if instance.__class__ in doc.django.related_models: + if instance_cls in doc.django.related_models: yield doc + def _get_cls_from_instance(self, instance): + """ + Get class from instance. + + Supports getting a class from a model, list, or queryset. 
+ """ + if isinstance(instance, models.Model): + return instance.__class__ + elif isinstance(instance, list): + return instance[0].__class__ if instance else None + elif isinstance(instance, models.QuerySet): + return instance.model + else: + return None + def update_related(self, instance, **kwargs): """ Update docs that have related_models. + + The `many` flag has been introduced to handle mass updates of objects. """ if not DEDConfig.autosync_enabled(): return + many = kwargs.pop("many", False) for doc in self._get_related_doc(instance): doc_instance = doc() try: - related = doc_instance.get_instances_from_related(instance) + if many: + related = doc_instance.get_instances_from_many_related( + self._get_cls_from_instance(instance), + instance + ) + else: + related = doc_instance.get_instances_from_related(instance) except ObjectDoesNotExist: related = None @@ -113,14 +139,23 @@ def update_related(self, instance, **kwargs): def delete_related(self, instance, **kwargs): """ Remove `instance` from related models. + + The `many` flag has been introduced to handle mass updates of objects. 
""" if not DEDConfig.autosync_enabled(): return + many = kwargs.pop("many", False) for doc in self._get_related_doc(instance): doc_instance = doc(related_instance_to_ignore=instance) try: - related = doc_instance.get_instances_from_related(instance) + if many: + related = doc_instance.get_instances_from_many_related( + self._get_cls_from_instance(instance), + instance + ) + else: + related = doc_instance.get_instances_from_related(instance) except ObjectDoesNotExist: related = None @@ -135,8 +170,9 @@ def update(self, instance, **kwargs): if not DEDConfig.autosync_enabled(): return - if instance.__class__ in self._models: - for doc in self._models[instance.__class__]: + instance_cls = self._get_cls_from_instance(instance) + if instance_cls in self._models: + for doc in self._models[instance_cls]: if not doc.django.ignore_signals: doc().update(instance, **kwargs) From f6c50d82feb086272d9f41ffa282b175979acbd5 Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 21:46:52 +0700 Subject: [PATCH 3/7] disabled distribution for delete signals in queryset --- django_elasticsearch_dsl/signals.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/django_elasticsearch_dsl/signals.py b/django_elasticsearch_dsl/signals.py index 35a631c4..e50f0f46 100644 --- a/django_elasticsearch_dsl/signals.py +++ b/django_elasticsearch_dsl/signals.py @@ -47,7 +47,7 @@ def handle_m2m_changed(self, sender, instance, action, **kwargs): if action in ('post_add', 'post_remove', 'post_clear'): self.handle_save(sender, instance) elif action in ('pre_remove', 'pre_clear'): - self.handle_pre_delete(sender, instance) + self.handle_pre_delete(sender, instance, origin=kwargs['model']()) def handle_save(self, sender, instance, **kwargs): """Handle save. @@ -62,15 +62,21 @@ def handle_pre_delete(self, sender, instance, **kwargs): """Handle removing of instance object from related models instance. 
We need to do this before the real delete otherwise the relation doesn't exists anymore and we can't get the related models instance. + + Disabling distribution for deletion cases other than deletion by entity. """ - registry.delete_related(instance) + if isinstance(kwargs.get("origin"), models.Model): + registry.delete_related(instance) def handle_delete(self, sender, instance, **kwargs): """Handle delete. Given an individual model instance, delete the object from index. + + Disabling distribution for deletion cases other than deletion by entity. """ - registry.delete(instance, raise_on_error=False) + if isinstance(kwargs.get("origin"), models.Model): + registry.delete(instance, raise_on_error=False) class RealTimeSignalProcessor(BaseSignalProcessor): From d7f92aaf2ab5e8d125361f9e680cf6251a18200a Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 21:48:14 +0700 Subject: [PATCH 4/7] added exclusion of related entities by queryset --- django_elasticsearch_dsl/fields.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/django_elasticsearch_dsl/fields.py b/django_elasticsearch_dsl/fields.py index 2652b68e..8bc6ea97 100644 --- a/django_elasticsearch_dsl/fields.py +++ b/django_elasticsearch_dsl/fields.py @@ -88,6 +88,15 @@ def get_value_from_instance(self, instance, field_value_to_ignore=None): if instance == field_value_to_ignore: return None + elif isinstance(field_value_to_ignore, models.QuerySet) and \ + isinstance(instance, models.Model) and \ + field_value_to_ignore.contains(instance): + return None + elif isinstance(field_value_to_ignore, models.QuerySet) and \ + isinstance(instance, models.QuerySet): + instance = instance.exclude( + id__in=field_value_to_ignore.values_list("id", flat=True) + ) # convert lazy object like lazy translations to string if isinstance(instance, Promise): From 8d7ce64b419357c32e8c1cc4cdcbc55454c19b68 Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 21:48:44 +0700 Subject: [PATCH 5/7] completed and updated 
testing --- tests/documents.py | 129 +++++++++++++++++++++++++++- tests/fixtures.py | 1 + tests/models.py | 86 +++++++++++++++++++ tests/test_integration.py | 168 ++++++++++++++++++++++++++++++++++++- tests/test_registries.py | 172 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 553 insertions(+), 3 deletions(-) diff --git a/tests/documents.py b/tests/documents.py index 08ef83f1..248e5483 100644 --- a/tests/documents.py +++ b/tests/documents.py @@ -2,7 +2,17 @@ from django_elasticsearch_dsl import Document, fields from django_elasticsearch_dsl.registries import registry -from .models import Ad, Category, Car, Manufacturer, Article +from .models import ( + Ad, + Category, + Car, + Manufacturer, + Article, + CarBulkManager, + CategoryBulkManager, + AdBulkManager, + ManufacturerBulkManager +) index_settings = { 'number_of_shards': 1, @@ -180,3 +190,120 @@ def generate_id(cls, article): ad_index = AdDocument._index car_index = CarDocument._index + + +@registry.register_document +class CarBulkDocument(Document): + + manufacturer = fields.ObjectField(properties={ + 'name': fields.TextField(), + 'country': fields.TextField(), + }) + + ads = fields.NestedField(properties={ + 'description': fields.TextField(analyzer=html_strip), + 'title': fields.TextField(), + 'pk': fields.IntegerField(), + }) + + categories = fields.NestedField(properties={ + 'title': fields.TextField(), + 'slug': fields.TextField(), + 'icon': fields.FileField(), + }) + + class Django: + model = CarBulkManager + related_models = [AdBulkManager, + ManufacturerBulkManager, + CategoryBulkManager] + fields = [ + 'name', + 'launched', + 'type', + ] + + class Index: + name = 'test_cars_bulk' + settings = index_settings + + def get_queryset(self): + return super(CarBulkDocument, self).get_queryset().select_related( + 'manufacturer') + + def get_instances_from_many_related(self, cls, related_instance): + if isinstance(related_instance, list): + if cls == AdBulkManager: + return 
CarBulkManager.objects.filter( + id__in=[ + item.car_id for item in related_instance + ] + ) + elif cls == ManufacturerBulkManager: + return CarBulkManager.objects.filter( + manufacturer_id__in=[ + item.id for item in related_instance + ] + ) + elif cls == CategoryBulkManager: + return CarBulkManager.objects.filter( + categories__id__in=[ + item.id for item in related_instance + ] + ) + else: + if cls == AdBulkManager: + return CarBulkManager.objects.filter( + id__in=related_instance.values_list("car_id", flat=True) + ) + elif cls == ManufacturerBulkManager: + return CarBulkManager.objects.filter( + manufacturer_id__in=related_instance.values_list( + "id", flat=True + ) + ) + elif cls == CategoryBulkManager: + return CarBulkManager.objects.filter( + categories__id__in=related_instance.values_list( + "id", flat=True + ) + ) + + +@registry.register_document +class ManufacturerBulkDocument(Document): + country = fields.TextField() + + class Django: + model = ManufacturerBulkManager + fields = [ + 'name', + 'created', + 'country_code', + 'logo', + ] + + class Index: + name = 'test_manufacturers_bulk' + settings = index_settings + + +@registry.register_document +class AdBulkDocument(Document): + description = fields.TextField( + analyzer=html_strip, + fields={'raw': fields.KeywordField()} + ) + + class Django: + model = AdBulkManager + fields = [ + 'title', + 'created', + 'modified', + 'url', + ] + + class Index: + name = 'test_ads_bulk' + settings = index_settings diff --git a/tests/fixtures.py b/tests/fixtures.py index 4d3bad9e..ba427e74 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -45,5 +45,6 @@ class Django: Doc.get_queryset = Mock(return_value=mock_qs) if _related_models: Doc.get_instances_from_related = Mock() + Doc.get_instances_from_many_related = Mock() return Doc diff --git a/tests/models.py b/tests/models.py index d1bd4f11..3eb09ce3 100644 --- a/tests/models.py +++ b/tests/models.py @@ -9,6 +9,8 @@ from django.utils.translation import 
gettext_lazy as _ from six import python_2_unicode_compatible +from django_elasticsearch_dsl.db import DjangoElasticsearchDslModelManager + + @python_2_unicode_compatible class Car(models.Model): @@ -106,3 +108,87 @@ class Meta: def __str__(self): return self.slug + + +@python_2_unicode_compatible +class CarBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + TYPE_CHOICES = ( + ('se', "Sedan"), + ('br', "Break"), + ('4x', "4x4"), + ('co', "Coupé"), + ) + + name = models.CharField(max_length=255) + launched = models.DateField() + type = models.CharField( + max_length=2, + choices=TYPE_CHOICES, + default='se', + ) + manufacturer = models.ForeignKey( + 'ManufacturerBulkManager', null=True, on_delete=models.SET_NULL + ) + categories = models.ManyToManyField('CategoryBulkManager') + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class ManufacturerBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + name = models.CharField(max_length=255, default=_("Test lazy tanslation")) + country_code = models.CharField(max_length=2) + created = models.DateField() + logo = models.ImageField(blank=True) + + class Meta: + app_label = 'tests' + + def country(self): + return COUNTRIES.get(self.country_code, self.country_code) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class CategoryBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + title = models.CharField(max_length=255) + slug = models.CharField(max_length=255) + icon = models.ImageField(blank=True) + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.title + + +@python_2_unicode_compatible +class AdBulkManager(models.Model): + objects = DjangoElasticsearchDslModelManager.as_manager() + + title = models.CharField(max_length=255) + description = models.TextField() + created = 
models.DateField(auto_now_add=True) + modified = models.DateField(auto_now=True) + url = models.URLField() + car = models.ForeignKey( + 'CarBulkManager', related_name='ads', null=True, on_delete=models.SET_NULL + ) + + class Meta: + app_label = 'tests' + + def __str__(self): + return self.title diff --git a/tests/test_integration.py b/tests/test_integration.py index 0dce997c..190e66d0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,9 +23,22 @@ CarWithPrepareDocument, ArticleDocument, ArticleWithSlugAsIdDocument, - index_settings + index_settings, + CarBulkDocument, + AdBulkDocument, ManufacturerBulkDocument +) +from .models import ( + Car, + Manufacturer, + Ad, + Category, + Article, + COUNTRIES, + CarBulkManager, + ManufacturerBulkManager, + AdBulkManager, + CategoryBulkManager ) -from .models import Car, Manufacturer, Ad, Category, Article, COUNTRIES @unittest.skipUnless(is_es_online(), 'Elasticsearch is offline') @@ -394,3 +407,154 @@ def test_custom_document_id(self): "using a custom id is broken".format(article_slug) ) self.assertEqual(es_obj.slug, article.slug) + + +@unittest.skipUnless(is_es_online(), 'Elasticsearch is offline') +class IntegrationBulkOperationConfTestCase(ESTestCase, TestCase): + + def setUp(self): + super().setUp() + + manufacturers = ManufacturerBulkManager.objects.bulk_create([ + ManufacturerBulkManager( + name="Peugeot", created=datetime(1900, 10, 9, 0, 0), + country_code="FR", logo='logo.jpg' + ) + ]) + self.manufacturer = manufacturers[0] + + cars = CarBulkManager.objects.bulk_create([ + CarBulkManager( + name="508", launched=datetime(2010, 9, 9, 0, 0), + manufacturer=self.manufacturer + ), + CarBulkManager( + name="208", launched=datetime(2010, 10, 9, 0, 0), + manufacturer=self.manufacturer + ), + CarBulkManager( + name="308", launched=datetime(2010, 11, 9, 0, 0) + ) + ]) + self.car1 = cars[0] + self.car2 = cars[1] + self.car3 = cars[2] + + self.assertEqual(self.car1.name, "508") + 
self.assertEqual(self.car2.name, "208") + self.assertEqual(self.car3.name, "308") + + categories = CategoryBulkManager.objects.bulk_create([ + CategoryBulkManager( + title="Category 1", slug="category-1", icon="icon.jpeg" + ), + CategoryBulkManager(title="Category 2", slug="category-2") + ]) + self.category1 = categories[0] + self.category2 = categories[1] + + self.assertEqual(self.category1.title, "Category 1") + self.assertEqual(self.category2.title, "Category 2") + + self.car2.categories.add(self.category1) + self.car2.save() + + self.car3.categories.add(self.category1, self.category2) + self.car3.save() + + ads = AdBulkManager.objects.bulk_create([ + AdBulkManager( + title=_("Ad number 1"), url="www.ad1.com", + description="My super ad description 1", + car=self.car1 + ), + AdBulkManager( + title="Ad number 2", url="www.ad2.com", + description="My super ad descriptio 2", + car=self.car1 + ) + ]) + self.ad1 = ads[0] + self.ad2 = ads[1] + + self.assertEqual(self.ad1.title, _("Ad number 1")) + self.assertEqual(self.ad2.title, "Ad number 2") + + def test_docs_are_updated_by_bulk_operations(self): + old_car2_name = self.car2.name + car2_name = "1008" + + s = CarBulkDocument.search().query("match", name=old_car2_name) + self.assertEqual(s.count(), 1) + + s = CarBulkDocument.search().query("match", name=car2_name) + self.assertEqual(s.count(), 0) + + CarBulkManager.objects.filter(id=self.car2.id).update(name=car2_name) + + s = CarBulkDocument.search().query("match", name=old_car2_name) + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search().query("match", name=car2_name) + self.assertEqual(s.count(), 1) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + car3_doc = s.execute()[0] + self.assertEqual(car3_doc.manufacturer.name, None) + + CarBulkManager.objects.filter(id=self.car3.id).update(manufacturer_id=self.manufacturer.pk) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + car3_doc = s.execute()[0] + 
self.assertEqual(car3_doc.manufacturer.name, self.manufacturer.name) + + s = CarBulkDocument.search().query("match", name=self.car3.name) + self.assertEqual(s.count(), 1) + + CarBulkManager.objects.filter(id=self.car3.id).delete() + s = CarBulkDocument.search().query("match", name=self.car3.name) + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search() + self.assertEqual(s.count(), 2) + + s = ManufacturerBulkDocument.search() + self.assertEqual(s.count(), 1) + + ManufacturerBulkManager.objects.all().delete() + s = ManufacturerBulkDocument.search() + self.assertEqual(s.count(), 0) + + s = CarBulkDocument.search() + for result in s.execute(): + self.assertEqual(result.manufacturer.name, None) + + def test_related_docs_are_updated_by_bulk_operations(self): + ManufacturerBulkManager.objects.filter(id=self.manufacturer.id).update( + name="Citroen" + ) + + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(car2_doc.manufacturer.name, 'Citroen') + self.assertEqual(len(car2_doc.ads), 0) + + ad3 = AdBulkManager.objects.bulk_create([ + AdBulkManager(title=_("Ad number 3"), url="www.ad3.com", + description="My super ad description 3", + car=self.car2) + ]) + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(len(car2_doc.ads), 1) + + AdBulkManager.objects.filter(id__in=[ad.id for ad in ad3]).delete() + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + self.assertEqual(len(car2_doc.ads), 0) + + ManufacturerBulkManager.objects.filter(id=self.manufacturer.id).delete() + s = CarBulkDocument.search().query("match", name=self.car2.name) + car2_doc = s.execute()[0] + + self.assertEqual(car2_doc.manufacturer.name, None) diff --git a/tests/test_registries.py b/tests/test_registries.py index c72894cb..e4dabad3 100644 --- a/tests/test_registries.py +++ b/tests/test_registries.py @@ -144,3 +144,175 @@ def 
test_autosync(self): self.assertFalse(self.doc_a1.update.called) settings.ELASTICSEARCH_DSL_AUTOSYNC = True + + +class DocumentRegistryBulkOperationsTestCase(WithFixturesMixin, TestCase): + """ + Test case for working with bulk operations. + """ + + def setUp(self) -> None: + self.registry = DocumentRegistry() + self.index_1 = Index(name='index_1') + self.index_2 = Index(name='index_2') + + self.doc_a1 = self._generate_doc_mock(self.ModelA, self.index_1) + self.doc_a2 = self._generate_doc_mock(self.ModelA, self.index_1) + self.doc_b1 = self._generate_doc_mock(self.ModelB, self.index_2) + self.doc_c1 = self._generate_doc_mock(self.ModelC, self.index_1) + + def test_update_instances(self): + """ + Checking for `update`. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = self.ModelA.objects.all() + self.registry.update(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances) + self.doc_a2.update.assert_called_once_with(instances) + + def test_update_instances_as_list(self): + """ + Checking for `update` where instances is list. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = [self.ModelA()] + self.registry.update(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances) + self.doc_a2.update.assert_called_once_with(instances) + + def test_update_related_instances(self): + """ + Checking the correct call of the get function from + related objects. 
+ """ + doc_d1 = self._generate_doc_mock( + self.ModelD, self.index_1, + _related_models=[self.ModelE, self.ModelB] + ) + doc_d2 = self._generate_doc_mock( + self.ModelD, self.index_1, _related_models=[self.ModelE] + ) + + instances_e = self.ModelE.objects.all() + instances_b = self.ModelB.objects.all() + related_instances = self.ModelD.objects.all() + + doc_d2.get_instances_from_many_related.return_value = related_instances + doc_d1.get_instances_from_many_related.return_value = related_instances + self.registry.update_related(instances_e, many=True) + + doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_called_once_with(related_instances) + + doc_d1.get_instances_from_many_related.reset_mock() + doc_d1.update.reset_mock() + doc_d2.get_instances_from_many_related.reset_mock() + doc_d2.update.reset_mock() + + self.registry.update_related(instances_b, many=True) + doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelB, instances_b) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_not_called() + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_not_called() + + def test_update_related_instances_not_defined(self): + """ + Checking the correct call, if the function of + getting objects from related is not defined. 
+ """ + doc_d1 = self._generate_doc_mock(_model=self.ModelD, index=self.index_1, + _related_models=[self.ModelE]) + + instances = self.ModelE.objects.all() + + doc_d1.get_instances_from_related.return_value = None + self.registry.update_related(instances) + + doc_d1.get_instances_from_related.assert_called_once_with(instances) + doc_d1.update.assert_not_called() + + def test_delete_instances(self): + """ + Checking the correct call `delete`. + """ + doc_a3 = self._generate_doc_mock( + self.ModelA, self.index_1, _ignore_signals=True + ) + + instances = self.ModelA.objects.all() + self.registry.delete(instances) + + self.assertFalse(doc_a3.update.called) + self.assertFalse(self.doc_b1.update.called) + self.doc_a1.update.assert_called_once_with(instances, action='delete') + self.doc_a2.update.assert_called_once_with(instances, action='delete') + + def test_delete_related_instances(self): + """ + Checking the correct call `delete_related`. + + The signature is similar to `update_related`. + """ + doc_d1 = self._generate_doc_mock( + self.ModelD, self.index_1, + _related_models=[self.ModelE, self.ModelB] + ) + doc_d2 = self._generate_doc_mock( + self.ModelD, self.index_1, _related_models=[self.ModelE] + ) + + instances_e = self.ModelE.objects.all() + instances_b = self.ModelB.objects.all() + related_instances = self.ModelD.objects.all() + + doc_d2.get_instances_from_many_related.return_value = related_instances + doc_d1.get_instances_from_many_related.return_value = related_instances + self.registry.delete_related(instances_e, many=True) + + doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_called_once_with(related_instances) + + 
doc_d1.get_instances_from_many_related.reset_mock() + doc_d1.update.reset_mock() + doc_d2.get_instances_from_many_related.reset_mock() + doc_d2.update.reset_mock() + + self.registry.delete_related(instances_b, many=True) + doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelB, instances_b) + doc_d1.get_instances_from_related.assert_not_called() + doc_d1.update.assert_called_once_with(related_instances) + doc_d2.get_instances_from_many_related.assert_not_called() + doc_d2.get_instances_from_related.assert_not_called() + doc_d2.update.assert_not_called() + + def test_autosync(self): + settings.ELASTICSEARCH_DSL_AUTOSYNC = False + + instances = self.ModelA.objects.all() + self.registry.update(instances) + self.assertFalse(self.doc_a1.update.called) + + settings.ELASTICSEARCH_DSL_AUTOSYNC = True From fb711744791f2f9b3e12d4ba05208a47aac43689 Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 22:39:09 +0700 Subject: [PATCH 6/7] Corrected by flake8 --- django_elasticsearch_dsl/db/__init__.py | 2 +- django_elasticsearch_dsl/db/managers.py | 15 +++++++++------ django_elasticsearch_dsl/fields.py | 6 +++--- django_elasticsearch_dsl/signals.py | 6 ++++-- tests/models.py | 5 ++++- tests/test_integration.py | 10 +++++++--- tests/test_registries.py | 24 ++++++++++++++++++------ 7 files changed, 46 insertions(+), 22 deletions(-) diff --git a/django_elasticsearch_dsl/db/__init__.py b/django_elasticsearch_dsl/db/__init__.py index c0db935f..ed287ab3 100644 --- a/django_elasticsearch_dsl/db/__init__.py +++ b/django_elasticsearch_dsl/db/__init__.py @@ -1,4 +1,4 @@ -from .managers import ( +from .managers import ( # noqa DjangoElasticsearchDslManagerMixin, DjangoElasticsearchDslModelManager ) diff --git a/django_elasticsearch_dsl/db/managers.py b/django_elasticsearch_dsl/db/managers.py index 304a6b2c..9150598f 100644 --- a/django_elasticsearch_dsl/db/managers.py +++ b/django_elasticsearch_dsl/db/managers.py @@ -1,5 +1,3 @@ -from typing import List, Union - 
from django.db import models from .utils import get_queryset_by_ids @@ -16,7 +14,7 @@ class DjangoElasticsearchDslManagerMixin(object): """ _registry = registry - def _normalize_results(self, result) -> Union[List[models.Model], models.QuerySet]: + def _normalize_results(self, result): if isinstance(result, models.Model): return [result] elif isinstance(result, (list, models.QuerySet)): @@ -24,7 +22,8 @@ def _normalize_results(self, result) -> Union[List[models.Model], models.QuerySe else: raise TypeError( "Incorrect results type. " - "Expected 'django.db.models.Model', or 'django.db.models.Queryset', " + "Expected 'django.db.models.Model', " + " or 'django.db.models.Queryset', " "but got %s" % type(result) ) @@ -66,7 +65,8 @@ def _handle_delete(self, result): ) -class DjangoElasticsearchDslModelManager(models.QuerySet, DjangoElasticsearchDslManagerMixin): +class DjangoElasticsearchDslModelManager(models.QuerySet, + DjangoElasticsearchDslManagerMixin): """Django Elasticsearch Dsl model manager. Working with possible bulk operations, updates documents accordingly. @@ -109,7 +109,10 @@ def delete(self): After deleting it causes `handle_delete`. 
""" - objs = get_queryset_by_ids(self.model, list(self.values_list("id", flat=True))) + objs = get_queryset_by_ids( + self.model, + list(self.values_list("id", flat=True)) + ) self._handle_pre_delete(objs) objs = list(objs) diff --git a/django_elasticsearch_dsl/fields.py b/django_elasticsearch_dsl/fields.py index 8bc6ea97..7ed255ca 100644 --- a/django_elasticsearch_dsl/fields.py +++ b/django_elasticsearch_dsl/fields.py @@ -89,11 +89,11 @@ def get_value_from_instance(self, instance, field_value_to_ignore=None): if instance == field_value_to_ignore: return None elif isinstance(field_value_to_ignore, models.QuerySet) and \ - isinstance(instance, models.Model) and \ - field_value_to_ignore.contains(instance): + isinstance(instance, models.Model) and \ + field_value_to_ignore.contains(instance): return None elif isinstance(field_value_to_ignore, models.QuerySet) and \ - isinstance(instance, models.QuerySet): + isinstance(instance, models.QuerySet): instance = instance.exclude( id__in=field_value_to_ignore.values_list("id", flat=True) ) diff --git a/django_elasticsearch_dsl/signals.py b/django_elasticsearch_dsl/signals.py index e50f0f46..f9ba485e 100644 --- a/django_elasticsearch_dsl/signals.py +++ b/django_elasticsearch_dsl/signals.py @@ -63,7 +63,8 @@ def handle_pre_delete(self, sender, instance, **kwargs): We need to do this before the real delete otherwise the relation doesn't exists anymore and we can't get the related models instance. - Disabling distribution for deletion cases other than deletion by entity. + Disabling distribution for deletion cases other + than deletion by entity. """ if isinstance(kwargs.get("origin"), models.Model): registry.delete_related(instance) @@ -73,7 +74,8 @@ def handle_delete(self, sender, instance, **kwargs): Given an individual model instance, delete the object from index. - Disabling distribution for deletion cases other than deletion by entity. + Disabling distribution for deletion cases other + than deletion by entity. 
""" if isinstance(kwargs.get("origin"), models.Model): registry.delete(instance, raise_on_error=False) diff --git a/tests/models.py b/tests/models.py index 3eb09ce3..37968125 100644 --- a/tests/models.py +++ b/tests/models.py @@ -184,7 +184,10 @@ class AdBulkManager(models.Model): modified = models.DateField(auto_now=True) url = models.URLField() car = models.ForeignKey( - 'CarBulkManager', related_name='ads', null=True, on_delete=models.SET_NULL + 'CarBulkManager', + related_name='ads', + null=True, + on_delete=models.SET_NULL ) class Meta: diff --git a/tests/test_integration.py b/tests/test_integration.py index 190e66d0..a7657c89 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -25,7 +25,7 @@ ArticleWithSlugAsIdDocument, index_settings, CarBulkDocument, - AdBulkDocument, ManufacturerBulkDocument + ManufacturerBulkDocument ) from .models import ( Car, @@ -502,7 +502,9 @@ def test_docs_are_updated_by_bulk_operations(self): car3_doc = s.execute()[0] self.assertEqual(car3_doc.manufacturer.name, None) - CarBulkManager.objects.filter(id=self.car3.id).update(manufacturer_id=self.manufacturer.pk) + CarBulkManager.objects.filter( + id=self.car3.id + ).update(manufacturer_id=self.manufacturer.pk) s = CarBulkDocument.search().query("match", name=self.car3.name) car3_doc = s.execute()[0] @@ -553,7 +555,9 @@ def test_related_docs_are_updated_by_bulk_operations(self): car2_doc = s.execute()[0] self.assertEqual(len(car2_doc.ads), 0) - ManufacturerBulkManager.objects.filter(id=self.manufacturer.id).delete() + ManufacturerBulkManager.objects.filter( + id=self.manufacturer.id + ).delete() s = CarBulkDocument.search().query("match", name=self.car2.name) car2_doc = s.execute()[0] diff --git a/tests/test_registries.py b/tests/test_registries.py index e4dabad3..ef13d68b 100644 --- a/tests/test_registries.py +++ b/tests/test_registries.py @@ -214,10 +214,14 @@ def test_update_related_instances(self): doc_d1.get_instances_from_many_related.return_value = 
related_instances self.registry.update_related(instances_e, many=True) - doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) doc_d1.get_instances_from_related.assert_not_called() doc_d1.update.assert_called_once_with(related_instances) - doc_d2.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d2.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) doc_d2.get_instances_from_related.assert_not_called() doc_d2.update.assert_called_once_with(related_instances) @@ -227,7 +231,9 @@ def test_update_related_instances(self): doc_d2.update.reset_mock() self.registry.update_related(instances_b, many=True) - doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelB, instances_b) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelB, instances_b + ) doc_d1.get_instances_from_related.assert_not_called() doc_d1.update.assert_called_once_with(related_instances) doc_d2.get_instances_from_many_related.assert_not_called() @@ -288,10 +294,14 @@ def test_delete_related_instances(self): doc_d1.get_instances_from_many_related.return_value = related_instances self.registry.delete_related(instances_e, many=True) - doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) doc_d1.get_instances_from_related.assert_not_called() doc_d1.update.assert_called_once_with(related_instances) - doc_d2.get_instances_from_many_related.assert_called_once_with(self.ModelE, instances_e) + doc_d2.get_instances_from_many_related.assert_called_once_with( + self.ModelE, instances_e + ) doc_d2.get_instances_from_related.assert_not_called() doc_d2.update.assert_called_once_with(related_instances) @@ -301,7 +311,9 @@ def 
test_delete_related_instances(self): doc_d2.update.reset_mock() self.registry.delete_related(instances_b, many=True) - doc_d1.get_instances_from_many_related.assert_called_once_with(self.ModelB, instances_b) + doc_d1.get_instances_from_many_related.assert_called_once_with( + self.ModelB, instances_b + ) doc_d1.get_instances_from_related.assert_not_called() doc_d1.update.assert_called_once_with(related_instances) doc_d2.get_instances_from_many_related.assert_not_called() From 84334fc3d6f8ae7cbd66e43cf8ffd3da00dfdbd7 Mon Sep 17 00:00:00 2001 From: Kirill Date: Sun, 14 May 2023 22:39:44 +0700 Subject: [PATCH 7/7] updated readme --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index 81c46f90..b7cb1883 100644 --- a/README.rst +++ b/README.rst @@ -28,6 +28,7 @@ Features - Elasticsearch auto mapping from django models fields. - Complex field type support (ObjectField, NestedField). - Index fast using `parallel` indexing. +- Bulk operations support. - Requirements - Django >= 1.11