Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 115 additions & 2 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1852,8 +1852,39 @@ struct ggml_cuda_fattn_route_plan {
bool allow_vec;
bool unsafe_vec_after_turbo_v_decode;
best_fattn_kernel kernel;
const char * none_reason;
};

// Decides whether a route plan may be forced onto the tile kernel on HIP builds.
//
// When GGML_USE_HIP is defined, returns true only if every condition holds:
//   - plan.effective_type_K and plan.effective_type_V are both GGML_TYPE_F16,
//   - Q->ne[0], K->ne[0] and V->ne[0] are all exactly 256,
//   - K->ne[1] is a multiple of FATTN_KQ_STRIDE.
// On any other build it returns false unconditionally.
//
// Parameters:
//   Q, K, V  tensors whose ne[] extents are inspected; must be non-null on HIP builds
//            because they are dereferenced without a check.
//   mask     accepted but never inspected in either build configuration.
//   plan     only effective_type_K / effective_type_V are read.
//
// NOTE(review): the mask shape is not part of the eligibility test. Whether the tile kernel
// copes with every mask layout that reaches this function cannot be determined from here;
// confirm against the tile kernel's requirements.
static inline bool ggml_cuda_fattn_can_force_hip_tile_fallback(
        const ggml_tensor * Q,
        const ggml_tensor * K,
        const ggml_tensor * V,
        const ggml_tensor * mask,
        const ggml_cuda_fattn_route_plan & plan) {
    // mask is unused on every path; mark it once, outside the preprocessor branches, so that
    // HIP builds do not emit an unused-parameter warning.
    GGML_UNUSED(mask);

#if defined(GGML_USE_HIP)
    // Both effective K and V types must be F16.
    if (plan.effective_type_K != GGML_TYPE_F16 || plan.effective_type_V != GGML_TYPE_F16) {
        return false;
    }

    // Only the case where Q, K and V all have ne[0] == 256 is eligible.
    if (Q->ne[0] != 256 || K->ne[0] != 256 || V->ne[0] != 256) {
        return false;
    }

    // K->ne[1] has to be a whole multiple of FATTN_KQ_STRIDE.
    if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
        return false;
    }

    return true;
#else
    // Non-HIP builds never force the fallback; silence unused-parameter warnings.
    GGML_UNUSED(Q);
    GGML_UNUSED(K);
    GGML_UNUSED(V);
    GGML_UNUSED(plan);
    return false;
#endif
}

static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int device, const ggml_tensor * dst) {
GGML_ASSERT(dst != nullptr);
GGML_ASSERT(dst->src[0] != nullptr);
Expand All @@ -1863,6 +1894,7 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];

ggml_cuda_fattn_route_plan plan = {};

Expand Down Expand Up @@ -1958,9 +1990,54 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi
ggml_cuda_fattn_effective_vec_shape_unsafe(Q, &K_eff, &V_eff);

plan.allow_vec = !plan.unsafe_vec_after_turbo_v_decode;
plan.none_reason = nullptr;

plan.kernel = ggml_cuda_get_best_fattn_kernel(device, &dst_eff, plan.allow_vec);

if (plan.kernel == BEST_FATTN_KERNEL_NONE) {
if (V_eff.ne[0] != K_eff.ne[0]) {
plan.none_reason = "K/V head dim mismatch";
} else if (mask && mask->ne[2] != 1) {
plan.none_reason = "mask ne[2] != 1";
} else if (K_eff.ne[1] % FATTN_KQ_STRIDE != 0) {
plan.none_reason = "K sequence length is not FATTN_KQ_STRIDE padded";
} else if (plan.effective_type_K != GGML_TYPE_F16 || plan.effective_type_V != GGML_TYPE_F16) {
plan.none_reason = "unsupported effective K/V type pair";
} else {
plan.none_reason = "kernel selector returned NONE";
}
}

if (plan.kernel == BEST_FATTN_KERNEL_NONE &&
ggml_cuda_fattn_can_force_hip_tile_fallback(Q, &K_eff, &V_eff, dst->src[3], plan)) {
plan.kernel = BEST_FATTN_KERNEL_TILE;
plan.none_reason = "forced HIP tile fallback";

if (ggml_cuda_fattn_route_debug_enabled()) {
fprintf(stderr,
"CUDA_FA_ROUTE_FALLBACK "
"device=%d forcing=HIP_TILE "
"Q=[%lld,%lld,%lld,%lld] "
"mask=%s mask_shape=[%lld,%lld,%lld,%lld] "
"Keff=%s Veff=%s "
"Kshape=[%lld,%lld,%lld,%lld] Vshape=[%lld,%lld,%lld,%lld] "
"reason=%s\n",
device,
(long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3],
mask ? "yes" : "no",
mask ? (long long) mask->ne[0] : -1LL,
mask ? (long long) mask->ne[1] : -1LL,
mask ? (long long) mask->ne[2] : -1LL,
mask ? (long long) mask->ne[3] : -1LL,
ggml_type_name(plan.effective_type_K),
ggml_type_name(plan.effective_type_V),
(long long) K_eff.ne[0], (long long) K_eff.ne[1], (long long) K_eff.ne[2], (long long) K_eff.ne[3],
(long long) V_eff.ne[0], (long long) V_eff.ne[1], (long long) V_eff.ne[2], (long long) V_eff.ne[3],
plan.none_reason);
fflush(stderr);
}
}

plan.need_generic_f16_K = false;
plan.need_generic_f16_V = false;
switch (plan.kernel) {
Expand All @@ -1983,18 +2060,29 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi
"CUDA_FA_ROUTE_PLAN "
"device=%d "
"Q=[%lld,%lld,%lld,%lld] "
"mask=%s mask_shape=[%lld,%lld,%lld,%lld] mask_nb=[%lld,%lld,%lld,%lld] "
"Kraw=%s Vraw=%s "
"Kshape=[%lld,%lld,%lld,%lld] Vshape=[%lld,%lld,%lld,%lld] "
"Keff=%s Veff=%s "
"turbo_kv=%d "
"decode=%d dk=%d dv=%d "
"unsafe_vec_after_turbo_v_decode=%d allow_vec=%d "
"kernel=%s "
"none_reason=%s "
"need_f16_K=%d need_f16_V=%d "
"Knb=[%lld,%lld,%lld,%lld] Vnb=[%lld,%lld,%lld,%lld] "
"Keff_nb=[%lld,%lld,%lld,%lld] Veff_nb=[%lld,%lld,%lld,%lld]\n",
device,
(long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3],
mask ? "yes" : "no",
mask ? (long long) mask->ne[0] : -1LL,
mask ? (long long) mask->ne[1] : -1LL,
mask ? (long long) mask->ne[2] : -1LL,
mask ? (long long) mask->ne[3] : -1LL,
mask ? (long long) mask->nb[0] : -1LL,
mask ? (long long) mask->nb[1] : -1LL,
mask ? (long long) mask->nb[2] : -1LL,
mask ? (long long) mask->nb[3] : -1LL,
ggml_type_name(K->type), ggml_type_name(V->type),
(long long) K->ne[0], (long long) K->ne[1], (long long) K->ne[2], (long long) K->ne[3],
(long long) V->ne[0], (long long) V->ne[1], (long long) V->ne[2], (long long) V->ne[3],
Expand All @@ -2006,6 +2094,7 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi
(int) plan.unsafe_vec_after_turbo_v_decode,
(int) plan.allow_vec,
ggml_cuda_fattn_kernel_name(plan.kernel),
plan.none_reason ? plan.none_reason : "n/a",
(int) plan.need_generic_f16_K,
(int) plan.need_generic_f16_V,
(long long) K->nb[0], (long long) K->nb[1], (long long) K->nb[2], (long long) K->nb[3],
Expand Down Expand Up @@ -2056,6 +2145,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];

#if defined(GGML_CUDA_FA_ALL_QUANTS) || defined(GGML_CUDA_FA_HALF_QUANTS)
if ((ggml_cuda_fattn_is_ranked_kv_type(K->type) || ggml_cuda_fattn_is_ranked_kv_type(V->type)) &&
Expand Down Expand Up @@ -2457,8 +2547,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

switch (selected_kernel) {
case BEST_FATTN_KERNEL_NONE:
fprintf(stderr, "No CUDA FA kernel selected: K=%s V=%s D=%lld\n",
ggml_type_name(K->type), ggml_type_name(V->type), (long long) Q->ne[0]);
fprintf(stderr,
"No CUDA FA kernel selected: "
"Q=[%lld,%lld,%lld,%lld] "
"K=%s Kshape=[%lld,%lld,%lld,%lld] Knb=[%lld,%lld,%lld,%lld] "
"V=%s Vshape=[%lld,%lld,%lld,%lld] Vnb=[%lld,%lld,%lld,%lld] "
"mask=%s mask_shape=[%lld,%lld,%lld,%lld] mask_nb=[%lld,%lld,%lld,%lld] "
"allow_vec=%d none_reason=%s\n",
(long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3],
ggml_type_name(K->type),
(long long) K->ne[0], (long long) K->ne[1], (long long) K->ne[2], (long long) K->ne[3],
(long long) K->nb[0], (long long) K->nb[1], (long long) K->nb[2], (long long) K->nb[3],
ggml_type_name(V->type),
(long long) V->ne[0], (long long) V->ne[1], (long long) V->ne[2], (long long) V->ne[3],
(long long) V->nb[0], (long long) V->nb[1], (long long) V->nb[2], (long long) V->nb[3],
mask ? "yes" : "no",
mask ? (long long) mask->ne[0] : -1LL,
mask ? (long long) mask->ne[1] : -1LL,
mask ? (long long) mask->ne[2] : -1LL,
mask ? (long long) mask->ne[3] : -1LL,
mask ? (long long) mask->nb[0] : -1LL,
mask ? (long long) mask->nb[1] : -1LL,
mask ? (long long) mask->nb[2] : -1LL,
mask ? (long long) mask->nb[3] : -1LL,
(int) plan.allow_vec,
plan.none_reason ? plan.none_reason : "n/a");
GGML_ABORT("no CUDA FA kernel selected");
case BEST_FATTN_KERNEL_TILE:
ggml_cuda_flash_attn_ext_tile(ctx, dst);
Expand Down