Summary
The aten_gru translation in onnxscript/function_libs/torch_lib/ops/core.py (added in #2674) does not set linear_before_reset=1 on the ONNX GRU op. This causes numerically incorrect results because PyTorch's nn.GRU uses the linear_before_reset=1 variant.
Details
PyTorch GRU computes the new gate as:
n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
This matches the ONNX GRU spec with linear_before_reset=1. The default linear_before_reset=0 instead applies the reset gate before the hidden-state linear transformation, i.e. n_t = tanh(W_in @ x_t + b_in + W_hn @ (r_t * h_{t-1}) + b_hn), which is a different equation.
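To see that the two variants really disagree, here is a minimal numpy comparison of the two new-gate formulas (sizes and values are arbitrary, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
H = 4  # hidden size, arbitrary for this demo

# Illustrative random gate inputs and parameters (not taken from any real model).
x_proj = rng.standard_normal(H)      # stands in for W_in @ x_t + b_in
h_prev = rng.standard_normal(H)      # previous hidden state h_{t-1}
W_hn = rng.standard_normal((H, H))
b_hn = rng.standard_normal(H)
r_t = rng.uniform(size=H)            # reset gate activations, in (0, 1)

# linear_before_reset=1 (PyTorch): reset gate applied AFTER the hidden linear map.
n_after = np.tanh(x_proj + r_t * (W_hn @ h_prev + b_hn))

# linear_before_reset=0 (ONNX default): reset gate applied BEFORE the linear map.
n_before = np.tanh(x_proj + W_hn @ (r_t * h_prev) + b_hn)

print("max abs diff between variants:", np.abs(n_after - n_before).max())  # nonzero
```

The two expressions only coincide in degenerate cases (e.g. r_t all ones), so exporting PyTorch's GRU with the default attribute silently changes the recurrence.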
The two op.GRU calls (~lines 4352 and 4362) need linear_before_reset=1 added.
Reproduction
import torch, numpy as np
m = torch.nn.GRU(1, 32, batch_first=True)
m.eval()
inp = torch.randn(1, 10, 1)
with torch.no_grad():
    pt_out, _ = m(inp)
torch.onnx.export(m, (inp,), f="gru.onnx")
import onnxruntime as ort
sess = ort.InferenceSession("gru.onnx")
onnx_out = sess.run(None, {sess.get_inputs()[0].name: inp.numpy()})[0]
print("Max abs diff:", np.abs(pt_out.numpy() - onnx_out).max())
# Expected: ~1e-7 (float32 precision)
# Actual: ~0.1 (incorrect GRU equation)
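That PyTorch really implements the linear_before_reset=1 equations can be checked directly: a small numpy re-implementation of that variant, fed with the module's own weights, reproduces nn.GRU's output to float32 precision (sizes and seed below are arbitrary):

```python
import numpy as np
import torch

torch.manual_seed(0)
D, H, T = 3, 8, 5  # input size, hidden size, sequence length (arbitrary)
m = torch.nn.GRU(D, H, batch_first=True)
m.eval()
x = torch.randn(1, T, D)
with torch.no_grad():
    ref, _ = m(x)

# PyTorch packs gates in order (r | z | n) along the first weight dimension.
W_ih = m.weight_ih_l0.detach().numpy()  # (3H, D)
W_hh = m.weight_hh_l0.detach().numpy()  # (3H, H)
b_ih = m.bias_ih_l0.detach().numpy()
b_hh = m.bias_hh_l0.detach().numpy()

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

h = np.zeros(H)
outs = []
for t in range(T):
    xt = x[0, t].numpy()
    gi = W_ih @ xt + b_ih   # input-side pre-activations for all three gates
    gh = W_hh @ h + b_hh    # hidden-side pre-activations for all three gates
    r = sigmoid(gi[:H] + gh[:H])
    z = sigmoid(gi[H:2 * H] + gh[H:2 * H])
    # linear_before_reset=1: r multiplies the already linear-transformed hidden state.
    n = np.tanh(gi[2 * H:] + r * gh[2 * H:])
    h = (1 - z) * n + z * h
    outs.append(h.copy())

diff = np.abs(np.stack(outs) - ref[0].numpy()).max()
print("max abs diff vs torch:", diff)  # float32-level agreement
```

Swapping the n-gate line for the linear_before_reset=0 form reproduces the ~0.1 error seen in the ONNX Runtime comparison above.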
Environment
- torch 2.10.0
- onnxscript 0.6.2
- onnxruntime 1.22.0
References
- ONNX GRU operator spec (linear_before_reset)
- #2674 (aten_gru translation)