fmha_fwd_splitkv_kernel.hpp Source File

fmha_fwd_splitkv_kernel.hpp Source File#

Composable Kernel: fmha_fwd_splitkv_kernel.hpp Source File
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11#include <string>
12#include <type_traits>
13
14// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
15// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
16// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
17// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
18// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
19
20namespace ck_tile {
21
22template <typename FmhaPipeline_, typename EpiloguePipeline_>
24{
// Launch-related constants forwarded from the pipeline. kBlockSize is the value
// later passed to dim3() when the block dimensions are built in this struct.
27 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
28 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
29
// The resolved per-CU block count must be strictly positive.
30 static_assert(kBlockPerCu > 0);
// Value taken from the pipeline's Problem rather than the pipeline itself;
// GetName() emits no "o<N>_" fragment when this equals -1.
31 static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
32
41
43
// Compile-time feature switches mirrored from the pipeline (or its Problem).
// The rest of this struct branches on them with `if constexpr` and
// std::conditional_t, and GetName() encodes several of them in the kernel name.
44 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
45 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
46 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
47 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
48 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
49 static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
50 static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
51 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
52 static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
53 static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
// When set, GridSize() launches one grid row per KV head and multiplies the
// query length by nhead_q / nhead_kv instead of launching one row per Q head.
54 static constexpr bool kMergeNumHeadGroupsSeqLenQ =
55 FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
// True when the mask type actually masks; selects MaskKargs over an empty base.
58 static constexpr bool kHasMask = FmhaMask::IsMasking;
59
60 static_assert(!kMergeNumHeadGroupsSeqLenQ ||
62 !kHasMask));
63
64 // clang-format off
// Type-to-string trait: maps an element type to the short tag that GetName()
// embeds in the kernel name. The primary template is declared but not defined,
// so using an unlisted type fails at compile time.
65 template <typename T> struct t2s;
66 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
67 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
68 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
69 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
70 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
71 // clang-format on
72
// Builds the human-readable name of this kernel instantiation (host only).
//
// The name is assembled, in order, from: the "fmha_fwd_splitkv_d" prefix and QK
// head dim, the Q data-type tag (see t2s), "group"/"batch", the block tile
// sizes ("b..."), the block-warp layouts of gemm0 and gemm1 ("r..." twice), the
// warp tiles of gemm0 and gemm1 ("w..." twice), an optional "o<kBlockPerCu>_"
// fragment, the pipeline name, the V layout ("vr"/"vc"), the padding tag, and
// the mask / LSE / static-quant / paged-KV suffixes.
//
// Returns: the assembled std::string. Per the comment below, the format must
// stay in sync with generate.py.
73 CK_TILE_HOST static std::string GetName()
74 {
75 // sync with generate.py
76 // clang-format off
77 using bfs = typename FmhaPipeline::BlockFmhaShape;
78 using g0br = typename bfs::Gemm0BlockWarps;
79 using g1br = typename bfs::Gemm1BlockWarps;
80 using g0wt = typename bfs::Gemm0WarpTile;
81 using g1wt = typename bfs::Gemm1WarpTile;
// Short local aliases to keep the concatenation readable; undefined again
// right after the return statement.
82 #define _SS_ std::string
83 #define _TS_ std::to_string
// Padding tag: "p" followed by one marker per padded dimension
// (s = seqlen_q, sk = seqlen_k, d = hdim_q, dv = hdim_v); empty if none is padded.
84 auto pn = [&] () {
85 std::string n;
86 if (kPadSeqLenQ) n += "s";
87 if (kPadSeqLenK) n += "sk";
88 if (kPadHeadDimQ) n += "d";
89 if (kPadHeadDimV) n += "dv";
90 return n.empty() ? n : std::string("p") + n; }();
91 return
92 _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
93 "_" + (kIsGroupMode ? "group" : "batch") + "_"
94 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
95 _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
96 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
97 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
98 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
99 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
// The "o<N>_" fragment is omitted when the Problem-level value is -1; when
// present it prints the pipeline-level kBlockPerCu.
100 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
101 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
103 (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
104 (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
105 #undef _SS_
106 #undef _TS_
107 // clang-format on
108 }
109
110 template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
111 // arg
113 {
114 };
115
116 // kargs use aggregate initialization, so no constructor is provided
117 // use inheritance to minimize karg size
118 // users need to call the MakeKargs() function to create kargs.
156
158 {
160
// Stores the logits soft-cap value in this kargs struct.
//
// logits_soft_cap_: requested cap. A strictly positive value is stored as is;
// zero or a negative value stores 0.f instead.
//
// NOTE(review): listing lines 166 and 171 are absent from this extract. The
// kernel body later reads kargs.logits_soft_cap_rcp, so each branch presumably
// also sets that reciprocal field — confirm against the original header.
161 void init_logits_soft_cap(float logits_soft_cap_)
162 {
163 if(0 < logits_soft_cap_)
164 {
165 logits_soft_cap = logits_soft_cap_;
167 }
168 else
169 {
170 logits_soft_cap = 0.f;
172 }
173 }
174
177 };
178
185
190
192 {
193 // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194 const void* alibi_slope_ptr;
195 ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196 };
197
199 {
200 // ck_tile::index_t window_size_left, window_size_right;
203 };
204
206 {
207 float scale_p;
208 };
209
216
221
223 {
225 };
226
228 : CommonKargs,
229 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
230 BatchModeBiasKargs,
231 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
232 AlibiKargs,
233 EmptyKargs<0>>>,
234 std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
235 std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
236 std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
237 std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
238 {
240
242 ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
243 // single kcache page-block
244 ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
245 // single vcache page-block
248 };
249
251 : CommonKargs,
252 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
253 CommonBiasKargs,
254 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
255 AlibiKargs,
256 EmptyKargs<0>>>,
257 std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
258 std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
259 std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
260 std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
261 {
265
266 ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
267 // for single kcache page-block
268 ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
269 // for single vcache page-block
270 };
271
272 using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
273
280
// Builds the kernel-argument aggregate for batch mode (host only).
//
// This overload participates in overload resolution only when kIsGroupMode is
// false (see the enable_if on the return type). The common fields and the
// batch-mode trailing fields are set through aggregate initialization; the
// optional base classes (bias / mask / fp8 static quant / page-block table or
// cache_batch_idx / logits soft cap) start value-initialized and are filled in
// afterwards under `if constexpr`, because which base is present depends on
// the compile-time flags.
//
// Parameters that belong to a feature disabled at compile time (e.g. scale_p
// without kDoFp8StaticQuant, window sizes without kHasMask) are ignored.
//
// Returns: the populated Kargs value.
281 template <bool Cond = !kIsGroupMode>
282 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
283 MakeKargs(const void* q_ptr,
284 const void* k_ptr,
285 const void* v_ptr,
286 const void* bias_ptr,
287 void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
288 final lse */
289 void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
290 o */
291 ck_tile::index_t batch,
292 ck_tile::index_t seqlen_q,
293 ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
294 const void* seqlen_k_ptr, // only used for (paged-) kvcache
295 ck_tile::index_t hdim_q,
296 ck_tile::index_t hdim_v,
297 ck_tile::index_t num_head_q,
298 ck_tile::index_t nhead_ratio_qk,
299 ck_tile::index_t num_splits,
300 const void* block_table_ptr,
301 ck_tile::index_t batch_stride_block_table,
302 ck_tile::index_t page_block_size,
303 const void* cache_batch_idx,
304 float scale_s,
305 float scale_p,
306 float logits_soft_cap,
307 ck_tile::index_t stride_q,
308 ck_tile::index_t stride_k,
309 ck_tile::index_t stride_v,
310 ck_tile::index_t stride_bias,
311 ck_tile::index_t stride_o_acc,
312 ck_tile::index_t nhead_stride_q,
313 ck_tile::index_t nhead_stride_k,
314 ck_tile::index_t nhead_stride_v,
315 ck_tile::index_t nhead_stride_bias,
316 ck_tile::index_t nhead_stride_lse_acc,
317 ck_tile::index_t nhead_stride_o_acc,
318 ck_tile::index_t batch_stride_q,
319 ck_tile::index_t batch_stride_k,
320 ck_tile::index_t batch_stride_v,
321 ck_tile::index_t batch_stride_bias,
322 ck_tile::index_t batch_stride_lse_acc,
323 ck_tile::index_t batch_stride_o_acc,
324 ck_tile::index_t split_stride_lse_acc,
325 ck_tile::index_t split_stride_o_acc,
326 ck_tile::index_t window_size_left,
327 ck_tile::index_t window_size_right,
328 ck_tile::index_t mask_type)
329 {
330 Kargs kargs{{q_ptr,
331 k_ptr,
332 v_ptr,
333 lse_acc_ptr,
334 o_acc_ptr,
335 batch,
336 seqlen_q,
337 seqlen_k,
338 hdim_q,
339 hdim_v,
340 num_head_q,
341 nhead_ratio_qk,
342 num_splits,
// With the fast-exp2 path enabled, log2(e) is folded into the stored scale.
343#if CK_TILE_FMHA_FWD_FAST_EXP2
344 static_cast<float>(scale_s * ck_tile::log2e_v<>),
345#else
346 scale_s,
347#endif
348 stride_q,
349 stride_k,
350 stride_v,
351 stride_o_acc,
352 nhead_stride_q,
353 nhead_stride_k,
354 nhead_stride_v,
355 nhead_stride_lse_acc,
356 nhead_stride_o_acc,
357 split_stride_lse_acc,
358 split_stride_o_acc}, // args for common karg
359 {}, // placeholder for bias
360 {}, // placeholder for mask
361 {}, // placeholder for fp8_static_quant args
362 {}, // placeholder for paged-block table or cache_batch_idx
363 {}, // placeholder for logits_soft_cap
364 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
365 batch_stride_q,
366 batch_stride_k,
367 batch_stride_v,
368 batch_stride_lse_acc,
369 batch_stride_o_acc};
370
// NOTE(review): listing line 371 is absent from this extract; given the
// `else if constexpr` below, it is presumably the `if constexpr` that selects
// the elementwise-bias case — confirm against the original header.
372 {
373 kargs.bias_ptr = bias_ptr;
374 kargs.stride_bias = stride_bias;
375 kargs.nhead_stride_bias = nhead_stride_bias;
376 kargs.batch_stride_bias = batch_stride_bias;
377 }
// ALIBI reuses bias_ptr / stride_bias to carry the slope pointer and stride.
378 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
379 {
380 kargs.alibi_slope_ptr = bias_ptr;
381 kargs.alibi_slope_stride = stride_bias;
382 }
383 if constexpr(kHasMask)
384 {
385 kargs.window_size_left = window_size_left;
386 kargs.window_size_right = window_size_right;
387 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
388 }
389 if constexpr(kDoFp8StaticQuant)
390 {
391 kargs.scale_p = scale_p;
392 }
// Batch mode always carries one of the two: the page-block table (paged KV)
// or the optional cache_batch_idx indirection (non-paged KV).
393 if constexpr(kIsPagedKV)
394 {
395 kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
396 kargs.batch_stride_block_table = batch_stride_block_table;
397 kargs.page_block_size = page_block_size;
398 }
399 else
400 {
401 kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
402 }
403 if constexpr(kHasLogitsSoftCap)
404 {
405 kargs.init_logits_soft_cap(logits_soft_cap);
406 }
407
408 return kargs;
409 }
410
// Builds the kernel-argument aggregate for group mode (host only).
//
// This overload participates in overload resolution only when kIsGroupMode is
// true. Unlike the batch-mode overload it takes no seqlen_q / seqlen_k values:
// both fields are initialized to -1 and the kernel entry point later derives
// the per-batch lengths from seqstart_q_ptr, seqstart_k_ptr and (when
// non-null) seqlen_k_ptr. There are no batch strides for Q / LSE / O here;
// batch_stride_k / batch_stride_v are kept only for the paged-KV case.
//
// Optional base classes (bias / mask / fp8 static quant / page-block table /
// logits soft cap) start value-initialized and are filled in afterwards under
// `if constexpr`. Parameters of a feature disabled at compile time are ignored.
//
// Returns: the populated Kargs value.
411 template <bool Cond = kIsGroupMode>
412 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
413 MakeKargs(const void* q_ptr,
414 const void* k_ptr,
415 const void* v_ptr,
416 const void* bias_ptr,
417 void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
418 final lse */
419 void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
420 o */
421 ck_tile::index_t batch,
422 const void* seqstart_q_ptr,
423 const void* seqstart_k_ptr,
424 const void* seqlen_k_ptr,
425 ck_tile::index_t hdim_q,
426 ck_tile::index_t hdim_v,
427 ck_tile::index_t num_head_q,
428 ck_tile::index_t nhead_ratio_qk,
429 ck_tile::index_t num_splits,
430 const void* block_table_ptr,
431 ck_tile::index_t batch_stride_block_table,
432 ck_tile::index_t page_block_size,
433 bool is_gappy,
434 float scale_s,
435 float scale_p,
436 float logits_soft_cap,
437 ck_tile::index_t stride_q,
438 ck_tile::index_t stride_k,
439 ck_tile::index_t stride_v,
440 ck_tile::index_t stride_bias,
441 ck_tile::index_t stride_o_acc,
442 ck_tile::index_t nhead_stride_q,
443 ck_tile::index_t nhead_stride_k,
444 ck_tile::index_t nhead_stride_v,
445 ck_tile::index_t nhead_stride_bias,
446 ck_tile::index_t nhead_stride_lse_acc,
447 ck_tile::index_t nhead_stride_o_acc,
448 ck_tile::index_t batch_stride_k, // only used for paged-kvcache
449 ck_tile::index_t batch_stride_v, // only used for paged-kvcache
450 ck_tile::index_t split_stride_lse_acc,
451 ck_tile::index_t split_stride_o_acc,
452 ck_tile::index_t window_size_left,
453 ck_tile::index_t window_size_right,
454 ck_tile::index_t mask_type)
455 {
456 Kargs kargs{{q_ptr,
457 k_ptr,
458 v_ptr,
459 lse_acc_ptr,
460 o_acc_ptr,
461 batch,
462 -1, // seqlen_q will be updated by another pointer
463 -1, // seqlen_k will be updated by another pointer
464 hdim_q,
465 hdim_v,
466 num_head_q,
467 nhead_ratio_qk,
468 num_splits,
// With the fast-exp2 path enabled, log2(e) is folded into the stored scale.
469#if CK_TILE_FMHA_FWD_FAST_EXP2
470 static_cast<float>(scale_s * ck_tile::log2e_v<>),
471#else
472 scale_s,
473#endif
474 stride_q,
475 stride_k,
476 stride_v,
477 stride_o_acc,
478 nhead_stride_q,
479 nhead_stride_k,
480 nhead_stride_v,
481 nhead_stride_lse_acc,
482 nhead_stride_o_acc,
483 split_stride_lse_acc,
484 split_stride_o_acc}, // args for common karg
485 {}, // placeholder for bias
486 {}, // placeholder for mask
487 {}, // placeholder for fp8_static_quant args
488 {}, // placeholder for paged-block table
489 {}, // placeholder for logits_soft_cap
490 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
491 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
492 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
493 batch_stride_k,
494 batch_stride_v};
495
// NOTE(review): listing line 496 is absent from this extract; given the
// `else if constexpr` below, it is presumably the `if constexpr` that selects
// the elementwise-bias case — confirm against the original header.
// Group mode stores no batch_stride_bias, unlike the batch-mode overload.
497 {
498 kargs.bias_ptr = bias_ptr;
499 kargs.stride_bias = stride_bias;
500 kargs.nhead_stride_bias = nhead_stride_bias;
501 }
// ALIBI reuses bias_ptr / stride_bias to carry the slope pointer and stride.
502 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
503 {
504 kargs.alibi_slope_ptr = bias_ptr;
505 kargs.alibi_slope_stride = stride_bias;
506 }
507 if constexpr(kHasMask)
508 {
509 kargs.window_size_left = window_size_left;
510 kargs.window_size_right = window_size_right;
511 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
512 }
513 if constexpr(kDoFp8StaticQuant)
514 {
515 kargs.scale_p = scale_p;
516 }
// is_gappy is only stored for paged KV; the kernel uses it to treat
// seqstart_k_ptr[i_batch] as a logical-to-physical offset.
517 if constexpr(kIsPagedKV)
518 {
519 kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
520 kargs.batch_stride_block_table = batch_stride_block_table;
521 kargs.page_block_size = page_block_size;
522 kargs.is_gappy = is_gappy;
523 }
524 if constexpr(kHasLogitsSoftCap)
525 {
526 kargs.init_logits_soft_cap(logits_soft_cap);
527 }
528
529 return kargs;
530 }
531
// Computes the launch grid for this kernel (host only).
//
// Grid layout:
//   x = ceil(max_seqlen_q_ / kM0) * ceil(hdim_v / kN1) * num_splits
//   y = number of heads launched
//   z = batch_size
//
// When kMergeNumHeadGroupsSeqLenQ is set, y is nhead_kv and the query length is
// multiplied by nhead_q / nhead_kv (integer division); otherwise y is nhead_q
// and the query length is used unchanged. GetTileIndex() decodes blockIdx.x
// with the matching layout (split index varies fastest).
//
// Returns: a dim3 holding (x, y, z) as described above.
532 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
533 ck_tile::index_t nhead_q,
534 ck_tile::index_t nhead_kv,
535 ck_tile::index_t max_seqlen_q,
536 ck_tile::index_t hdim_v,
537 ck_tile::index_t num_splits)
538 {
539 ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
540 ck_tile::index_t max_seqlen_q_ =
541 max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
542
543 // TODO: this may need tuning
544 return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
545 ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
546 nhead_,
547 batch_size);
548 }
549
// Decodes the block indices into the work item handled by this thread block.
//
// blockIdx.x is split as ((i_tile_m * num_tile_n1) + i_tile_n) * num_splits
// + i_split, i.e. the split index varies fastest, then the hdim_v tile, then
// the seqlen_q tile. blockIdx.y is the head index and blockIdx.z the batch
// index.
//
// Returns: tuple (i_tile_m, i_tile_n, i_split, i_nhead, i_batch). With masking
// enabled the M-tile index is reversed, so the first block along x receives the
// last M tile.
550 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
551 {
552 const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
553
// Quotient/remainder helper; the remainder is derived from the quotient
// (dividend - quotient * divisor) instead of using a separate `%`.
554 const auto f = [](index_t dividend, index_t divisor) {
555 index_t quotient = dividend / divisor;
556 index_t modulus = dividend - quotient * divisor;
557 return ck_tile::make_tuple(quotient, modulus);
558 };
559
560 const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
561 const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
562 const index_t i_nhead = blockIdx.y;
563 const index_t i_batch = blockIdx.z;
564
565 if constexpr(kHasMask)
566 {
567 // assume that num_tile_n1 is always 1
// NOTE(review): gridDim.x / num_splits equals the number of M tiles only
// under that assumption; with num_tile_n1 > 1 it would be
// num_tile_m * num_tile_n1 and the reversed index would exceed the M range.
568 return ck_tile::make_tuple(
569 (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
570 }
571 else
572 {
573 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
574 }
575 }
576
578 {
579 if(is_wave32())
580 {
581 return dim3(kBlockSize / 2);
582 }
583 else
584 {
585 return dim3(kBlockSize);
586 }
587 }
588
590 {
591 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
592 }
593
595 {
596 // allocate LDS
597 __shared__ char smem_ptr[GetSmemSize()];
598
599 // divide problem
600 const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
601
602 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
603 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
604
605 long_index_t batch_offset_q = 0;
606 long_index_t batch_offset_k = 0; // unused for paged-kvcache
607 long_index_t batch_offset_v = 0; // unused for paged-kvcache
608 long_index_t batch_offset_bias = 0;
609 long_index_t batch_offset_lse_acc = 0;
610 long_index_t batch_offset_o_acc = 0;
611 index_t kv_l2p_offset =
612 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
613
614 if constexpr(kIsGroupMode)
615 {
616 // get starting offset for each batch
617 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
618 const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
619
620 batch_offset_q = query_start * kargs.stride_q;
621 batch_offset_k = key_start * kargs.stride_k;
622 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
623 {
624 batch_offset_v = key_start * kargs.stride_v;
625 }
626 else
627 {
628 batch_offset_v = key_start;
629 }
631 {
632 batch_offset_bias = query_start * kargs.stride_bias;
633 }
634
635 batch_offset_lse_acc = query_start;
636 batch_offset_o_acc = query_start * kargs.stride_o_acc;
637
638 // get real # queries & # keys under group mode
639 kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
640
641 // # of required blocks is different in each groups, terminate unnecessary blocks
642 // earlier
643 if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
644 {
645 return;
646 }
647
648 if(kargs.seqlen_k_ptr != nullptr)
649 {
650 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
651 }
652 else
653 {
654 kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
655 }
656
657 if constexpr(kIsPagedKV)
658 {
659 if(kargs.is_gappy)
660 {
661 // seqstart_k_ptr has different meaning in this case
662 kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
663 }
664 }
665 }
666 else
667 {
668 const index_t i_cache_batch = [&, i_batch_ = i_batch] {
669 if constexpr(kIsPagedKV)
670 {
671 return i_batch_;
672 }
673 else
674 {
675 return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
676 : i_batch_);
677 }
678 }();
679
680 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
681 batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
682 batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
683 batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
684 batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
685
687 {
688 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
689 }
690
691 if(kargs.seqlen_k_ptr != nullptr)
692 {
693 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
694 }
695 }
696
697 // for simplicity, batch stride we just modify the pointer
698 const index_t i_nhead_k =
699 (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
700
701 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
702 static_cast<long_index_t>(i_nhead) *
703 (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
704 kargs.nhead_stride_q +
705 batch_offset_q;
706 const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
707 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
708 batch_offset_k;
709 const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
710 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
711 batch_offset_v;
712
713 ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
714 static_cast<long_index_t>(i_nhead) *
715 (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
716 kargs.nhead_stride_o_acc +
717 batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
718
719 // Q/K/V DRAM and DRAM window
720 const auto q_dram = [&] {
721 const auto q_dram_naive = [&] {
722 if constexpr(kMergeNumHeadGroupsSeqLenQ)
723 {
724 // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
725 // hdim_q)
727 q_ptr,
728 make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
729 make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
731 number<1>{});
732
734 view,
736 make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
737 make_pass_through_transform(kargs.hdim_q)),
740 }
741 else
742 {
744 q_ptr,
745 make_tuple(kargs.seqlen_q, kargs.hdim_q),
746 make_tuple(kargs.stride_q, 1),
748 number<1>{});
749 }
750 }();
751
752 if constexpr(FmhaPipeline::kQLoadOnce)
753 {
754 return pad_tensor_view(
755 q_dram_naive,
758 }
759 else
760 {
761 return pad_tensor_view(
762 q_dram_naive,
765 }
766 }();
767
768 const auto make_k_dram = [&](const KDataType* data, index_t height) {
770 data, // will update this pointer if using paged-kvcache
771 make_tuple(height, kargs.hdim_q),
772 make_tuple(kargs.stride_k, 1),
774 number<1>{});
775
776 return pad_tensor_view(
777 k_dram_naive,
780 };
781 const auto k_dram = [&]() {
782 if constexpr(kIsPagedKV)
783 {
784 return make_k_dram(nullptr, kargs.page_block_size);
785 }
786 else
787 {
788 return make_k_dram(k_ptr, kargs.seqlen_k);
789 }
790 }();
791
792 const auto make_v_dram = [&](const VDataType* data, index_t length) {
793 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
794 {
796 data, // will update this pointer if using paged-kvcache
797 make_tuple(length, kargs.hdim_v),
798 make_tuple(kargs.stride_v, 1),
800 number<1>{});
801
802 const auto v_dram_transposed =
803 transform_tensor_view(v_dram_naive,
808
809 return pad_tensor_view(
810 v_dram_transposed,
813 }
814 else
815 {
817 data, // will update this pointer if using paged-kvcache
818 make_tuple(kargs.hdim_v, length),
819 make_tuple(kargs.stride_v, 1),
821 number<1>{});
822
823 return pad_tensor_view(
824 v_dram_naive,
827 }
828 };
829 const auto v_dram = [&]() {
830 if constexpr(kIsPagedKV)
831 {
832 return make_v_dram(nullptr, kargs.page_block_size);
833 }
834 else
835 {
836 return make_v_dram(v_ptr, kargs.seqlen_k);
837 }
838 }();
839
840 auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
841 if constexpr(kIsPagedKV)
842 {
843 const auto* block_indices =
844 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
845 i_batch_ * kargs.batch_stride_block_table;
846 const index_t num_blocks =
847 integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
848
849 const long_index_t fixed_offset =
850 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
851
853 kargs.k_ptr,
854 kargs.batch_stride_k, // kcache page-block stride/size
855 fixed_offset,
856 block_indices,
857 num_blocks,
858 kargs.page_block_size,
859 k_dram,
860 make_k_dram(nullptr,
861 (kv_l2p_offset + kargs.seqlen_k) -
862 (num_blocks - 1) * kargs.page_block_size));
863 }
864 else
865 {
866 return make_page_block_navigator(k_dram);
867 }
868 }();
869
870 auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
871 if constexpr(kIsPagedKV)
872 {
873 const auto* block_indices =
874 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
875 i_batch_ * kargs.batch_stride_block_table;
876 const index_t num_blocks =
877 integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
878
879 const long_index_t fixed_offset =
880 static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
881
883 kargs.v_ptr,
884 kargs.batch_stride_v, // vcache page-block stride/size
885 fixed_offset,
886 block_indices,
887 num_blocks,
888 kargs.page_block_size,
889 v_dram,
890 make_v_dram(nullptr,
891 (kv_l2p_offset + kargs.seqlen_k) -
892 (num_blocks - 1) * kargs.page_block_size));
893 }
894 else
895 {
896 return make_page_block_navigator(v_dram);
897 }
898 }();
899
900 auto q_dram_window = make_tile_window(
901 q_dram,
902 [&]() {
903 if constexpr(FmhaPipeline::kQLoadOnce)
906 else
908 }(),
909 {i_m0, 0});
910
911 auto k_dram_window_lengths =
913 auto v_dram_window_lengths =
915
918 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
919 constexpr auto bias_dram_window_lengths =
922 {
923 const BiasDataType* bias_ptr =
924 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
925 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
926 batch_offset_bias;
927
928 const auto bias_dram = [&]() {
929 const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
930 bias_ptr,
931 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
932 make_tuple(kargs.stride_bias, 1),
934 number<1>{});
935
936 return pad_tensor_view(
937 bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
938 }();
939
940 return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
941 }
942 else
943 {
944 return make_null_tile_window(bias_dram_window_lengths);
945 }
946 }();
947
948 // lse acc
949 auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
950 constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
951 LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
952 static_cast<long_index_t>(i_nhead_) *
953 (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
954 kargs.nhead_stride_lse_acc +
955 batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
956
957 const auto lse_acc_dram = [&] {
958 const auto lse_acc_dram_naive = [&] {
959 if constexpr(kMergeNumHeadGroupsSeqLenQ)
960 {
961 // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
963 lse_acc_ptr,
964 make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
965 make_tuple(kargs.nhead_stride_lse_acc, 1),
966 number<1>{},
967 number<1>{});
968
969 return transform_tensor_view(view,
971 kargs.nhead_ratio_qk, kargs.seqlen_q))),
974 }
975 else
976 {
978 lse_acc_ptr,
979 make_tuple(kargs.seqlen_q),
980 make_tuple(1),
981 number<1>{},
982 number<1>{});
983 }
984 }();
985 return pad_tensor_view(
986 lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
987 }();
988
989 return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
990 }();
991
992 FmhaMask mask = [&]() {
993 if constexpr(kHasMask)
995 kargs.window_size_left,
996 kargs.window_size_right,
997 kargs.seqlen_q,
998 kargs.seqlen_k,
1000 else
1001 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1002 }();
1003
1004 // WA i_batch capture structure binding before c++20
1005 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1007 {
1008 // data loading, shared by entire wg
1009 // TODO: how to use s_read?
1010 SaccDataType slope =
1011 *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1012 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1013#if CK_TILE_FMHA_FWD_FAST_EXP2
1014 slope *= ck_tile::log2e_v<>;
1015#endif
1016 if constexpr(kHasMask)
1017 {
1019 kargs.window_size_left,
1020 kargs.window_size_right,
1021 kargs.seqlen_q,
1022 kargs.seqlen_k,
1023 kargs.mask_type);
1024 }
1025 else
1026 {
1028 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1029 }
1030 }
1031 else
1032 {
1034 }
1035 }();
1036
1037 AttentionVariant variant;
1038 const auto variant_params = [&] {
1039 if constexpr(kHasLogitsSoftCap)
1040 {
1042 mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1043 }
1044 else
1045 {
1046 return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1047 }
1048 }();
1049
1050 BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
1051
1052 auto o_acc_tile = [&, i_split_ = i_split]() {
1053 if constexpr(kDoFp8StaticQuant)
1054 {
1055 return FmhaPipeline{}(q_dram_window,
1056 identity{}, // q_element_func
1057 k_dram_window_lengths,
1058 k_page_block_navigator,
1059 identity{}, // k_element_func
1060 v_dram_window_lengths,
1061 v_page_block_navigator,
1062 identity{}, // v_element_func
1063 bias_dram_window,
1064 identity{}, // bias_element_func
1065 lse_acc_dram_window,
1066 identity{}, // lse_element_func
1067 identity{}, // s_acc_element_func
1068 scales{kargs.scale_p}, // p_compute_element_func
1069 identity{}, // o_acc_element_func
1070 kargs.num_splits,
1071 i_split_,
1072 mask,
1073 position_encoding,
1074 kargs.scale_s,
1075 variant,
1076 variant_params,
1077 block_indices,
1078 kv_l2p_offset,
1079 smem_ptr);
1080 }
1081 else
1082 {
1083 return FmhaPipeline{}(q_dram_window,
1084 k_dram_window_lengths,
1085 k_page_block_navigator,
1086 v_dram_window_lengths,
1087 v_page_block_navigator,
1088 bias_dram_window,
1089 lse_acc_dram_window,
1090 kargs.num_splits,
1091 i_split_,
1092 mask,
1093 position_encoding,
1094 kargs.scale_s,
1095 variant,
1096 variant_params,
1097 block_indices,
1098 kv_l2p_offset,
1099 smem_ptr);
1100 }
1101 }();
1102
1103 // Oacc DRAM and Oacc DRAM window
1104 auto o_acc_dram = [&] {
1105 const auto o_acc_dram_naive = [&] {
1106 if constexpr(kMergeNumHeadGroupsSeqLenQ)
1107 {
1108 // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
1109 // hdim_v)
1111 o_acc_ptr,
1112 make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
1113 make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
1115 number<1>{});
1116
1117 return transform_tensor_view(
1118 view,
1119 make_tuple(
1120 make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
1121 make_pass_through_transform(kargs.hdim_v)),
1124 }
1125 else
1126 {
1128 o_acc_ptr,
1129 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1130 make_tuple(kargs.stride_o_acc, 1),
1132 number<1>{});
1133 }
1134 }();
1135
1136 return pad_tensor_view(
1137 o_acc_dram_naive,
1140 }();
1141
1142 auto o_acc_dram_window =
1143 make_tile_window(o_acc_dram,
1145 {i_m0, i_n1});
1146
1147 EpiloguePipeline{}(o_acc_dram_window, o_acc_tile, nullptr);
1148 }
1149};
1150
1151} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition page_block_navigator.hpp:333
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_position_encoding.hpp:137
Definition fmha_fwd_splitkv_kernel.hpp:192
const void * alibi_slope_ptr
Definition fmha_fwd_splitkv_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition fmha_fwd_splitkv_kernel.hpp:195
Definition fmha_fwd_splitkv_kernel.hpp:187
ck_tile::index_t batch_stride_bias
Definition fmha_fwd_splitkv_kernel.hpp:188
Definition fmha_fwd_splitkv_kernel.hpp:238
ck_tile::index_t batch_stride_lse_acc
Definition fmha_fwd_splitkv_kernel.hpp:246
ck_tile::index_t batch_stride_v
Definition fmha_fwd_splitkv_kernel.hpp:244
ck_tile::index_t batch_stride_q
Definition fmha_fwd_splitkv_kernel.hpp:241
ck_tile::index_t batch_stride_k
Definition fmha_fwd_splitkv_kernel.hpp:242
ck_tile::index_t batch_stride_o_acc
Definition fmha_fwd_splitkv_kernel.hpp:247
const int32_t * seqlen_k_ptr
Definition fmha_fwd_splitkv_kernel.hpp:239
Definition fmha_fwd_splitkv_kernel.hpp:275
ck_tile::index_t batch_idx
Definition fmha_fwd_splitkv_kernel.hpp:276
ck_tile::index_t kv_head_idx
Definition fmha_fwd_splitkv_kernel.hpp:278
ck_tile::index_t qo_head_idx
Definition fmha_fwd_splitkv_kernel.hpp:277
Definition fmha_fwd_splitkv_kernel.hpp:223
const int32_t * cache_batch_idx
Definition fmha_fwd_splitkv_kernel.hpp:224
Definition fmha_fwd_splitkv_kernel.hpp:180
const void * bias_ptr
Definition fmha_fwd_splitkv_kernel.hpp:181
ck_tile::index_t stride_bias
Definition fmha_fwd_splitkv_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition fmha_fwd_splitkv_kernel.hpp:183
Definition fmha_fwd_splitkv_kernel.hpp:120
ck_tile::index_t split_stride_o_acc
Definition fmha_fwd_splitkv_kernel.hpp:154
ck_tile::index_t nhead_stride_o_acc
Definition fmha_fwd_splitkv_kernel.hpp:151
const void * k_ptr
Definition fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t num_splits
Definition fmha_fwd_splitkv_kernel.hpp:138
void * lse_acc_ptr
Definition fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_splitkv_kernel.hpp:147
void * o_acc_ptr
Definition fmha_fwd_splitkv_kernel.hpp:125
ck_tile::index_t nhead_stride_lse_acc
Definition fmha_fwd_splitkv_kernel.hpp:150
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_splitkv_kernel.hpp:149
ck_tile::index_t hdim_q
Definition fmha_fwd_splitkv_kernel.hpp:131
const void * v_ptr
Definition fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_splitkv_kernel.hpp:148
ck_tile::index_t split_stride_lse_acc
Definition fmha_fwd_splitkv_kernel.hpp:153
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_splitkv_kernel.hpp:137
ck_tile::index_t stride_q
Definition fmha_fwd_splitkv_kernel.hpp:142
ck_tile::index_t seqlen_k
Definition fmha_fwd_splitkv_kernel.hpp:130
ck_tile::index_t batch
Definition fmha_fwd_splitkv_kernel.hpp:127
const void * q_ptr
Definition fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t stride_v
Definition fmha_fwd_splitkv_kernel.hpp:144
ck_tile::index_t stride_k
Definition fmha_fwd_splitkv_kernel.hpp:143
ck_tile::index_t num_head_q
Definition fmha_fwd_splitkv_kernel.hpp:134
ck_tile::index_t seqlen_q
Definition fmha_fwd_splitkv_kernel.hpp:129
ck_tile::index_t stride_o_acc
Definition fmha_fwd_splitkv_kernel.hpp:145
float scale_s
Definition fmha_fwd_splitkv_kernel.hpp:140
ck_tile::index_t hdim_v
Definition fmha_fwd_splitkv_kernel.hpp:132
Definition fmha_fwd_splitkv_kernel.hpp:211
ck_tile::index_t page_block_size
Definition fmha_fwd_splitkv_kernel.hpp:214
ck_tile::index_t batch_stride_block_table
Definition fmha_fwd_splitkv_kernel.hpp:213
const int32_t * block_table_ptr
Definition fmha_fwd_splitkv_kernel.hpp:212
Definition fmha_fwd_splitkv_kernel.hpp:113
Definition fmha_fwd_splitkv_kernel.hpp:206
float scale_p
Definition fmha_fwd_splitkv_kernel.hpp:207
Definition fmha_fwd_splitkv_kernel.hpp:261
const int32_t * seqlen_k_ptr
Definition fmha_fwd_splitkv_kernel.hpp:264
ck_tile::index_t batch_stride_k
Definition fmha_fwd_splitkv_kernel.hpp:266
const int32_t * seqstart_q_ptr
Definition fmha_fwd_splitkv_kernel.hpp:262
const int32_t * seqstart_k_ptr
Definition fmha_fwd_splitkv_kernel.hpp:263
ck_tile::index_t batch_stride_v
Definition fmha_fwd_splitkv_kernel.hpp:268
Definition fmha_fwd_splitkv_kernel.hpp:218
bool is_gappy
Definition fmha_fwd_splitkv_kernel.hpp:219
float logits_soft_cap_rcp
Definition fmha_fwd_splitkv_kernel.hpp:176
float logits_soft_cap
Definition fmha_fwd_splitkv_kernel.hpp:175
void init_logits_soft_cap(float logits_soft_cap_)
Definition fmha_fwd_splitkv_kernel.hpp:161
Definition fmha_fwd_splitkv_kernel.hpp:199
ck_tile::index_t window_size_right
Definition fmha_fwd_splitkv_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_fwd_splitkv_kernel.hpp:202
ck_tile::index_t window_size_left
Definition fmha_fwd_splitkv_kernel.hpp:201
static constexpr const char * name
Definition fmha_fwd_splitkv_kernel.hpp:68
static constexpr const char * name
Definition fmha_fwd_splitkv_kernel.hpp:70
static constexpr const char * name
Definition fmha_fwd_splitkv_kernel.hpp:67
static constexpr const char * name
Definition fmha_fwd_splitkv_kernel.hpp:69
static constexpr const char * name
Definition fmha_fwd_splitkv_kernel.hpp:66
Definition fmha_fwd_splitkv_kernel.hpp:65
Definition fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition fmha_fwd_splitkv_kernel.hpp:50
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_fwd_splitkv_kernel.hpp:36
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_splitkv_kernel.hpp:26
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_splitkv_kernel.hpp:42
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_splitkv_kernel.hpp:47
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition fmha_fwd_splitkv_kernel.hpp:272
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition fmha_fwd_splitkv_kernel.hpp:283
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition fmha_fwd_splitkv_kernel.hpp:413
static constexpr bool kPadSeqLenK
Definition fmha_fwd_splitkv_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_splitkv_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_splitkv_kernel.hpp:34
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_splitkv_kernel.hpp:40
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition fmha_fwd_splitkv_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_fwd_splitkv_kernel.hpp:57
static constexpr bool kDoFp8StaticQuant
Definition fmha_fwd_splitkv_kernel.hpp:52
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_splitkv_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_fwd_splitkv_kernel.hpp:38
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_splitkv_kernel.hpp:577
static constexpr bool kHasMask
Definition fmha_fwd_splitkv_kernel.hpp:58
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition fmha_fwd_splitkv_kernel.hpp:532
static constexpr bool kPadHeadDimV
Definition fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_splitkv_kernel.hpp:35
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_splitkv_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_splitkv_kernel.hpp:37
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition fmha_fwd_splitkv_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition fmha_fwd_splitkv_kernel.hpp:56
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_splitkv_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_splitkv_kernel.hpp:33
static constexpr bool kStoreLSE
Definition fmha_fwd_splitkv_kernel.hpp:51
static constexpr bool kIsPagedKV
Definition fmha_fwd_splitkv_kernel.hpp:53
static constexpr bool kIsGroupMode
Definition fmha_fwd_splitkv_kernel.hpp:44
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_splitkv_kernel.hpp:594
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_splitkv_kernel.hpp:589
static constexpr bool kHasLogitsSoftCap
Definition fmha_fwd_splitkv_kernel.hpp:49
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_splitkv_kernel.hpp:25
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_fwd_splitkv_kernel.hpp:550
Definition variants.hpp:63
Definition variants.hpp:51
Definition tile/core/utility/functional.hpp:86
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49