-
Notifications
You must be signed in to change notification settings - Fork 107
[Flexible Outputs] Allow multi-output matching #2873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
|
|
||
| import abc | ||
| import dataclasses | ||
| import inspect | ||
| import itertools | ||
| from typing import ( | ||
| Callable, | ||
|
|
@@ -214,11 +215,14 @@ class ReplacementPatternFunction: | |
|
|
||
| def __init__(self, function) -> None: | ||
| self._function = function | ||
| # Cache signature inspection to avoid repeated introspection on hot path | ||
| self._accepts_match = "_match" in inspect.signature(function).parameters | ||
|
|
||
| def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None: | ||
| context = RewriterContext() | ||
| bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match} | ||
| try: | ||
| new_outputs = self._function(context, **match.bindings) | ||
| new_outputs = self._function(context, **bindings) | ||
|
Comment on lines
216
to
+225
|
||
| except _basics.MatchFailureError as e: | ||
| match.fail(e.reason, list(e.failure_sources)) | ||
| return None | ||
|
|
@@ -313,6 +317,11 @@ def __init__( | |
| # Initialize the base pattern matching functionality | ||
| super().__init__(target_pattern, condition_function, matcher, verbose, name) | ||
|
|
||
| # Check if any node in the pattern uses flexible outputs (cache for hot path) | ||
| self._has_flexible_outputs = any( | ||
| node.allow_flexible_outputs for node in self._target_pattern._nodes | ||
| ) | ||
|
|
||
| if not isinstance(replacement_pattern, ReplacementPatternFunction): | ||
| replacement_pattern = ReplacementPatternFunction(replacement_pattern) | ||
| self._replacement_pattern = replacement_pattern | ||
|
|
@@ -357,7 +366,8 @@ def try_rewrite( | |
| _basics.MatchStatus.REPLACEMENT_FAILED, | ||
| ) | ||
| return None | ||
| if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: | ||
|
|
||
| if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: | ||
| raise ValueError( | ||
| f"Number of outputs from replacement function does not match the number of outputs from the target pattern. " | ||
| f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}." | ||
|
|
@@ -766,14 +776,33 @@ def _apply_to_graph_or_function( | |
| for n in delta.new_nodes: | ||
| n.metadata_props[RULE_NAME_TAG] = rule.name | ||
|
|
||
| convenience.replace_nodes_and_values( | ||
| graph_or_function, | ||
| node, | ||
| delta.match.nodes if rule.remove_nodes else [], | ||
| delta.new_nodes, | ||
| delta.match.outputs, | ||
| delta.new_outputs, | ||
| ) | ||
| # Check if this is a flexible output case (matched node has more outputs than captured) | ||
| flexible_node = None | ||
| for matched_node in delta.match.nodes: | ||
| if len(matched_node.outputs) > len(delta.match.outputs): | ||
| flexible_node = matched_node | ||
| break | ||
|
|
||
| if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs): | ||
| # Flexible output replacement: replace all outputs of the flexible node | ||
| convenience.replace_nodes_and_values( | ||
| graph_or_function, | ||
| node, | ||
| delta.match.nodes if rule.remove_nodes else [], | ||
| delta.new_nodes, | ||
| flexible_node.outputs, | ||
| delta.new_outputs, | ||
| ) | ||
| else: | ||
| # Standard replacement | ||
| convenience.replace_nodes_and_values( | ||
| graph_or_function, | ||
| node, | ||
| delta.match.nodes if rule.remove_nodes else [], | ||
| delta.new_nodes, | ||
| delta.match.outputs, | ||
| delta.new_outputs, | ||
| ) | ||
|
Comment on lines
+779
to
+805
|
||
|
|
||
| if merge_metadata: | ||
| _default_metadata_merger.copy_merged_metadata( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -989,5 +989,247 @@ def test_pattern_builder_context(self): | |||||
| self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) | ||||||
|
|
||||||
|
|
||||||
| class FlexibleOutputTest(unittest.TestCase): | ||||||
| """Test patterns with flexible output counts using _allow_flexible_outputs.""" | ||||||
|
|
||||||
| def test_flexible_outputs_with_split(self): | ||||||
| """Test that _allow_flexible_outputs works for Split with varying output counts.""" | ||||||
|
|
||||||
| def relu_split_pattern(op, x): | ||||||
| relu = op.Relu(x) | ||||||
| return op.Split(relu, _allow_flexible_outputs=True) | ||||||
|
|
||||||
| def relu_split_rewrite(op, _match=None, x=None): | ||||||
| if x is None or _match is None: | ||||||
| return None | ||||||
|
|
||||||
| split = next((n for n in _match.nodes if n.op_type == "Split"), None) | ||||||
| if not split: | ||||||
| return None | ||||||
|
|
||||||
| num_outputs = len(split.outputs) | ||||||
| split_results = op.Split(x, _outputs=num_outputs, **split.attributes) | ||||||
|
|
||||||
| return tuple(op.Relu(s) for s in split_results) if num_outputs > 1 else op.Relu(split_results) | ||||||
|
|
||||||
| rule = pattern.RewriteRule(relu_split_pattern, relu_split_rewrite) | ||||||
|
|
||||||
| # Test model with Relu -> Split pattern | ||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[10] x) => (float[5] out1, float[5] out2) | ||||||
| { | ||||||
| relu_out = Relu(x) | ||||||
| out1, out2 = Split<axis=0, num_outputs=2>(relu_out) | ||||||
| } | ||||||
| """ | ||||||
| ) | ||||||
|
|
||||||
| optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule]) | ||||||
|
|
||||||
| # Verify transformation: 1 Relu + 1 Split -> 1 Split + 2 Relu | ||||||
| def count_ops(proto, op_type): | ||||||
| return sum(1 for n in proto.graph.node if n.op_type == op_type) | ||||||
|
|
||||||
| self.assertEqual(count_ops(model_proto, "Relu"), 1) | ||||||
| self.assertEqual(count_ops(model_proto, "Split"), 1) | ||||||
| self.assertEqual(count_ops(optimized, "Relu"), 2) | ||||||
| self.assertEqual(count_ops(optimized, "Split"), 1) | ||||||
|
|
||||||
| def test_flexible_outputs_without_match_parameter(self): | ||||||
| def pattern_func(op, x): | ||||||
| return op.Split(x, _allow_flexible_outputs=True) | ||||||
|
|
||||||
| def rewrite_func(op, x=None): | ||||||
| if x is None: | ||||||
| return None | ||||||
| return op.Split(op.Relu(x), _outputs=2) | ||||||
|
|
||||||
| rule = pattern.RewriteRule(pattern_func, rewrite_func) | ||||||
|
|
||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[10] x) => (float[5] out1, float[5] out2) | ||||||
| { | ||||||
| out1, out2 = Split<axis=0, num_outputs=2>(x) | ||||||
| } | ||||||
| """ | ||||||
| ) | ||||||
|
|
||||||
| optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule]) | ||||||
|
|
||||||
| def count_ops(proto, op_type): | ||||||
| return sum(1 for n in proto.graph.node if n.op_type == op_type) | ||||||
|
|
||||||
| self.assertEqual(count_ops(optimized, "Relu"), 1) | ||||||
| self.assertEqual(count_ops(optimized, "Split"), 1) | ||||||
|
|
||||||
| def test_output_count_validation_without_flexible(self): | ||||||
| def pattern_func(op, x): | ||||||
| return op.Relu(x) | ||||||
|
|
||||||
| def bad_rewrite_func(op, x=None): | ||||||
| return op.Split(x, _outputs=2) | ||||||
|
|
||||||
| rule = pattern.RewriteRule(pattern_func, bad_rewrite_func) | ||||||
|
|
||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[10] x) => (float[10] out) | ||||||
| { | ||||||
| out = Relu(x) | ||||||
| } | ||||||
| """ | ||||||
| ) | ||||||
|
|
||||||
| with self.assertRaises(ValueError) as ctx: | ||||||
| onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule]) | ||||||
|
|
||||||
| self.assertIn("Number of outputs", str(ctx.exception)) | ||||||
|
|
||||||
| def test_standard_replacement_path(self): | ||||||
| def pattern_func(op, x): | ||||||
| return op.Relu(x) | ||||||
|
|
||||||
| def rewrite_func(op, x=None): | ||||||
| return op.Sigmoid(x) | ||||||
|
|
||||||
| rule = pattern.RewriteRule(pattern_func, rewrite_func) | ||||||
|
|
||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[10] x) => (float[10] out) | ||||||
| { | ||||||
| out = Relu(x) | ||||||
| } | ||||||
| """ | ||||||
| ) | ||||||
|
|
||||||
| optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule]) | ||||||
|
|
||||||
| def count_ops(proto, op_type): | ||||||
| return sum(1 for n in proto.graph.node if n.op_type == op_type) | ||||||
|
|
||||||
| self.assertEqual(count_ops(optimized, "Relu"), 0) | ||||||
| self.assertEqual(count_ops(optimized, "Sigmoid"), 1) | ||||||
|
|
||||||
| def test_flexible_outputs_with_three_outputs(self): | ||||||
| def pattern_func(op, x): | ||||||
| return op.Split(x, _allow_flexible_outputs=True) | ||||||
|
|
||||||
| def rewrite_func(op, _match=None, x=None): | ||||||
| if x is None or _match is None: | ||||||
| return None | ||||||
|
|
||||||
| split = next((n for n in _match.nodes if n.op_type == "Split"), None) | ||||||
| if not split: | ||||||
| return None | ||||||
|
|
||||||
| num_outputs = len(split.outputs) | ||||||
| relu = op.Relu(x) | ||||||
| split_results = op.Split(relu, _outputs=num_outputs, **split.attributes) | ||||||
| return split_results | ||||||
|
|
||||||
| rule = pattern.RewriteRule(pattern_func, rewrite_func) | ||||||
|
|
||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[15] x) => (float[5] out1, float[5] out2, float[5] out3) | ||||||
| { | ||||||
| out1, out2, out3 = Split<axis=0, num_outputs=3>(x) | ||||||
| } | ||||||
| """ | ||||||
| ) | ||||||
|
|
||||||
| optimized = onnxscript.rewriter.rewrite(model_proto, pattern_rewrite_rules=[rule]) | ||||||
|
|
||||||
| def count_ops(proto, op_type): | ||||||
| return sum(1 for n in proto.graph.node if n.op_type == op_type) | ||||||
|
|
||||||
| self.assertEqual(count_ops(optimized, "Relu"), 1) | ||||||
| self.assertEqual(count_ops(optimized, "Split"), 1) | ||||||
|
|
||||||
| def test_flexible_outputs_with_single_output(self): | ||||||
| def pattern_func(op, x): | ||||||
| return op.Split(x, _allow_flexible_outputs=True) | ||||||
|
|
||||||
| def rewrite_func(op, _match=None, x=None): | ||||||
| if x is None or _match is None: | ||||||
| return None | ||||||
|
|
||||||
| split = next((n for n in _match.nodes if n.op_type == "Split"), None) | ||||||
| if not split: | ||||||
| return None | ||||||
|
|
||||||
| relu = op.Relu(x) | ||||||
| if len(split.outputs) == 1: | ||||||
| return op.Split(relu, _outputs=1, **split.attributes) | ||||||
| return relu | ||||||
|
|
||||||
| rule = pattern.RewriteRule(pattern_func, rewrite_func) | ||||||
|
|
||||||
| model_proto = onnx.parser.parse_model( | ||||||
| """ | ||||||
| <ir_version: 7, opset_import: [ "" : 18]> | ||||||
| agraph (float[10] x) => (float[10] out) | ||||||
| { | ||||||
| out = Split<axis=0>(x) | ||||||
|
||||||
| out = Split<axis=0>(x) | |
| out = Split<axis=0, num_outputs=1>(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Skipping
_valid_to_replacefor flexible-output patterns can make rewrites unsafe: if a variadic-output node has additional outputs used outside the matched subgraph, the match will still succeed andremove_nodes=Truecan delete a producer of still-used values unless the replacement reliably substitutes all those outputs. Consider running the removability check against the full output list of the matched node whenallow_flexible_outputsis set, and/or require the replacement to return exactlylen(matched_node.outputs)outputs when nodes are being removed.