diff --git a/CHANGELOG.md b/CHANGELOG.md index f566ce5be..2c5e2cac8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,10 @@ Semantic versioning in our case means: - Updates GitHub Action's base Python image version to `3.8.8` +### Features + +- Adds a math operations evaluator to improve and allow several violation checks. + ## 0.15.1 diff --git a/tests/test_transformations/test_enhancements.py b/tests/test_transformations/test_enhancements.py new file mode 100644 index 000000000..75643737e --- /dev/null +++ b/tests/test_transformations/test_enhancements.py @@ -0,0 +1,41 @@ +import pytest + + +@pytest.mark.parametrize(('expression', 'output'), [ + ('-1 + 1', 0), + ('1 * 2', 2), + ('"a" * 5', 'aaaaa'), + ('b"hello" * 2', b'hellohello'), + ('"hello " + "world"', 'hello world'), + ('(2 + 6) / 4 - 2', 0), + ('1 << 4', 16), + ('255 >> 4', 15), + ('2**4', 16), + ('5^9', 12), + ('12 & 24', 8), + ('6 | 9', 15), + ('5 % 3', 2), + ('4 // 3', 1), + ('(6 - 2) * ((3 << 3) // 10) % 5 | 7**2', 51), +]) +def test_evaluate_valid_operations(parse_ast_tree, expression: str, output): + """Tests that the operations are correctly evaluated.""" + tree = parse_ast_tree(expression) + assert tree.body[0].value.wps_op_eval == output + + +@pytest.mark.parametrize('expression', [ + 'x * 2', + 'x << y', + '-x + y', + '0 / 0', + '"a" * 2.1', + '"a" + 1', + '3 << 1.5', + '((4 - 1) * 3 - 9) // (7 >> 4)', + '[[1, 0], [0, 1]] @ [[1, 1], [0, 0]]', +]) +def test_evaluate_invalid_operations(parse_ast_tree, expression: str): + """Tests that the operations can not be evaluated and thus return None.""" + tree = parse_ast_tree(expression) + assert tree.body[0].value.wps_op_eval is None diff --git a/tests/test_visitors/test_ast/test_builtins/test_collection_hashes/test_float_keys.py b/tests/test_visitors/test_ast/test_builtins/test_collection_hashes/test_float_keys.py index 280aca768..a6857ad39 100644 --- a/tests/test_visitors/test_ast/test_builtins/test_collection_hashes/test_float_keys.py +++ b/tests/test_visitors/test_ast/test_builtins/test_collection_hashes/test_float_keys.py @@ -21,6 +21,9 @@ '1.0', '-0.3', '+0.0', + '1 / 3', + '-1 - 0.5', + '0 + 0.1', ]) def test_dict_with_float_key( assert_errors, @@ -47,9 +50,6 @@ def test_dict_with_float_key( @pytest.mark.parametrize('element', [ '1', '"-0.3"', - '0 + 0.1', - '0 - 1.0', - '1 / 3', '1 // 3', 'call()', 'name', diff --git a/wemake_python_styleguide/logic/nodes.py b/wemake_python_styleguide/logic/nodes.py index feb10dfbc..c994aec5d 100644 --- a/wemake_python_styleguide/logic/nodes.py +++ b/wemake_python_styleguide/logic/nodes.py @@ -1,6 +1,7 @@ import ast -from typing import Optional +from typing import Optional, Union +from wemake_python_styleguide.logic.safe_eval import literal_eval_with_names from wemake_python_styleguide.types import ContextNodes @@ -26,3 +27,16 @@ def get_parent(node: ast.AST) -> Optional[ast.AST]: def get_context(node: ast.AST) -> Optional[ContextNodes]: """Returns the context or ``None`` if node has no context.""" return getattr(node, 'wps_context', None) + + +def evaluate_node(node: ast.AST) -> Optional[Union[int, float, str, bytes]]: + """Returns the value of a node or its evaluation.""" + if isinstance(node, ast.Name): + return None + if isinstance(node, (ast.Str, ast.Bytes)): + return node.s + try: + signed_node = literal_eval_with_names(node) + except Exception: + return None + return signed_node diff --git a/wemake_python_styleguide/transformations/ast/enhancements.py b/wemake_python_styleguide/transformations/ast/enhancements.py index d5c24769c..4013d198d 100644 --- a/wemake_python_styleguide/transformations/ast/enhancements.py +++ b/wemake_python_styleguide/transformations/ast/enhancements.py @@ -1,8 +1,13 @@ import ast -from typing import Optional, Tuple, Type +import operator +from contextlib import suppress +from types import MappingProxyType +from typing import Optional, Tuple, Type, Union + +from typing_extensions import Final from wemake_python_styleguide.compat.aliases import FunctionNodes -from wemake_python_styleguide.logic.nodes import get_parent +from wemake_python_styleguide.logic.nodes import evaluate_node, get_parent from wemake_python_styleguide.types import ContextNodes _CONTEXTS: Tuple[Type[ContextNodes], ...] = ( @@ -11,6 +16,21 @@ *FunctionNodes, ) +_AST_OPS_TO_OPERATORS: Final = MappingProxyType({ + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, + ast.BitAnd: operator.and_, + ast.BitOr: operator.or_, + ast.BitXor: operator.xor, +}) + def set_if_chain(tree: ast.AST) -> ast.AST: """ @@ -71,6 +91,31 @@ def set_node_context(tree: ast.AST) -> ast.AST: return tree +def set_constant_evaluations(tree: ast.AST) -> ast.AST: + """ + Used to evaluate operations between constants. + + We want this to be able to analyze parts of the code in which a math + operation is making the linter unable to understand if the code is + compliant or not. + + Example: + .. code:: python + + value = array[1 + 0.5] + + This should not be allowed, because we would be using a float to index an + array, but since there is an addition, the linter does not know that and + does not raise an error. + """ + for stmt in ast.walk(tree): + parent = get_parent(stmt) + if isinstance(stmt, ast.BinOp) and not isinstance(parent, ast.BinOp): + evaluation = evaluate_operation(stmt) + setattr(stmt, 'wps_op_eval', evaluation) # noqa: B010 + return tree + + def _find_context( node: ast.AST, contexts: Tuple[Type[ast.AST], ...], @@ -96,3 +141,27 @@ def _apply_if_statement(statement: ast.If) -> None: if child in statement.orelse: setattr(statement, 'wps_if_chained', True) # noqa: B010 setattr(child, 'wps_if_chain', statement) # noqa: B010 + + +def evaluate_operation( + statement: ast.BinOp, +) -> Optional[Union[int, float, str, bytes]]: + """Tries to evaluate all math operations inside the statement.""" + if isinstance(statement.left, ast.BinOp): + left = evaluate_operation(statement.left) + else: + left = evaluate_node(statement.left) + + if isinstance(statement.right, ast.BinOp): + right = evaluate_operation(statement.right) + else: + right = evaluate_node(statement.right) + + op = _AST_OPS_TO_OPERATORS.get(type(statement.op)) + + evaluation = None + if op is not None: + with suppress(Exception): + evaluation = op(left, right) + + return evaluation diff --git a/wemake_python_styleguide/transformations/ast_tree.py b/wemake_python_styleguide/transformations/ast_tree.py index 547028a6e..b55be1449 100644 --- a/wemake_python_styleguide/transformations/ast_tree.py +++ b/wemake_python_styleguide/transformations/ast_tree.py @@ -8,6 +8,7 @@ fix_line_number, ) from wemake_python_styleguide.transformations.ast.enhancements import ( + set_constant_evaluations, set_if_chain, set_node_context, ) @@ -85,6 +86,7 @@ def transform(tree: ast.AST) -> ast.AST: # Enhancements, order is not important: set_node_context, set_if_chain, + set_constant_evaluations, ) for transformation in pipeline: diff --git a/wemake_python_styleguide/visitors/ast/builtins.py b/wemake_python_styleguide/visitors/ast/builtins.py index 3aa8dd109..24c1d72ce 100644 --- a/wemake_python_styleguide/visitors/ast/builtins.py +++ b/wemake_python_styleguide/visitors/ast/builtins.py @@ -449,12 +449,17 @@ def _check_float_keys(self, keys: _HashItems) -> None: if dict_key is None: continue + evaluates_to_float = False + if isinstance(dict_key, ast.BinOp): + evaluated_key = getattr(dict_key, 'wps_op_eval', None) + evaluates_to_float = isinstance(evaluated_key, float) + real_key = operators.unwrap_unary_node(dict_key) is_float_key = ( isinstance(real_key, ast.Num) and isinstance(real_key.n, float) ) - if is_float_key: + if is_float_key or evaluates_to_float: self.add_violation(best_practices.FloatKeyViolation(dict_key)) def _check_unhashable_elements(