[Flexible Outputs] Allow multi-output matching#2873
[Flexible Outputs] Allow multi-output matching#2873srikris-sridhar wants to merge 1 commit intomicrosoft:mainfrom
Conversation
|
@justinchuby Hopefully this is a good attempt. It's my first PR, so go easy on me :) |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2873 +/- ##
==========================================
+ Coverage 71.96% 71.98% +0.02%
==========================================
Files 239 239
Lines 29224 29262 +38
Branches 2878 2885 +7
==========================================
+ Hits 21031 21065 +34
- Misses 7216 7218 +2
- Partials 977 979 +2 ☔ View full report in Codecov by Sentry. |
@microsoft-github-policy-service agree |
Adds support for pattern matching on multiple outputs
for optimizations involving ops like Split.
```
def pattern(self, op, x):
relu = op.Relu(x)
return op.Split(relu, _allow_flexible_outputs=True)
```
Fixes microsoft#2581
2d0cc42 to
38cf331
Compare
|
Thanks for creating a PR! Will review it this week. cc @gramalingam |
There was a problem hiding this comment.
Pull request overview
Adds support for rewrite patterns that can match nodes with variadic/unknown output arity (e.g., Split) by introducing an _allow_flexible_outputs pattern option and adjusting rewrite application to optionally replace all outputs of the matched variadic node.
Changes:
- Add
_allow_flexible_outputsto pattern IR node construction and store it onNodePattern. - Update matcher + rewrite application to support flexible-output rewrites (including optional
_matchinjection into replacement functions). - Add unit tests covering flexible-output matching/rewriting scenarios.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/_pattern_ir.py | Introduces _allow_flexible_outputs plumbing into NodePattern. |
| onnxscript/rewriter/_matcher.py | Adjusts removability validation behavior for flexible-output patterns. |
| onnxscript/rewriter/_rewrite_rule.py | Injects _match into replacement functions, relaxes output-count validation for flexible patterns, and adds flexible-output replacement path. |
| onnxscript/rewriter/pattern_test.py | Adds tests validating flexible-output matching/rewrite behavior. |
Comments suppressed due to low confidence (1)
onnxscript/rewriter/_rewrite_rule.py:374
- Skipping the replacement-output arity check whenever the pattern contains any flexible-output node is too broad: it can silently allow output-count mismatches even when the flexible-output node isn’t the one being replaced (or when the replacement isn’t actually returning the full variadic outputs). Consider validating expected arity based on the actual matched flexible node(s) (e.g.,
len(bound_node.outputs)) vs the normaltarget_pattern.num_outputs, and fail clearly when neither condition is satisfied.
# Check if any node in the pattern uses flexible outputs (cache for hot path)
self._has_flexible_outputs = any(
node.allow_flexible_outputs for node in self._target_pattern._nodes
)
if not isinstance(replacement_pattern, ReplacementPatternFunction):
replacement_pattern = ReplacementPatternFunction(replacement_pattern)
self._replacement_pattern = replacement_pattern
self.remove_nodes = remove_nodes
self.graph_pre_visitor = graph_pre_visitor
self.graph_post_visitor = graph_post_visitor
self.as_function = as_function
def __str__(self) -> str:
return self.name if self.name else "Anonymous Rule"
def try_rewrite(
self,
model: ir.Model,
graph_or_function: ir.Graph | ir.Function,
node: ir.Node,
*,
verbose: int | None = None,
tracer: _basics.MatchingTracer | None = None,
) -> ReplacementSubgraph | None:
"""If the node matches the pattern, then replace the node with the replacement pattern."""
# Use the inherited match method from Pattern
match = self.match(
model,
graph_or_function,
node,
verbose=verbose,
check_nodes_are_removable=self.remove_nodes,
tracer=tracer,
)
if not match:
return None
replacement_subgraph = self._replacement_pattern.get_replacement(match)
if replacement_subgraph is None:
if tracer:
tracer.log(
self,
graph_or_function,
node,
match,
_basics.MatchStatus.REPLACEMENT_FAILED,
)
return None
if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
raise ValueError(
f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
)
| def __init__(self, function) -> None: | ||
| self._function = function | ||
| # Cache signature inspection to avoid repeated introspection on hot path | ||
| self._accepts_match = "_match" in inspect.signature(function).parameters | ||
|
|
||
| def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None: | ||
| context = RewriterContext() | ||
| bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match} | ||
| try: | ||
| new_outputs = self._function(context, **match.bindings) | ||
| new_outputs = self._function(context, **bindings) |
There was a problem hiding this comment.
inspect.signature(function) can raise TypeError/ValueError for some callables (e.g., certain builtins, C-extension callables, or objects with non-introspectable __call__). That would make constructing ReplacementPatternFunction fail even though the callable is otherwise usable. Consider guarding this with try/except and defaulting to not injecting _match when the signature can't be inspected.
| # Check if this is a flexible output case (matched node has more outputs than captured) | ||
| flexible_node = None | ||
| for matched_node in delta.match.nodes: | ||
| if len(matched_node.outputs) > len(delta.match.outputs): | ||
| flexible_node = matched_node | ||
| break | ||
|
|
||
| if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs): | ||
| # Flexible output replacement: replace all outputs of the flexible node | ||
| convenience.replace_nodes_and_values( | ||
| graph_or_function, | ||
| node, | ||
| delta.match.nodes if rule.remove_nodes else [], | ||
| delta.new_nodes, | ||
| flexible_node.outputs, | ||
| delta.new_outputs, | ||
| ) | ||
| else: | ||
| # Standard replacement | ||
| convenience.replace_nodes_and_values( | ||
| graph_or_function, | ||
| node, | ||
| delta.match.nodes if rule.remove_nodes else [], | ||
| delta.new_nodes, | ||
| delta.match.outputs, | ||
| delta.new_outputs, | ||
| ) |
There was a problem hiding this comment.
The flexible-output replacement detection (len(matched_node.outputs) > len(delta.match.outputs)) is a heuristic that can misidentify the node (or miss it) when the pattern has multiple outputs / multiple matched nodes. Also, when the heuristic doesn’t trigger, the code falls back to standard replacement even though removability checks were skipped for flexible-output patterns, which can lead to removing a node whose other outputs are still used. Prefer identifying flexible nodes via delta.match.node_bindings + NodePattern.allow_flexible_outputs, and enforce that remove_nodes=True rules either (a) replace all outputs of that node or (b) don’t apply.
| # Skip removability check for flexible output nodes since they may have | ||
| # additional outputs beyond those captured in the pattern | ||
| if check_removable and not pattern.output_node.allow_flexible_outputs: |
There was a problem hiding this comment.
Skipping _valid_to_replace for flexible-output patterns can make rewrites unsafe: if a variadic-output node has additional outputs used outside the matched subgraph, the match will still succeed and remove_nodes=True can delete a producer of still-used values unless the replacement reliably substitutes all those outputs. Consider running the removability check against the full output list of the matched node when allow_flexible_outputs is set, and/or require the replacement to return exactly len(matched_node.outputs) outputs when nodes are being removed.
| # Skip removability check for flexible output nodes since they may have | |
| # additional outputs beyond those captured in the pattern | |
| if check_removable and not pattern.output_node.allow_flexible_outputs: | |
| if check_removable: |
| attributes: dict[str, AttrPattern], | ||
| outputs: Sequence[str | None], | ||
| *, | ||
| allow_other_attributes: bool | None, | ||
| allow_other_inputs: bool | None, | ||
| allow_flexible_outputs: bool | None = None, | ||
| check: Callable | None = None, | ||
| ): | ||
| if allow_other_attributes is None: | ||
| # Default behavior: allow other unmatched attributes in the node. | ||
| allow_other_attributes = True | ||
| if allow_other_inputs is None: | ||
| # TODO(rama): Should we default to True? For now, we preserve the current behavior. | ||
| allow_other_inputs = False | ||
| if allow_flexible_outputs is None: | ||
| # Default behavior: do not match flexible outputs | ||
| allow_flexible_outputs = False | ||
| self.domain = domain | ||
| self.op = StringConstantPattern(op) if isinstance(op, str) else op | ||
| self.inputs = [_to_value_pattern(x) for x in inputs] | ||
| self.attributes = attributes | ||
| self.allow_other_attributes = allow_other_attributes | ||
| self.allow_other_inputs = allow_other_inputs | ||
| self.allow_flexible_outputs = allow_flexible_outputs |
There was a problem hiding this comment.
allow_flexible_outputs is stored on NodePattern, but NodePattern.clone() (used by GraphPattern.commute() and potentially other copying paths) doesn’t pass this flag to the constructor, so cloned/commuted patterns will lose flexible-output behavior. Ensure cloning preserves allow_flexible_outputs so the feature works consistently across pattern transformations.
| <ir_version: 7, opset_import: [ "" : 18]> | ||
| agraph (float[10] x) => (float[10] out) | ||
| { | ||
| out = Split<axis=0>(x) |
There was a problem hiding this comment.
This test constructs a Split node with a single output but without specifying num_outputs/split. Depending on the ONNX opset/schema validation, this may be considered invalid or ambiguous. To make the test robust, consider explicitly setting num_outputs=1 (and/or providing a split input) so the model is unambiguously valid across checkers/runtimes.
| out = Split<axis=0>(x) | |
| out = Split<axis=0, num_outputs=1>(x) |
Adds support for pattern matching on multiple outputs for optimizations involving ops like Split.
Fixes #2581