Skip to content

[Flexible Outputs] Allow multi-output matching#2873

Open
srikris-sridhar wants to merge 1 commit intomicrosoft:mainfrom
srikris-sridhar:dev/srikris/allow_other_output
Open

[Flexible Outputs] Allow multi-output matching#2873
srikris-sridhar wants to merge 1 commit intomicrosoft:mainfrom
srikris-sridhar:dev/srikris/allow_other_output

Conversation

@srikris-sridhar
Copy link
Copy Markdown

Adds support for pattern matching on multiple outputs for optimizations involving ops like Split.

def pattern(self, op, x):
        relu = op.Relu(x)
        return op.Split(relu, _allow_flexible_outputs=True)

Fixes #2581

@srikris-sridhar
Copy link
Copy Markdown
Author

@justinchuby Hopefully this is a good attempt. It's my first PR, so go easy on me :)

@codecov
Copy link
Copy Markdown

codecov bot commented Mar 29, 2026

Codecov Report

❌ Patch coverage is 88.37209% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.98%. Comparing base (1077da7) to head (2d0cc42).

Files with missing lines Patch % Lines
onnxscript/rewriter/pattern_test.py 82.60% 2 Missing and 2 partials ⚠️
onnxscript/rewriter/_rewrite_rule.py 92.85% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2873      +/-   ##
==========================================
+ Coverage   71.96%   71.98%   +0.02%     
==========================================
  Files         239      239              
  Lines       29224    29262      +38     
  Branches     2878     2885       +7     
==========================================
+ Hits        21031    21065      +34     
- Misses       7216     7218       +2     
- Partials      977      979       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@srikris-sridhar
Copy link
Copy Markdown
Author

@srikris-sridhar please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

Adds support for pattern matching on multiple outputs
for optimizations involving ops like Split.

```
def pattern(self, op, x):
        relu = op.Relu(x)
        return op.Split(relu, _allow_flexible_outputs=True)
```

Fixes microsoft#2581
@srikris-sridhar srikris-sridhar force-pushed the dev/srikris/allow_other_output branch from 2d0cc42 to 38cf331 Compare March 29, 2026 06:17
@justinchuby
Copy link
Copy Markdown
Collaborator

Thanks for creating a PR! Will review it this week. cc @gramalingam

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds support for rewrite patterns that can match nodes with variadic/unknown output arity (e.g., Split) by introducing an _allow_flexible_outputs pattern option and adjusting rewrite application to optionally replace all outputs of the matched variadic node.

Changes:

  • Add _allow_flexible_outputs to pattern IR node construction and store it on NodePattern.
  • Update matcher + rewrite application to support flexible-output rewrites (including optional _match injection into replacement functions).
  • Add unit tests covering flexible-output matching/rewriting scenarios.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File Description
onnxscript/rewriter/_pattern_ir.py Introduces _allow_flexible_outputs plumbing into NodePattern.
onnxscript/rewriter/_matcher.py Adjusts removability validation behavior for flexible-output patterns.
onnxscript/rewriter/_rewrite_rule.py Injects _match into replacement functions, relaxes output-count validation for flexible patterns, and adds flexible-output replacement path.
onnxscript/rewriter/pattern_test.py Adds tests validating flexible-output matching/rewrite behavior.
Comments suppressed due to low confidence (1)

onnxscript/rewriter/_rewrite_rule.py:374

  • Skipping the replacement-output arity check whenever the pattern contains any flexible-output node is too broad: it can silently allow output-count mismatches even when the flexible-output node isn’t the one being replaced (or when the replacement isn’t actually returning the full variadic outputs). Consider validating expected arity based on the actual matched flexible node(s) (e.g., len(bound_node.outputs)) vs the normal target_pattern.num_outputs, and fail clearly when neither condition is satisfied.
        # Check if any node in the pattern uses flexible outputs (cache for hot path)
        self._has_flexible_outputs = any(
            node.allow_flexible_outputs for node in self._target_pattern._nodes
        )

        if not isinstance(replacement_pattern, ReplacementPatternFunction):
            replacement_pattern = ReplacementPatternFunction(replacement_pattern)
        self._replacement_pattern = replacement_pattern
        self.remove_nodes = remove_nodes
        self.graph_pre_visitor = graph_pre_visitor
        self.graph_post_visitor = graph_post_visitor
        self.as_function = as_function

    def __str__(self) -> str:
        return self.name if self.name else "Anonymous Rule"

    def try_rewrite(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        node: ir.Node,
        *,
        verbose: int | None = None,
        tracer: _basics.MatchingTracer | None = None,
    ) -> ReplacementSubgraph | None:
        """If the node matches the pattern, then replace the node with the replacement pattern."""
        # Use the inherited match method from Pattern
        match = self.match(
            model,
            graph_or_function,
            node,
            verbose=verbose,
            check_nodes_are_removable=self.remove_nodes,
            tracer=tracer,
        )
        if not match:
            return None

        replacement_subgraph = self._replacement_pattern.get_replacement(match)
        if replacement_subgraph is None:
            if tracer:
                tracer.log(
                    self,
                    graph_or_function,
                    node,
                    match,
                    _basics.MatchStatus.REPLACEMENT_FAILED,
                )
            return None

        if not self._has_flexible_outputs and len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs:
            raise ValueError(
                f"Number of outputs from replacement function does not match the number of outputs from the target pattern. "
                f"Expected {self._target_pattern.num_outputs}, but got {len(replacement_subgraph.new_outputs)}."
            )

Comment on lines 216 to +225
def __init__(self, function) -> None:
self._function = function
# Cache signature inspection to avoid repeated introspection on hot path
self._accepts_match = "_match" in inspect.signature(function).parameters

def get_replacement(self, match: _basics.MatchResult) -> ReplacementSubgraph | None:
context = RewriterContext()
bindings = match.bindings if not self._accepts_match else {**match.bindings, "_match": match}
try:
new_outputs = self._function(context, **match.bindings)
new_outputs = self._function(context, **bindings)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inspect.signature(function) can raise TypeError/ValueError for some callables (e.g., certain builtins, C-extension callables, or objects with non-introspectable __call__). That would make constructing ReplacementPatternFunction fail even though the callable is otherwise usable. Consider guarding this with try/except and defaulting to not injecting _match when the signature can't be inspected.

Copilot uses AI. Check for mistakes.
Comment on lines +779 to +805
# Check if this is a flexible output case (matched node has more outputs than captured)
flexible_node = None
for matched_node in delta.match.nodes:
if len(matched_node.outputs) > len(delta.match.outputs):
flexible_node = matched_node
break

if flexible_node and len(delta.new_outputs) == len(flexible_node.outputs):
# Flexible output replacement: replace all outputs of the flexible node
convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
delta.new_nodes,
flexible_node.outputs,
delta.new_outputs,
)
else:
# Standard replacement
convenience.replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
delta.new_nodes,
delta.match.outputs,
delta.new_outputs,
)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The flexible-output replacement detection (len(matched_node.outputs) > len(delta.match.outputs)) is a heuristic that can misidentify the node (or miss it) when the pattern has multiple outputs / multiple matched nodes. Also, when the heuristic doesn’t trigger, the code falls back to standard replacement even though removability checks were skipped for flexible-output patterns, which can lead to removing a node whose other outputs are still used. Prefer identifying flexible nodes via delta.match.node_bindings + NodePattern.allow_flexible_outputs, and enforce that remove_nodes=True rules either (a) replace all outputs of that node or (b) don’t apply.

Copilot uses AI. Check for mistakes.
Comment on lines +313 to +315
# Skip removability check for flexible output nodes since they may have
# additional outputs beyond those captured in the pattern
if check_removable and not pattern.output_node.allow_flexible_outputs:
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipping _valid_to_replace for flexible-output patterns can make rewrites unsafe: if a variadic-output node has additional outputs used outside the matched subgraph, the match will still succeed and remove_nodes=True can delete a producer of still-used values unless the replacement reliably substitutes all those outputs. Consider running the removability check against the full output list of the matched node when allow_flexible_outputs is set, and/or require the replacement to return exactly len(matched_node.outputs) outputs when nodes are being removed.

Suggested change
# Skip removability check for flexible output nodes since they may have
# additional outputs beyond those captured in the pattern
if check_removable and not pattern.output_node.allow_flexible_outputs:
if check_removable:

Copilot uses AI. Check for mistakes.
Comment on lines 440 to +463
attributes: dict[str, AttrPattern],
outputs: Sequence[str | None],
*,
allow_other_attributes: bool | None,
allow_other_inputs: bool | None,
allow_flexible_outputs: bool | None = None,
check: Callable | None = None,
):
if allow_other_attributes is None:
# Default behavior: allow other unmatched attributes in the node.
allow_other_attributes = True
if allow_other_inputs is None:
# TODO(rama): Should we default to True? For now, we preserve the current behavior.
allow_other_inputs = False
if allow_flexible_outputs is None:
# Default behavior: do not match flexible outputs
allow_flexible_outputs = False
self.domain = domain
self.op = StringConstantPattern(op) if isinstance(op, str) else op
self.inputs = [_to_value_pattern(x) for x in inputs]
self.attributes = attributes
self.allow_other_attributes = allow_other_attributes
self.allow_other_inputs = allow_other_inputs
self.allow_flexible_outputs = allow_flexible_outputs
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allow_flexible_outputs is stored on NodePattern, but NodePattern.clone() (used by GraphPattern.commute() and potentially other copying paths) doesn’t pass this flag to the constructor, so cloned/commuted patterns will lose flexible-output behavior. Ensure cloning preserves allow_flexible_outputs so the feature works consistently across pattern transformations.

Copilot uses AI. Check for mistakes.
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[10] x) => (float[10] out)
{
out = Split<axis=0>(x)
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test constructs a Split node with a single output but without specifying num_outputs/split. Depending on the ONNX opset/schema validation, this may be considered invalid or ambiguous. To make the test robust, consider explicitly setting num_outputs=1 (and/or providing a split input) so the model is unambiguously valid across checkers/runtimes.

Suggested change
out = Split<axis=0>(x)
out = Split<axis=0, num_outputs=1>(x)

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

Matching variadic lists of tensors?

4 participants