
fix(fuse_batchnorm): support convtranspose + bn fusion with group != 1 #2879

Open
AyoubMDL wants to merge 1 commit into microsoft:main from AyoubMDL:convtranspose-bn-grouped

Conversation

@AyoubMDL
Contributor

@AyoubMDL AyoubMDL commented Apr 1, 2026

Fixes #2867

@justinchuby
Collaborator

Thanks!

Contributor

Copilot AI left a comment


Pull request overview

Fixes BatchNormalization fusion for grouped ConvTranspose so that the optimization no longer crashes with a broadcast/reshape mismatch when group != 1, addressing #2867.

Changes:

  • Implement group-aware weight scaling for ConvTranspose + BatchNormalization fusion.
  • Add a guard to skip fusion for semantically invalid grouped ConvTranspose (when in_channels % group != 0).
  • Expand tests to cover grouped Conv/ConvTranspose fusion and the invalid-model skip case.
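The group-aware scaling in the first bullet can be sketched in NumPy. This is a minimal sketch, not the PR's actual helper: the function name is hypothetical, and it assumes the ONNX ConvTranspose weight layout `(C_in, C_out/group, kH, kW)` with one BatchNorm scale factor per output channel.

```python
import numpy as np

def scale_convtranspose_weights(weights: np.ndarray, scale: np.ndarray, group: int) -> np.ndarray:
    """Hypothetical sketch: scale ConvTranspose weights of shape
    (C_in, C_out/group, kH, kW) by per-output-channel BatchNorm factors
    `scale` (length C_out), respecting the group structure."""
    c_in = weights.shape[0]
    if c_in % group != 0:
        # Semantically invalid model; the PR skips fusion in this case.
        raise ValueError("in_channels must be divisible by group")
    # Split the input-channel axis into (group, C_in/group) so each group
    # lines up with its own slice of output-channel scale factors.
    grouped = weights.reshape(group, c_in // group, *weights.shape[1:])
    scale_grouped = scale.reshape(group, 1, -1, *([1] * (weights.ndim - 2)))
    return (grouped * scale_grouped).reshape(weights.shape)
```

With group == 1 this degenerates to broadcasting the scale over axis 1, which is why the ungrouped case never hit the mismatch.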

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
--- | ---
onnxscript/rewriter/rules/common/_fuse_batchnorm.py | Refactors weight scaling and adds correct grouped ConvTranspose handling + validation to avoid reshape/broadcast failures.
onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py | Extends unit tests for grouped fusion behavior and adds a regression test for invalid grouped ConvTranspose.

Comment on lines 33 to +47
```python
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion."""

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""
        raise NotImplementedError()

    def _scale_weights(
        self,
        weights: np.ndarray,
        scale_factor: np.ndarray,
        attributes: Mapping[str, ir.Attr],
    ) -> np.ndarray:
        axis = self.get_filters_axis(attributes)
        return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
```

Copilot AI Apr 2, 2026


_FuseBatchNormBase no longer enforces an abstract interface (the @abstractmethod decorator was removed), so a new subclass can be instantiated without implementing get_filters_axis/_scale_weights and will only fail later at runtime. Consider restoring an abstract contract (e.g., keep get_filters_axis abstract and implement it in FuseBatchNormIntoConvTranspose even if unused, or refactor to make _scale_weights the required abstract method and provide a helper for the axis-based default).

Comment on lines +223 to +235
```python
    def test_fuse_batchnorm_convtranspose_grouped_invalid_skipped(self):
        """Fusion is skipped when in_channels is not divisible by group (semantically invalid model)."""
        # in_channels=32 is not divisible by group=3; the ONNX checker won't catch this.
        model_proto = onnx.parser.parse_model("""
            < ir_version: 7, opset_import: ["" : 17] >
            test_model (float[N, 32, 14, 14] X) => (float[N, ?, ?, ?] Y)
            <float[32, 64, 3, 3] W,
             float[192] gamma, float[192] beta, float[192] input_mean, float[192] input_var>
            {
                X1 = ConvTranspose<group=3>(X, W)
                Y = BatchNormalization(X1, gamma, beta, input_mean, input_var)
            }
        """)
```

Copilot AI Apr 2, 2026


This new test builds a semantically-invalid ConvTranspose model to verify fusion is skipped, but it never calls onnx.checker.check_model (unlike the other tests in this file). Adding a checker call would ensure the test actually covers the stated scenario that the ONNX checker accepts the model, and will catch accidental construction errors unrelated to the fusion logic.

@codecov

codecov bot commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 91.89189% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.97%. Comparing base (1077da7) to head (3c55666).
⚠️ Report is 1 commit behind head on main.

Files with missing lines | Patch % | Lines
--- | --- | ---
...nnxscript/rewriter/rules/common/_fuse_batchnorm.py | 87.50% | 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2879   +/-   ##
=======================================
  Coverage   71.96%   71.97%           
=======================================
  Files         239      239           
  Lines       29224    29251   +27     
  Branches     2878     2880    +2     
=======================================
+ Hits        21031    21053   +22     
- Misses       7216     7219    +3     
- Partials      977      979    +2     

☔ View full report in Codecov by Sentry.


Development

Successfully merging this pull request may close these issues.

fusion of grouped ConvTranspose and BatchNormalization results in exception

3 participants