diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 1dc0875871..c74b1844a1 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -229,6 +229,8 @@ def __init__(self, graph: ir.Graph) -> None: # and allows sharing them across different layers/contexts. self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {} + self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} + def opset(self, domain: str, version: int = 1) -> OpBuilder: """Create an OpBuilder bound to the given domain and version.""" return OpBuilder(self, domain, version) @@ -241,6 +243,10 @@ def op(self) -> OpBuilder: def graph(self) -> ir.Graph: return self._graph + @property + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._functions + def initializer( self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True ) -> ir.Value: @@ -505,13 +511,13 @@ def call_op( self, op_type: str, inputs: Sequence[ir.Value | ir.TensorProtocol], - kwargs: dict[str, Any], + kwargs: dict[str, ir.Value | ir.TensorProtocol], + /, + domain: str = "", + version: int | None = None, + outputs: int | Sequence[str | ir.Value] = 1, ): """Create an ONNX node and add it to the graph, returning its output value(s).""" - domain = kwargs.pop("_domain", "") - version = kwargs.pop("_version", None) - outputs = kwargs.pop("_outputs", 1) - count = self.graph.num_nodes() node_name = self._qualify_node_name(f"{op_type}_node_{count}") @@ -543,7 +549,45 @@ def call_op( def call( self, - function, + function: ir.Function | onnxscript.OnnxFunction, + *args, + _outputs: int | Sequence[str | ir.Value] | None = None, + **kwargs, + ): + """Call a function as a single function node.""" + if isinstance(function, ir.Function): + graph = function.graph + elif isinstance(function, onnxscript.OnnxFunction): + graph = function.graph() + function = function.function_ir + else: + raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") + + if _outputs is None: + _outputs = len(graph.outputs) + output_values = self._adapt_outputs(_outputs, function.name) + + node = ir.node( + op_type=function.name, + inputs=args, + attributes=kwargs or None, + outputs=output_values, + domain=function.domain, + name=self._qualify_node_name(function.name), + ) + # Attach scope metadata to the node + node.metadata_props["namespace"] = self._build_namespace() + node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes()) + node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names()) + + self.add_node(node) + self._functions[function.identifier()] = function + + return node.outputs if len(node.outputs) > 1 else node.outputs[0] + + def call_inline( + self, + function: ir.Function | onnxscript.OnnxFunction, *args, _outputs: Sequence[str] | None = None, _prefix: str = "", @@ -553,6 +597,7 @@ def call( graph = function.graph elif isinstance(function, onnxscript.OnnxFunction): graph = function.graph() + function = function.function_ir else: raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") output_renaming: dict[str, str] = {} @@ -567,9 +612,12 @@ def call( else: for output in graph.outputs: output_renaming[output.name] = self._qualify_value_name(output.name) + nodes, outputs = _inliner.instantiate(graph, args, kwargs) + if _prefix: self.push_module(_prefix) + for node in nodes: node.name = self._qualify_node_name(node.name) for output in node.outputs: @@ -579,6 +627,7 @@ def call( else: output.name = self._qualify_value_name(output.name) self.add_node(node) + if _prefix: self.pop_module() return outputs if len(outputs) > 1 else outputs[0] @@ -672,11 +721,12 @@ def version(self) -> int | None: return self._version def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): - if "_domain" not in kwargs: - kwargs["_domain"] = self._domain - if self._version is not None and "_version" not in kwargs: - kwargs["_version"] = self._version - return self._builder.call_op(op_type, inputs, kwargs) + domain = kwargs.pop("_domain", self._domain) + version = kwargs.pop("_version", self._version) + outputs = kwargs.pop("_outputs", 1) + return self._builder.call_op( + op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs + ) def __getattr__(self, op_type: str) -> Callable: return lambda *args, **kwargs: self._call_op(op_type, args, kwargs) @@ -684,7 +734,30 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]: + return self._builder.functions + def call( + self, + function, + *args, + _outputs: Sequence[str] | int | None = None, + **kwargs, + ): + """Call a function as a single function node. + + Args: + function: The function to call (ir.Function or onnxscript.OnnxFunction). + *args: Positional arguments to pass to the function. + _outputs: Optional sequence of output names, or an integer specifying the number of outputs. + **kwargs: Keyword arguments to pass to the function. + + Returns: + The output value(s) from the function call. + """ + return self._builder.call(function, *args, _outputs=_outputs, **kwargs) + + def call_inline( self, function, *args, @@ -692,7 +765,7 @@ def call( _prefix: str = "", **kwargs, ): - """Call a function and inline it into the graph. + """Inline a function body into the current graph. Args: function: The function to call (ir.Function or onnxscript.OnnxFunction). @@ -703,8 +776,8 @@ def call( **kwargs: Keyword arguments to pass to the function. Returns: - The output value(s) from the function call. + The output value(s) from the inlined function body. """ - return self._builder.call( + return self._builder.call_inline( function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs ) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index f6f301954b..c439e9e61c 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -652,8 +652,8 @@ def test_attributes_are_created_properly(self): self.assertEqual(strs_attr.type, ir.AttributeType.STRINGS) self.assertEqual(list(strs_attr.value), ["a", "b", "c"]) - def test_call_inlines_onnxscript_function(self): - """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + def test_call_inline_inlines_onnxscript_function(self): + """Test that GraphBuilder.call_inline inlines an @onnxscript.script function.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -664,7 +664,7 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y) + result = op.call_inline(mul_add_relu, x, y) # The inlined function should produce 3 nodes: Mul, Add, Relu nodes = list(op.builder.graph) @@ -688,8 +688,8 @@ def mul_add_relu(X, Y): self.assertIs(mul_node.inputs[0], x) self.assertIs(mul_node.inputs[1], y) - def test_call_with_outputs_option(self): - """Test that GraphBuilder.call respects the _outputs option for renaming.""" + def test_call_inline_with_outputs_option(self): + """Test that GraphBuilder.call_inline respects the _outputs option for renaming.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -700,7 +700,7 @@ def add_mul(X, Y): b = X * Y return a, b - result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + result = op.call_inline(add_mul, x, y, _outputs=["sum_result", "product_result"]) # The result should be a list of 2 ir.Values (when function returns multiple outputs) self.assertIsInstance(result, list) @@ -727,8 +727,8 @@ def test_call_with_outer_scope_value(self): def add_product(X): return op.Add(X, product) # Reference to 'product' from outer scope - x_plus = op.call(add_product, x, _outputs=["x_plus"]) - y_plus = op.call(add_product, y, _outputs=["y_plus"]) + x_plus = op.call_inline(add_product, x, _outputs=["x_plus"]) + y_plus = op.call_inline(add_product, y, _outputs=["y_plus"]) op.builder.graph.outputs.extend([x_plus, y_plus]) @@ -742,8 +742,8 @@ def add_product(X): # Verify that the two graphs are structurally equivalent onnxscript.testing.assert_isomorphic_graph(op.builder.graph, op2.builder.graph) - def test_call_with_prefix_option(self): - """Test that GraphBuilder.call respects the _prefix option for hierarchical naming.""" + def test_call_inline_with_prefix_option(self): + """Test that GraphBuilder.call_inline respects the _prefix option for hierarchical naming.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -754,7 +754,7 @@ def mul_add_relu(X, Y): tmp = tmp + X return op.Relu(tmp) - result = op.call(mul_add_relu, x, y, _prefix="layer1") + result = op.call_inline(mul_add_relu, x, y, _prefix="layer1") # The nodes should have the prefix in their names nodes = list(op.builder.graph) @@ -770,8 +770,8 @@ def mul_add_relu(X, Y): # Verify the result is a single ir.Value self.assertIsInstance(result, ir.Value) - def test_call_with_outputs_and_prefix_options(self): - """Test that GraphBuilder.call respects both _outputs and _prefix options together. + def test_call_inline_with_outputs_and_prefix_options(self): + """Test that GraphBuilder.call_inline respects both _outputs and _prefix options together. Note: _outputs names are set before the prefix context is applied, so they don't get the prefix in their names. However, the inlined nodes do get the prefix applied, and @@ -791,7 +791,7 @@ def add_mul(X, Y): b = XSquare * YSquare return a, b - result = op.call( + result = op.call_inline( add_mul, x, y, _outputs=["custom_sum", "custom_product"], _prefix="math_ops" ) @@ -830,8 +830,8 @@ def add_mul(X, Y): f"Intermediate value {y_square.name} should have prefix", ) - def test_call_outputs_mismatch_error(self): - """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + def test_call_inline_outputs_mismatch_error(self): + """Test that GraphBuilder.call_inline raises an error if _outputs has wrong count.""" # Create a GraphBuilder first op, x, y = _create_builder_with_inputs() @@ -844,10 +844,161 @@ def add_mul(X, Y): # The function returns 2 outputs, but we provide only 1 name with self.assertRaises(ValueError) as cm: - op.call(add_mul, x, y, _outputs=["only_one_name"]) + op.call_inline(add_mul, x, y, _outputs=["only_one_name"]) self.assertIn("does not match", str(cm.exception)) + def test_call_creates_single_function_node(self): + """Test that GraphBuilder.call creates a single function call node.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y) + + # Only a single node should be created (the function call) + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + node = nodes[0] + self.assertEqual(node.op_type, "mul_add_relu") + self.assertEqual(list(node.inputs), [x, y]) + + # The result should be a single ir.Value + self.assertIsInstance(result, ir.Value) + self.assertIs(result, node.outputs[0]) + + def test_call_registers_function(self): + """Test that GraphBuilder.call registers the function in GraphBuilder.functions.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call(simple_add, x, y) + + # The function should be registered + self.assertEqual(len(op.builder.functions), 1) + registered = next(iter(op.builder.functions.values())) + self.assertEqual(registered.name, "simple_add") + + def test_call_inline_does_not_register_function(self): + """Test that GraphBuilder.call_inline does not register the function.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + op.call_inline(simple_add, x, y) + + # No function should be registered when inlining + self.assertEqual(len(op.builder.functions), 0) + + def test_call_with_outputs_option(self): + """Test that GraphBuilder.call respects the _outputs option for renaming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + + # The result should be a sequence of 2 ir.Values + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names + self.assertEqual(sum_result.name, "v_sum_result") + self.assertEqual(product_result.name, "v_product_result") + + # Only one node (the function call) + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "add_mul") + + def test_call_with_push_module_prefix(self): + """Test that GraphBuilder.call respects push_module for hierarchical naming.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + op.builder.push_module("layer1") + result = op.call(mul_add_relu, x, y) + op.builder.pop_module() + + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + + # The node name should have the prefix + self.assertTrue( + nodes[0].name.startswith("layer1/"), + f"Node name {nodes[0].name} should start with layer1/", + ) + + self.assertIsInstance(result, ir.Value) + + def test_call_via_op_builder(self): + """Test that GraphBuilder.call works when called through OpBuilder.call.""" + op, x, y = _create_builder_with_inputs() + + @script(default_opset=op) + def simple_add(X, Y): + return op.Add(X, Y) + + # Call through OpBuilder (not GraphBuilder directly) + result = op.call(simple_add, x, y) + + # Should produce a single function call node + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "simple_add") + self.assertIsInstance(result, ir.Value) + + # Function should be registered + self.assertEqual(len(op.builder.functions), 1) + + def test_call_inline_produces_more_nodes_than_call(self): + """Test that call_inline produces op nodes while call produces one function node.""" + # Inline version + op1, x1, y1 = _create_builder_with_inputs() + + @script(default_opset=op1) + def mul_add(X, Y): + tmp = X * Y + return op1.Add(tmp, X) + + op1.call_inline(mul_add, x1, y1) + inline_nodes = list(op1.builder.graph) + + # Non-inline version + op2, x2, y2 = _create_builder_with_inputs() + + @script(default_opset=op2) + def mul_add2(X, Y): + tmp = X * Y + return op2.Add(tmp, X) + + op2.call(mul_add2, x2, y2) + non_inline_nodes = list(op2.builder.graph) + + # Inlining should produce 2 nodes (Mul, Add), non-inlining should produce 1 + self.assertEqual(len(inline_nodes), 2) + self.assertEqual(len(non_inline_nodes), 1) + self.assertEqual(non_inline_nodes[0].op_type, "mul_add2") + class BuildSubgraphTest(unittest.TestCase): """Tests for GraphBuilder.subgraph().""" diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index 051cb3e686..c3fd1b3dad 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -222,6 +222,10 @@ def __call__(self, *args, **kwargs): def name(self) -> str: return self._name + @property + def domain(self) -> str: + return self._opset.domain + @property def opset(self) -> Opset: return self._opset