Fix BatchNorm fusion producing invalid ONNX when Conv nodes share weight initializers #2883
Conversation
Add a check in `_FuseBatchNormBase.check()` to verify that the inbound Conv/ConvTranspose/Gemm node's weight and bias initializers are not shared with nodes outside the matched pattern. When two Conv+BatchNorm pairs share the same weight initializer, fusing the first pair overwrites the shared initializer in the graph with fused values, leaving the second Conv node with an invalid (unregistered) weight reference and producing an invalid ONNX model.

Fixes #2382

Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/10e4a5fd-e010-48dc-8a29-991b7b0a6ca7

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main    #2883      +/-   ##
==========================================
+ Coverage   72.04%   72.06%   +0.01%
==========================================
  Files         239      239
  Lines       29305    29324      +19
  Branches     2880     2884       +4
==========================================
+ Hits        21112    21131      +19
  Misses       7216     7216
  Partials      977      977
```

☔ View full report in Codecov by Sentry.
gramalingam left a comment
Looks ok. I think it would help to have a utility to do this kind of destructive transformation (the rule shouldn't have done this in the first place). Specifically, a utility could implement `replace_uses_of_initializer_in(old_initializer_value, new_value, nodes)` to abstract this logic. If the initializer has no uses outside the set of nodes, we can simply alter/swap. Otherwise, we can create a new initializer with the new value and a new (unique) name and use that instead. This sort of improvement can be done in a later PR, though.
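The utility suggested above could be sketched roughly as follows. This is a hypothetical illustration over toy `Value`/`Node` stand-ins, not onnxscript's actual IR classes; the function name comes from the comment, everything else is assumed.

```python
# Hypothetical sketch of replace_uses_of_initializer_in over a toy IR.
# Value/Node below are stand-ins, NOT onnxscript's real classes.
from dataclasses import dataclass, field


@dataclass
class Value:
    name: str
    data: object = None
    uses: list = field(default_factory=list)  # nodes consuming this value


@dataclass
class Node:
    name: str
    inputs: list  # Values consumed by this node


def replace_uses_of_initializer_in(initializers, old_value, new_data, nodes):
    """Replace `old_value` with `new_data` for `nodes` only.

    If `old_value` has no uses outside `nodes`, mutate it in place (safe swap).
    Otherwise register a fresh, uniquely named initializer and rewire only
    `nodes` to it, leaving the other users untouched.
    """
    outside_uses = [n for n in old_value.uses if n not in nodes]
    if not outside_uses:
        old_value.data = new_data  # destructive update is safe here
        return old_value
    # Shared with outside nodes: create a new initializer under a unique name.
    fresh_name = old_value.name
    while fresh_name in initializers:
        fresh_name += "_fused"
    new_value = Value(fresh_name, new_data, uses=list(nodes))
    initializers[fresh_name] = new_value
    for node in nodes:
        node.inputs = [new_value if v is old_value else v for v in node.inputs]
        old_value.uses.remove(node)
    return new_value


# Two Conv nodes share "W"; rewriting only conv1 must not touch conv2's view.
W = Value("W", data=[1.0])
conv1, conv2 = Node("conv1", [W]), Node("conv2", [W])
W.uses = [conv1, conv2]
initializers = {"W": W}
fused = replace_uses_of_initializer_in(initializers, W, [2.0], [conv1])
```

With this helper, the fusion rule would never need the shared-initializer rejection check at all: the shared case simply produces a second, uniquely named initializer.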
When a model reuses `Conv2d+BatchNorm2d` blocks (the same weights referenced twice), the BatchNorm fusion rewrite creates a new initializer with the same name as the shared weight, overwriting it in the graph's initializer dict. This sets `_is_initializer = False` on the original value still referenced by the second Conv node, producing an invalid model.

Reproducer:
Changes
- `_fuse_batchnorm.py`: Added a check in `_FuseBatchNormBase.check()` that rejects the fusion when the inbound node's weight/bias initializers are used by nodes outside the matched pattern. This prevents the overwrite of shared initializers.
- `_fuse_batchnorm_test.py`: Added a regression test with two Conv+BN pairs sharing the same weight and bias initializers.
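The logic of the added guard can be sketched as follows. This is a hypothetical illustration with a toy `Node` stand-in, not onnxscript's actual `check()` signature or IR types:

```python
# Hypothetical sketch of the guard: reject fusion when the Conv's weight/bias
# initializers feed any node outside the matched Conv+BN pair.
# `Node` is a toy stand-in, not onnxscript's IR class.
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    inputs: tuple  # names of the values this node consumes


def initializer_users(graph_nodes, init_name):
    """All nodes in the graph that consume the initializer `init_name`."""
    return [n for n in graph_nodes if init_name in n.inputs]


def can_fuse(graph_nodes, conv, batchnorm):
    """True only if the Conv's weight/bias are private to the matched pair."""
    matched = {conv, batchnorm}
    for init_name in conv.inputs[1:]:  # skip the activation input
        if any(u not in matched for u in initializer_users(graph_nodes, init_name)):
            return False  # shared initializer: fusing would corrupt other users
    return True


# Two Conv+BN pairs sharing weight "W": fusing either pair must be rejected.
conv1 = Node("conv1", ("X", "W"))
bn1 = Node("bn1", ("c1", "scale", "bias", "mean", "var"))
conv2 = Node("conv2", ("X", "W"))
bn2 = Node("bn2", ("c2", "scale", "bias", "mean", "var"))
graph = [conv1, bn1, conv2, bn2]
print(can_fuse(graph, conv1, bn1))  # False: "W" is also used by conv2
```

When each Conv owns its own weight, the same check passes and the fusion proceeds as before, so the guard only disables the unsafe case.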