[Common] Enable determinism for cuDNN >= 9.18 on Blackwell #2584
base: main
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary
This PR enables deterministic FusedAttention on Blackwell GPUs (SM 100+) for FP16/BF16 with cuDNN >= 9.18.0.
Key Changes:
Implementation Notes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant PyTorch/JAX
participant Backend Selection
participant cuDNN Frontend
participant Forward Pass
participant Backward Pass
User->>PyTorch/JAX: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO
Note over User,PyTorch/JAX: 0=deterministic, 1=non-deterministic
PyTorch/JAX->>Backend Selection: get_fused_attn_backend(deterministic)
Note over Backend Selection: New parameter: deterministic
alt Blackwell (sm_arch >= 100) Training
Backend Selection->>Backend Selection: Check cuDNN version & constraints
alt Non-deterministic (cuDNN >= 9.7.0)
Note over Backend Selection: Requires: dropout=0 OR bias=NONE
else Deterministic (cuDNN >= 9.18.0)
Note over Backend Selection: Requires: dropout=0 AND bias=NONE
end
end
Backend Selection->>PyTorch/JAX: Return backend (arbitrary_seqlen or max512)
PyTorch/JAX->>Forward Pass: nvte_fused_attn_fwd(deterministic=false)
Note over Forward Pass: Always uses deterministic algorithm
Forward Pass->>cuDNN Frontend: Execute deterministic forward
cuDNN Frontend-->>Forward Pass: Return O, aux tensors
alt Training Mode
PyTorch/JAX->>Backward Pass: nvte_fused_attn_bwd(deterministic)
Note over Backward Pass: Uses actual deterministic flag
alt Deterministic
Backward Pass->>cuDNN Frontend: Execute deterministic backward (9.18+)
else Non-deterministic
Backward Pass->>cuDNN Frontend: Execute non-deterministic backward (9.7+)
end
cuDNN Frontend-->>Backward Pass: Return dQ, dK, dV
end
Backward Pass-->>User: Gradients
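The alt branches above boil down to a simple predicate. Below is a minimal, standalone sketch of that constraint logic (plain Python with placeholder bias labels, not the actual C++ backend-selection code in nvte_get_fused_attn_backend):

```python
# Illustrative sketch of the Blackwell (sm_arch >= 100) training constraints
# summarized in the diagram above; not the real backend-selection implementation.
def blackwell_bprop_supported(deterministic, cudnn_version, dropout, bias_type):
    """Return True if the fused-attention backward pass is expected to be available."""
    if deterministic:
        # Deterministic bprop: cuDNN >= 9.18.0, and requires dropout == 0 AND no bias.
        return cudnn_version >= (9, 18, 0) and dropout == 0.0 and bias_type == "NO_BIAS"
    # Non-deterministic bprop: cuDNN >= 9.7.0, and requires dropout == 0 OR no bias.
    return cudnn_version >= (9, 7, 0) and (dropout == 0.0 or bias_type == "NO_BIAS")


# Deterministic mode with a bias is rejected; bias-free, dropout-free is allowed.
assert not blackwell_bprop_supported(True, (9, 18, 0), 0.0, "POST_SCALE_BIAS")
assert blackwell_bprop_supported(True, (9, 18, 0), 0.0, "NO_BIAS")
```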
1 file reviewed, 1 comment
Greptile Overview
Greptile Summary
Overview
This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism-checking logic from Python to the C++ backend selection layer.
Key Changes
Architecture
The change follows a layered approach:
The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User as User/Test
participant PyAPI as Python API
participant Utils as utils.py
participant CppExt as C++ Extensions
participant Backend as Backend Selection
participant cuDNN as cuDNN Library
User->>PyAPI: Call attention with deterministic=True
PyAPI->>Utils: get_attention_backend(params)
Utils->>Utils: Extract deterministic from params
Utils->>CppExt: get_fused_attn_backend(..., deterministic)
CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
alt Blackwell (sm_arch >= 100) & Training & Deterministic
Backend->>Backend: Check cuDNN version >= 9.18.0
Backend->>Backend: Check bias_type == NO_BIAS
Backend->>Backend: Check dropout == 0.0
alt All checks pass
Backend-->>CppExt: F16_arbitrary_seqlen backend
else Any check fails
Backend-->>CppExt: No_Backend (disabled)
end
else Other architectures or inference
Backend->>Backend: Apply standard backend selection
Backend-->>CppExt: Selected backend
end
CppExt-->>Utils: Backend choice
Utils-->>PyAPI: Backend configuration
alt Forward Pass
PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
Note over PyAPI,CppExt: Forward always uses deterministic=true
else Backward Pass
PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
Note over PyAPI,CppExt: Backward respects user's deterministic flag
end
CppExt->>cuDNN: Execute attention operation
cuDNN-->>CppExt: Results
CppExt-->>PyAPI: Output tensors
PyAPI-->>User: Attention output
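As a rough illustration of the calling pattern in this diagram - forward always deterministic, backward honoring the user's flag - here is a sketch with hypothetical wrapper callables (the real entry points are the C APIs nvte_fused_attn_fwd/nvte_fused_attn_bwd, which take many more arguments):

```python
import os


# Hypothetical wrappers; only the threading of the deterministic flag matters here.
def run_attention_step(q, k, v, fused_attn_fwd, fused_attn_bwd, is_training):
    # User-facing switch: 0 = deterministic, anything else = non-deterministic.
    deterministic = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"

    # Forward pass: always runs the deterministic algorithm.
    out, aux = fused_attn_fwd(q, k, v, deterministic=True)

    if is_training:
        # Backward pass: respects the user's flag (on Blackwell this needs
        # cuDNN >= 9.18.0 when deterministic, >= 9.7.0 otherwise).
        dq, dk, dv = fused_attn_bwd(q, k, v, out, aux, deterministic=deterministic)
        return out, (dq, dk, dv)
    return out, None
```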
2 files reviewed, 2 comments
make .xml file specific to deterministic tests in qa/ Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
No files reviewed, no comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
fix typo Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
fix indentation Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
3 files reviewed, 3 comments
2 files reviewed, 2 comments
3 files reviewed, 3 comments
1 file reviewed, 1 comment
/te-ci L0
1 file reviewed, 1 comment
No files reviewed, no comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
/te-ci L0
Signed-off-by: Charlene Yang <[email protected]>
/te-ci jax L0
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L0
/te-ci L1
Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1
for more information, see https://pre-commit.ci
13 files reviewed, 3 comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
fix and/or logic Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1
Cool, we are currently suffering from this issue.
KshitijLakhani left a comment
Left a few comments - some suggested changes and some questions.
Looks good to me otherwise. Approving so as not to block the merge, if urgent.
It would be helpful to have a table in the PR description showing what is supported for cuDNN < 9.18, cuDNN >= 9.18, < sm100, sm100+, dropout, dbias, etc.
I would also suggest looking into the number of tests being run and their timing (you can compare this PR's L0 JAX and L0 PyTorch timings to the timings in TE 2.11 or in TE main CI) - we would not want to go overboard with our timing budget. Reporting the timing in the PR would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA bit (which runs in the CI) in a separate PR subsequently.
Lastly, this might take some effort but would ensure correctness: since the test-skipping code in the TE JAX tests has been modified, it would be good to compare the test count before and after this PR to verify that tests which should not be skipped are not being skipped incorrectly.
| mkdir -p "$XML_LOG_DIR" | ||
|
|
||
| python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" | ||
| NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py" |
It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention - to run fused attn 2x, with and without determinism?
That would greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am unsure I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!
Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run all of them deterministically only?
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
I noticed that part of my question for the JAX side L0 test is answered here.
Seems like the intention is to run 2x attention tests - with and without determinism.
I think we'd have to think more about this as I'd assume this would consume substantial testing budget.
Thoughts?
  float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
  size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
- int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+ int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?
  window_size[1],
  return_max_logit,
  cuda_graph,
+ deterministic,
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?
  float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
  size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
- int64_t window_size_right);
+ int64_t window_size_right, bool deterministic);
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
This is just an effort to be explicit with the env vars and has nothing to do with determinism directly, right?
if any(x >= 100 for x in compute_capabilities) and is_training:
    assert (
        FusedAttnHelper.is_non_deterministic_allowed()
        and get_cudnn_version() >= (9, 7, 0)
What is the min cuDNN version check for? i.e., what was supported from cuDNN 9.7 onwards?
assert (
    FusedAttnHelper.is_non_deterministic_allowed()
    and get_cudnn_version() >= (9, 7, 0)
    and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
My understanding was that:
- >=sm100 + dropout + no dbias = supported, but non-deterministic, as dropout requires choosing a non-deterministic kernel
- >=sm100 + no dropout + dbias = not supported, as dbias requires choosing the deterministic path
- >=sm100 + no dropout + no dbias = supported
If this is true, wouldn't case #2 falsely pass even though it is not supported?
Or is my understanding incorrect?
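To make the question concrete, here is a tiny standalone check of the asserted condition against the three cases above (placeholder strings instead of the real AttnBiasType enum; illustrative only):

```python
# Evaluate the asserted condition
#   attn_bias_type == NO_BIAS or dropout_probability == 0.0
# for the three sm100+ training cases listed above.
NO_BIAS, POST_SCALE_BIAS = "no_bias", "post_scale_bias"  # placeholder labels

cases = {
    "dropout + no dbias": (NO_BIAS, 0.1),
    "no dropout + dbias": (POST_SCALE_BIAS, 0.0),
    "no dropout + no dbias": (NO_BIAS, 0.0),
}
for name, (attn_bias_type, dropout_probability) in cases.items():
    passes = attn_bias_type == NO_BIAS or dropout_probability == 0.0
    print(f"{name}: condition passes = {passes}")

# All three print True, so case #2 ("no dropout + dbias") is not rejected by this
# condition alone - which is exactly the question raised above.
```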
| "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" | ||
| ) | ||
|
|
||
| if get_device_compute_capability(0) >= 100 and self.is_training: |
Thanks for adding the is_training flag in the check
    (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
    or get_cudnn_version() < 90700
):
    pytest.skip(
For sm100+ non-deterministic bprop (cuDNN 9.7+), ONLY bias or ONLY dropout is supported, but not both at the same time, right?
/te-ci L0 L1
Description
This PR enables determinism for FusedAttention on Blackwell for FP16/BF16 precisions and cuDNN >= 9.18.0. To run with determinism, please set this flag: export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.
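For reference, a minimal usage sketch (assuming the PyTorch DotProductAttention module on a Blackwell GPU with cuDNN >= 9.18.0, no attention bias, and zero dropout; exact module arguments and tensor layout may differ across TE versions):

```python
# Minimal sketch: request deterministic FusedAttention before running attention.
import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # 0 = deterministic

import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# sbhd layout: (sequence, batch, heads, head_dim); FP16/BF16 only.
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = attn(q, k, v)
out.sum().backward()  # backward pass should use the deterministic bprop kernel
```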
Type of change
Changes
Please see Description.
Checklist: