Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions onnxscript/rewriter/rules/common/_fuse_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M
if initializer.is_graph_input():
return check_result.fail(f"{initializer.name} is a graph input.")

# Check that the inbound node's weight and bias initializers are not shared
# with other nodes outside this matched pattern. When the fusion creates new
# initializers with the same name as the original shared weights, it overwrites
# the original initializer in the graph, leaving other nodes that reference the
# original value with an invalid (unregistered) input.
matched_nodes = {inbound_node, batchnorm_node}
inbound_initializers = [inbound_node.inputs[1]]
if len(inbound_node.inputs) > 2:
inbound_initializers.append(inbound_node.inputs[2])
for init_value in inbound_initializers:
for user, _ in init_value.uses():
if user not in matched_nodes:
return check_result.fail(
f"Initializer '{init_value.name}' is used by another node "
f"'{user.name}' outside the matched pattern."
)

return check_result


Expand Down
50 changes: 50 additions & 0 deletions onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,56 @@ def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self):
bias_names_2 = conv_nodes[1].inputs[2].name
self.assertNotEqual(bias_names_1, bias_names_2)

def test_fuse_batchnorm_skips_shared_weight_initializers(self):
"""Test that BatchNorm fusion is skipped when Conv nodes share weight initializers.

Regression test for https://github.com/microsoft/onnxscript/issues/2382.
When two Conv+BatchNorm pairs share the same weight initializer, fusing the
first pair would overwrite the shared initializer, leaving the second Conv
node with an invalid (unregistered) weight reference.
"""
model_proto = onnx.parser.parse_model("""
< ir_version: 7, opset_import: ["" : 17] >
test_model (float[N, 32, 14, 16] X1, float[N, 32, 14, 16] X2)
=> (float [N, ?, ?, ?] Y)
{
C1 = Conv(X1, W, B)
BN1 = BatchNormalization(C1, gamma, beta, input_mean, input_var)
C2 = Conv(X2, W, B)
BN2 = BatchNormalization(C2, gamma, beta, input_mean, input_var)
Y = Add(BN1, BN2)
}
""")
initializers = [
onnx.numpy_helper.from_array(
np.random.randn(16, 32, 3, 3).astype(np.float32), name="W"
),
onnx.numpy_helper.from_array(np.random.randn(16).astype(np.float32), name="B"),
*self._create_batchnorm_params(size=16),
]
model_proto.graph.initializer.extend(initializers)
onnx.checker.check_model(model_proto, True)
model = ir.serde.deserialize_model(model_proto)

count = _fuse_batchnorm.rules.apply_to_model(model)

# No fusion should be applied because the weight initializer is shared
self.assertEqual(count, 0)

# The model should still be valid after the (non-)optimization
output_model_proto = ir.serde.serialize_model(model)
onnx.checker.check_model(output_model_proto, True)

# Check inference produces correct results
testing.assert_numerically_equal(
model_proto,
model,
(
np.random.rand(1, 32, 14, 16).astype(np.float32),
np.random.rand(1, 32, 14, 16).astype(np.float32),
),
)


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading