diff --git a/pylint/plugins/hass_imports.py b/pylint/plugins/hass_imports.py index 2fe70fad10dda8..33e79ea8f7ba16 100644 --- a/pylint/plugins/hass_imports.py +++ b/pylint/plugins/hass_imports.py @@ -227,6 +227,11 @@ class HassImportsFormatChecker(BaseChecker): "hass-import-constant-alias", "Used when a constant should be imported as an alias", ), + "W7427": ( + "`%s` should be imported using `from %s import %s as %s`", + "hass-alias-import", + "Used when an alias import should be imported with from ... import ... as ...", + ), } options = () @@ -247,22 +252,33 @@ def visit_import(self, node: nodes.Import) -> None: """Check for improper `import _` invocations.""" if self.current_package is None: return - for module, _alias in node.names: + for module, alias in node.names: if module.startswith(f"{self.current_package}."): self.add_message("hass-relative-import", node=node) continue + + # Disable imports from component sub-modules if ( module.startswith("homeassistant.components.") and len(module.split(".")) > 3 - ): - if ( + # Unless from tests for the same component + and not ( self.current_package.startswith("tests.components.") and self.current_package.split(".")[2] == module.split(".")[2] - ): - # Ignore check if the component being tested matches - # the component being imported from - continue + ) + ): self.add_message("hass-component-root-import", node=node) + continue + + # Prefer `from ... import ... as ...` over `import ... as ...` + if alias and module.startswith("homeassistant."): + prefix, _delimiter, imported_module = module.rpartition(".") + self.add_message( + "hass-alias-import", + node=node, + args=(imported_module, prefix, imported_module, alias), + ) + continue def _visit_importfrom_relative( self, current_package: str, node: nodes.ImportFrom diff --git a/tests/pylint/__init__.py b/tests/pylint/__init__.py index abe4c14c8791ac..553664c94a284a 100644 --- a/tests/pylint/__init__.py +++ b/tests/pylint/__init__.py @@ -1,19 +1,24 @@ """Tests for pylint.""" +from collections.abc import Generator import contextlib from pylint.testutils.unittest_linter import UnittestLinter @contextlib.contextmanager -def assert_no_messages(linter: UnittestLinter): +def assert_no_messages( + linter: UnittestLinter, *, ignore_codes: list[str] | None = None +) -> Generator[None]: """Assert that no messages are added by the given method.""" - with assert_adds_messages(linter): + with assert_adds_messages(linter, ignore_codes=ignore_codes): yield @contextlib.contextmanager -def assert_adds_messages(linter: UnittestLinter, *messages): +def assert_adds_messages( + linter: UnittestLinter, *messages, ignore_codes: list[str] | None = None +) -> Generator[None]: """Assert that exactly the given method adds the given messages. The list of messages must exactly match *all* the messages added by the @@ -22,6 +27,8 @@ def assert_adds_messages(linter: UnittestLinter, *messages): """ yield got = linter.release_messages() + if ignore_codes: + got = [msg for msg in got if msg.msg_id not in ignore_codes] no_msg = "No message." expected = "\n".join(repr(m) for m in messages) or no_msg got_str = "\n".join(repr(m) for m in got) or no_msg diff --git a/tests/pylint/test_imports.py b/tests/pylint/test_imports.py index 5044e73d253486..98d9c27dc43bd1 100644 --- a/tests/pylint/test_imports.py +++ b/tests/pylint/test_imports.py @@ -186,7 +186,7 @@ def test_good_root_import( ) imports_checker.visit_module(node.parent) - with assert_no_messages(linter): + with assert_no_messages(linter, ignore_codes=["hass-alias-import"]): if import_node.startswith("import"): imports_checker.visit_import(node) if import_node.startswith("from"): @@ -368,3 +368,50 @@ def test_domain_alias( imports_checker.visit_import(import_node) else: imports_checker.visit_importfrom(import_node) + + +def test_import_as(linter: UnittestLinter, imports_checker: BaseChecker) -> None: + """Prefer `from x.y import z as my_z` over `import x.y.z as my_z`.""" + + module_name = "pylint_test" + + import_nodes = astroid.extract_node( + """ + import foo #@ + import foo as my_foo #@ + import foo.bar #@ + import foo.bar as my_bar #@ + import foo.local.bar #@ + import foo.local.bar as my_bar #@ + import homeassistant.bar #@ + import homeassistant.bar as my_bar #@ + import homeassistant.local.bar #@ + import homeassistant.local.bar as my_bar #@ + """, + module_name, + ) + imports_checker.visit_module(import_nodes[0].parent) + + expected_messages = [ + pylint.testutils.MessageTest( + msg_id="hass-alias-import", + node=import_nodes[7], + args=("bar", "homeassistant", "bar", "my_bar"), + line=9, + col_offset=0, + end_line=9, + end_col_offset=34, + ), + pylint.testutils.MessageTest( + msg_id="hass-alias-import", + node=import_nodes[9], + args=("bar", "homeassistant.local", "bar", "my_bar"), + line=11, + col_offset=0, + end_line=11, + end_col_offset=40, + ), + ] + with assert_adds_messages(linter, *expected_messages): + for import_node in import_nodes: + imports_checker.visit_import(import_node)