fix(fuse_batchnorm): support ConvTranspose + BN fusion with group != 1 #2879

AyoubMDL wants to merge 1 commit into microsoft:main
Conversation
Thanks!
Pull request overview
Fixes BatchNormalization fusion for grouped ConvTranspose so that optimization no longer crashes (broadcast/reshape mismatch) when group != 1, addressing #2867.
Changes:
- Implement group-aware weight scaling for ConvTranspose + BatchNormalization fusion.
- Add a guard to skip fusion for semantically invalid grouped ConvTranspose (when in_channels % group != 0).
- Expand tests to cover grouped Conv/ConvTranspose fusion and the invalid-model skip case.
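The group-aware scaling can be sketched with plain NumPy. ConvTranspose stores its weights as (C_in, C_out // group, kH, kW), so the per-output-channel BN factor (shape (C_out,)) must be split across groups before it can broadcast; the function name and shapes below are illustrative assumptions, not the PR's actual helper:

```python
import numpy as np

def scale_convtranspose_weights(W: np.ndarray, scale: np.ndarray, group: int) -> np.ndarray:
    """Hypothetical sketch of group-aware BN scaling for ConvTranspose weights.

    W: (C_in, C_out // group, kH, kW); scale: (C_out,) = gamma / sqrt(var + eps).
    """
    c_in = W.shape[0]
    if c_in % group != 0:
        # Mirrors the PR's guard: a grouped ConvTranspose whose in_channels is
        # not divisible by group is semantically invalid, so fusion must be skipped.
        raise ValueError("in_channels must be divisible by group")
    c_out_per_group = W.shape[1]
    # Split the input-channel axis into (group, C_in // group) so each group's
    # weight slice lines up with its slice of the BN scale vector.
    W5 = W.reshape(group, c_in // group, c_out_per_group, *W.shape[2:])
    s5 = scale.reshape(group, 1, c_out_per_group, 1, 1)
    return (W5 * s5).reshape(W.shape)
```

With group == 1 this degenerates to the plain axis-1 broadcast, which is why the ungrouped path never hit the bug.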
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/rules/common/_fuse_batchnorm.py | Refactors weight scaling and adds correct grouped ConvTranspose handling + validation to avoid reshape/broadcast failures. |
| onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py | Extends unit tests for grouped fusion behavior and adds a regression test for invalid grouped ConvTranspose. |
```python
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion."""

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""
        raise NotImplementedError()

    def _scale_weights(
        self,
        weights: np.ndarray,
        scale_factor: np.ndarray,
        attributes: Mapping[str, ir.Attr],
    ) -> np.ndarray:
        axis = self.get_filters_axis(attributes)
        return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
```
_FuseBatchNormBase no longer enforces an abstract interface (the @abstractmethod decorator was removed), so a new subclass can be instantiated without implementing get_filters_axis/_scale_weights and will only fail later at runtime. Consider restoring an abstract contract (e.g., keep get_filters_axis abstract and implement it in FuseBatchNormIntoConvTranspose even if unused, or refactor to make _scale_weights the required abstract method and provide a helper for the axis-based default).
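The reviewer's suggested contract can be sketched in isolation (class names below are hypothetical, and RewriteRuleClassBase and the ir types are elided, so this is a standalone illustration, not the PR's code):

```python
from abc import ABC, abstractmethod
from typing import Mapping


class FuseBatchNormBaseSketch(ABC):
    """Standalone sketch: keep get_filters_axis abstract so the contract is enforced."""

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, object]) -> int:
        """Axis along which the BatchNorm scale broadcasts over the weights."""
        raise NotImplementedError


class IncompleteRule(FuseBatchNormBaseSketch):
    pass  # forgot to implement get_filters_axis


# With the abstract contract in place, the mistake surfaces at construction
# time instead of deep inside the rewrite pass:
try:
    IncompleteRule()  # type: ignore[abstract]
except TypeError as exc:
    print(f"caught: {exc}")
```

A subclass that does implement the method instantiates normally, so existing rules are unaffected by restoring the decorator.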
```python
def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self):
    """Fusion is skipped when in_channels is not divisible by group (semantically invalid model)."""
    # in_channels=32 is not divisible by group=3; the ONNX checker won't catch this.
    model_proto = onnx.parser.parse_model("""
        < ir_version: 7, opset_import: ["" : 17] >
        test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y)
        <float[32, 64, 3, 3] W,
         float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var>
        {
            X1 = ConvTranspose<group=3>(X, W)
            Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
        }
    """)
```
This new test builds a semantically-invalid ConvTranspose model to verify fusion is skipped, but it never calls onnx.checker.check_model (unlike the other tests in this file). Adding a checker call would ensure the test actually covers the stated scenario that the ONNX checker accepts the model, and will catch accidental construction errors unrelated to the fusion logic.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main    #2879   +/-  ##
======================================
  Coverage   71.96%   71.97%
======================================
  Files         239      239
  Lines       29224    29251    +27
  Branches     2878     2880     +2
======================================
+ Hits        21031    21053    +22
- Misses       7216     7219     +3
- Partials      977      979     +2
```

☔ View full report in Codecov by Sentry.
Fixes #2867