diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index cea674d4d..7758f5aba 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -526,6 +526,10 @@ def is_model_type(info: TypeInfo) -> bool: return info.metaclass_type is not None and info.metaclass_type.type.has_base(fullnames.MODEL_METACLASS_FULLNAME) +def is_registered_model_type(info: TypeInfo, django_context: "DjangoContext") -> bool: + return info.fullname in {get_class_fullname(cls) for cls in django_context.all_registered_model_classes} + + def get_model_from_expression( expr: Expression, *, diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index f2bd21009..ae5f2efeb 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -237,7 +237,10 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan assert isinstance(default_return_type, Instance) outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() - if outer_model_info is None or not helpers.is_model_type(outer_model_info): + if outer_model_info is None or ( + not helpers.is_model_type(outer_model_info) + and not helpers.is_registered_model_type(outer_model_info, django_context) + ): return ctx.default_return_type assert isinstance(outer_model_info, TypeInfo) diff --git a/tests/typecheck/models/test_3rd_party_models.yml b/tests/typecheck/models/test_3rd_party_models.yml index 9a6636575..e7a23ef09 100644 --- a/tests/typecheck/models/test_3rd_party_models.yml +++ b/tests/typecheck/models/test_3rd_party_models.yml @@ -1,8 +1,15 @@ - case: handles_type_annotations_on_3rd_party_models + installed_apps: + - django_extensions + - myapp main: | - from django.db import models - from django_extensions.db.models import TimeStampedModel # type: ignore [import-untyped] + from myapp.models import A + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from django_extensions.db.models import TimeStampedModel # type: ignore[import-untyped] - class A(TimeStampedModel): - name = models.CharField() - count = models.IntegerField() + class A(TimeStampedModel): + name = models.CharField()