block_fmha_pipeline_qx_ks_vs_custom_policy.hpp Source File

block_fmha_pipeline_qx_ks_vs_custom_policy.hpp Source File

Composable Kernel: block_fmha_pipeline_qx_ks_vs_custom_policy.hpp Source File
block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
19
20namespace ck_tile {
21
22template <bool QLoadOnce_>
24
25template <>
26struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
27{
28 static constexpr bool QLoadOnce = true;
29
30 template <typename Problem>
32 {
33 return 0;
34 }
35
36 // TODO: GetAlignment*() currently didn't consider if need padding or not
37 // so in pipeline still need check padding requirement
38 template <typename Problem>
39 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
40 {
41 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
42
44 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
45 using WG = remove_cvref_t<decltype(config.template at<0>())>;
46
47 return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
48 }
49
50 template <typename Problem>
52 {
54
55 return BlockGemm::template MakeABlockTileDistribution<
56 Problem::BlockFmhaShape::kM0,
57 Problem::BlockFmhaShape::kSubQKHeaddim>();
58 }
59
60 template <typename Problem>
61 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
62 {
63 using GemmProblem =
64 BlockGemmProblem<typename Problem::QDataType,
65 typename Problem::KDataType,
66 typename Problem::SaccDataType,
67 Problem::kNumGemm0Warps * get_warp_size(),
68 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
69 Problem::BlockFmhaShape::kN0,
70 Problem::BlockFmhaShape::kK0>,
71 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
72 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
73
74 constexpr auto warp_gemm = []() {
75 if constexpr(get_warp_size() == 64 &&
76 std::is_same_v<typename Problem::QDataType, fp8_t> &&
77 std::is_same_v<typename Problem::KDataType, fp8_t> &&
78 std::is_same_v<typename Problem::SaccDataType, float>)
79 {
80 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
81 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
82 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
83
84 // TODO: hard coded here. Otherwise, it produces incorrect results
85 constexpr index_t swizzle_factor = 4;
87 swizzle_factor>{};
88 }
89 else
90 {
91 constexpr bool SwizzleA =
92 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
93 return WarpGemmDispatcher<typename Problem::QDataType,
94 typename Problem::KDataType,
95 typename Problem::SaccDataType,
96 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
97 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
98 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
99 true, // TransposeC
100 SwizzleA>{};
101 }
102 }();
103
104 using BlockGemmPolicy =
105 BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
106 typename Problem::KDataType,
107 typename Problem::SaccDataType,
108 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
109 decltype(warp_gemm)>;
110
111 if constexpr(1 < Problem::kNumGemm0Warps)
113 else
115 }
116};
117
118template <>
119struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
120{
121 static constexpr bool QLoadOnce = false;
122
123 template <typename Problem>
125 {
126 constexpr index_t lds_alignment = 16; // optional
127 constexpr index_t q_smem_size =
129 sizeof(typename Problem::QDataType) *
130 MakeQLdsBlockDescriptor<Problem>().get_element_space_size(),
131 lds_alignment) *
132 lds_alignment;
133 return q_smem_size;
134 }
135
136 template <typename Problem>
137 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
138 {
139 constexpr index_t kBlockSize = Problem::kBlockSize;
140 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
141 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
142
143 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
144
145 // this should align with MakeQDramTileDistribution()
146 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
147 static_assert(0 < ElemPerThread);
148 return min(ElemPerThread, MaxVectorSize);
149 }
150
151 template <typename Problem>
153 {
155
156 constexpr index_t kBlockSize = Problem::kBlockSize;
157 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
158 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
159
160 constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
161
162 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
163 static_assert(0 < ElemPerThread);
164 constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
165
166 constexpr index_t KPerThread = kMaxVecLoad;
167 constexpr index_t KThreads = kKPerBlock / KPerThread;
168 constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
169 constexpr index_t NumWarps = kBlockSize / get_warp_size();
170 constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
171
179 sequence<0, 1>>{});
180 }
181
182 // 3d + padding
183 template <typename Problem>
185 {
187
188 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
189 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
190 constexpr index_t kKPack = 16 / sizeof(QDataType);
191
192 constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
194 make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
196 number<1>{});
197
198 constexpr auto q_lds_block_desc = transform_tensor_descriptor(
199 q_lds_block_desc_0,
201 make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
204
205 return q_lds_block_desc;
206 }
207
208 template <typename Problem>
209 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
210 {
211 using GemmProblem =
212 BlockGemmProblem<typename Problem::QDataType,
213 typename Problem::KDataType,
214 typename Problem::SaccDataType,
215 Problem::kNumGemm0Warps * get_warp_size(),
216 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
217 Problem::BlockFmhaShape::kN0,
218 Problem::BlockFmhaShape::kK0>,
219 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
220 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
221
222 constexpr auto warp_gemm = []() {
223 if constexpr(get_warp_size() == 64 &&
224 std::is_same_v<typename Problem::QDataType, fp8_t> &&
225 std::is_same_v<typename Problem::KDataType, fp8_t> &&
226 std::is_same_v<typename Problem::SaccDataType, float>)
227 {
228 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
229 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
230 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
231
232 // TODO: hard coded here. Otherwise, it produces incorrect results
233 constexpr index_t swizzle_factor = 4;
235 swizzle_factor>{};
236 }
237 else
238 {
239 constexpr bool SwizzleA =
240 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
241 return WarpGemmDispatcher<typename Problem::QDataType,
242 typename Problem::KDataType,
243 typename Problem::SaccDataType,
244 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
245 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
246 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
247 true, // TransposeC
248 SwizzleA>{};
249 }
250 }();
251
252 using BlockGemmPolicy =
253 BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::QDataType,
254 typename Problem::KDataType,
255 typename Problem::SaccDataType,
256 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
257 decltype(warp_gemm)>;
258
260 }
261};
262
263// This pipeline is qkv all located in LDS
264template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
266{
267 static constexpr bool AsyncCopy = AsyncCopy_;
268
269 static constexpr index_t NumPrefetchK = NumPrefetchK_;
270 static constexpr index_t NumPrefetchV = NumPrefetchK_;
271
273
275
// Primary template: builds, at compile time, a sequence of LDS buffer indices, one per loop
// iteration. The first k_loops_ entries are K iterations; the remaining entries are presumably
// the v_loops_ V iterations (the generating index sequence is on a line not shown in this listing).
276 template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
278 {
// K and V rotate through one shared pool of LDS buffers, sized by the deeper prefetch.
279 static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
// NOTE(review): despite the name, ceil_ is (v_loops_ - 1) rounded DOWN to a multiple of
// num_lds_buffers_ (integer division truncates). Adding it in Make() keeps the left operand
// of `%` non-negative without changing the result modulo num_lds_buffers_.
280 static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;
281
282 // for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
283 // overlap with the Lds buffers used by first two gemm_0 iterations of K
284 static constexpr auto Make()
285 {
286 // ensure v_loop_-1 is assigned to num_lds_buffers-1
287 return transform_sequences(
288 [&](auto i) {
// K iterations: plain round-robin 0, 1, ..., num_lds_buffers_ - 1, 0, ...
289 if(i < k_loops_)
290 return i % num_lds_buffers_;
291 else
// V iterations: round-robin shifted so that index k_loops_ + v_loops_ - 1 yields
// (num_lds_buffers_ - 1 + ceil_) % num_lds_buffers_ == num_lds_buffers_ - 1.
292 return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
294 },
296 };
297
298 using type = remove_cvref_t<decltype(Make())>;
299 };
300
301 // clang-format off
304
307
310
313
316
319 // clang-format on
320
321 template <typename Problem>
323 {
325
326 constexpr index_t kN0 = BlockFmhaShape::kN0;
327 constexpr index_t kK0 = BlockFmhaShape::kK0;
328 constexpr index_t kK1 = BlockFmhaShape::kK1;
329 constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
330
331 constexpr index_t k0_loops = kQKHeaddim / kK0;
332 constexpr index_t k1_loops = kN0 / kK1;
333
335 }
336
337 template <typename Problem>
338 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
339 {
340 // TODO: this is for 3d layout
342 return 16 / sizeof(KDataType);
343 }
344
345 template <typename Problem>
346 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
347 {
349 if constexpr(AsyncCopy)
350 {
351#if defined(__gfx950__)
352 constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
353#else
354 constexpr index_t MaxLoadSizeInBytes = 4; // dword
355#endif
356
357 return MaxLoadSizeInBytes / sizeof(KDataType);
358 }
359 else
360 {
361 constexpr index_t kBlockSize = Problem::kBlockSize;
362 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
363 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
364
365 constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
366 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
367
368 return min(MaxVectorSize, ElemPerThread);
369 }
370 }
371
372 template <typename Problem>
373 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
374 {
375 // TODO: this is for 3d layout
377 constexpr index_t kBlockSize = Problem::kBlockSize;
378 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
379 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
380 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
381 constexpr index_t kMaxVecLoad =
382 min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
383
384 return kMaxVecLoad;
385 }
386
387 template <typename Problem>
388 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
389 {
392 constexpr index_t kBlockSize = Problem::kBlockSize;
393 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
394 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
395 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
396 constexpr index_t kMaxVecLoad =
397 min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
398
399 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
400 {
401 constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
402
403 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
404 ? kMaxVecLoad
405 : (total_pixels / kMinVecLoad);
406
407 return kVecLoad;
408 }
409 else
410 {
411 return kMaxVecLoad;
412 }
413 }
414
415 template <typename Problem>
416 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
417 {
419 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
420 using WG = remove_cvref_t<decltype(config.template at<0>())>;
421
422 return WG::WarpGemmAttribute::Impl::kCM1PerLane;
423 }
424
425 template <typename Problem>
426 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
427 {
429 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
430 using WG = remove_cvref_t<decltype(config.template at<0>())>;
431
432 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
433 return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
434 }
435
// Element count (not bytes) of ONE shared K/V LDS buffer: the larger of the K footprint and the
// V footprint, so a single buffer can hold either tile. GetSmemSizeKV multiplies this by
// sizeof(KDataType) and by NumKVLdsBuffers.
436 template <typename Problem>
438 {
439 // this function assume K/V can share smem
440 constexpr index_t SingleKSize = [&]() {
441 if constexpr(!AsyncCopy)
442 {
443 return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
444 }
445 else
446 {
// Async path: same NumIssues / NumWarps / kPad arithmetic as the async K LDS store/load
// descriptors; keep these in sync if that layout changes.
447 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
448 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
449 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
450 constexpr index_t WarpSize = ck_tile::get_warp_size();
451
452 constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
453 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
454 constexpr index_t kPad = KPack;
455
456 static_assert(WarpSize * KVector >= kKPerBlock &&
457 WarpSize * KVector % kKPerBlock == 0);
458 constexpr index_t LanesPerK = kKPerBlock / KVector;
459 constexpr index_t LaneGroups = WarpSize / LanesPerK;
460 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
461
// NumIssues * NumWarps rows, each WarpSize * KVector elements followed by kPad padding.
462 return NumIssues * NumWarps * (WarpSize * KVector + kPad);
463 }
464 }();
465
466 constexpr index_t SingleVSize = [&]() {
468 constexpr index_t Banks = get_n_lds_banks();
469 constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
// NOTE(review): this uses GetSmemKPackK, while MakeVLdsBlockDescriptor uses GetSmemKPackV.
// GetSmemKPackV = min(per-thread pixels, 16 / sizeof(VDataType)), which is <= GetSmemKPackK
// when K and V share a data type. The size below grows with kKPack, so this should over- rather
// than under-estimate. However, the static_asserts here check the K pack, not the pack the
// V descriptor actually uses. Confirm this is intended.
470 constexpr index_t kKPack = GetSmemKPackK<Problem>();
471 static_assert(PixelsPerRow % kKPack == 0);
472 constexpr index_t NPerRow = PixelsPerRow / kKPack;
473 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
474 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
475 static_assert(kNPerBlock % NPerRow == 0);
476 static_assert(kKPerBlock % kKPack == 0);
477
// (kK / kKPack) x (kN / NPerRow) rows, each PixelsPerRow elements plus kKPack padding.
478 return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
479 }();
480
481 return max(SingleKSize, SingleVSize);
482 }
483
484 // TODO: this is used for non async copy desc. unify in the future
485 template <typename Problem>
487 {
488 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
489 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
490 constexpr index_t kKPack = GetSmemKPackK<Problem>();
491
492 constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
494 make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
496 number<1>{});
497
498 constexpr auto k_lds_block_desc = transform_tensor_descriptor(
499 k_lds_block_desc_0,
505
506 return k_lds_block_desc;
507 }
508
509 template <typename Problem, index_t IBuf = 0>
510 CK_TILE_HOST_DEVICE static constexpr auto
512 {
513 // K is always k-major, we use async-copy to load into LDS
514 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
515 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
516 constexpr index_t kBlockSize = Problem::kBlockSize;
517 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
518 constexpr index_t WarpSize = ck_tile::get_warp_size();
519
520 constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
521 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
522 constexpr index_t kPad =
523 KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
524
525 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
526 constexpr index_t LanesPerK =
527 kKPerBlock / KVector; // how many lane (within a wave) to load K
528 constexpr index_t LaneGroups =
529 WarpSize /
530 LanesPerK; // how many groups (within a wave), they may load different N, but same K
531 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
532 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
533
534 constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
536 number<LaneGroups>{}, // n1
537 number<NumWarps>{}, // n2
538 number<LanesPerK>{}, // k0
539 number<KVector>{}), // k1
540 make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
544 number<1>{}),
547 number<1>{});
548
549 // TODO this layout is hard coded, and will be used in async copy buffer view load
550 // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
551 constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
552 k_lds_block_desc_0,
557 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
558 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
559
560 return k_lds_block_desc_issues_warps_lanes;
561 }
562
563 template <typename Problem>
565 {
566 // K is always k-major, we use async-copy to load into LDS
567 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
568 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
569 constexpr index_t kBlockSize = Problem::kBlockSize;
570 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
571 constexpr index_t WarpSize = ck_tile::get_warp_size();
572
573 constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
574 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
575 constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
576
577 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
578 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
579 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
580 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
581 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
582 // constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
583 // constexpr index_t SingleVSize =
584 // MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
585 constexpr index_t BufferSize =
586 GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
587
588 constexpr auto k_lds_block_desc_0 =
590 number<NumIssues>{}, // n0
591 number<NumWarps>{}, // n2
592 number<LaneGroups>{}, // n1
593 number<kKPerBlock / KPack>{}, // k0
594 number<KPack>{}), // k1
596 number<NumWarps*(WarpSize * KVector + kPad)>{},
600 number<1>{}),
602 number<1>{});
603
604 constexpr auto k_lds_block_desc = transform_tensor_descriptor(
605 k_lds_block_desc_0,
614
615 return k_lds_block_desc;
616 }
617
618 // 3d + padding
619 template <typename Problem>
621 {
623 constexpr index_t Banks = get_n_lds_banks();
624 constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
625 constexpr index_t kKPack = GetSmemKPackV<Problem>();
626 static_assert(PixelsPerRow % kKPack == 0);
627 constexpr index_t NPerRow = PixelsPerRow / kKPack;
628 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
629 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
630 static_assert(kNPerBlock % NPerRow == 0);
631 static_assert(kKPerBlock % kKPack == 0);
632
633 constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
635 number<kKPerBlock / kKPack>{},
636 number<kNPerBlock / NPerRow>{},
640 number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
643 number<1>{}),
645 number<1>{});
646
647 constexpr auto v_lds_block_desc = transform_tensor_descriptor(
648 v_lds_block_desc_0,
651 number<NumKVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
655
656 return v_lds_block_desc;
657 }
658
659 template <typename Problem>
661 {
662 // TODO: assume Q is in register
663 // TODO: assume K/V has same data type
664 constexpr index_t single_smem_size =
665 GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
666
667 return QXPolicy::template GetSmemSizeQ<Problem>() + single_smem_size * NumKVLdsBuffers;
668 }
669
670 template <typename Problem>
672 {
673 if constexpr(AsyncCopy)
674 {
676 }
677 else
678 {
680 }
681 }
682
683 // this method is only available when Problem::kHasDropout is present
684 template <typename Problem>
685 CK_TILE_HOST_DEVICE static constexpr std::
686 enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
688 {
689 if constexpr(Problem::kHasDropout)
690 {
691 constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm<Problem>();
692 constexpr auto config =
693 decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
694 using WG = remove_cvref_t<decltype(config.template at<0>())>;
695 constexpr index_t MWarp = config.template at<1>();
696 constexpr index_t kMPerStep = MWarp * WG::kM;
697 constexpr index_t kNPerStep = WG::kN;
698
699 return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
700 }
701 else
702 {
703 return 0;
704 }
705 }
706
707 // fallback version if Problem::kHasDropout is not exist
708 template <typename Problem>
710 {
711 return 0;
712 }
713
714 template <typename Problem>
716 {
717 if constexpr(!AsyncCopy)
718 {
720
721 constexpr index_t kBlockSize = Problem::kBlockSize;
722 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
723 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
724
725 constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
726 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
727
728 constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
729 constexpr index_t K0 = kKPerBlock / K1;
730 constexpr index_t N2 = get_warp_size() / K0;
731 constexpr index_t N1 = kBlockSize / get_warp_size();
732 constexpr index_t N0 = kNPerBlock / (N2 * N1);
733
740 sequence<0, 1>>{});
741 }
742 else
743 {
744 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
745 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
746 constexpr index_t kBlockSize = Problem::kBlockSize;
747 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
748 constexpr index_t WarpSize = ck_tile::get_warp_size();
749
750 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
751
752 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
753 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
754 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
755 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
756 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
757
758 constexpr index_t N0 = NumIssues;
759 constexpr index_t N1 = LaneGroups;
760 constexpr index_t N2 = NumWarps;
761 constexpr index_t K0 = LanesPerK;
762 constexpr index_t K1 = KVector;
763
770 sequence<0, 1>>{});
771 }
772 }
773
774 template <typename Problem>
776 {
778
779 constexpr index_t kBlockSize = Problem::kBlockSize;
780 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
781 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
782
783 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
784 {
785 constexpr index_t N1 = GetAlignmentV<Problem>();
786 constexpr index_t N0 = kNPerBlock / N1; // P
787
788 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
789 constexpr index_t kKPack = GetSmemKPackV<Problem>();
790 constexpr index_t K3 = total_pixels / N1;
791 constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
792 if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
793 {
794 static_assert(kNPerBlock % 16 == 0);
795 constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
796 constexpr index_t K0 = kBlockSize / get_warp_size();
797 constexpr index_t N2 = 2;
798 constexpr index_t N1_m = kNPack / N2;
799 constexpr index_t N0_m = kNPerBlock / kNPack;
800 constexpr index_t K1 = get_warp_size() / N1_m;
801 constexpr index_t K2_m = kKPerBlock / K1 / K0;
806 tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
808 sequence<1, 2, 1>, // N0 K2 N2
810 }
811 else if constexpr(get_warp_size() % (K2 * N0) == 0)
812 {
813 constexpr index_t K1 = get_warp_size() / (K2 * N0);
814 constexpr index_t K0 = kBlockSize / get_warp_size();
815 static_assert(kKPerBlock == K0 * K1 * K2 * K3);
822 sequence<3, 1>>{});
823 }
824 else
825 {
826 constexpr index_t K1 = (K2 * N0) / get_warp_size();
827 constexpr index_t K2_m = K2 / K1;
828 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
829 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
836 sequence<3, 1>>{});
837 }
838 }
839 else
840 {
841 constexpr index_t K1 = GetAlignmentV<Problem>();
842 constexpr index_t K0 = kKPerBlock / K1;
843 constexpr index_t N2 = get_warp_size() / K0;
844 constexpr index_t N1 = kBlockSize / get_warp_size();
845 static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
846 static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
847 constexpr index_t N0 = kNPerBlock / (N2 * N1);
848 static_assert(N0 != 0);
849
850 constexpr auto dstr = make_static_tile_distribution(
853 tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
855 sequence<1, 2>, // N0 K1
856 sequence<0, 1>>{});
857 if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
858 kNPerBlock * kKPerBlock)
859 {
860 return dstr;
861 }
862 else
863 {
864 static_assert(kKPerBlock % 16 == 0);
865 constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
866 constexpr index_t K0_m = kKPerBlock / kKPerIter;
867 constexpr index_t K2 = 2;
868 constexpr index_t K1_m = kKPerIter / K2;
869 constexpr index_t N2_m = get_warp_size() / K1_m;
870 constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
871 constexpr auto dstr_m = make_static_tile_distribution(
875 tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
877 sequence<2, 1, 2>, // K0 N0 K2
879 static_assert(container_reduce(dstr_m.get_lengths(),
880 std::multiplies<index_t>{},
881 1) == kNPerBlock * kKPerBlock);
882 return dstr_m;
883 }
884 }
885 }
886
887 template <typename BlockGemm>
889 {
890 return BlockGemm::MakeCBlockTile().get_tile_distribution();
891 }
892
893 template <typename Problem>
895 {
896 // This descriptor only used when V layout is seqlen * hdim
898 static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
899 constexpr index_t kBlockSize = Problem::kBlockSize;
900 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
901 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
902
903 constexpr index_t N1 = GetAlignmentV<Problem>();
904 constexpr index_t N0 = kNPerBlock / N1;
905 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
906 constexpr index_t K3 = total_pixels / N1;
907 constexpr index_t kKPack = GetSmemKPackV<Problem>();
908 constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
909 if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
910 {
911 static_assert(kNPerBlock % 16 == 0);
912 constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
913 constexpr index_t K0 = kBlockSize / get_warp_size();
914 constexpr index_t N2 = 2;
915 constexpr index_t N1_m = kNPack / N2;
916 constexpr index_t N0_m = kNPerBlock / kNPack;
917 constexpr index_t K1 = get_warp_size() / N1_m;
918 constexpr index_t K2_m = kKPerBlock / K1 / K0;
922 tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
924 sequence<1, 1, 2>, // N0 K2 <-> N2
926 }
927 else if constexpr(get_warp_size() % (K2 * N0) == 0)
928 {
929 constexpr index_t K1 = get_warp_size() / (K2 * N0);
930 constexpr index_t K0 = kBlockSize / get_warp_size();
931
938 sequence<1, 3>>{});
939 }
940 else
941 {
942 constexpr index_t K1 = (K2 * N0) / get_warp_size();
943 constexpr index_t K2_m = K2 / K1;
944 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
945 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
952 sequence<1, 3>>{});
953 }
954 }
955
956 template <typename Problem>
957 CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
958 {
959 using GemmProblem =
960 BlockGemmProblem<typename Problem::PDataType,
961 typename Problem::VDataType,
962 typename Problem::OaccDataType,
963 Problem::kNumGemm1Warps * get_warp_size(),
964 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
965 Problem::BlockFmhaShape::kN1,
966 Problem::BlockFmhaShape::kK1>,
967 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
968 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
969
970 auto warp_gemm = [&]() {
971 if constexpr(get_warp_size() == 64 &&
972 std::is_same_v<typename Problem::PDataType, fp8_t> &&
973 std::is_same_v<typename Problem::VDataType, fp8_t> &&
974 std::is_same_v<typename Problem::OaccDataType, float>)
975 {
976 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 32);
977 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32);
978 static_assert(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32);
979
981 }
982 else
983 {
984 return WarpGemmDispatcher<typename Problem::PDataType,
985 typename Problem::VDataType,
986 typename Problem::OaccDataType,
987 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
988 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
989 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
990 true>{};
991 }
992 }();
993
994 using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
995
996 using BlockGemmPolicy =
997 BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::PDataType,
998 typename Problem::VDataType,
999 typename Problem::OaccDataType,
1000 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
1001 WarpGemm>;
1003 }
1004};
1005
1006} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_with_offset(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, const offset &os, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:319
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence< Xs... >)
Definition tile/core/container/sequence.hpp:832
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ >, 2, swizzle_factor > > WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
Definition warp_gemm.hpp:394
unsigned char uint8_t
Definition stdint.h:124
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:184
static constexpr bool QLoadOnce
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:121
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:124
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:209
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:152
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:137
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:61
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:31
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:51
static constexpr bool QLoadOnce
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:28
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:39
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:23
sequence< 1, 2, 1, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:318
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:309
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:312
sequence< 1, 2, 0, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:315
sequence< 1, 2, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:306
sequence< 1, 2, 0, 1, 0, 1, 2, 0 > type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:303
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:278
static constexpr index_t num_lds_buffers_
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:279
remove_cvref_t< decltype(Make())> type
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:298
static constexpr index_t ceil_
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:280
static constexpr auto Make()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:284
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static constexpr bool AsyncCopy
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:267
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledVRegBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:894
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeKV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:660
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:346
static CK_TILE_HOST_DEVICE constexpr std::enable_if_t< std::is_convertible_v< decltype(Problem::kHasDropout), bool >, ck_tile::index_t > GetSmemSizeDropout(int)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:687
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:671
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:388
static CK_TILE_HOST_DEVICE constexpr auto GetKVBlockGemm()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:957
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:373
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:620
static constexpr index_t NumPrefetchK
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:269
static constexpr index_t NumKVLdsBuffers
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:272
static CK_TILE_HOST_DEVICE constexpr auto GetLdsBufferSequence()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:322
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeDropout(...)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:709
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsLoadBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:564
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsStoreBlockDescriptor(number< IBuf >=number< 0 >{})
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:511
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:416
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:775
BlockFmhaPipelineQXCustomPolicy< QLoadOnce_ > QXPolicy
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:274
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:715
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:426
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:338
static CK_TILE_HOST_DEVICE constexpr auto GetSingleSmemElementSpaceSize()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:437
static constexpr index_t NumPrefetchV
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:270
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasDramTileDistribution()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:888
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:486
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
Definition block_gemm_areg_bsmem_creg_v2_custom_policy.hpp:16
Definition block_gemm_areg_bsmem_creg_v2.hpp:16
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_gemm_asmem_bsmem_creg_v1.hpp:16
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192