diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index e6b0d069c9fb..2970c8d46656 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1852,8 +1852,39 @@ struct ggml_cuda_fattn_route_plan { bool allow_vec; bool unsafe_vec_after_turbo_v_decode; best_fattn_kernel kernel; + const char * none_reason; }; +static inline bool ggml_cuda_fattn_can_force_hip_tile_fallback( + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + const ggml_tensor * mask, + const ggml_cuda_fattn_route_plan & plan) { +#if defined(GGML_USE_HIP) + if (plan.effective_type_K != GGML_TYPE_F16 || plan.effective_type_V != GGML_TYPE_F16) { + return false; + } + + if (Q->ne[0] != 256 || K->ne[0] != 256 || V->ne[0] != 256) { + return false; + } + + if (K->ne[1] % FATTN_KQ_STRIDE != 0) { + return false; + } + + return true; +#else + GGML_UNUSED(Q); + GGML_UNUSED(K); + GGML_UNUSED(V); + GGML_UNUSED(mask); + GGML_UNUSED(plan); + return false; +#endif +} + static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int device, const ggml_tensor * dst) { GGML_ASSERT(dst != nullptr); GGML_ASSERT(dst->src[0] != nullptr); @@ -1863,6 +1894,7 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; ggml_cuda_fattn_route_plan plan = {}; @@ -1958,9 +1990,54 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi ggml_cuda_fattn_effective_vec_shape_unsafe(Q, &K_eff, &V_eff); plan.allow_vec = !plan.unsafe_vec_after_turbo_v_decode; + plan.none_reason = nullptr; plan.kernel = ggml_cuda_get_best_fattn_kernel(device, &dst_eff, plan.allow_vec); + if (plan.kernel == BEST_FATTN_KERNEL_NONE) { + if (V_eff.ne[0] != K_eff.ne[0]) { + plan.none_reason = "K/V head dim mismatch"; + } else if (mask && mask->ne[2] != 1) { + plan.none_reason = "mask ne[2] != 1"; + } else if (K_eff.ne[1] % FATTN_KQ_STRIDE != 0) { + plan.none_reason = "K sequence length is not FATTN_KQ_STRIDE padded"; + } else if (plan.effective_type_K != GGML_TYPE_F16 || plan.effective_type_V != GGML_TYPE_F16) { + plan.none_reason = "unsupported effective K/V type pair"; + } else { + plan.none_reason = "kernel selector returned NONE"; + } + } + + if (plan.kernel == BEST_FATTN_KERNEL_NONE && + ggml_cuda_fattn_can_force_hip_tile_fallback(Q, &K_eff, &V_eff, dst->src[3], plan)) { + plan.kernel = BEST_FATTN_KERNEL_TILE; + plan.none_reason = "forced HIP tile fallback"; + + if (ggml_cuda_fattn_route_debug_enabled()) { + fprintf(stderr, + "CUDA_FA_ROUTE_FALLBACK " + "device=%d forcing=HIP_TILE " + "Q=[%lld,%lld,%lld,%lld] " + "mask=%s mask_shape=[%lld,%lld,%lld,%lld] " + "Keff=%s Veff=%s " + "Kshape=[%lld,%lld,%lld,%lld] Vshape=[%lld,%lld,%lld,%lld] " + "reason=%s\n", + device, + (long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3], + mask ? "yes" : "no", + mask ? (long long) mask->ne[0] : -1LL, + mask ? (long long) mask->ne[1] : -1LL, + mask ? (long long) mask->ne[2] : -1LL, + mask ? (long long) mask->ne[3] : -1LL, + ggml_type_name(plan.effective_type_K), + ggml_type_name(plan.effective_type_V), + (long long) K_eff.ne[0], (long long) K_eff.ne[1], (long long) K_eff.ne[2], (long long) K_eff.ne[3], + (long long) V_eff.ne[0], (long long) V_eff.ne[1], (long long) V_eff.ne[2], (long long) V_eff.ne[3], + plan.none_reason); + fflush(stderr); + } + } + plan.need_generic_f16_K = false; plan.need_generic_f16_V = false; switch (plan.kernel) { @@ -1983,6 +2060,7 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi "CUDA_FA_ROUTE_PLAN " "device=%d " "Q=[%lld,%lld,%lld,%lld] " + "mask=%s mask_shape=[%lld,%lld,%lld,%lld] mask_nb=[%lld,%lld,%lld,%lld] " "Kraw=%s Vraw=%s " "Kshape=[%lld,%lld,%lld,%lld] Vshape=[%lld,%lld,%lld,%lld] " "Keff=%s Veff=%s " @@ -1990,11 +2068,21 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi "decode=%d dk=%d dv=%d " "unsafe_vec_after_turbo_v_decode=%d allow_vec=%d " "kernel=%s " + "none_reason=%s " "need_f16_K=%d need_f16_V=%d " "Knb=[%lld,%lld,%lld,%lld] Vnb=[%lld,%lld,%lld,%lld] " "Keff_nb=[%lld,%lld,%lld,%lld] Veff_nb=[%lld,%lld,%lld,%lld]\n", device, (long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3], + mask ? "yes" : "no", + mask ? (long long) mask->ne[0] : -1LL, + mask ? (long long) mask->ne[1] : -1LL, + mask ? (long long) mask->ne[2] : -1LL, + mask ? (long long) mask->ne[3] : -1LL, + mask ? (long long) mask->nb[0] : -1LL, + mask ? (long long) mask->nb[1] : -1LL, + mask ? (long long) mask->nb[2] : -1LL, + mask ? (long long) mask->nb[3] : -1LL, ggml_type_name(K->type), ggml_type_name(V->type), (long long) K->ne[0], (long long) K->ne[1], (long long) K->ne[2], (long long) K->ne[3], (long long) V->ne[0], (long long) V->ne[1], (long long) V->ne[2], (long long) V->ne[3], @@ -2006,6 +2094,7 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi (int) plan.unsafe_vec_after_turbo_v_decode, (int) plan.allow_vec, ggml_cuda_fattn_kernel_name(plan.kernel), + plan.none_reason ? plan.none_reason : "n/a", (int) plan.need_generic_f16_K, (int) plan.need_generic_f16_V, (long long) K->nb[0], (long long) K->nb[1], (long long) K->nb[2], (long long) K->nb[3], @@ -2056,6 +2145,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; #if defined(GGML_CUDA_FA_ALL_QUANTS) || defined(GGML_CUDA_FA_HALF_QUANTS) if ((ggml_cuda_fattn_is_ranked_kv_type(K->type) || ggml_cuda_fattn_is_ranked_kv_type(V->type)) && @@ -2457,8 +2547,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst switch (selected_kernel) { case BEST_FATTN_KERNEL_NONE: - fprintf(stderr, "No CUDA FA kernel selected: K=%s V=%s D=%lld\n", - ggml_type_name(K->type), ggml_type_name(V->type), (long long) Q->ne[0]); + fprintf(stderr, + "No CUDA FA kernel selected: " + "Q=[%lld,%lld,%lld,%lld] " + "K=%s Kshape=[%lld,%lld,%lld,%lld] Knb=[%lld,%lld,%lld,%lld] " + "V=%s Vshape=[%lld,%lld,%lld,%lld] Vnb=[%lld,%lld,%lld,%lld] " + "mask=%s mask_shape=[%lld,%lld,%lld,%lld] mask_nb=[%lld,%lld,%lld,%lld] " + "allow_vec=%d none_reason=%s\n", + (long long) Q->ne[0], (long long) Q->ne[1], (long long) Q->ne[2], (long long) Q->ne[3], + ggml_type_name(K->type), + (long long) K->ne[0], (long long) K->ne[1], (long long) K->ne[2], (long long) K->ne[3], + (long long) K->nb[0], (long long) K->nb[1], (long long) K->nb[2], (long long) K->nb[3], + ggml_type_name(V->type), + (long long) V->ne[0], (long long) V->ne[1], (long long) V->ne[2], (long long) V->ne[3], + (long long) V->nb[0], (long long) V->nb[1], (long long) V->nb[2], (long long) V->nb[3], + mask ? "yes" : "no", + mask ? (long long) mask->ne[0] : -1LL, + mask ? (long long) mask->ne[1] : -1LL, + mask ? (long long) mask->ne[2] : -1LL, + mask ? (long long) mask->ne[3] : -1LL, + mask ? (long long) mask->nb[0] : -1LL, + mask ? (long long) mask->nb[1] : -1LL, + mask ? (long long) mask->nb[2] : -1LL, + mask ? (long long) mask->nb[3] : -1LL, + (int) plan.allow_vec, + plan.none_reason ? plan.none_reason : "n/a"); GGML_ABORT("no CUDA FA kernel selected"); case BEST_FATTN_KERNEL_TILE: ggml_cuda_flash_attn_ext_tile(ctx, dst);