Merged
Changes from all commits
36 commits
259662b
update FE to 1.17
cyanguwa Jan 10, 2026
c243794
add determinism flag
cyanguwa Jan 10, 2026
5578a60
add determinism to test
cyanguwa Jan 12, 2026
9bc1d64
add determinism to qa/
cyanguwa Jan 12, 2026
b1bdab7
move bias/dbias/versioning/dropout logic to C API
cyanguwa Jan 12, 2026
ea109c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
70fc94b
Update qa/L0_pytorch_unittest/test.sh
cyanguwa Jan 12, 2026
e82bd96
add determinism to Jax extension
cyanguwa Jan 13, 2026
8365962
add determinism to Jax tests
cyanguwa Jan 13, 2026
c7db02b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
4aaa627
Update tests/jax/test_fused_attn.py
cyanguwa Jan 13, 2026
0ee6b87
Update transformer_engine/common/fused_attn/fused_attn.cpp
cyanguwa Jan 13, 2026
4bd5e95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
bd31e01
fix the AI fixes
cyanguwa Jan 13, 2026
eb2e055
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2026
6f0e515
fix Jax extension call
cyanguwa Jan 13, 2026
2c22cbf
minor fixes based on comments
cyanguwa Jan 13, 2026
279f2f6
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 13, 2026
aae98f3
fix selection logic and fwd arg
cyanguwa Jan 14, 2026
b962d32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
1068594
fix version check in Jax test
cyanguwa Jan 14, 2026
fed07f2
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 14, 2026
c51cf44
fix pytorch CI failures
cyanguwa Jan 15, 2026
3885684
fix Jax CI failures
cyanguwa Jan 15, 2026
8bf3a0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
f3a0234
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 15, 2026
f526569
fix non-/determinism logic and CI
cyanguwa Jan 15, 2026
0cb374a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
9162ff6
fix formatting
cyanguwa Jan 15, 2026
ee90c5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
77d0f2a
Update transformer_engine/common/fused_attn/fused_attn.cpp
cyanguwa Jan 15, 2026
6a3346f
Merge branch 'main' into blackwell_determinism
cyanguwa Jan 18, 2026
65a67c6
update to 9.18.1 for requirement
cyanguwa Jan 19, 2026
7187d02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
cd5bcf3
reduce Jax CI tests for determinism
cyanguwa Jan 19, 2026
57e2f58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 102 files
1 change: 1 addition & 0 deletions qa/L0_jax_unittest/test.sh
@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
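Both additions gate the deterministic runs on NVTE_ALLOW_NONDETERMINISTIC_ALGO=0: unset or "1" (the default) lets cuDNN pick faster, non-deterministic algorithms, while "0" requests deterministic ones. The qa scripts set the variable per invocation on the command line; as a minimal sketch, the same toggle from Python (it must be set before Transformer Engine or the test module is imported, since the flag is read at import time):

import os

# "0" requests deterministic fused-attention algorithms; the default "1"
# allows non-deterministic (typically faster) kernel selection.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"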
212 changes: 203 additions & 9 deletions tests/jax/test_fused_attn.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import os
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
@@ -49,6 +50,9 @@
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats

# NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 requests deterministic algorithms (default "1")
_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


@pytest.fixture(autouse=True, scope="module")
def init():
@@ -413,15 +417,25 @@ def _check_configs(self):
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if (
get_device_compute_capability(0) >= 100
and self.dropout_prob == 0.1
and self.attn_bias_type is not AttnBiasType.NO_BIAS
):
pytest.skip(
"For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
)

if get_device_compute_capability(0) >= 100 and self.is_training:
if FusedAttnHelper.is_non_deterministic_allowed() and (
(self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
or get_cudnn_version() < 90700
):
pytest.skip(
"For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
" dropout"
)
if not FusedAttnHelper.is_non_deterministic_allowed() and (
self.dropout_prob != 0.0
or self.attn_bias_type != AttnBiasType.NO_BIAS
or get_cudnn_version() < 91801
):
pytest.skip(
"For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
" dropout"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
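Taken together, the two skip branches above imply the following sm100+ training (bprop) support matrix: non-deterministic execution needs cuDNN 9.7+ and allows bias or dropout but not both together; deterministic execution needs cuDNN 9.18.1+ and allows neither. A minimal sketch of that matrix as a predicate (the helper name is hypothetical; cudnn_version uses the same integer encoding as get_cudnn_version(), e.g. 91801 for 9.18.1):

def sm100_bprop_config_supported(deterministic, has_bias, has_dropout, cudnn_version):
    # Hypothetical helper summarizing the pytest skips above.
    if deterministic:
        # Deterministic bprop: cuDNN 9.18.1+, and neither bias nor dropout.
        return cudnn_version >= 91801 and not has_bias and not has_dropout
    # Non-deterministic bprop: cuDNN 9.7+, bias and dropout not combined.
    return cudnn_version >= 90700 and not (has_bias and has_dropout)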
@@ -1269,6 +1283,7 @@ def check_dqkv(primitive, reference, pad, idx):
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
"""
Fused attention tester
@@ -1392,3 +1407,182 @@ def test_backward(
seq_desc_format,
)
runner.test_backward()


@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
"""
Fused attention tester with determinism
"""

@staticmethod
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def _test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
TestFusedAttn._test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)

@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
"""
TestFusedAttn.test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
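The deterministic suite reuses TestFusedAttn's runners over a narrowed parameter grid, so it stays cheap enough for CI. For local debugging it can be selected by class name, mirroring the qa script; a sketch assuming it is run from the repository root:

import os
import pytest

# The flag must be set before pytest imports tests/jax/test_fused_attn.py,
# because _deterministic is evaluated at module import time.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
pytest.main(["-v", "-k", "TestFusedAttnWithDeterminism", "tests/jax/test_fused_attn.py"])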