Skip to content

Commit

Permalink
feat: detect builtin module name shadowing
Browse files Browse the repository at this point in the history
  • Loading branch information
asfaltboy committed Feb 20, 2024
1 parent cb65879 commit 06a3280
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 24 deletions.
79 changes: 56 additions & 23 deletions flake8_builtins.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
import sys
from flake8 import utils as stdin_utils

import ast
import builtins
import inspect
from pathlib import Path


class BuiltinsChecker:
name = 'flake8_builtins'
version = '1.5.2'
name = "flake8_builtins"
version = "1.5.2"
assign_msg = 'A001 variable "{0}" is shadowing a Python builtin'
argument_msg = 'A002 argument "{0}" is shadowing a Python builtin'
class_attribute_msg = 'A003 class attribute "{0}" is shadowing a Python builtin'
import_msg = 'A004 import statement "{0}" is shadowing a Python builtin'
module_name_msg = (
'A005 the module is shadowing a Python builtin module "{0}"'
)

names = []
ignore_list = {
'__name__',
'__doc__',
'credits',
'_',
"__name__",
"__doc__",
"credits",
"_",
}
ignored_module_names = {}

def __init__(self, tree, filename):
self.tree = tree
Expand All @@ -28,11 +34,18 @@ def __init__(self, tree, filename):
@classmethod
def add_options(cls, option_manager):
option_manager.add_option(
'--builtins-ignorelist',
metavar='builtins',
"--builtins-ignorelist",
metavar="builtins",
parse_from_config=True,
comma_separated_list=True,
help='A comma separated list of builtins to skip checking',
help="A comma separated list of builtins to skip checking",
)
option_manager.add_option(
"--builtins-allowed-modules",
metavar="builtins",
parse_from_config=True,
comma_separated_list=True,
help="A comma separated list of builtin module names to allow",
)

@classmethod
Expand All @@ -43,33 +56,42 @@ def parse_options(cls, options):
cls.names = {
a[0] for a in inspect.getmembers(builtins) if a[0] not in cls.ignore_list
}
flake8_builtins = getattr(options, 'builtins', None)
flake8_builtins = getattr(options, "builtins", None)
if flake8_builtins:
cls.names.update(flake8_builtins)

if options.builtins_allowed_modules is not None:
cls.ignored_module_names.update(options.builtins_allowed_modules)

cls.module_names = {
m for m in sys.stdlib_module_names if m not in cls.ignored_module_names
}

def run(self):
tree = self.tree

if self.filename == 'stdin':
if self.filename == "stdin":
lines = stdin_utils.stdin_get_value()
tree = ast.parse(lines)
else:
yield from self.check_module_name(self.filename)

for statement in ast.walk(tree):
for child in ast.iter_child_nodes(statement):
child.__flake8_builtins_parent = statement

function_nodes = [ast.FunctionDef]
if getattr(ast, 'AsyncFunctionDef', None):
if getattr(ast, "AsyncFunctionDef", None):
function_nodes.append(ast.AsyncFunctionDef)
function_nodes = tuple(function_nodes)

for_nodes = [ast.For]
if getattr(ast, 'AsyncFor', None):
if getattr(ast, "AsyncFor", None):
for_nodes.append(ast.AsyncFor)
for_nodes = tuple(for_nodes)

with_nodes = [ast.With]
if getattr(ast, 'AsyncWith', None):
if getattr(ast, "AsyncWith", None):
with_nodes.append(ast.AsyncWith)
with_nodes = tuple(with_nodes)

Expand Down Expand Up @@ -126,13 +148,13 @@ def check_assignment(self, statement):
elif isinstance(item, ast.Name) and item.id in self.names:
yield self.error(item, message=msg, variable=item.id)
elif isinstance(item, ast.Starred):
if hasattr(item.value, 'id') and item.value.id in self.names:
if hasattr(item.value, "id") and item.value.id in self.names:
yield self.error(
statement,
message=msg,
variable=item.value.id,
)
elif hasattr(item.value, 'elts'):
elif hasattr(item.value, "elts"):
stack.extend(list(item.value.elts))

def check_function_definition(self, statement):
Expand All @@ -145,8 +167,8 @@ def check_function_definition(self, statement):

all_arguments = []
all_arguments.extend(statement.args.args)
all_arguments.extend(getattr(statement.args, 'kwonlyargs', []))
all_arguments.extend(getattr(statement.args, 'posonlyargs', []))
all_arguments.extend(getattr(statement.args, "kwonlyargs", []))
all_arguments.extend(getattr(statement.args, "posonlyargs", []))

for arg in all_arguments:
if isinstance(arg, ast.arg) and arg.arg in self.names:
Expand All @@ -165,12 +187,12 @@ def check_for_loop(self, statement):
elif isinstance(item, ast.Name) and item.id in self.names:
yield self.error(statement, variable=item.id)
elif isinstance(item, ast.Starred):
if hasattr(item.value, 'id') and item.value.id in self.names:
if hasattr(item.value, "id") and item.value.id in self.names:
yield self.error(
statement,
variable=item.value.id,
)
elif hasattr(item.value, 'elts'):
elif hasattr(item.value, "elts"):
stack.extend(list(item.value.elts))

def check_with(self, statement):
Expand Down Expand Up @@ -234,13 +256,24 @@ def check_class(self, statement):
if statement.name in self.names:
yield self.error(statement, variable=statement.name)

def error(self, statement, variable, message=None):
def error(self, statement=None, variable=None, message=None):
if not message:
message = self.assign_msg

# lineno and col_offset must be integers
return (
statement.lineno,
statement.col_offset,
statement.lineno if statement else 0,
statement.col_offset if statement else 0,
message.format(variable),
type(self),
)

def check_module_name(self, filename: str):
path = Path(filename)
module_name = path.name.removesuffix(".py")
if module_name in self.module_names:
yield self.error(
None,
module_name,
message=self.module_name_msg,
)
13 changes: 12 additions & 1 deletion run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
class FakeOptions:
builtins_ignorelist = []
builtins = None
builtins_allowed_modules = None

def __init__(self, ignore_list='', builtins=None):
def __init__(self, ignore_list='', builtins=None, builtins_allowed_modules=None):
if ignore_list:
self.builtins_ignorelist = ignore_list
if builtins:
self.builtins = builtins
if builtins_allowed_modules:
self.builtins_allowed_modules = builtins_allowed_modules


def check_code(source, expected_codes=None, ignore_list=None, builtins=None, filename='/home/script.py'):
Expand Down Expand Up @@ -471,3 +474,11 @@ def test_tuple_unpacking():
check_code(source)


def test_module_name():
source = ''
check_code(source, expected_codes='A005', filename='./temp/logging.py')


def test_module_name_not_builtin():
source = ''
check_code(source, filename='log_config')

0 comments on commit 06a3280

Please sign in to comment.