Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions onnxscript/rewriter/_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,12 @@ def _match_single_output_node(
if output_values is None:
# TODO(rama): Is this a valid (useful) case?
return match
if check_removable and not _valid_to_replace(match.nodes, output_values):
# TODO(rama): Match status should be updated to reflect failure reason.
return match.fail("Matched nodes have other uses preventing replacement.")
# Skip removability check for flexible output nodes since they may have
# additional outputs beyond those captured in the pattern
if check_removable and not pattern.output_node.allow_flexible_outputs:
Comment on lines +313 to +315
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipping _valid_to_replace for flexible-output patterns can make rewrites unsafe: if a variadic-output node has additional outputs used outside the matched subgraph, the match will still succeed and remove_nodes=True can delete a producer of still-used values unless the replacement reliably substitutes all those outputs. Consider running the removability check against the full output list of the matched node when allow_flexible_outputs is set, and/or require the replacement to return exactly len(matched_node.outputs) outputs when nodes are being removed.

Suggested change
# Skip removability check for flexible output nodes since they may have
# additional outputs beyond those captured in the pattern
if check_removable and not pattern.output_node.allow_flexible_outputs:
if check_removable:

Copilot uses AI. Check for mistakes.
if not _valid_to_replace(match.nodes, output_values):
# TODO(rama): Match status should be updated to reflect failure reason.
return match.fail("Matched nodes have other uses preventing replacement.")

match.outputs.extend(output_values)
return match
Expand Down
7 changes: 7 additions & 0 deletions onnxscript/rewriter/_pattern_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __call__(
_outputs: int | list[str | None] = 1,
_allow_other_attributes: bool | None = None,
_allow_other_inputs: bool | None = None,
_allow_flexible_outputs: bool | None = None,
_check: Callable | None = None,
**kwargs,
):
Expand Down Expand Up @@ -280,6 +281,7 @@ def __call__(
_outputs,
allow_other_attributes=_allow_other_attributes,
allow_other_inputs=_allow_other_inputs,
allow_flexible_outputs=_allow_flexible_outputs,
check=_check,
)
self.pattern_builder.add_node(node_pattern)
Expand Down Expand Up @@ -440,6 +442,7 @@ def __init__(
*,
allow_other_attributes: bool | None,
allow_other_inputs: bool | None,
allow_flexible_outputs: bool | None = None,
check: Callable | None = None,
):
if allow_other_attributes is None:
Expand All @@ -448,12 +451,16 @@ def __init__(
if allow_other_inputs is None:
# TODO(rama): Should we default to True? For now, we preserve the current behavior.
allow_other_inputs = False
if allow_flexible_outputs is None:
# Default behavior: do not match flexible outputs
allow_flexible_outputs = False
self.domain = domain
self.op = StringConstantPattern(op) if isinstance(op, str) else op
self.inputs = [_to_value_pattern(x) for x in inputs]
self.attributes = attributes
self.allow_other_attributes = allow_other_attributes
self.allow_other_inputs = allow_other_inputs
self.allow_flexible_outputs = allow_flexible_outputs
self._check = check
# In the common case, domain and op are constants, which can be used to optimize matching.
if isinstance(op, str) and isinstance(domain, StringConstantPattern):
Expand Down
49 changes: 39 additions & 10 deletions onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import abc
import dataclasses
import inspect
import itertools
from typing import (
Callable,
Expand Down Expand Up @@ -214,11 +215,14 @@ class ReplacementPatternFunction:

def __init__(self, function) -> None:
    """Wrap a replacement function and record whether it accepts ``_match``.

    Args:
        function: The callable that builds the replacement. It is stored
            unchanged on ``self._function``.

    The result of the signature inspection is stored on
    ``self._accepts_match``; ``get_replacement`` reads that flag to decide
    whether to add the ``_match`` keyword to the bindings it passes.
    """
    self._function = function
    # Inspect the signature once here so the per-match path does not have to
    # repeat the introspection.
    try:
        parameters = inspect.signature(function).parameters
    except (TypeError, ValueError):
        # inspect.signature cannot describe every callable (for example some
        # builtins, C-extension callables, or objects with an unusual
        # __signature__). Such callables are still usable, so fall back to
        # not injecting `_match` instead of failing construction.
        self._accepts_match = False
    else:
        self._accepts_match = "_match" in parameters

def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
context = RewriterContext()
bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match}
try:
new_outputs = self._function(context, **match.bindings)
new_outputs = self._function(context, **bindings)
Comment on lines 216 to +225
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inspect.signature(function) can raise TypeError/ValueError for some callables (e.g., certain builtins, C-extension callables, or objects with non-introspectable __call__). That would make constructing ReplacementPatternFunction fail even though the callable is otherwise usable. Consider guarding this with try/except and defaulting to not injecting _match when the signature can't be inspected.

Copilot uses AI. Check for mistakes.
except _basics.MatchFailureError as e:
match.fail(e.reason, list(e.failure_sources))
return None
Expand Down Expand Up @@ -313,6 +317,11 @@ def __init__(
# Initialize the base pattern matching functionality
super().__init__(target_pattern, condition_function, matcher, verbose, name)

# Check if any node in the pattern uses flexible outputs (cache for hot path)
self._has_flexible_outputs = any(
node.allow_flexible_outputs for node in self._target_pattern._nodes
)

if not isinstance(replacement_pattern, ReplacementPatternFunction):
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
self._replacement_pattern = replacement_pattern
Expand Down Expand Up @@ -357,7 +366,8 @@ def try_rewrite(
_basics.MatchStatus.REPLACEMENT_FAILED,
)
return None
if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:

if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
raise ValueError(
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
Expand Down Expand Up @@ -766,14 +776,33 @@ def _apply_to_graph_or_function(
for n in delta.new_nodes:
n.metadata_props[RULE_NAME_TAG] = rule.name

convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
delta.new_nodes,
delta.match.outputs,
delta.new_outputs,
)
# Check if this is a flexible output case (matched node has more outputs than captured)
flexible_node = None
for matched_node in delta.match.nodes:
if len(matched_node.outputs) > len(delta.match.outputs):
flexible_node = matched_node
break

if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs):
# Flexible output replacement: replace all outputs of the flexible node
convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
delta.new_nodes,
flexible_node.outputs,
delta.new_outputs,
)
else:
# Standard replacement
convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
delta.new_nodes,
delta.match.outputs,
delta.new_outputs,
)
Comment on lines +779 to +805
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flexible-output replacement detection (len(matched_node.outputs) > len(delta.match.outputs)) is a heuristic that can misidentify the node (or miss it) when the pattern has multiple outputs / multiple matched nodes. Also, when the heuristic doesn’t trigger, the code falls back to standard replacement even though removability checks were skipped for flexible-output patterns, which can lead to removing a node whose other outputs are still used. Prefer identifying flexible nodes via delta.match.node_bindings + NodePattern.allow_flexible_outputs, and enforce that remove_nodes=True rules either (a) replace all outputs of that node or (b) don’t apply.

Copilot uses AI. Check for mistakes.

if merge_metadata:
_default_metadata_merger.copy_merged_metadata(
Expand Down
242 changes: 242 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,5 +989,247 @@ def test_pattern_builder_context(self):
self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"])


class FlexibleOutputTest(unittest.TestCase):
"""Test patterns with flexible output counts using _allow_flexible_outputs."""

def test_flexible_outputs_with_split(self):
    """Rewrite Relu -> Split (flexible outputs) into Split -> one Relu per output."""

    def relu_split_pattern(op, x):
        # Target: a Split marked with _allow_flexible_outputs, fed by Relu(x).
        return op.Split(op.Relu(x), _allow_flexible_outputs=True)

    def relu_split_rewrite(op, _match=None, x=None):
        if x is None:
            return None
        if _match is None:
            return None

        # Locate the matched Split node to learn how many outputs it has.
        split = None
        for candidate in _match.nodes:
            if candidate.op_type == "Split":
                split = candidate
                break
        if not split:
            return None

        num_outputs = len(split.outputs)
        split_results = op.Split(x, _outputs=num_outputs, **split.attributes)

        if num_outputs > 1:
            return tuple(op.Relu(s) for s in split_results)
        return op.Relu(split_results)

    split_rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite)

    # Model containing the Relu -> Split chain with two Split outputs.
    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[5] out1, float[5] out2)
{
relu_out = Relu(x)
out1, out2 = Split<axis=0, num_outputs=2>(relu_out)
}
"""
    original = onnx.parser.parse_model(source_text)

    rewritten = onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[split_rule])

    def occurrences(proto, op_type):
        return len([node for node in proto.graph.node if node.op_type == op_type])

    # Expected shape change: 1 Relu + 1 Split before, 1 Split + 2 Relu after.
    expectations = (
        (original, "Relu", 1),
        (original, "Split", 1),
        (rewritten, "Relu", 2),
        (rewritten, "Split", 1),
    )
    for proto, op_type, expected in expectations:
        self.assertEqual(occurrences(proto, op_type), expected)

def test_flexible_outputs_without_match_parameter(self):
    """A rewrite function that does not declare ``_match`` still works with a flexible-output pattern."""

    def pattern_func(op, x):
        return op.Split(x, _allow_flexible_outputs=True)

    def rewrite_func(op, x=None):
        if x is None:
            return None
        # The output count is hard-coded because no match object is requested.
        activated = op.Relu(x)
        return op.Split(activated, _outputs=2)

    split_rule = pattern.RewriteRule(pattern_func, rewrite_func)

    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[5] out1, float[5] out2)
{
out1, out2 = Split<axis=0, num_outputs=2>(x)
}
"""
    original = onnx.parser.parse_model(source_text)

    rewritten = onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[split_rule])

    def occurrences(proto, op_type):
        return len([node for node in proto.graph.node if node.op_type == op_type])

    for op_type in ("Relu", "Split"):
        self.assertEqual(occurrences(rewritten, op_type), 1)

def test_output_count_validation_without_flexible(self):
    """Without flexible outputs, a replacement with the wrong output count raises ValueError."""

    def pattern_func(op, x):
        return op.Relu(x)

    def bad_rewrite_func(op, x=None):
        # Deliberately returns two outputs for a single-output pattern.
        return op.Split(x, _outputs=2)

    mismatched_rule = pattern.RewriteRule(pattern_func, bad_rewrite_func)

    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[10] out)
{
out = Relu(x)
}
"""
    original = onnx.parser.parse_model(source_text)

    with self.assertRaises(ValueError) as raised:
        onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[mismatched_rule])

    message = str(raised.exception)
    self.assertIn("Number of outputs", message)

def test_standard_replacement_path(self):
    """A plain single-output rule (no flexible outputs) replaces Relu with Sigmoid."""

    def pattern_func(op, x):
        return op.Relu(x)

    def rewrite_func(op, x=None):
        return op.Sigmoid(x)

    relu_to_sigmoid = pattern.RewriteRule(pattern_func, rewrite_func)

    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[10] out)
{
out = Relu(x)
}
"""
    original = onnx.parser.parse_model(source_text)

    rewritten = onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[relu_to_sigmoid])

    def occurrences(proto, op_type):
        return len([node for node in proto.graph.node if node.op_type == op_type])

    for op_type, expected in (("Relu", 0), ("Sigmoid", 1)):
        self.assertEqual(occurrences(rewritten, op_type), expected)

def test_flexible_outputs_with_three_outputs(self):
    """A flexible-output Split pattern applies to a Split with three outputs."""

    def pattern_func(op, x):
        return op.Split(x, _allow_flexible_outputs=True)

    def rewrite_func(op, _match=None, x=None):
        if x is None:
            return None
        if _match is None:
            return None

        # Find the matched Split so the replacement can mirror its output count.
        split = None
        for candidate in _match.nodes:
            if candidate.op_type == "Split":
                split = candidate
                break
        if not split:
            return None

        num_outputs = len(split.outputs)
        return op.Split(op.Relu(x), _outputs=num_outputs, **split.attributes)

    split_rule = pattern.RewriteRule(pattern_func, rewrite_func)

    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[15] x) => (float[5] out1, float[5] out2, float[5] out3)
{
out1, out2, out3 = Split<axis=0, num_outputs=3>(x)
}
"""
    original = onnx.parser.parse_model(source_text)

    rewritten = onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[split_rule])

    def occurrences(proto, op_type):
        return len([node for node in proto.graph.node if node.op_type == op_type])

    for op_type in ("Relu", "Split"):
        self.assertEqual(occurrences(rewritten, op_type), 1)

def test_flexible_outputs_with_single_output(self):
def pattern_func(op, x):
return op.Split(x, _allow_flexible_outputs=True)

def rewrite_func(op, _match=None, x=None):
if x is None or _match is None:
return None

split = next((n for n in _match.nodes if n.op_type == "Split"), None)
if not split:
return None

relu = op.Relu(x)
if len(split.outputs) == 1:
return op.Split(relu, _outputs=1, **split.attributes)
return relu

rule = pattern.RewriteRule(pattern_func, rewrite_func)

model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[10] out)
{
out = Split<axis=0>(x)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test constructs a Split node with a single output but without specifying num_outputs/split. Depending on the ONNX opset/schema validation, this may be considered invalid or ambiguous. To make the test robust, consider explicitly setting num_outputs=1 (and/or providing a split input) so the model is unambiguously valid across checkers/runtimes.

Suggested change
out = Split<axis=0>(x)
out = Split<axis=0, num_outputs=1>(x)

Copilot uses AI. Check for mistakes.
}
"""
)

optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])

def count_ops(proto, op_type):
return sum(1 for n in proto.graph.node if n.op_type == op_type)

self.assertEqual(count_ops(optimized, "Relu"), 1)
self.assertEqual(count_ops(optimized, "Split"), 1)

def test_flexible_outputs_with_partial_usage(self):
    """A flexible-output Split is rewritten when its outputs feed downstream nodes (Abs, Neg)."""

    def pattern_func(op, x):
        return op.Split(x, _allow_flexible_outputs=True)

    def rewrite_func(op, _match=None, x=None):
        if x is None:
            return None
        if _match is None:
            return None

        # Find the matched Split so the replacement can mirror its output count.
        split = None
        for candidate in _match.nodes:
            if candidate.op_type == "Split":
                split = candidate
                break
        if not split:
            return None

        num_outputs = len(split.outputs)
        activated = op.Relu(x)
        return op.Split(activated, _outputs=num_outputs, **split.attributes)

    split_rule = pattern.RewriteRule(pattern_func, rewrite_func)

    # The Split results are intermediate values consumed by Abs and Neg.
    source_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[5] out1, float[5] out2, float[5] sum)
{
s1, s2 = Split<axis=0, num_outputs=2>(x)
out1 = Abs(s1)
out2 = Neg(s2)
sum = Add(out1, out2)
}
"""
    original = onnx.parser.parse_model(source_text)

    rewritten = onnxscript.rewriter.rewrite(original, pattern_rewrite_rules=[split_rule])

    def occurrences(proto, op_type):
        return len([node for node in proto.graph.node if node.op_type == op_type])

    for op_type in ("Relu", "Split"):
        self.assertEqual(occurrences(rewritten, op_type), 1)


if __name__ == "__main__":
unittest.main()
Loading