Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions refurb/checks/logical/use_in.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
from dataclasses import dataclass

from mypy.nodes import OpExpr

from refurb.checks.common import get_common_expr_in_comparison_chain
from mypy.nodes import (
BytesExpr,
CallExpr,
ComparisonExpr,
ComplexExpr,
Expression,
FloatExpr,
IndexExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
)

from refurb.checks.common import (
extract_binary_oper,
get_common_expr_in_comparison_chain,
get_common_expr_positions,
)
from refurb.error import Error


Expand Down Expand Up @@ -36,6 +54,52 @@ class ErrorInfo(Error):
categories = ("logical", "readability")


def _is_simple_expr(node: Expression) -> bool:
"""
Check if an expression is simple enough to be safely eagerly evaluated.

Simple expressions are those that cannot raise exceptions or have side
effects when evaluated, making them safe for use in `in` tuple checks
where short-circuit evaluation is lost.
"""
match node:
case NameExpr() | IntExpr() | StrExpr() | BytesExpr() | FloatExpr() | ComplexExpr():
return True

case MemberExpr(expr=expr):
return _is_simple_expr(expr)

case UnaryExpr(expr=expr):
return _is_simple_expr(expr)

case OpExpr(left=left, right=right):
return _is_simple_expr(left) and _is_simple_expr(right)

case IndexExpr() | CallExpr():
return False

return False # pragma: no cover


def _get_non_common_operands(node: OpExpr) -> list[Expression] | None:
"""
Extract non-common operands from a comparison chain.

Given `a == b or c == d` where some operands are common,
returns the non-common operands (those that would be eagerly
evaluated in an `in` tuple).
"""
match extract_binary_oper("or", node):
case (
ComparisonExpr(operators=[lhs_oper], operands=[a, b]),
ComparisonExpr(operators=[rhs_oper], operands=[c, d]),
) if lhs_oper == rhs_oper == "==" and (indices := get_common_expr_positions(a, b, c, d)):
operands = [a, b, c, d]
return [op for i, op in enumerate(operands) if i not in indices]

return None # pragma: no cover


def create_message(indices: tuple[int, int]) -> str:
names = ["x", "y", "z"]
common_name = names[indices[0]]
Expand All @@ -53,4 +117,8 @@ def check(node: OpExpr, errors: list[Error]) -> None:
if data := get_common_expr_in_comparison_chain(node, oper="or"):
expr, indices = data

non_common = _get_non_common_operands(node)
if non_common is not None and not all(_is_simple_expr(op) for op in non_common):
return

errors.append(ErrorInfo.from_node(expr, create_message(indices)))
15 changes: 15 additions & 0 deletions test/data/err_108.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,22 @@ class C:
_ = "abc" == x or "def" == x
_ = "abc" == x or x == "def"

# simple compound expressions should still match
_ = x == c.y or x == "def"
_ = x == -1 or x == 1
a = 1
b = 2
_ = x == a + b or x == "def"

# these should not

_ = x == "abc" or y == "def"
_ = x == "abc" or x == "def" and y == "ghi"

# short-circuit dependent expressions should not match (see #350)
events = [1, 2, 3]
cutoff = 0
_ = cutoff == 0 or events[cutoff - 1] == 0
d = {"a": 1}
_ = x == "abc" or x == d["a"]
_ = x == len(x) or x == "abc"
3 changes: 3 additions & 0 deletions test/data/err_108.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ test/data/err_108.py:17:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z
test/data/err_108.py:21:5 [FURB108]: Replace `x == y or z == x` with `x in (y, z)`
test/data/err_108.py:22:5 [FURB108]: Replace `x == y or z == y` with `y in (x, z)`
test/data/err_108.py:23:5 [FURB108]: Replace `x == y or y == z` with `y in (x, z)`
test/data/err_108.py:26:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)`
test/data/err_108.py:27:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)`
test/data/err_108.py:30:5 [FURB108]: Replace `x == y or x == z` with `x in (y, z)`