diff --git a/rewrite/rewrite/python/format/spaces_visitor.py b/rewrite/rewrite/python/format/spaces_visitor.py index 49956cff..1bfcc228 100644 --- a/rewrite/rewrite/python/format/spaces_visitor.py +++ b/rewrite/rewrite/python/format/spaces_visitor.py @@ -6,8 +6,8 @@ MethodDeclaration, Empty, ArrayAccess, Space, If, Block, ClassDeclaration, VariableDeclarations, JRightPadded, \ Import, ParameterizedType, Parentheses, WhileLoop from rewrite.python import PythonVisitor, SpacesStyle, Binary, ChainedAssignment, Slice, CollectionLiteral, \ - ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree, ComprehensionExpression -from rewrite.visitor import P + ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree, ComprehensionExpression, NamedArgument +from rewrite.visitor import P, Cursor class SpacesVisitor(PythonVisitor): @@ -99,6 +99,39 @@ def _process_argument(index, arg, args_size): ) ) + def visit_named_argument(self, named_argument: NamedArgument, p: P) -> J: + a = cast(NamedArgument, super().visit_named_argument(named_argument, p)) + if a.padding.value is not None: + a = a.padding.with_value( + space_before_left_padded(a.padding.value, self._style.around_operators.eq_in_keyword_argument)) + return a.padding.with_value( + space_before_left_padded_element(a.padding.value, self._style.around_operators.eq_in_keyword_argument)) + return a + + @staticmethod + def _part_of_method_header(cursor: Cursor) -> bool: + if (c := cursor.parent_tree_cursor()) and isinstance(c.value, VariableDeclarations): + return c.parent_tree_cursor() is not None and isinstance(c.parent_tree_cursor().value, MethodDeclaration) + return False + + def visit_variable(self, named_variable: VariableDeclarations.NamedVariable, p: P) -> J: + v = cast(VariableDeclarations.NamedVariable, super().visit_variable(named_variable, p)) + + # Check if the variable is a named parameter in a method declaration + if not self._part_of_method_header(self.cursor): + return v + + if v.padding.initializer is not None and v.padding.initializer.element is not None: + use_space = self._style.around_operators.eq_in_named_parameter or v.variable_type is not None + # Argument with a typehint will always receive a space e.g. foo(a: int =1) <-> foo(a: int = 1) + use_space |= self.cursor.first_enclosing_or_throw(VariableDeclarations).type_expression is not None + + v = v.padding.with_initializer( + space_before_left_padded(v.padding.initializer, use_space)) + v = v.padding.with_initializer( + space_before_left_padded_element(v.padding.initializer, use_space)) + return v + def visit_block(self, block: Block, p: P) -> J: b = cast(Block, super().visit_block(block, p)) b = space_before(b, self._style.other.before_colon) @@ -226,8 +259,10 @@ def visit_binary(self, binary: j.Binary, p: P) -> J: b = self._apply_binary_space_around(b, self._style.around_operators.additive) elif op in [j.Binary.Type.Multiplication, j.Binary.Type.Division, j.Binary.Type.Modulo]: b = self._apply_binary_space_around(b, self._style.around_operators.multiplicative) + elif op in [j.Binary.Type.Equal, j.Binary.Type.NotEqual]: + b = self._apply_binary_space_around(b, self._style.around_operators.equality) elif op in [j.Binary.Type.LessThan, j.Binary.Type.GreaterThan, j.Binary.Type.LessThanOrEqual, - j.Binary.Type.GreaterThanOrEqual, j.Binary.Type.Equal, j.Binary.Type.NotEqual]: + j.Binary.Type.GreaterThanOrEqual]: b = self._apply_binary_space_around(b, self._style.around_operators.relational) elif op in [j.Binary.Type.BitAnd, j.Binary.Type.BitOr, j.Binary.Type.BitXor]: b = self._apply_binary_space_around(b, self._style.around_operators.bitwise) @@ -246,10 +281,13 @@ def visit_python_binary(self, binary: Binary, p: P) -> J: if op == Binary.Type.In or op == Binary.Type.Is or op == Binary.Type.IsNot or op == Binary.Type.NotIn: # TODO: Not sure what style options to use for these operators b = self._apply_binary_space_around(b, True) - elif op == Binary.Type.FloorDivision or op == Binary.Type.MatrixMultiplication or op == Binary.Type.Power: + elif op in [Binary.Type.FloorDivision, Binary.Type.MatrixMultiplication]: b = self._apply_binary_space_around(b, self._style.around_operators.multiplicative) elif op == Binary.Type.StringConcatenation: b = self._apply_binary_space_around(b, self._style.around_operators.additive) + elif op == Binary.Type.Power: + b = self._apply_binary_space_around(b, self._style.around_operators.power) + return b def visit_if(self, if_stm: If, p: P) -> J: diff --git a/rewrite/tests/python/all/format/spaces/around_operator_test.py b/rewrite/tests/python/all/format/spaces/around_operator_test.py new file mode 100644 index 00000000..5b0fa15f --- /dev/null +++ b/rewrite/tests/python/all/format/spaces/around_operator_test.py @@ -0,0 +1,461 @@ +from typing import Callable + +import pytest + +from rewrite.python import IntelliJ, SpacesVisitor, SpacesStyle +from rewrite.test import rewrite_run, python, RecipeSpec, from_visitor + + +def get_around_operator_style( + _with: Callable[[SpacesStyle.AroundOperators], SpacesStyle.AroundOperators]) -> SpacesStyle: + style = IntelliJ.spaces() + return style.with_around_operators( + _with(style.around_operators) + ) + + +def test_spaces_around_assignment(): + style = get_around_operator_style(lambda x: x.with_assignment(True)) + rewrite_run( + # language=python + python( + """ + a = 1 + a= 1 + a =1 + a=1 + """, + """ + a = 1 + a = 1 + a = 1 + a = 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_assignment(False)) + rewrite_run( + # language=python + python( + """ + a = 1 + a= 1 + a =1 + a=1 + """, + """ + a=1 + a=1 + a=1 + a=1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_equality(): + style = get_around_operator_style(lambda x: x.with_equality(True)) + rewrite_run( + # language=python + python( + """ + a == 1 + a== 1 + a ==1 + a==1 + """, + """ + a == 1 + a == 1 + a == 1 + a == 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_equality(False)) + rewrite_run( + # language=python + python( + """ + a == 1 + a== 1 + a ==1 + a==1 + """, + """ + a==1 + a==1 + a==1 + a==1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_relational(): + style = get_around_operator_style(lambda x: x.with_relational(True)) + rewrite_run( + # language=python + python( + """ + a < 1 + a< 1 + a <1 + a<1 + """, + """ + a < 1 + a < 1 + a < 1 + a < 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_relational(False)) + rewrite_run( + # language=python + python( + """ + a < 1 + a< 1 + a <1 + a<1 + """, + """ + a<1 + a<1 + a<1 + a<1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_bitwise(): + style = get_around_operator_style(lambda x: x.with_bitwise(True)) + rewrite_run( + # language=python + python( + """ + a & 1 + a& 1 + a &1 + a&1 + """, + """ + a & 1 + a & 1 + a & 1 + a & 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_bitwise(False)) + rewrite_run( + # language=python + python( + """ + a & 1 + a& 1 + a &1 + a&1 + """, + """ + a&1 + a&1 + a&1 + a&1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_additive(): + style = get_around_operator_style(lambda x: x.with_additive(True)) + rewrite_run( + # language=python + python( + """ + a + 1 + a+ 1 + a +1 + a+1 + """, + """ + a + 1 + a + 1 + a + 1 + a + 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_additive(False)) + rewrite_run( + # language=python + python( + """ + a + 1 + a+ 1 + a +1 + a+1 + """, + """ + a+1 + a+1 + a+1 + a+1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_multiplicative(): + style = get_around_operator_style(lambda x: x.with_multiplicative(True)) + rewrite_run( + # language=python + python( + """ + a * 1 + a* 1 + a *1 + a*1 + """, + """ + a * 1 + a * 1 + a * 1 + a * 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_multiplicative(False)) + rewrite_run( + # language=python + python( + """ + a * 1 + a* 1 + a *1 + a*1 + """, + """ + a*1 + a*1 + a*1 + a*1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_shift(): + style = get_around_operator_style(lambda x: x.with_shift(True)) + rewrite_run( + # language=python + python( + """ + a << 1 + a<< 1 + a <<1 + a<<1 + """, + """ + a << 1 + a << 1 + a << 1 + a << 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_shift(False)) + rewrite_run( + # language=python + python( + """ + a << 1 + a<< 1 + a <<1 + a<<1 + """, + """ + a<<1 + a<<1 + a<<1 + a<<1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_power(): + style = get_around_operator_style(lambda x: x.with_power(True)) + rewrite_run( + # language=python + python( + """ + a ** 1 + a** 1 + a **1 + a**1 + """, + """ + a ** 1 + a ** 1 + a ** 1 + a ** 1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_power(False)) + rewrite_run( + # language=python + python( + """ + a ** 1 + a** 1 + a **1 + a**1 + """, + """ + a**1 + a**1 + a**1 + a**1 + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_around_eq_in_named_parameter(): + style = get_around_operator_style(lambda x: x.with_eq_in_named_parameter(True)) + rewrite_run( + # language=python + python( + """ + def func(a = 1): pass + def func(a= 1): pass + def func(a =1): pass + def func(a=1): pass + def func(a: int=1): pass + def func(a: int =1): pass + """, + """ + def func(a = 1): pass + def func(a = 1): pass + def func(a = 1): pass + def func(a = 1): pass + def func(a: int = 1): pass + def func(a: int = 1): pass + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_eq_in_named_parameter(False)) + rewrite_run( + # language=python + python( + """ + def func(a = 1): pass + def func(a= 1): pass + def func(a =1): pass + def func(a=1): pass + def func(a : int =1): pass + def func(a : int =1): pass + """, + """ + def func(a=1): pass + def func(a=1): pass + def func(a=1): pass + def func(a=1): pass + def func(a: int = 1): pass + def func(a: int = 1): pass + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + +def test_spaces_around_eq_in_keyword_argument(): + style = get_around_operator_style(lambda x: x.with_eq_in_keyword_argument(True)) + rewrite_run( + # language=python + python( + """ + func(a = 1) + func(a= 1) + func(a =1) + func(a=1) + """, + """ + func(a = 1) + func(a = 1) + func(a = 1) + func(a = 1) + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + style = get_around_operator_style(lambda x: x.with_eq_in_keyword_argument(False)) + rewrite_run( + # language=python + python( + """ + func(a = 1) + func(a= 1) + func(a =1) + func(a=1) + """, + """ + func(a=1) + func(a=1) + func(a=1) + func(a=1) + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) diff --git a/rewrite/tests/python/all/format/spaces/method_declaration_spaces_test.py b/rewrite/tests/python/all/format/spaces/method_declaration_spaces_test.py index c13f8322..0abdb435 100644 --- a/rewrite/tests/python/all/format/spaces/method_declaration_spaces_test.py +++ b/rewrite/tests/python/all/format/spaces/method_declaration_spaces_test.py @@ -117,15 +117,15 @@ def test_spaces_after_within_method_declaration_type_hints(): # language=python python( """ - def x(a : int , b : int = 2): + def x(a : int , b = "foo", c : int=2): pass - def y(a :int , b : int = 2): + def y(a :int , b = "foo", c : int=2): pass """, """ - def x(a: int, b: int = 2): + def x(a: int, b="foo", c: int = 2): pass - def y(a: int, b: int = 2): + def y(a: int, b="foo", c: int = 2): pass """ ),