block_fmha_pipeline_qs_ks_vs.hpp Source File

block_fmha_pipeline_qs_ks_vs.hpp Source File

Composable Kernel: block_fmha_pipeline_qs_ks_vs.hpp Source File
block_fmha_pipeline_qs_ks_vs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14// This pipeline keeps the Q, K and V tiles all in LDS
15template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
17{
33
// Compile-time configuration mirrored from Problem / BlockFmhaShape / Policy.
// kQLoadOnce is fixed to false here; the static_assert makes the build fail if the
// policy's QLoadOnce disagrees.
36 static constexpr bool kQLoadOnce = false;
37 static_assert(kQLoadOnce == Policy::QLoadOnce);
38
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
// Tile extents. operator() statically checks them against the incoming windows:
// Q window dim0 == kM0, K window == kN0 x kK0, V window == kN1 x kK1, bias window == kM0 x kN0.
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kN1 = BlockFmhaShape::kN1;
45 static constexpr index_t kK1 = BlockFmhaShape::kK1;
46 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
47 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
48
// Feature switches forwarded unchanged from Problem.
49 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
50 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
51 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
52 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
53 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
54 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
55 static constexpr auto BiasEnum = Problem::BiasEnum;
56 static constexpr bool kStoreLSE = Problem::kStoreLSE;
57 static constexpr bool kHasDropout = Problem::kHasDropout;
58
// NOTE(review): this static_assert is cut off in this listing (embedded numbering jumps
// 60 -> 63); the rest of its condition and its message are not visible here. The visible
// part relates CK_TILE_FMHA_FWD_FAST_EXP2, kHasLogitsSoftCap and the NO_BIAS bias mode.
59 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
60 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
63
64 // last dimension vector length used to create tensor view (and decide buffer_load vector length)
65 // ... together with tensor distribution. The tensor distribution should be able to override this.
// Each alignment drops to 1 when the governing padding flag is set; otherwise the
// policy supplies the value.
66 static constexpr index_t kAlignmentQ =
67 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
// Note: K is also keyed on kPadHeadDimQ (there is no separate K head-dim flag in this struct).
68 static constexpr index_t kAlignmentK =
69 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
// V's governing flag depends on its layout: kPadHeadDimV for RowMajor, kPadSeqLenK otherwise.
70 static constexpr index_t kAlignmentV = []() {
71 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
72 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
73 else
74 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
75 }();
76
77 static constexpr index_t kAlignmentO =
78 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
79 static constexpr index_t kAlignmentBias =
80 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
81
// kBlockPerCu: uses Problem::kBlockPerCu when it is not -1 (-1 acts as "not specified");
// otherwise a default is chosen from kQKHeaddim:
//   <= 32 -> 2, <= 64 -> 3, <= 128 -> 1 or 2 (see note), <= 256 -> 1, larger -> 1.
// NOTE(review): presumably a blocks-per-compute-unit occupancy hint — confirm against the
// kernel launch code, which is not part of this listing.
82 static constexpr index_t kBlockPerCu = []() {
83 if constexpr(Problem::kBlockPerCu != -1)
84 return Problem::kBlockPerCu;
85 else
86 {
87 if constexpr(kQKHeaddim <= 32)
88 {
89 return 2;
90 }
91 else if constexpr(kQKHeaddim <= 64)
92 {
93 return 3;
94 }
95 else if constexpr(kQKHeaddim <= 128)
96 {
// NOTE(review): the condition that selects between 1 and 2 is missing from this listing
// (embedded numbering jumps 96 -> 98); only the two return arms are visible.
98 return 1;
99 else
100 return 2;
101 }
102 else if constexpr(kQKHeaddim <= 256)
103 {
104 return 1;
105 }
106 else
107 {
108 return 1;
109 }
110 }
111 }();
112
// String tag exposed as `name` for this pipeline.
113 static constexpr const char* name = "qs";
114
// Dropout helper type: BlockDropout when kHasDropout is set, NullBlockDropout otherwise.
// The main operator() takes a DropoutType& but leaves it unnamed and unused.
115 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
116
// Body of GetSmemSize(): returns whatever Policy::GetSmemSize<Problem>() reports.
// NOTE(review): the signature line (embedded number 117) is missing from this listing; the
// symbol index at the end of the page gives it as
// `static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()`.
118 {
119 return Policy::template GetSmemSize<Problem>();
120 }
121
// Main pipeline entry point. For the Q tile addressed by q_dram_block_window_tmp it loops
// over K/V tiles, computing S = Q*K with gemm_0, maintaining a running row-max `m` and
// row-sum `l`, and accumulating O += P*V with gemm_1; finally o_acc is scaled by 1/l,
// passed through o_acc_element_func and returned.
//
// - Each *_element_func is applied element-wise (tile_elementwise_in) to the matching tile.
// - The rand-val window and the dropout argument are unnamed and unused in this overload.
// - smem_ptr is the base pointer from which the Q/K/V LDS tile views are built.
// - When kStoreLSE is set, the log-sum-exp per row is written to lse_dram_window_tmp.
//
// NOTE(review): this listing comes from a rendered documentation page and its embedded line
// numbers skip in many places (e.g. 178 -> 180, 224 -> 226, 309 -> 311, 431 -> 433).
// The symbol index at the end of the page references make_tensor_view, set_tile, set_tile_if,
// block_tile_reduce_sync and block_sync_lds, none of which appear in the visible lines, so
// they were presumably on the dropped lines. Comments below describe only what is visible.
122 template <typename QDramBlockWindowTmp,
123 typename KDramBlockWindowTmp,
124 typename VDramBlockWindowTmp,
125 typename BiasDramBlockWindowTmp,
126 typename RandValDramBlockWindowTmp,
127 typename LSEDramBlockWindowTmp,
128 typename QElementFunction,
129 typename KElementFunction,
130 typename VElementFunction,
131 typename BiasElementFunction,
132 typename LSEElementFunction,
133 typename SAccElementFunction,
134 typename PComputeElementFunction,
135 typename OAccElementFunction,
136 typename PositionEncoding,
137 typename AttentionVariantParams,
138 typename BlockIndices>
// NOTE(review): the return-type line (embedded number 139) is missing from this listing; the
// symbol index gives this overload as `CK_TILE_HOST_DEVICE auto operator()(...) const`.
140 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
141 const QElementFunction& q_element_func,
142 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
143 const KElementFunction& k_element_func,
144 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
145 const VElementFunction& v_element_func,
146 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
147 const BiasElementFunction& bias_element_func,
148 RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
149 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
150 const LSEElementFunction& lse_element_func,
151 const SAccElementFunction& s_acc_element_func,
152 const PComputeElementFunction& p_compute_element_func,
153 const OAccElementFunction& o_acc_element_func,
154 FmhaMask mask,
155 PositionEncoding position_encoding,
156 float scale_s,
157 const AttentionVariant& variant,
158 const AttentionVariantParams& variant_params,
159 const BlockIndices& block_indices,
160 void* smem_ptr,
161 DropoutType& /* unused_dropout */) const
162 {
// Compile-time check: the Q/K/V window element types must match the pipeline's data types.
163 static_assert(
164 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
165 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
166 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
167 "wrong!");
168
// Compile-time check: window lengths must match the configured tile extents.
169 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
170 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
171 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
172 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
173 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
174 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
175 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
176 "wrong!");
177
// NOTE(review): the lines that open the Q/K/V view constructions and finish the
// q_lds_window / k_lds_window initialisers are missing below (179, 183, 188, 191, 194).
// Visible facts: the Q view is based at smem_ptr, the K view at
// smem_ptr + GetSmemSizeQ<Problem>() bytes, and the V view again at smem_ptr
// (i.e. the same base address as the Q view).
178 // Q tile in LDS
180 reinterpret_cast<QDataType*>(smem_ptr),
181 Policy::template MakeQLdsBlockDescriptor<Problem>());
182 auto q_lds_window =
184
185 // K tile in LDS
186 KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
187 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
189 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
190 auto k_lds_window =
192
193 // V tile in LDS
195 reinterpret_cast<VDataType*>(smem_ptr),
196 Policy::template MakeVLdsBlockDescriptor<Problem>());
197 auto v_lds_window = make_tile_window(
198 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
199
200 // Block GEMM
// gemm_0 is the Q*K product, gemm_1 the P*V product (see their call sites below).
201 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
202 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
203
204 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
205 auto s_acc = SaccBlockTileType{};
206
207 // reduction function for softmax
208 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
209 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
210
211 // infer Sacc, S, P, M, L, Oacc type
212 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
213
// M/L tiles have the type produced by reducing an S tile along dimension 1 (one value per row).
214 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
215 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
216
217 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
218
219 // init Oacc, M, L
220 auto o_acc = OaccBlockTileType{};
221 auto m = MLBlockTileType{};
222 auto l = MLBlockTileType{};
223
// NOTE(review): the initialisation of `m` (embedded line 225) is not visible in this listing.
224 clear_tile(o_acc);
226 clear_tile(l);
227
// Ask the mask which K range [seqlen_k_start, seqlen_k_end) this Q tile has to visit,
// then derive the number of kN0-wide iterations.
228 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
229 const auto [seqlen_k_start, seqlen_k_end] =
230 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
231
232 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
233
234 // check early exit if masked and no work to do.
235 if constexpr(FmhaMask::IsMasking)
236 {
237 if(num_total_loop <= 0)
238 {
239 if constexpr(kStoreLSE)
240 {
241 auto lse =
242 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
243
// NOTE(review): the statement that fills `lse` before the store (embedded line 244)
// is not visible in this listing.
245
246 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
247 }
248
249 // Note: o_acc is all cleared here, return it
250 // Note: q loaded but no fence, ignore it.
251 return o_acc;
252 }
253 }
254
// Re-anchor the K, bias and V windows at seqlen_k_start so the loop starts at the first
// K tile this Q tile needs.
255 auto k_dram_block_window =
256 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
257 k_dram_block_window_tmp.get_window_lengths(),
258 {seqlen_k_start, 0});
259
260 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
261 auto bias_dram_window =
262 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
263 bias_dram_block_window_tmp.get_window_lengths(),
264 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
265 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
266
267 auto v_dram_window =
268 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
269 v_dram_block_window_tmp.get_window_lengths(),
270 {0, seqlen_k_start}, // TODO: hdim split?
271 Policy::template MakeVDramTileDistribution<Problem>());
272
273 // prefetch K tile
// k0_loops: number of kK0-wide steps across the QK head dimension (gemm_0 reduction).
// k1_loops: number of kK1-wide steps across one kN0-wide K tile (gemm_1 reduction).
274 index_t i_total_loops = 0;
275 constexpr index_t k0_loops = kQKHeaddim / kK0;
276 constexpr index_t k1_loops = kN0 / kK1;
277
// The prologue + tail structure below handles two k0 steps by itself, hence k0_loops >= 2.
278 static_assert(2 <= k0_loops);
279 static_assert(1 <= k1_loops);
// One iteration per kN0-wide K/V tile.
280 do
281 {
282 // STAGE 1, QK gemm
283 auto q_dram_window =
284 make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
285 q_dram_block_window_tmp.get_window_lengths(),
286 q_dram_block_window_tmp.get_window_origin(),
287 Policy::template MakeQDramTileDistribution<Problem>());
288
289 auto k_dram_window =
290 make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
291 k_dram_block_window.get_window_lengths(),
292 k_dram_block_window.get_window_origin(),
293 Policy::template MakeKDramTileDistribution<Problem>());
294
// Prologue: load k0 step 0 of Q and K, store it to LDS, and already issue the global
// loads for step 1. Q is re-read from global memory on every outer iteration.
295 auto q_block_tile = load_tile(q_dram_window);
296 auto k_block_tile = load_tile(k_dram_window);
297 {
298 move_tile_window(q_dram_window, {0, kK0});
299 move_tile_window(k_dram_window, {0, kK0});
300
301 clear_tile(s_acc); // initialize C
302
303 store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile));
304 q_block_tile = load_tile(q_dram_window);
305
306 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
307 k_block_tile = load_tile(k_dram_window);
308 }
309
// NOTE(review): the conditions guarding the two sched_barrier blocks around the bias load
// (embedded lines 310 and 316) are missing from this listing.
311 {
312 __builtin_amdgcn_sched_barrier(
313 0); // prevent from messing up the order of global loads
314 }
315 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
317 {
318 __builtin_amdgcn_sched_barrier(
319 0); // prevent from messing up the order of global loads
320 }
321
// Steady state: gemm on the step held in LDS, write the already-loaded next step to LDS,
// and issue the global read for the step after that.
// NOTE(review): lines 325, 327, 332 and 337 are missing in this loop body (the openings of
// the two store calls among them).
322 if constexpr(k0_loops > 2)
323 {
324 static_for<0, k0_loops - 2, 1>{}([&](auto) {
326 gemm_0(s_acc, q_lds_window, k_lds_window);
328
329 move_tile_window(q_dram_window, {0, kK0});
330 move_tile_window(k_dram_window, {0, kK0});
331
333 q_lds_window,
334 tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1
335 q_block_tile = load_tile(q_dram_window); // global read i + 2
336
338 k_lds_window,
339 tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
340 k_block_tile = load_tile(k_dram_window); // global read i + 2
341 });
342 }
343
// Tail: consume the second-to-last step from LDS, store the last loaded step, consume it.
// NOTE(review): lines 345, 347 and 351 are missing here.
344 { // tail
346 gemm_0(s_acc, q_lds_window, k_lds_window);
348
349 store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile));
350 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
352
353 gemm_0(s_acc, q_lds_window, k_lds_window);
354 }
355
// Issue the first V load now; the sched_barrier(0) calls fence it on both sides.
356 __builtin_amdgcn_sched_barrier(0);
357 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
358 __builtin_amdgcn_sched_barrier(0);
359
360 // STAGE 2, scale_s, add bias, mask, softmax
// NOTE(review): the condition opening this first branch (embedded line 361) is missing;
// since the next branch tests ALIBI and this one consumes bias_tile, it is presumably the
// ELEMENTWISE_BIAS case — confirm against the original header.
362 {
363 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
364 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
// NOTE(review): the call that applies the lambda below to (s_acc, bias_tile) starts on the
// missing line 365, and the first half of the FAST_EXP2 expression is on the missing line 370.
366 [&](auto& x, const auto& y) {
367#if !CK_TILE_FMHA_FWD_FAST_EXP2
368 x += type_convert<SaccDataType>(bias_element_func(y));
369#else
371 type_convert<SaccDataType>(bias_element_func(y));
372#endif
373 },
374 s_acc,
375 bias_tile);
376 }
377 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
378 {
// ALIBI: scale every element by scale_s, then let position_encoding.update() adjust it
// using the element's absolute (row, col) position.
379 const auto k_origin = k_dram_block_window.get_window_origin();
380 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
381 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
382 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
383 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
384 const auto tile_idx = get_x_indices_from_distributed_indices(
385 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
386
387 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
388 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
389 constexpr auto i_j_idx = make_tuple(idx0, idx1);
390
391 s_acc(i_j_idx) *= scale_s;
392 position_encoding.update(s_acc(i_j_idx), row, col);
393 });
394 });
395 }
396 else
397 {
398 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
399 if constexpr(kHasLogitsSoftCap)
400 {
// Soft-cap path: each element goes through variant.QueryTransform then
// variant.LogitsTransform.
401 auto apply_logits_transform =
402 [&variant, &variant_params, &block_indices](auto& x) {
403 x = variant.LogitsTransform(variant_params,
404 variant.QueryTransform(variant_params, x),
405 block_indices.batch_idx,
406 block_indices.qo_head_idx,
407 block_indices.kv_head_idx);
408 };
// Both preprocessor arms are identical in this listing.
409#if !CK_TILE_FMHA_FWD_FAST_EXP2
410 tile_elementwise_inout(apply_logits_transform, s_acc);
411#else
412 tile_elementwise_inout(apply_logits_transform, s_acc);
413#endif
414 }
415 else
416 {
// Without FAST_EXP2 the scale is applied here; with FAST_EXP2 it is applied inside the
// exp2() argument further down (scale_s * s - row_max).
417#if !CK_TILE_FMHA_FWD_FAST_EXP2
418 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
419#endif
420 }
421 }
// Advance the bias window by one K tile for the next iteration.
422 move_tile_window(bias_dram_window, {0, kN0});
423 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
424 {
425 const auto k_origin = k_dram_block_window.get_window_origin();
426 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
427 k_origin.at(number<0>{}),
428 number<kM0>{},
429 number<kN0>{});
430 if(need_perpixel_check)
431 {
// NOTE(review): the call name is on the missing line 432 (presumably set_tile_if, which the
// symbol index lists). Visible arguments: s_acc, -infinity, and a predicate that is true
// where variant.LogitsMask(...) returns false.
433 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
434 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
435 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
436 return !variant.LogitsMask(variant_params,
437 block_indices.batch_idx,
438 row,
439 col,
440 block_indices.qo_head_idx,
441 block_indices.kv_head_idx);
442 });
443 }
444 }
445
// NOTE(review): the declarations of m_local (447), the statement on line 452, the call
// opening on line 455 and the declaration of p_compute (458) are missing from this listing;
// only their trailing argument lines are visible.
446 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
448 s,
449 sequence<1>{},
450 f_max,
451 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
453
454 const auto m_old = m; // m{j-1}
456 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
457
459 s.get_tile_distribution()); // Pcompute{j}
460
// NOTE(review): most of this lambda's condition and its first return expression are missing
// (embedded lines 462-464 and 467-468). The visible parts show that it involves
// FmhaMask::IsMasking, yields raw_m as one arm of a conditional expression, and returns
// raw_m unchanged in the else-branch.
461 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
465 FmhaMask::IsMasking)
466 {
469 : raw_m;
470 }
471 else
472 {
473 return raw_m;
474 }
475 };
476
// P = exp(S - m) per element (exp2 under FAST_EXP2, where scale_s may be folded in here).
477 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
478 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
479 constexpr auto i_idx = make_tuple(idx0);
480#if CK_TILE_FMHA_FWD_FAST_EXP2
481 auto row_max = scale_s * get_validated_m(m[i_idx]);
482#endif
483 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
484 constexpr auto i_j_idx = make_tuple(idx0, idx1);
485#if CK_TILE_FMHA_FWD_FAST_EXP2
// NOTE(review): the `if constexpr` condition for this first arm (lines 486-487) is missing.
488 {
489 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
490 }
491 else
492 {
493 if constexpr(kHasLogitsSoftCap)
494 {
495 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
496 }
497 else
498 {
499 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
500 }
501 }
502#else
503 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
504#endif
505 });
506 });
507
// NOTE(review): the declaration of rowsum_p (508) and the statement on line 511 are missing;
// rowsum_p is used below when updating `l`.
509 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
510
512
513 const auto p =
514 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
515
516 __builtin_amdgcn_sched_barrier(0);
517
518 // l{j}, Oacc{j}
// Rescale the running sum and the output accumulator by tmp = exp(m_old - m_new) so that
// contributions from earlier tiles stay consistent with the updated row maximum.
519 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
520 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
521 constexpr auto i_idx = make_tuple(idx0);
522#if CK_TILE_FMHA_FWD_FAST_EXP2
523 const auto tmp = [&]() {
// NOTE(review): the `if constexpr` condition for this first arm (lines 524-525) is missing.
526 {
527 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
528 }
529 else
530 {
531 if constexpr(kHasLogitsSoftCap)
532 {
533
534 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
535 }
536 else
537 {
538 auto row_max = scale_s * get_validated_m(m[i_idx]);
539 return exp2(scale_s * m_old[i_idx] - row_max);
540 }
541 }
542 }();
543#else
544 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
545#endif
546 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
547 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
548 constexpr auto i_j_idx = make_tuple(idx0, idx1);
549 // FIXME: this uses a different equation from the FA v2 paper,
550 // but produces the correct result.
551 // Is the equation wrong?
552 o_acc(i_j_idx) *= tmp;
553 });
554 });
555
// Write the prefetched V tile to LDS; a RowMajor V goes through shuffle_tile first.
// NOTE(review): lines 556, 559 and 562 are missing (among them the declaration of
// v_shuffle_tmp and the opening of the store call).
557 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
558 {
560 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
561 shuffle_tile(v_shuffle_tmp, v_prefetch);
563 v_lds_window,
564 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
565 }
566 else
567 {
568 store_tile(v_lds_window,
569 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
570 }
571 move_tile_window(v_dram_window, {0, kK1});
572
573 // STAGE 3, KV gemm
// For each k1 step except the last: load the next V slice from global memory, run gemm_1 on
// the slice of P matching the V data currently in LDS, then store the next V slice to LDS.
// NOTE(review): lines 578, 580, 583 and 586 are missing in this loop body.
574 if constexpr(k1_loops > 1)
575 {
576 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
577 const auto v = load_tile(v_dram_window); // load next v
579 gemm_1(o_acc,
581 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
582 v_lds_window);
584 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
585 {
587 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
588 shuffle_tile(v_shuffle_tmp, v);
589 store_tile(v_lds_window,
590 tile_elementwise_in(v_element_func,
591 v_shuffle_tmp)); // store the prefetch
592 }
593 else
594 {
595 store_tile(v_lds_window,
596 tile_elementwise_in(v_element_func, v)); // store next v
597 }
598 move_tile_window(v_dram_window, {0, kK1});
599 });
600 }
601 // move K tile windows
602 move_tile_window(k_dram_block_window, {kN0, 0});
603 // tail
// Last k1 step: gemm_1 on the final kK1-wide slice of P with the V data currently in LDS.
// NOTE(review): lines 605 and 609 are missing here.
604 {
606 gemm_1(o_acc,
607 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
608 v_lds_window);
610 }
611 } while(++i_total_loops < num_total_loop);
612
613 // store lse
// lse = m + log(l) in natural-log units. Under FAST_EXP2 the max is divided by C_LOG2E
// (and multiplied by scale_s in the branch where scaling was deferred to the exp2 argument).
614 if constexpr(kStoreLSE)
615 {
616 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
617
618 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
619 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
620 constexpr auto i_idx = make_tuple(idx0);
621#if CK_TILE_FMHA_FWD_FAST_EXP2
// NOTE(review): the `if constexpr` condition for this first arm (lines 622-623) is missing.
624 {
625 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
626 }
627 else
628 {
629 if constexpr(kHasLogitsSoftCap)
630 {
631 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
632 }
633 else
634 {
635 lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
636 }
637 }
638#else
639 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
640#endif
641 });
642
643 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
644 }
645
646 // finally, O
// Normalise each output row by 1/l. With masking enabled a row whose l is exactly 0 is
// multiplied by 0 instead, avoiding a division by zero.
647 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
648
649 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
650 constexpr auto i_idx = make_tuple(idx0);
651 const auto tmp = [&]() {
652 if constexpr(FmhaMask::IsMasking)
653 {
654 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
655 }
656 else
657 return 1 / l[i_idx];
658 }();
659 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
660 constexpr auto i_j_idx = make_tuple(idx0, idx1);
661 o_acc(i_j_idx) *= tmp;
662 });
663 });
664
665 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
666
667 return o_acc;
668 }
669
// Convenience overload without element-wise functors: forwards every argument to the full
// operator() and passes identity{} for all eight functor slots (Q, K, V, bias, LSE, S-acc,
// P-compute, O-acc). Returns whatever the full overload returns.
670 template <typename QDramBlockWindowTmp,
671 typename KDramBlockWindowTmp,
672 typename VDramBlockWindowTmp,
673 typename BiasDramBlockWindowTmp,
674 typename RandValDramBlockWindowTmp,
675 typename LSEDramBlockWindowTmp,
676 typename PositionEncoding,
677 typename AttentionVariantParams,
678 typename BlockIndices>
// NOTE(review): the return-type line (embedded number 679) is missing from this listing; the
// symbol index gives this overload as `CK_TILE_HOST_DEVICE auto operator()(...) const`.
680 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
681 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
682 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
683 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
684 RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
685 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
686 FmhaMask mask,
687 PositionEncoding position_encoding,
688 float scale_s,
689 const AttentionVariant& variant,
690 const AttentionVariantParams& variant_params,
691 const BlockIndices& block_indices,
692 void* smem_ptr,
693 DropoutType& dropout) const
694 {
// Argument order must match the full overload: each tensor window is followed by its
// element functor, then the four remaining functors, then mask / encoding / scale / variant.
695 return operator()(q_dram_block_window_tmp,
696 identity{},
697 k_dram_block_window_tmp,
698 identity{},
699 v_dram_block_window_tmp,
700 identity{},
701 bias_dram_block_window_tmp,
702 identity{},
703 randval_dram_block_window_tmp,
704 lse_dram_block_window_tmp,
705 identity{},
706 identity{},
707 identity{},
708 identity{},
709 mask,
710 position_encoding,
711 scale_s,
712 variant,
713 variant_params,
714 block_indices,
715 smem_ptr,
716 dropout);
717 }
718};
719
720} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_pipeline_qs_ks_vs.hpp:17
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qs_ks_vs.hpp:57
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qs_ks_vs.hpp:50
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qs_ks_vs.hpp:39
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qs_ks_vs.hpp:66
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qs_ks_vs.hpp:82
static constexpr index_t kM0
Definition block_fmha_pipeline_qs_ks_vs.hpp:41
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qs_ks_vs.hpp:680
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:24
static constexpr const char * name
Definition block_fmha_pipeline_qs_ks_vs.hpp:113
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qs_ks_vs.hpp:31
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qs_ks_vs.hpp:34
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qs_ks_vs.hpp:56
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qs_ks_vs.hpp:55
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:30
static constexpr index_t kK0
Definition block_fmha_pipeline_qs_ks_vs.hpp:43
static constexpr index_t kK1
Definition block_fmha_pipeline_qs_ks_vs.hpp:45
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:20
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:21
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:28
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qs_ks_vs.hpp:79
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qs_ks_vs.hpp:51
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qs_ks_vs.hpp:77
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qs_ks_vs.hpp:32
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:27
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &) const
Definition block_fmha_pipeline_qs_ks_vs.hpp:140
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:29
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qs_ks_vs.hpp:49
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qs_ks_vs.hpp:19
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qs_ks_vs.hpp:52
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qs_ks_vs.hpp:47
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qs_ks_vs.hpp:54
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qs_ks_vs.hpp:36
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:23
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qs_ks_vs.hpp:53
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qs_ks_vs.hpp:68
static constexpr index_t kN0
Definition block_fmha_pipeline_qs_ks_vs.hpp:42
static constexpr index_t kN1
Definition block_fmha_pipeline_qs_ks_vs.hpp:44
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qs_ks_vs.hpp:18
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:25
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qs_ks_vs.hpp:117
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:26
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qs_ks_vs.hpp:35
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qs_ks_vs.hpp:22
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qs_ks_vs.hpp:70
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qs_ks_vs.hpp:115
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qs_ks_vs.hpp:46
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469