-
Notifications
You must be signed in to change notification settings - Fork 107
fix(fuse_batchnorm): support convtranspose + bn fusion with group != 1 #2879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,8 +3,7 @@ | |
| import unittest | ||
|
|
||
| import numpy as np | ||
| import onnx.checker | ||
| import onnx.parser | ||
| import onnx | ||
| import parameterized | ||
|
|
||
| from onnxscript import ir | ||
|
|
@@ -31,29 +30,33 @@ def _create_batchnorm_params(self, size: int): | |
|
|
||
| @parameterized.parameterized.expand( | ||
| [ | ||
| ("bias_false", False), | ||
| ("bias_true", True), | ||
| ("bias_false_group1", False, 1), | ||
| ("bias_true_group1", True, 1), | ||
| ("bias_false_group4", False, 4), | ||
| ("bias_true_group4", True, 4), | ||
| ] | ||
| ) | ||
| def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): | ||
| def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool, group: int): | ||
| # ConvTranspose weight: [in_channels, out_channels/group, kH, kW] | ||
| out_channels = 64 * group | ||
| convtranspose_inputs = "X, W" | ||
| parameters = ( | ||
| "float[32, 64, 3, 3] W, " | ||
| "float[64] gamma, " | ||
| "float[64] beta, " | ||
| "float[64] input_mean, " | ||
| "float[64] input_var" | ||
| f"float[32, 64, 3, 3] W, " | ||
| f"float[{out_channels}] gamma, " | ||
| f"float[{out_channels}] beta, " | ||
| f"float[{out_channels}] input_mean, " | ||
| f"float[{out_channels}] input_var" | ||
| ) | ||
| if convtranspose_bias: | ||
| parameters += ", float[64] B" | ||
| parameters += f", float[{out_channels}] B" | ||
| convtranspose_inputs += ", B" | ||
|
|
||
| model_proto = onnx.parser.parse_model(f""" | ||
| < ir_version: 7, opset_import: ["" : 17] > | ||
| test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y) | ||
| <{parameters}> | ||
| {{ | ||
| X1 = ConvTranspose({convtranspose_inputs}) | ||
| X1 = ConvTranspose<group={group}>({convtranspose_inputs}) | ||
| Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) | ||
| }} | ||
| """) | ||
|
|
@@ -62,11 +65,13 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): | |
| onnx.numpy_helper.from_array( | ||
| np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" | ||
| ), | ||
| *self._create_batchnorm_params(size=64), | ||
| *self._create_batchnorm_params(size=out_channels), | ||
| ] | ||
| if convtranspose_bias: | ||
| initializers.append( | ||
| onnx.numpy_helper.from_array(np.random.randn(64).astype(np.float32), name="B") | ||
| onnx.numpy_helper.from_array( | ||
| np.random.randn(out_channels).astype(np.float32), name="B" | ||
| ) | ||
| ) | ||
| model_proto.graph.initializer.extend(initializers) | ||
|
|
||
|
|
@@ -90,14 +95,18 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): | |
|
|
||
| @parameterized.parameterized.expand( | ||
| [ | ||
| ("bias_false", False), | ||
| ("bias_true", True), | ||
| ("bias_false_group1", False, 1), | ||
| ("bias_true_group1", True, 1), | ||
| ("bias_false_group2", False, 2), | ||
| ("bias_true_group2", True, 2), | ||
| ] | ||
| ) | ||
| def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): | ||
| def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool, group: int): | ||
| # Conv weight: [out_channels, in_channels/group, kH, kW] | ||
| in_channels_per_group = 32 // group | ||
| conv_inputs = "X, W" | ||
| parameters = ( | ||
| "float[64, 32, 3, 3] W, " | ||
| f"float[64, {in_channels_per_group}, 3, 3] W, " | ||
| "float[64] gamma, " | ||
| "float[64] beta, " | ||
| "float[64] input_mean, " | ||
|
|
@@ -112,14 +121,14 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): | |
| test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y) | ||
| <{parameters}> | ||
| {{ | ||
| X1 = Conv({conv_inputs}) | ||
| X1 = Conv<group={group}>({conv_inputs}) | ||
| Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) | ||
| }} | ||
| """) | ||
| # Add initializers | ||
| initializers = [ | ||
| onnx.numpy_helper.from_array( | ||
| np.random.randn(64, 32, 3, 3).astype(np.float32), name="W" | ||
| np.random.randn(64, in_channels_per_group, 3, 3).astype(np.float32), name="W" | ||
| ), | ||
| *self._create_batchnorm_params(size=64), | ||
| ] | ||
|
|
@@ -211,6 +220,32 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): | |
| output_model_proto = ir.serde.serialize_model(model) | ||
| onnx.checker.check_model(output_model_proto, True) | ||
|
|
||
| def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self): | ||
| """Fusion is skipped when in_channels is not divisible by group (semantically invalid model).""" | ||
| # in_channels=32 is not divisible by group=3, the ONNX checker won't catch this. | ||
| model_proto = onnx.parser.parse_model(""" | ||
| < ir_version: 7, opset_import: ["" : 17] > | ||
| test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y) | ||
| <float[32, 64, 3, 3] W, | ||
| float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var> | ||
| { | ||
| X1 = ConvTranspose<group=3>(X, W) | ||
| Y = BatchNormalization(X1, gamma, beta, input_mean, input_var) | ||
| } | ||
| """) | ||
|
Comment on lines
+223
to
+235
|
||
| initializers = [ | ||
| onnx.numpy_helper.from_array( | ||
| np.random.randn(32, 64, 3, 3).astype(np.float32), name="W" | ||
| ), | ||
| *self._create_batchnorm_params(size=192), | ||
| ] | ||
| model_proto.graph.initializer.extend(initializers) | ||
| model = ir.serde.deserialize_model(model_proto) | ||
| count = _fuse_batchnorm.rules.apply_to_model(model) | ||
|
|
||
| # Fusion must be skipped, applying it would crash on the invalid dimensions. | ||
| self.assertEqual(count, 0) | ||
|
|
||
| def test_fuse_batchnorm_non_initializers(self): | ||
| model_proto = onnx.parser.parse_model(""" | ||
| < ir_version: 7, opset_import: ["" : 17] > | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_FuseBatchNormBase` no longer enforces an abstract interface (the `@abstractmethod` decorator was removed), so a new subclass can be instantiated without implementing `get_filters_axis` / `_scale_weights` and will only fail later at runtime. Consider restoring an abstract contract (e.g., keep `get_filters_axis` abstract and implement it in `FuseBatchNormIntoConvTranspose` even if unused, or refactor to make `_scale_weights` the required abstract method and provide a helper for the axis-based default).
_FuseBatchNormBaseno longer enforces an abstract interface (the@abstractmethoddecorator was removed), so a new subclass can be instantiated without implementingget_filters_axis/_scale_weightsand will only fail later at runtime. Consider restoring an abstract contract (e.g., keepget_filters_axisabstract and implement it inFuseBatchNormIntoConvTransposeeven if unused, or refactor to make_scale_weightsthe required abstract method and provide a helper for the axis-based default).