diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py
index f54b77033f..3cb8181685 100644
--- a/onnxscript/rewriter/_matcher.py
+++ b/onnxscript/rewriter/_matcher.py
@@ -310,9 +310,17 @@ def _match_single_output_node(
         if output_values is None:
             # TODO(rama): Is this a valid (useful) case?
             return match
-        if check_removable and not _valid_to_replace(match.nodes, output_values):
+        # A flexible-output node may own outputs beyond those captured by the
+        # pattern, so the removability check is skipped for such patterns.
+        # NOTE(review): this bypasses the check for every matched node, not just
+        # the flexible output node -- confirm that is intended.
+        if (
+            check_removable
+            and not pattern.output_node.allow_flexible_outputs
+            and not _valid_to_replace(match.nodes, output_values)
+        ):
             # TODO(rama): Match status should be updated to reflect failure reason.
             return match.fail("Matched nodes have other uses preventing replacement.")
 
         match.outputs.extend(output_values)
         return match
diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py
index 674d1fc593..11ed16ecdf 100644
--- a/onnxscript/rewriter/_pattern_ir.py
+++ b/onnxscript/rewriter/_pattern_ir.py
@@ -248,6 +248,7 @@ def __call__(
         _outputs: int | list[str | None] = 1,
         _allow_other_attributes: bool | None = None,
         _allow_other_inputs: bool | None = None,
+        _allow_flexible_outputs: bool | None = None,
         _check: Callable | None = None,
         **kwargs,
     ):
@@ -280,6 +281,7 @@ def __call__(
             _outputs,
             allow_other_attributes=_allow_other_attributes,
             allow_other_inputs=_allow_other_inputs,
+            allow_flexible_outputs=_allow_flexible_outputs,
             check=_check,
         )
         self.pattern_builder.add_node(node_pattern)
@@ -440,6 +442,7 @@ def __init__(
         *,
         allow_other_attributes: bool | None,
         allow_other_inputs: bool | None,
+        allow_flexible_outputs: bool | None = None,
         check: Callable | None = None,
     ):
         if allow_other_attributes is None:
@@ -448,12 +451,16 @@ def __init__(
         if allow_other_inputs is None:
             # TODO(rama): Should we default to True? For now, we preserve the current behavior.
             allow_other_inputs = False
+        if allow_flexible_outputs is None:
+            # Default behavior: do not match flexible outputs
+            allow_flexible_outputs = False
         self.domain = domain
         self.op = StringConstantPattern(op) if isinstance(op, str) else op
         self.inputs = [_to_value_pattern(x) for x in inputs]
         self.attributes = attributes
         self.allow_other_attributes = allow_other_attributes
         self.allow_other_inputs = allow_other_inputs
+        self.allow_flexible_outputs = allow_flexible_outputs
         self._check = check
         # In the common case, domain and op are constants, which can be used to optimize matching.
         if isinstance(op, str) and isinstance(domain, StringConstantPattern):
diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py
index 108bc01c29..dfe0889854 100644
--- a/onnxscript/rewriter/_rewrite_rule.py
+++ b/onnxscript/rewriter/_rewrite_rule.py
@@ -6,6 +6,7 @@
 
 import abc
 import dataclasses
+import inspect
 import itertools
 from typing import (
     Callable,
@@ -214,11 +215,25 @@ class ReplacementPatternFunction:
 
     def __init__(self, function) -> None:
         self._function = function
+        # Inspect the signature once, up front, so get_replacement() does not
+        # repeat the introspection for every match.
+        try:
+            parameters = inspect.signature(function).parameters
+        except (TypeError, ValueError):
+            # Some callables (e.g. certain builtins) expose no signature; they
+            # are called without the optional "_match" argument.
+            parameters = {}
+        # True when the replacement function declares a "_match" parameter.
+        self._accepts_match = "_match" in parameters
 
     def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
         context = RewriterContext()
+        bindings = match.bindings
+        if self._accepts_match:
+            # Build a new dict so the bindings stored on the match stay untouched.
+            bindings = {**bindings, "_match": match}
         try:
-            new_outputs = self._function(context, **match.bindings)
+            new_outputs = self._function(context, **bindings)
         except _basics.MatchFailureError as e:
             match.fail(e.reason, list(e.failure_sources))
             return None
@@ -313,6 +328,11 @@ def __init__(
         # Initialize the base pattern matching functionality
         super().__init__(target_pattern, condition_function, matcher, verbose, name)
 
+        # Check if any node in the pattern uses flexible outputs (cache for hot path)
+        self._has_flexible_outputs = any(
+            node.allow_flexible_outputs for node in self._target_pattern._nodes
+        )
+
         if not isinstance(replacement_pattern, ReplacementPatternFunction):
             replacement_pattern = ReplacementPatternFunction(replacement_pattern)
         self._replacement_pattern = replacement_pattern
@@ -357,7 +377,12 @@ def try_rewrite(
                         _basics.MatchStatus.REPLACEMENT_FAILED,
                     )
                 return None
-            if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
+            # A pattern with flexible outputs may be replaced by a different
+            # number of values, so the strict count check is skipped for it.
+            if (
+                not self._has_flexible_outputs
+                and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs
+            ):
                 raise ValueError(
                     f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
                     f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
@@ -766,14 +791,38 @@ def _apply_to_graph_or_function(
                 for n in delta.new_nodes:
                     n.metadata_props[RULE_NAME_TAG] = rule.name
 
+                # A matched node may own more outputs than the match captured (flexible
+                # outputs). The first such node, if any, is the replacement candidate.
+                flexible_node = next(
+                    (
+                        matched_node
+                        for matched_node in delta.match.nodes
+                        if len(matched_node.outputs) > len(delta.match.outputs)
+                    ),
+                    None,
+                )
+                if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs):
+                    # Flexible output replacement: replace all outputs of the flexible node
+                    old_outputs = flexible_node.outputs
+                else:
+                    # Standard replacement: only the outputs captured by the match
+                    old_outputs = delta.match.outputs
+                if len(old_outputs) != len(delta.new_outputs):
+                    # try_rewrite() skips its output-count check for flexible-output
+                    # patterns, so refuse to pair up value lists of different lengths.
+                    raise ValueError(
+                        f"Rule {rule.name!r} produced {len(delta.new_outputs)} outputs, "
+                        f"but {len(old_outputs)} values are being replaced."
+                    )
+
                 convenience.replace_nodes_and_values(
                     graph_or_function,
                     node,
                     delta.match.nodes if rule.remove_nodes else [],
                     delta.new_nodes,
-                    delta.match.outputs,
+                    old_outputs,
                     delta.new_outputs,
                 )
 
                 if merge_metadata:
                     _default_metadata_merger.copy_merged_metadata(
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index f296b5320c..fed9b05deb 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -989,5 +989,254 @@ def test_pattern_builder_context(self):
         self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"])
 
 
+class FlexibleOutputTest(unittest.TestCase):
+    """Test patterns with flexible output counts using _allow_flexible_outputs."""
+
+    @staticmethod
+    def _count_ops(model_proto, op_type):
+        """Return the number of graph nodes in model_proto with the given op_type."""
+        return sum(1 for n in model_proto.graph.node if n.op_type == op_type)
+
+    @staticmethod
+    def _matched_split(match):
+        """Return the first matched Split node, or None when none was matched."""
+        return next((n for n in match.nodes if n.op_type == "Split"), None)
+
+    def test_flexible_outputs_with_split(self):
+        """Test that _allow_flexible_outputs works for Split with varying output counts."""
+
+        def relu_split_pattern(op, x):
+            relu = op.Relu(x)
+            return op.Split(relu, _allow_flexible_outputs=True)
+
+        def relu_split_rewrite(op, _match=None, x=None):
+            if x is None or _match is None:
+                return None
+
+            split = self._matched_split(_match)
+            if split is None:
+                return None
+
+            num_outputs = len(split.outputs)
+            split_results = op.Split(x, _outputs=num_outputs, **split.attributes)
+            if num_outputs > 1:
+                return tuple(op.Relu(s) for s in split_results)
+            return op.Relu(split_results)
+
+        rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite)
+
+        # Test model with Relu -> Split pattern
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[5] out1, float[5] out2)
+            {
+                relu_out = Relu(x)
+                out1, out2 = Split(relu_out)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        # Verify transformation: 1 Relu + 1 Split -> 1 Split + 2 Relu
+        self.assertEqual(self._count_ops(model_proto, "Relu"), 1)
+        self.assertEqual(self._count_ops(model_proto, "Split"), 1)
+        self.assertEqual(self._count_ops(optimized, "Relu"), 2)
+        self.assertEqual(self._count_ops(optimized, "Split"), 1)
+
+    def test_flexible_outputs_without_match_parameter(self):
+        """Test a replacement function that does not declare a _match parameter."""
+
+        def pattern_func(op, x):
+            return op.Split(x, _allow_flexible_outputs=True)
+
+        def rewrite_func(op, x=None):
+            if x is None:
+                return None
+            return op.Split(op.Relu(x), _outputs=2)
+
+        rule = pattern.RewriteRule(pattern_func, rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[5] out1, float[5] out2)
+            {
+                out1, out2 = Split(x)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
+        self.assertEqual(self._count_ops(optimized, "Split"), 1)
+
+    def test_output_count_validation_without_flexible(self):
+        """Test that a non-flexible pattern rejects a wrong output count."""
+
+        def pattern_func(op, x):
+            return op.Relu(x)
+
+        def bad_rewrite_func(op, x=None):
+            return op.Split(x, _outputs=2)
+
+        rule = pattern.RewriteRule(pattern_func, bad_rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[10] out)
+            {
+                out = Relu(x)
+            }
+            """
+        )
+
+        with self.assertRaises(ValueError) as ctx:
+            onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertIn("Number of outputs", str(ctx.exception))
+
+    def test_standard_replacement_path(self):
+        """Test that a pattern without flexible outputs is rewritten (Relu -> Sigmoid)."""
+
+        def pattern_func(op, x):
+            return op.Relu(x)
+
+        def rewrite_func(op, x=None):
+            return op.Sigmoid(x)
+
+        rule = pattern.RewriteRule(pattern_func, rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[10] out)
+            {
+                out = Relu(x)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertEqual(self._count_ops(optimized, "Relu"), 0)
+        self.assertEqual(self._count_ops(optimized, "Sigmoid"), 1)
+
+    def test_flexible_outputs_with_three_outputs(self):
+        """Test a flexible-output pattern against a Split with three outputs."""
+
+        def pattern_func(op, x):
+            return op.Split(x, _allow_flexible_outputs=True)
+
+        def rewrite_func(op, _match=None, x=None):
+            if x is None or _match is None:
+                return None
+
+            split = self._matched_split(_match)
+            if split is None:
+                return None
+
+            num_outputs = len(split.outputs)
+            relu = op.Relu(x)
+            return op.Split(relu, _outputs=num_outputs, **split.attributes)
+
+        rule = pattern.RewriteRule(pattern_func, rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[15] x) => (float[5] out1, float[5] out2, float[5] out3)
+            {
+                out1, out2, out3 = Split(x)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
+        self.assertEqual(self._count_ops(optimized, "Split"), 1)
+
+    def test_flexible_outputs_with_single_output(self):
+        """Test a flexible-output pattern against a Split with a single output."""
+
+        def pattern_func(op, x):
+            return op.Split(x, _allow_flexible_outputs=True)
+
+        def rewrite_func(op, _match=None, x=None):
+            if x is None or _match is None:
+                return None
+
+            split = self._matched_split(_match)
+            if split is None:
+                return None
+
+            relu = op.Relu(x)
+            if len(split.outputs) == 1:
+                return op.Split(relu, _outputs=1, **split.attributes)
+            return relu
+
+        rule = pattern.RewriteRule(pattern_func, rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[10] out)
+            {
+                out = Split(x)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
+        self.assertEqual(self._count_ops(optimized, "Split"), 1)
+
+    def test_flexible_outputs_with_partial_usage(self):
+        """Test a flexible-output Split whose outputs feed other nodes (Abs, Neg)."""
+
+        def pattern_func(op, x):
+            return op.Split(x, _allow_flexible_outputs=True)
+
+        def rewrite_func(op, _match=None, x=None):
+            if x is None or _match is None:
+                return None
+
+            split = self._matched_split(_match)
+            if split is None:
+                return None
+
+            num_outputs = len(split.outputs)
+            relu = op.Relu(x)
+            return op.Split(relu, _outputs=num_outputs, **split.attributes)
+
+        rule = pattern.RewriteRule(pattern_func, rewrite_func)
+
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[10] x) => (float[5] out1, float[5] out2, float[5] sum)
+            {
+                s1, s2 = Split(x)
+                out1 = Abs(s1)
+                out2 = Neg(s2)
+                sum = Add(out1, out2)
+            }
+            """
+        )
+
+        optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule])
+
+        self.assertEqual(self._count_ops(optimized, "Relu"), 1)
+        self.assertEqual(self._count_ops(optimized, "Split"), 1)
+        # The consumers of the Split outputs must survive the rewrite.
+        for op_type in ("Abs", "Neg", "Add"):
+            self.assertEqual(self._count_ops(optimized, op_type), 1)
+
+
 if __name__ == "__main__":
     unittest.main()