Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,54 @@ def initializer(
self._graph.register_initializer(value)
return value

def input(
    self,
    name: str,
    dtype: ir.DataType | None = None,
    shape: ir.Shape | Sequence[int | str | None] | None = None,
    *,
    type: ir.TypeProtocol | None = None,
    const_value: ir.TensorProtocol | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Value:
    """Create a graph input and return the corresponding ir.Value.

    The value is appended to ``self._graph.inputs``. When ``const_value`` is
    supplied, the value is additionally registered as a graph initializer
    (i.e. an input with a default constant tensor).

    Args:
        name: Name of the new value.
        dtype: Element type for a TensorType; ignored when ``type`` is given.
        shape: Shape of the value.
        type: Full type of the value. Specify at most one of ``dtype``/``type``.
        const_value: Constant tensor backing the value. Provide it to create
            an initializer; type and shape can then be derived from the tensor.
        metadata_props: Metadata properties serialized into the ONNX proto.

    Returns:
        The newly created ir.Value.
    """
    input_value = ir.val(
        name=name,
        dtype=dtype,
        shape=shape,
        type=type,
        const_value=const_value,
        metadata_props=metadata_props,
    )
    self._graph.inputs.append(input_value)
    # An input carrying a constant tensor doubles as an initializer.
    if const_value is not None:
        self._graph.register_initializer(input_value)
    return input_value

def add_output(self, value: ir.Value, name: str | None) -> None:
    """Append *value* to the graph outputs, optionally renaming it first.

    Args:
        value: The ir.Value to register as a graph output.
        name: New name for the output. If None (or empty), the value keeps
            its current name.
    """
    # Truthiness check: both None and "" leave the existing name untouched.
    if name:
        value.name = name
    self._graph.outputs.append(value)

def _input_to_ir_value(
self, value: VALUE_LIKE, like_type: ir.Value | None = None
) -> ir.Value:
Expand Down
96 changes: 96 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,102 @@ def test_output_names_are_unique_for_same_op_type(self):
names = [t1.name, t2.name, t3.name]
self.assertEqual(len(set(names)), 3)

def test_input_creates_and_registers_graph_input(self):
    """GraphBuilder.input returns a typed, shaped value appended to graph.inputs."""
    empty_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    gb = builder.GraphBuilder(empty_graph)

    result = gb.input("data", dtype=ir.DataType.FLOAT, shape=[2, 3])

    # The returned value carries the requested name, dtype, and shape ...
    self.assertEqual(result.name, "data")
    self.assertEqual(result.type.dtype, ir.DataType.FLOAT)
    self.assertEqual(list(result.shape), [2, 3])
    # ... and is the single registered graph input.
    self.assertEqual(len(empty_graph.inputs), 1)
    self.assertIs(empty_graph.inputs[0], result)
def test_input_with_const_value_registers_initializer(self):
    """GraphBuilder.input with const_value also registers the value as an initializer."""
    target_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    gb = builder.GraphBuilder(target_graph)
    backing_tensor = ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT, name="const_data")

    created = gb.input("const_input", const_value=backing_tensor)

    # Registered as a graph input ...
    self.assertEqual(len(target_graph.inputs), 1)
    self.assertIs(target_graph.inputs[0], created)
    # ... and, because a constant tensor was supplied, also as an initializer.
    self.assertIn("const_input", target_graph.initializers)
    self.assertIs(target_graph.initializers["const_input"], created)
    self.assertIs(created.const_value, backing_tensor)

def test_input_without_const_value_does_not_register_initializer(self):
    """GraphBuilder.input without const_value must not create an initializer entry."""
    target_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    gb = builder.GraphBuilder(target_graph)

    created = gb.input("regular_input", dtype=ir.DataType.FLOAT, shape=[2])

    # A plain input is appended to graph.inputs only.
    self.assertEqual(len(target_graph.inputs), 1)
    self.assertIs(target_graph.inputs[0], created)
    self.assertNotIn("regular_input", target_graph.initializers)

def test_add_output_renames_and_registers_output(self):
    """GraphBuilder.add_output applies the new name and appends to graph.outputs."""
    target_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    gb = builder.GraphBuilder(target_graph)
    out_value = ir.Value(name="old_name")

    gb.add_output(out_value, "new_name")

    # The value was renamed in place and is now the sole graph output.
    self.assertEqual(out_value.name, "new_name")
    self.assertEqual(len(target_graph.outputs), 1)
    self.assertIs(target_graph.outputs[0], out_value)

def test_initializer_qualification_behavior(self):
    """GraphBuilder.initializer prefixes the module path unless qualify=False."""
    target_graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": _default_opset_version},
    )
    gb = builder.GraphBuilder(target_graph)

    # Inside a pushed module scope, names are qualified by default.
    gb.push_module("layer1")
    scoped = gb.initializer(ir.tensor([1.0], name="w"), name="weight")
    plain = gb.initializer(
        ir.tensor([2.0], name="b"), name="bias", qualify=False
    )

    self.assertEqual(scoped.name, "layer1.weight")
    self.assertEqual(plain.name, "bias")
    self.assertIn("layer1.weight", target_graph.initializers)
    self.assertIn("bias", target_graph.initializers)

def test_multi_output_names_are_unique(self):
"""Test that multi-output ops produce unique names with counter suffix."""
op, x, y = _create_builder_with_inputs()
Expand Down
Loading