From 3cf4f554e6f93c12cd4a8973bd38ddd993fa968b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Mar 2026 09:38:09 -0700 Subject: [PATCH 1/6] Allow functions to not be inlined Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 51 ++++++++++++++++++++++++++------- onnxscript/_internal/values.py | 4 +++ 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 1dc0875871..59591622b4 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -229,6 +229,8 @@ def __init__(self, graph: ir.Graph) -> None: # and allows sharing them across different layers/contexts. self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {} + self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} + def opset(self, domain: str, version: int = 1) -> OpBuilder: """Create an OpBuilder bound to the given domain and version.""" return OpBuilder(self, domain, version) @@ -241,6 +243,10 @@ def op(self) -> OpBuilder: def graph(self) -> ir.Graph: return self._graph + @property + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._functions + def initializer( self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True ) -> ir.Value: @@ -543,16 +549,18 @@ def call_op( def call( self, - function, + function: ir.Function | onnxscript.OnnxFunction, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", + _inline: bool = True, **kwargs, ): if isinstance(function, ir.Function): graph = function.graph elif isinstance(function, onnxscript.OnnxFunction): graph = function.graph() + function = function.function_ir else: raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") output_renaming: dict[str, str] = {} @@ -567,18 +575,36 @@ def call( else: for output in graph.outputs: output_renaming[output.name] = self._qualify_value_name(output.name) - nodes, outputs = _inliner.instantiate(graph, args, kwargs) if _prefix: self.push_module(_prefix) - for node in nodes: - node.name = self._qualify_node_name(node.name) - for output in node.outputs: - if output.name: - if output.name in output_renaming: - output.name = output_renaming[output.name] - else: - output.name = self._qualify_value_name(output.name) + + if _inline: + nodes, outputs = _inliner.instantiate(graph, args, kwargs) + + for node in nodes: + node.name = self._qualify_node_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self._qualify_value_name(output.name) + self.add_node(node) + else: + node = ir.node( + op_type=function.name, + inputs=args, + attributes=kwargs or None, + outputs=[ + ir.Value(name=output_renaming[output.name]) for output in graph.outputs + ], + domain=function.domain, + name=self._qualify_node_name(function.name), + ) + outputs = node.outputs self.add_node(node) + self._functions[function.identifier] = function + if _prefix: self.pop_module() return outputs if len(outputs) > 1 else outputs[0] @@ -690,9 +716,10 @@ def call( *args, _outputs: Sequence[str] | None = None, _prefix: str = "", + _inline: bool = True, **kwargs, ): - """Call a function and inline it into the graph. + """Call a function and optionally inline it into the graph. Args: function: The function to call (ir.Function or onnxscript.OnnxFunction). @@ -700,6 +727,8 @@ def call( _outputs: Optional sequence of output names. If provided, must match the number of function outputs. _prefix: Optional prefix for module scoping (e.g., "layers.0"). + _inline: If True, the function body is inlined into the caller graph instead of being + called as a separate node. Defaults to True. **kwargs: Keyword arguments to pass to the function. Returns: diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 051cb3e686..c3fd1b3dad 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -222,6 +222,10 @@ def __call__(self, *args, **kwargs): def name(self) -> str: return self._name + @property + def domain(self) -> str: + return self._opset.domain + @property def opset(self) -> Opset: return self._opset From ebbf30fa18e035f49c66c8a4219df0460fd700a7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Mar 2026 09:53:28 -0700 Subject: [PATCH 2/6] Add functions Signed-off-by: Justin Chu --- onnxscript/_internal/builder.py | 8 +- onnxscript/_internal/builder_test.py | 151 +++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 59591622b4..0c59170f9f 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -710,6 +710,9 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._builder.functions + def call( self, function, @@ -728,12 +731,13 @@ def call( number of function outputs. _prefix: Optional prefix for module scoping (e.g., "layers.0"). _inline: If True, the function body is inlined into the caller graph instead of being - called as a separate node. Defaults to True. + called as a separate node. When False, the function will be added + to the ``.functions`` dictionary. Defaults to True. **kwargs: Keyword arguments to pass to the function. Returns: The output value(s) from the function call. """ return self._builder.call( - function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs + function, *args, _outputs=_outputs, _prefix=_prefix, _inline=_inline, **kwargs ) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index f6f301954b..cbce5b685f 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -848,6 +848,157 @@ def add_mul(X, Y): self.assertIn("does not match", str(cm.exception)) + def test_call_inline_false_creates_single_function_node(self): + """Test that _inline=False creates a single function call node instead of inlining.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _inline=False) + + # With _inline=False, only a single node should be created (the function call) + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + node = nodes[0] + self.assertEqual(node.op_type, "mul_add_relu") + self.assertEqual(list(node.inputs), [x, y]) + + # The result should be a single ir.Value + self.assertIsInstance(result, ir.Value) + self.assertIs(result, node.outputs[0]) + + def test_call_inline_false_registers_function(self): + """Test that _inline=False registers the function in GraphBuilder.functions.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call(simple_add, x, y, _inline=False) + + # The function should be registered + self.assertEqual(len(op.builder.functions), 1) + registered = list(op.builder.functions.values())[0] + self.assertEqual(registered.name, "simple_add") + + def test_call_inline_true_does_not_register_function(self): + """Test that _inline=True (default) does not register the function.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call(simple_add, x, y, _inline=True) + + # No function should be registered when inlining + self.assertEqual(len(op.builder.functions), 0) + + def test_call_inline_false_with_outputs_option(self): + """Test that _inline=False respects the _outputs option for renaming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + result = op.call( + add_mul, x, y, _outputs=["sum_result", "product_result"], _inline=False + ) + + # The result should be a sequence of 2 ir.Values + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names + self.assertEqual(sum_result.name, "v_sum_result") + self.assertEqual(product_result.name, "v_product_result") + + # Only one node (the function call), not inlined + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "add_mul") + + def test_call_inline_false_with_prefix_option(self): + """Test that _inline=False respects the _prefix option for hierarchical naming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _prefix="layer1", _inline=False) + + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + # The node name should have the prefix + self.assertTrue( + nodes[0].name.startswith("layer1/"), + f"Node name {nodes[0].name} should start with layer1/", + ) + + self.assertIsInstance(result, ir.Value) + + def test_call_inline_false_via_op_builder(self): + """Test that _inline=False works when called through OpBuilder.call.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + # Call through OpBuilder (not GraphBuilder directly) + result = op.call(simple_add, x, y, _inline=False) + + # Should produce a single function call node + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "simple_add") + self.assertIsInstance(result, ir.Value) + + # Function should be registered + self.assertEqual(len(op.builder.functions), 1) + + def test_call_inline_true_produces_more_nodes_than_inline_false(self): + """Test that inlining produces individual op nodes while non-inlining produces one.""" + # Inline version + op1, x1, y1 = _create_builder_with_inputs() + + @script(default_opset=op1) + def mul_add(X, Y): + tmp = X * Y + return op1.Add(tmp, X) + + op1.call(mul_add, x1, y1, _inline=True) + inline_nodes = list(op1.builder.graph) + + # Non-inline version + op2, x2, y2 = _create_builder_with_inputs() + + @script(default_opset=op2) + def mul_add2(X, Y): + tmp = X * Y + return op2.Add(tmp, X) + + op2.call(mul_add2, x2, y2, _inline=False) + non_inline_nodes = list(op2.builder.graph) + + # Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1 + self.assertEqual(len(inline_nodes), 2) + self.assertEqual(len(non_inline_nodes), 1) + self.assertEqual(non_inline_nodes[0].op_type, "mul_add2") + class BuildSubgraphTest(unittest.TestCase): """Tests for GraphBuilder.subgraph().""" From b1951059190e9227f316a42cec274788bf8dba39 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Mar 2026 09:53:47 -0700 Subject: [PATCH 3/6] Lint Signed-off-by: Justin Chu --- onnxscript/_internal/builder_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index cbce5b685f..710e61056c 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -884,7 +884,7 @@ def simple_add(X, Y): # The function should be registered self.assertEqual(len(op.builder.functions), 1) - registered = list(op.builder.functions.values())[0] + registered = next(iter(op.builder.functions.values())) self.assertEqual(registered.name, "simple_add") def test_call_inline_true_does_not_register_function(self): From 72639b14f487d5f85be97f31385d10955c92510c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 13 Mar 2026 10:11:10 -0700 Subject: [PATCH 4/6] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 0c59170f9f..fc0267271e 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -603,7 +603,7 @@ def call( ) outputs = node.outputs self.add_node(node) - self._functions[function.identifier] = function + self._functions[function.identifier()] = function if _prefix: self.pop_module() From 5a9e00bff4ff316acfbe45557d8b6fa58dc804fe Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Mar 2026 17:10:53 -0700 Subject: [PATCH 5/6] Update impl --- onnxscript/_internal/builder.py | 117 ++++++++++++++++++--------- onnxscript/_internal/builder_test.py | 84 ++++++++++--------- 2 files changed, 119 insertions(+), 82 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index fc0267271e..47d28ed3d3 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -511,13 +511,13 @@ def call_op( self, op_type: str, inputs: Sequence[ir.Value | ir.TensorProtocol], - kwargs: dict[str, Any], + kwargs: dict[str, ir.Value | ir.TensorProtocol], + /, + domain: str = "", + version: int | None = None, + outputs: int | Sequence[str | ir.Value] = 1, ): """Create an ONNX node and add it to the graph, returning its output value(s).""" - domain = kwargs.pop("_domain", "") - version = kwargs.pop("_version", None) - outputs = kwargs.pop("_outputs", 1) - count = self.graph.num_nodes() node_name = self._qualify_node_name(f"{op_type}_node_{count}") @@ -548,12 +548,49 @@ def call_op( return node.outputs if len(node.outputs) > 1 else node.outputs[0] def call( + self, + function: ir.Function | onnxscript.OnnxFunction, + *args, + _outputs: int | Sequence[str | ir.Value] | None = None, + **kwargs, + ): + """Call a function as a single function node.""" + if isinstance(function, ir.Function): + graph = function.graph + elif isinstance(function, onnxscript.OnnxFunction): + graph = function.graph() + function = function.function_ir + else: + raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") + + if _outputs is None: + _outputs = len(graph.outputs) + output_values = self._adapt_outputs(_outputs, function.name) + + node = ir.node( + op_type=function.name, + inputs=args, + attributes=kwargs or None, + outputs=output_values, + domain=function.domain, + name=self._qualify_node_name(function.name), + ) + # Attach scope metadata to the node + node.metadata_props["namespace"] = self._build_namespace() + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) + + self.add_node(node) + self._functions[function.identifier()] = function + + return node.outputs if len(node.outputs) > 1 else node.outputs[0] + + def call_inline( self, function: ir.Function | onnxscript.OnnxFunction, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", - _inline: bool = True, **kwargs, ): if isinstance(function, ir.Function): @@ -575,35 +612,21 @@ def call( else: for output in graph.outputs: output_renaming[output.name] = self._qualify_value_name(output.name) + + nodes, outputs = _inliner.instantiate(graph, args, kwargs) + if _prefix: self.push_module(_prefix) - if _inline: - nodes, outputs = _inliner.instantiate(graph, args, kwargs) - - for node in nodes: - node.name = self._qualify_node_name(node.name) - for output in node.outputs: - if output.name: - if output.name in output_renaming: - output.name = output_renaming[output.name] - else: - output.name = self._qualify_value_name(output.name) - self.add_node(node) - else: - node = ir.node( - op_type=function.name, - inputs=args, - attributes=kwargs or None, - outputs=[ - ir.Value(name=output_renaming[output.name]) for output in graph.outputs - ], - domain=function.domain, - name=self._qualify_node_name(function.name), - ) - outputs = node.outputs + for node in nodes: + node.name = self._qualify_node_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self._qualify_value_name(output.name) self.add_node(node) - self._functions[function.identifier()] = function if _prefix: self.pop_module() @@ -714,15 +737,34 @@ def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: return self._builder.functions def call( + self, + function, + *args, + _outputs: Sequence[str] | int | None = None, + **kwargs, + ): + """Call a function as a single function node. + + Args: + function: The function to call (ir.Function or onnxscript.OnnxFunction). + *args: Positional arguments to pass to the function. + _outputs: Optional sequence of output names, or an integer specifying the number of outputs. + **kwargs: Keyword arguments to pass to the function. + + Returns: + The output value(s) from the function call. + """ + return self._builder.call(function, *args, _outputs=_outputs, **kwargs) + + def call_inline( self, function, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", - _inline: bool = True, **kwargs, ): - """Call a function and optionally inline it into the graph. + """Inline a function body into the current graph. Args: function: The function to call (ir.Function or onnxscript.OnnxFunction). @@ -730,14 +772,11 @@ def call( _outputs: Optional sequence of output names. If provided, must match the number of function outputs. _prefix: Optional prefix for module scoping (e.g., "layers.0"). - _inline: If True, the function body is inlined into the caller graph instead of being - called as a separate node. When False, the function will be added - to the ``.functions`` dictionary. Defaults to True. **kwargs: Keyword arguments to pass to the function. Returns: - The output value(s) from the function call. + The output value(s) from the inlined function body. """ - return self._builder.call( - function, *args, _outputs=_outputs, _prefix=_prefix, _inline=_inline, **kwargs + return self._builder.call_inline( + function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs ) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 710e61056c..f902d67902 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -652,8 +652,8 @@ def test_attributes_are_created_properly(self): self.assertEqual(strs_attr.type, ir.AttributeType.STRINGS) self.assertEqual(list(strs_attr.value), ["a", "b", "c"]) - def test_call_inlines_onnxscript_function(self): - """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + def test_call_inline_inlines_onnxscript_function(self): + """Test that GraphBuilder.call_inline inlines an @onnxscript.script function.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -664,7 +664,7 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y) + result = op.call_inline(mul_add_relu, x, y) # The inlined function should produce 3 nodes: Mul, Add, Relu nodes = list(op.builder.graph) @@ -688,8 +688,8 @@ def mul_add_relu(X, Y): self.assertIs(mul_node.inputs[0], x) self.assertIs(mul_node.inputs[1], y) - def test_call_with_outputs_option(self): - """Test that GraphBuilder.call respects the _outputs option for renaming.""" + def test_call_inline_with_outputs_option(self): + """Test that GraphBuilder.call_inline respects the _outputs option for renaming.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -700,7 +700,7 @@ def add_mul(X, Y): b = X * Y return a, b - result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + result = op.call_inline(add_mul, x, y, _outputs=["sum_result", "product_result"]) # The result should be a list of 2 ir.Values (when function returns multiple outputs) self.assertIsInstance(result, list) @@ -727,8 +727,8 @@ def test_call_with_outer_scope_value(self): def add_product(X): return op.Add(X, product) # Reference to 'product' from outer scope - x_plus = op.call(add_product, x, _outputs=["x_plus"]) - y_plus = op.call(add_product, y, _outputs=["y_plus"]) + x_plus = op.call_inline(add_product, x, _outputs=["x_plus"]) + y_plus = op.call_inline(add_product, y, _outputs=["y_plus"]) op.builder.graph.outputs.extend([x_plus, y_plus]) @@ -742,8 +742,8 @@ def add_product(X): # Verify that the two graphs are structurally equivalent onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph) - def test_call_with_prefix_option(self): - """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + def test_call_inline_with_prefix_option(self): + """Test that GraphBuilder.call_inline respects the _prefix option for hierarchical naming.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -754,7 +754,7 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y, _prefix="layer1") + result = op.call_inline(mul_add_relu, x, y, _prefix="layer1") # The nodes should have the prefix in their names nodes = list(op.builder.graph) @@ -770,8 +770,8 @@ def mul_add_relu(X, Y): # Verify the result is a single ir.Value self.assertIsInstance(result, ir.Value) - def test_call_with_outputs_and_prefix_options(self): - """Test that GraphBuilder.call respects both _outputs and _prefix options together. + def test_call_inline_with_outputs_and_prefix_options(self): + """Test that GraphBuilder.call_inline respects both _outputs and _prefix options together. Note: _outputs names are set before the prefix context is applied, so they don't get the prefix in their names. However, the inlined nodes do get the prefix applied, and @@ -791,7 +791,7 @@ def add_mul(X, Y): b = XSquare * YSquare return a, b - result = op.call( + result = op.call_inline( add_mul, x, y, _outputs=["custom_sum", "custom_product"], _prefix="math_ops" ) @@ -830,8 +830,8 @@ def add_mul(X, Y): f"Intermediate value {y_square.name} should have prefix", ) - def test_call_outputs_mismatch_error(self): - """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + def test_call_inline_outputs_mismatch_error(self): + """Test that GraphBuilder.call_inline raises an error if _outputs has wrong count.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -844,12 +844,12 @@ def add_mul(X, Y): # The function returns 2 outputs, but we provide only 1 name with self.assertRaises(ValueError) as cm: - op.call(add_mul, x, y, _outputs=["only_one_name"]) + op.call_inline(add_mul, x, y, _outputs=["only_one_name"]) self.assertIn("does not match", str(cm.exception)) - def test_call_inline_false_creates_single_function_node(self): - """Test that _inline=False creates a single function call node instead of inlining.""" + def test_call_creates_single_function_node(self): + """Test that GraphBuilder.call creates a single function call node.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) @@ -858,9 +858,9 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y, _inline=False) + result = op.call(mul_add_relu, x, y) - # With _inline=False, only a single node should be created (the function call) + # Only a single node should be created (the function call) nodes = list(op.builder.graph) self.assertEqual(len(nodes), 1) @@ -872,36 +872,36 @@ def mul_add_relu(X, Y): self.assertIsInstance(result, ir.Value) self.assertIs(result, node.outputs[0]) - def test_call_inline_false_registers_function(self): - """Test that _inline=False registers the function in GraphBuilder.functions.""" + def test_call_registers_function(self): + """Test that GraphBuilder.call registers the function in GraphBuilder.functions.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) def simple_add(X, Y): return op.Add(X, Y) - op.call(simple_add, x, y, _inline=False) + op.call(simple_add, x, y) # The function should be registered self.assertEqual(len(op.builder.functions), 1) registered = next(iter(op.builder.functions.values())) self.assertEqual(registered.name, "simple_add") - def test_call_inline_true_does_not_register_function(self): - """Test that _inline=True (default) does not register the function.""" + def test_call_inline_does_not_register_function(self): + """Test that GraphBuilder.call_inline does not register the function.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) def simple_add(X, Y): return op.Add(X, Y) - op.call(simple_add, x, y, _inline=True) + op.call_inline(simple_add, x, y) # No function should be registered when inlining self.assertEqual(len(op.builder.functions), 0) - def test_call_inline_false_with_outputs_option(self): - """Test that _inline=False respects the _outputs option for renaming.""" + def test_call_with_outputs_option(self): + """Test that GraphBuilder.call respects the _outputs option for renaming.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) @@ -910,9 +910,7 @@ def add_mul(X, Y): b = X * Y return a, b - result = op.call( - add_mul, x, y, _outputs=["sum_result", "product_result"], _inline=False - ) + result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) # The result should be a sequence of 2 ir.Values self.assertEqual(len(result), 2) @@ -922,13 +920,13 @@ def add_mul(X, Y): self.assertEqual(sum_result.name, "v_sum_result") self.assertEqual(product_result.name, "v_product_result") - # Only one node (the function call), not inlined + # Only one node (the function call) nodes = list(op.builder.graph) self.assertEqual(len(nodes), 1) self.assertEqual(nodes[0].op_type, "add_mul") - def test_call_inline_false_with_prefix_option(self): - """Test that _inline=False respects the _prefix option for hierarchical naming.""" + def test_call_with_prefix_option(self): + """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) @@ -937,7 +935,7 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y, _prefix="layer1", _inline=False) + result = op.call(mul_add_relu, x, y, _prefix="layer1") nodes = list(op.builder.graph) self.assertEqual(len(nodes), 1) @@ -950,8 +948,8 @@ def mul_add_relu(X, Y): self.assertIsInstance(result, ir.Value) - def test_call_inline_false_via_op_builder(self): - """Test that _inline=False works when called through OpBuilder.call.""" + def test_call_via_op_builder(self): + """Test that GraphBuilder.call works when called through OpBuilder.call.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) @@ -959,7 +957,7 @@ def simple_add(X, Y): return op.Add(X, Y) # Call through OpBuilder (not GraphBuilder directly) - result = op.call(simple_add, x, y, _inline=False) + result = op.call(simple_add, x, y) # Should produce a single function call node nodes = list(op.builder.graph) @@ -970,8 +968,8 @@ def simple_add(X, Y): # Function should be registered self.assertEqual(len(op.builder.functions), 1) - def test_call_inline_true_produces_more_nodes_than_inline_false(self): - """Test that inlining produces individual op nodes while non-inlining produces one.""" + def test_call_inline_produces_more_nodes_than_call(self): + """Test that call_inline produces op nodes while call produces one function node.""" # Inline version op1, x1, y1 = _create_builder_with_inputs() @@ -980,7 +978,7 @@ def mul_add(X, Y): tmp = X * Y return op1.Add(tmp, X) - op1.call(mul_add, x1, y1, _inline=True) + op1.call_inline(mul_add, x1, y1) inline_nodes = list(op1.builder.graph) # Non-inline version @@ -991,7 +989,7 @@ def mul_add2(X, Y): tmp = X * Y return op2.Add(tmp, X) - op2.call(mul_add2, x2, y2, _inline=False) + op2.call(mul_add2, x2, y2) non_inline_nodes = list(op2.builder.graph) # Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1 From 0a3efad57712c8385036ba8b148619a774119963 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 23 Mar 2026 17:26:02 -0700 Subject: [PATCH 6/6] Fix OpBuilder to properly extract _domain, _version, _outputs from kwargs OpBuilder._call_op was inserting _domain, _version into the kwargs dict, but GraphBuilder.call_op expects domain, version, outputs as separate keyword arguments. This caused them to be treated as node attributes, breaking custom domain handling, schema lookup, type inference, shape inference, and output naming. Changes: - OpBuilder._call_op: pop _domain, _version, _outputs from kwargs and pass as separate keyword args to call_op - Remove _prefix from GraphBuilder.call and OpBuilder.call (only call_inline needs it) - Update test to use push_module/pop_module instead of _prefix on call --- onnxscript/_internal/builder.py | 11 ++++++----- onnxscript/_internal/builder_test.py | 8 +++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 47d28ed3d3..c74b1844a1 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -721,11 +721,12 @@ def version(self) -> int | None: return self._version def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): - if "_domain" not in kwargs: - kwargs["_domain"] = self._domain - if self._version is not None and "_version" not in kwargs: - kwargs["_version"] = self._version - return self._builder.call_op(op_type, inputs, kwargs) + domain = kwargs.pop("_domain", self._domain) + version = kwargs.pop("_version", self._version) + outputs = kwargs.pop("_outputs", 1) + return self._builder.call_op( + op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs + ) def __getattr__(self, op_type: str) -> Callable: return lambda *args, **kwargs: self._call_op(op_type, args, kwargs) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index f902d67902..c439e9e61c 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -925,8 +925,8 @@ def add_mul(X, Y): self.assertEqual(len(nodes), 1) self.assertEqual(nodes[0].op_type, "add_mul") - def test_call_with_prefix_option(self): - """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + def test_call_with_push_module_prefix(self): + """Test that GraphBuilder.call respects push_module for hierarchical naming.""" op, x, y = _create_builder_with_inputs() @script(default_opset=op) @@ -935,7 +935,9 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y, _prefix="layer1") + op.builder.push_module("layer1") + result = op.call(mul_add_relu, x, y) + op.builder.pop_module() nodes = list(op.builder.graph) self.assertEqual(len(nodes), 1)