diff --git a/refurb/checks/logical/use_in.py b/refurb/checks/logical/use_in.py index 4e06f08..cad5e99 100644 --- a/refurb/checks/logical/use_in.py +++ b/refurb/checks/logical/use_in.py @@ -1,8 +1,26 @@ from dataclasses import dataclass -from mypy.nodes import OpExpr - -from refurb.checks.common import get_common_expr_in_comparison_chain +from mypy.nodes import ( + BytesExpr, + CallExpr, + ComparisonExpr, + ComplexExpr, + Expression, + FloatExpr, + IndexExpr, + IntExpr, + MemberExpr, + NameExpr, + OpExpr, + StrExpr, + UnaryExpr, +) + +from refurb.checks.common import ( + extract_binary_oper, + get_common_expr_in_comparison_chain, + get_common_expr_positions, +) from refurb.error import Error @@ -36,6 +54,52 @@ class ErrorInfo(Error): categories = ("logical", "readability") +def _is_simple_expr(node: Expression) -> bool: + """ + Check if an expression is simple enough to be safely eagerly evaluated. + + Simple expressions are those that cannot raise exceptions or have side + effects when evaluated, making them safe for use in `in` tuple checks + where short-circuit evaluation is lost. + """ + match node: + case NameExpr() | IntExpr() | StrExpr() | BytesExpr() | FloatExpr() | ComplexExpr(): + return True + + case MemberExpr(expr=expr): + return _is_simple_expr(expr) + + case UnaryExpr(expr=expr): + return _is_simple_expr(expr) + + case OpExpr(left=left, right=right): + return _is_simple_expr(left) and _is_simple_expr(right) + + case IndexExpr() | CallExpr(): + return False + + return False # pragma: no cover + + +def _get_non_common_operands(node: OpExpr) -> list[Expression] | None: + """ + Extract non-common operands from a comparison chain. + + Given `a == b or c == d` where some operands are common, + returns the non-common operands (those that would be eagerly + evaluated in an `in` tuple). + """ + match extract_binary_oper("or", node): + case ( + ComparisonExpr(operators=[lhs_oper], operands=[a, b]), + ComparisonExpr(operators=[rhs_oper], operands=[c, d]), + ) if lhs_oper == rhs_oper == "==" and (indices := get_common_expr_positions(a, b, c, d)): + operands = [a, b, c, d] + return [op for i, op in enumerate(operands) if i not in indices] + + return None # pragma: no cover + + def create_message(indices: tuple[int, int]) -> str: names = ["x", "y", "z"] common_name = names[indices[0]] @@ -53,4 +117,8 @@ def check(node: OpExpr, errors: list[Error]) -> None: if data := get_common_expr_in_comparison_chain(node, oper="or"): expr, indices = data + non_common = _get_non_common_operands(node) + if non_common is not None and not all(_is_simple_expr(op) for op in non_common): + return + errors.append(ErrorInfo.from_node(expr, create_message(indices))) diff --git a/test/data/err_108.py b/test/data/err_108.py index 5f19491..310363a 100644 --- a/test/data/err_108.py +++ b/test/data/err_108.py @@ -22,7 +22,22 @@ class C: _ = "abc" == x or "def" == x _ = "abc" == x or x == "def" +# simple compound expressions should still match +_ = x == c.y or x == "def" +_ = x == -1 or x == 1 +a = 1 +b = 2 +_ = x == a + b or x == "def" + # these should not _ = x == "abc" or y == "def" _ = x == "abc" or x == "def" and y == "ghi" + +# short-circuit dependent expressions should not match (see #350) +events = [1, 2, 3] +cutoff = 0 +_ = cutoff == 0 or events[cutoff - 1] == 0 +d = {"a": 1} +_ = x == "abc" or x == d["a"] +_ = x == len(x) or x == "abc" diff --git a/test/data/err_108.txt b/test/data/err_108.txt index 81f89a1..af9d911 100644 --- a/test/data/err_108.txt +++ b/test/data/err_108.txt @@ -7,3 +7,6 @@ test/data/err_108.py:17:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z test/data/err_108.py:21:5 [FURB108]: Replace `x == y or z == x` with `x in (y, z)` test/data/err_108.py:22:5 [FURB108]: Replace `x == y or z == y` with `y in (x, z)` test/data/err_108.py:23:5 [FURB108]: Replace `x == y or y == z` with `y in (x, z)` +test/data/err_108.py:26:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)` +test/data/err_108.py:27:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)` +test/data/err_108.py:30:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)`