diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index e3298ffbd8..000d15c916 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. """Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns: - BatchNormalization ∘ Conv -> Conv -- BatchNormalization ∘ ConvTranpose -> ConvTranpose +- BatchNormalization ∘ ConvTranspose -> ConvTranspose - BatchNormalization ∘ Gemm -> Gemm Approach: @@ -14,7 +14,7 @@ - B_fused = (B - μ) * (gamma / std) + β """ -from abc import ABC, abstractmethod +from abc import ABC from typing import ClassVar, Mapping import numpy as np @@ -33,9 +33,18 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra class _FuseBatchNormBase(RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" - @abstractmethod def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" + raise NotImplementedError() + + def _scale_weights( + self, + weights: np.ndarray, + scale_factor: np.ndarray, + attributes: Mapping[str, ir.Attr], + ) -> np.ndarray: + axis = self.get_filters_axis(attributes) + return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value): batchnorm_node = batchnorm_out.producer() @@ -56,10 +65,8 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu inbound_node = inbound_out.producer() weights = inbound_node.inputs[1].const_value.numpy() - # Reshape scale factor so it is broadcastable - axis = self.get_filters_axis(inbound_node.attributes) fused_weights = ir.tensor( - weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis) + self._scale_weights(weights, scale_factor, 
inbound_node.attributes) ) # Update bias @@ -127,8 +134,26 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): op_type: ClassVar = "ConvTranspose" - def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: - return 1 + def _scale_weights( + self, + weights: np.ndarray, + scale_factor: np.ndarray, + attributes: Mapping[str, ir.Attr], + ) -> np.ndarray: + # ConvTranspose weight: (in_channels, out_channels/group, *kernel) + # Reshape weights: [in_channels, out_channels/group, *kernel] → [group, in_channels/group, out_channels/group, *kernel] + in_channels = weights.shape[0] + out_channels_per_group = weights.shape[1] + kernel_shape = weights.shape[2:] + group = attributes.get("group", ir.AttrInt64("group", 1)).as_int() + w = weights.reshape(group, in_channels // group, out_channels_per_group, *kernel_shape) + + # Per-group scale_factor: (out_channels,) -> (group, out_channels/group, 1, ..., 1) + s = scale_factor.reshape((group, out_channels_per_group) + (1,) * len(kernel_shape)) + # Insert a broadcast axis for in_channels/group -> (group, 1, out_channels/group, 1, ..., 1) + s = s[:, None, ...] + + return (w * s).reshape(weights.shape) def pattern(self, op, x): return op.BatchNormalization( @@ -137,6 +162,25 @@ def pattern(self, op, x): _outputs=["batchnorm_out"], ) + def check(self, context, x, inbound_out, batchnorm_out): + check_result = super().check(context, x, inbound_out, batchnorm_out) + if not check_result: + return check_result + + inbound_node = inbound_out.producer() + + in_channels = inbound_node.inputs[1].const_value.numpy().shape[0] + group = inbound_node.attributes.get("group", ir.AttrInt64("group", 1)).as_int() + + # The ONNX checker accepts an in_channels value that is not divisible by group, + # but such a model is invalid, so the fusion is skipped. + if in_channels % group != 0: + return check_result.fail( + f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
+ ) + + return check_result + class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 2007033ef6..04506e0fe0 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -3,8 +3,7 @@ import unittest import numpy as np -import onnx.checker -import onnx.parser +import onnx import parameterized from onnxscript import ir @@ -31,21 +30,25 @@ def _create_batchnorm_params(self, size: int): @parameterized.parameterized.expand( [ - ("bias_false", False), - ("bias_true", True), + ("bias_false_group1", False, 1), + ("bias_true_group1", True, 1), + ("bias_false_group4", False, 4), + ("bias_true_group4", True, 4), ] ) - def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): + def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool, group: int): + # ConvTranspose weight: [in_channels, out_channels/group, kH, kW] + out_channels = 64 * group convtranspose_inputs = "X, W" parameters = ( - "float[32, 64, 3, 3] W, " - "float[64] gamma, " - "float[64] beta, " - "float[64] input_mean, " - "float[64] input_var" + f"float[32, 64, 3, 3] W, " + f"float[{out_channels}] gamma, " + f"float[{out_channels}] beta, " + f"float[{out_channels}] input_mean, " + f"float[{out_channels}] input_var" ) if convtranspose_bias: - parameters += ", float[64] B" + parameters += f", float[{out_channels}] B" convtranspose_inputs += ", B" model_proto = onnx.parser.parse_model(f""" @@ -53,7 +56,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] 
Y) <{parameters}> {{ - X1 = ConvTranspose({convtranspose_inputs}) + X1 = ConvTranspose<group={group}>({convtranspose_inputs}) Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) }} """) @@ -62,11 +65,13 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): onnx.numpy_helper.from_array( np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" ), - *self._create_batchnorm_params(size=64), + *self._create_batchnorm_params(size=out_channels), ] if convtranspose_bias: initializers.append( - onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") + onnx.numpy_helper.from_array( + np.random.randn(out_channels).astype(np.float32), name="B" + ) ) model_proto.graph.initializer.extend(initializers) @@ -90,14 +95,18 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): @parameterized.parameterized.expand( [ - ("bias_false", False), - ("bias_true", True), + ("bias_false_group1", False, 1), + ("bias_true_group1", True, 1), + ("bias_false_group2", False, 2), + ("bias_true_group2", True, 2), ] ) - def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): + def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool, group: int): + # Conv weight: [out_channels, in_channels/group, kH, kW] + in_channels_per_group = 32 // group conv_inputs = "X, W" parameters = ( - "float[64, 32, 3, 3] W, " + f"float[64, {in_channels_per_group}, 3, 3] W, " "float[64] gamma, " "float[64] beta, " "float[64] input_mean, " @@ -112,14 +121,14 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] 
Y) <{parameters}> {{ - X1 = Conv({conv_inputs}) + X1 = Conv<group={group}>({conv_inputs}) Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) }} """) # Add initializers initializers = [ onnx.numpy_helper.from_array( - np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" + np.random.randn(64, in_channels_per_group, 3, 3).astype(np.float32), name="W" ), *self._create_batchnorm_params(size=64), ] @@ -211,6 +220,32 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): output_model_proto = ir.serde.serialize_model(model) onnx.checker.check_model(output_model_proto, True) + def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self): + """Fusion is skipped when in_channels is not divisible by group (semantically invalid model).""" + # in_channels=32 is not divisible by group=3; the ONNX checker won't catch this. + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y) + <float[32, 64, 3, 3] W, float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var> + { + X1 = ConvTranspose<group=3>(X, W) + Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) + } + """) + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" + ), + *self._create_batchnorm_params(size=192), + ] + model_proto.graph.initializer.extend(initializers) + model = ir.serde.deserialize_model(model_proto) + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Fusion must be skipped; applying it would crash on the invalid dimensions. + self.assertEqual(count, 0) + def test_fuse_batchnorm_non_initializers(self): model_proto = onnx.parser.parse_model(""" < ir_version: 7, opset_import: ["" : 17] >