Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions onnxscript/rewriter/rules/common/_fuse_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization ∘ Conv -> Conv
- BatchNormalization ∘ ConvTranpose -> ConvTranpose
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
- BatchNormalization ∘ Gemm -> Gemm

Approach:
Expand All @@ -14,7 +14,7 @@
- B_fused = (B - μ) * (gamma / std) + β
"""

from abc import ABC, abstractmethod
from abc import ABC
from typing import ClassVar, Mapping

import numpy as np
Expand All @@ -33,9 +33,18 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
"""Interface for BatchNormalization nodes fusion."""

@abstractmethod
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
    """Return the axis along which BatchNorm scale should be broadcasted.

    Args:
        attributes: Attributes of the matched inbound node (e.g. Conv/Gemm),
            used by subclasses whose filter axis depends on node attributes
            such as ``transB``.

    Returns:
        The axis index of the output-channel dimension in the node's weight
        tensor (e.g. 0 for Conv).
    """
    raise NotImplementedError()

def _scale_weights(
    self,
    weights: np.ndarray,
    scale_factor: np.ndarray,
    attributes: Mapping[str, ir.Attr],
) -> np.ndarray:
    """Scale *weights* by the BatchNorm factor along the filter axis.

    The 1-D ``scale_factor`` is reshaped so it broadcasts against the
    weight tensor on the axis reported by ``get_filters_axis``.
    """
    filters_axis = self.get_filters_axis(attributes)
    broadcast_scale = _reshape_for_broadcast(
        scale_factor, weights.ndim, axis=filters_axis
    )
    return weights * broadcast_scale
Comment on lines 33 to +47
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_FuseBatchNormBase no longer enforces an abstract interface (the @abstractmethod decorator was removed), so a new subclass can be instantiated without implementing get_filters_axis/_scale_weights and will only fail later at runtime. Consider restoring an abstract contract (e.g., keep get_filters_axis abstract and implement it in FuseBatchNormIntoConvTranspose even if unused, or refactor to make _scale_weights the required abstract method and provide a helper for the axis-based default).

Copilot uses AI. Check for mistakes.

def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
batchnorm_node = batchnorm_out.producer()
Expand All @@ -56,10 +65,8 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
inbound_node = inbound_out.producer()
weights = inbound_node.inputs[1].const_value.numpy()

# Reshape scale factor so it is broadcastable
axis = self.get_filters_axis(inbound_node.attributes)
fused_weights = ir.tensor(
weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
self._scale_weights(weights, scale_factor, inbound_node.attributes)
)

# Update bias
Expand Down Expand Up @@ -127,8 +134,26 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):

op_type: ClassVar = "ConvTranspose"

def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
return 1
def _scale_weights(
    self,
    weights: np.ndarray,
    scale_factor: np.ndarray,
    attributes: Mapping[str, ir.Attr],
) -> np.ndarray:
    """Scale ConvTranspose weights by the per-output-channel BatchNorm factor.

    ConvTranspose stores its weight as (in_channels, out_channels/group, *kernel),
    so with group > 1 the full set of out_channels is split across groups and a
    plain broadcast along axis 1 would be wrong. The weight is therefore viewed
    as (group, in_channels/group, out_channels/group, *kernel) and the
    (out_channels,)-shaped scale_factor is reshaped per group to line up with
    the out_channels/group axis.

    Args:
        weights: ConvTranspose weight of shape (in_channels, out_channels/group, *kernel).
        scale_factor: BatchNorm scale of shape (out_channels,).
        attributes: ConvTranspose node attributes; ``group`` defaults to 1.

    Returns:
        The scaled weights, reshaped back to the original weight shape.
    """
    # ConvTranspose weight: (in_channels, out_channels/group, *kernel)
    # Reshape weights: [in_channels, out_channels/group, *kernel] → [group, in_channels/group, out_channels/group, *kernel]
    in_channels = weights.shape[0]
    out_channels_per_group = weights.shape[1]
    kernel_shape = weights.shape[2:]
    group = attributes.get("group", ir.AttrInt64("group", 1)).as_int()
    w = weights.reshape(group, in_channels // group, out_channels_per_group, *kernel_shape)

    # Per group scale_factor: (out_channels,) -> (group, out_channels/group, 1, ..., 1)
    # so each group's slice of the scale lines up with its out_channels/group axis.
    s = scale_factor.reshape((group, out_channels_per_group) + (1,) * len(kernel_shape))
    # Insert the in_channels/group axis -> (group, 1, out_channels/group, *ones).
    s = s[:, None, ...]

    return (w * s).reshape(weights.shape)

def pattern(self, op, x):
return op.BatchNormalization(
Expand All @@ -137,6 +162,25 @@ def pattern(self, op, x):
_outputs=["batchnorm_out"],
)

def check(self, context, x, inbound_out, batchnorm_out):
    """Run the base checks, then reject grouped ConvTranspose whose
    in_channels is not divisible by ``group`` (fusion would reshape invalidly)."""
    result = super().check(context, x, inbound_out, batchnorm_out)
    if not result:
        return result

    convtranspose_node = inbound_out.producer()

    weight_shape = convtranspose_node.inputs[1].const_value.numpy().shape
    in_channels = weight_shape[0]
    group = convtranspose_node.attributes.get("group", ir.AttrInt64("group", 1)).as_int()

    # The ONNX checker accepts such a model even though it is semantically
    # invalid; skip the fusion instead of crashing on the bad reshape.
    if in_channels % group != 0:
        return result.fail(
            f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
        )

    return result


class FuseBatchNormIntoGemm(_FuseBatchNormBase):
"""Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
Expand Down
75 changes: 55 additions & 20 deletions onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import unittest

import numpy as np
import onnx.checker
import onnx.parser
import onnx
import parameterized

from onnxscript import ir
Expand All @@ -31,29 +30,33 @@ def _create_batchnorm_params(self, size: int):

@parameterized.parameterized.expand(
[
("bias_false", False),
("bias_true", True),
("bias_false_group1", False, 1),
("bias_true_group1", True, 1),
("bias_false_group4", False, 4),
("bias_true_group4", True, 4),
]
)
def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool, group: int):
# ConvTranspose weight: [in_channels, out_channels/group, kH, kW]
out_channels = 64 * group
convtranspose_inputs = "X, W"
parameters = (
"float[32, 64, 3, 3] W, "
"float[64] gamma, "
"float[64] beta, "
"float[64] input_mean, "
"float[64] input_var"
f"float[32, 64, 3, 3] W, "
f"float[{out_channels}] gamma, "
f"float[{out_channels}] beta, "
f"float[{out_channels}] input_mean, "
f"float[{out_channels}] input_var"
)
if convtranspose_bias:
parameters += ", float[64] B"
parameters += f", float[{out_channels}] B"
convtranspose_inputs += ", B"

model_proto = onnx.parser.parse_model(f"""
< ir_version: 7, opset_import: ["" : 17] >
test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
<{parameters}>
{{
X1 = ConvTranspose({convtranspose_inputs})
X1 = ConvTranspose<group={group}>({convtranspose_inputs})
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
}}
""")
Expand All @@ -62,11 +65,13 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):
onnx.numpy_helper.from_array(
np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
),
*self._create_batchnorm_params(size=64),
*self._create_batchnorm_params(size=out_channels),
]
if convtranspose_bias:
initializers.append(
onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B")
onnx.numpy_helper.from_array(
np.random.randn(out_channels).astype(np.float32), name="B"
)
)
model_proto.graph.initializer.extend(initializers)

Expand All @@ -90,14 +95,18 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool):

@parameterized.parameterized.expand(
[
("bias_false", False),
("bias_true", True),
("bias_false_group1", False, 1),
("bias_true_group1", True, 1),
("bias_false_group2", False, 2),
("bias_true_group2", True, 2),
]
)
def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool, group: int):
# Conv weight: [out_channels, in_channels/group, kH, kW]
in_channels_per_group = 32 // group
conv_inputs = "X, W"
parameters = (
"float[64, 32, 3, 3] W, "
f"float[64, {in_channels_per_group}, 3, 3] W, "
"float[64] gamma, "
"float[64] beta, "
"float[64] input_mean, "
Expand All @@ -112,14 +121,14 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool):
test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y)
<{parameters}>
{{
X1 = Conv({conv_inputs})
X1 = Conv<group={group}>({conv_inputs})
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
}}
""")
# Add initializers
initializers = [
onnx.numpy_helper.from_array(
np.random.randn(64, 32, 3, 3).astype(np.float32), name="W"
np.random.randn(64, in_channels_per_group, 3, 3).astype(np.float32), name="W"
),
*self._create_batchnorm_params(size=64),
]
Expand Down Expand Up @@ -211,6 +220,32 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int):
output_model_proto = ir.serde.serialize_model(model)
onnx.checker.check_model(output_model_proto, True)

def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self):
"""Fusion is skipped when in_channels is not divisible by group (semantically invalid model)."""
# in_channels=32 is not divisible by group=3, the ONNX checker won't catch this.
model_proto = onnx.parser.parse_model("""
< ir_version: 7, opset_import: ["" : 17] >
test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y)
<float[32, 64, 3, 3] W,
float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var>
{
X1 = ConvTranspose<group=3>(X, W)
Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
}
""")
Comment on lines +223 to +235
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new test builds a semantically-invalid ConvTranspose model to verify fusion is skipped, but it never calls onnx.checker.check_model (unlike the other tests in this file). Adding a checker call would ensure the test actually covers the stated scenario that the ONNX checker accepts the model, and will catch accidental construction errors unrelated to the fusion logic.

Copilot uses AI. Check for mistakes.
initializers = [
onnx.numpy_helper.from_array(
np.random.randn(32, 64, 3, 3).astype(np.float32), name="W"
),
*self._create_batchnorm_params(size=192),
]
model_proto.graph.initializer.extend(initializers)
model = ir.serde.deserialize_model(model_proto)
count = _fuse_batchnorm.rules.apply_to_model(model)

# Fusion must be skipped, applying it would crash on the invalid dimensions.
self.assertEqual(count, 0)

def test_fuse_batchnorm_non_initializers(self):
model_proto = onnx.parser.parse_model("""
< ir_version: 7, opset_import: ["" : 17] >
Expand Down
Loading