gridwise_gemm_wmma.hpp Source File

gridwise_gemm_wmma.hpp Source File

Composable Kernel: gridwise_gemm_wmma.hpp Source File
gridwise_gemm_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
18
19namespace ck {
20
// Device entry point for the WMMA gridwise GEMM.
//
// When compiled for a gfx11 or gfx12 target, the kernel declares a static
// __shared__ byte buffer of GridwiseGemm::SharedMemTrait::lds_size bytes and
// forwards every argument, plus that buffer, to
// GridwiseGemm::Run<HasMainKBlockLoop>. For any other target the body only
// assigns each argument to `ignore`, so the kernel does no work.
//
// Descriptors, element-wise operators and the block-to-C-tile map are taken by
// value; the three grid pointers are __restrict__-qualified.
21template <typename GridwiseGemm,
22 typename ADataType,
23 typename BDataType,
24 typename CDataType,
25 typename AGridDesc,
26 typename BGridDesc,
27 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CElementwiseOperation,
31 typename Block2CTileMap,
32 bool HasMainKBlockLoop>
33__global__ void
34#if CK_USE_LAUNCH_BOUNDS
    // NOTE(review): listing line 35 is absent from this view; presumably it holds a
    // __launch_bounds__(...) attribute — confirm against the original header.
36#endif
37 kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
38 const BDataType* __restrict__ p_b_grid,
39 CDataType* __restrict__ p_c_grid,
40 const AGridDesc a_grid_desc,
41 const BGridDesc b_grid_desc,
42 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
43 c_grid_desc_mblock_mperblock_nblock_nperblock,
44 const AElementwiseOperation a_element_op,
45 const BElementwiseOperation b_element_op,
46 const CElementwiseOperation c_element_op,
47 const Block2CTileMap block_2_ctile_map)
48{
49#if(defined(__gfx11__) || defined(__gfx12__))
    // Block-shared scratch buffer handed to Run as p_shared; its size is fixed at
    // compile time by the GridwiseGemm's SharedMemTrait.
50 __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
51
52 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
53 p_b_grid,
54 p_c_grid,
55 p_shared,
56 a_grid_desc,
57 b_grid_desc,
58 c_grid_desc_mblock_mperblock_nblock_nperblock,
59 a_element_op,
60 b_element_op,
61 c_element_op,
62 block_2_ctile_map);
63#else
    // Other targets: consume every parameter via `ignore` and do nothing else.
64 ignore = p_a_grid;
65 ignore = p_b_grid;
66 ignore = p_c_grid;
67 ignore = a_grid_desc;
68 ignore = b_grid_desc;
69 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
70 ignore = a_element_op;
71 ignore = b_element_op;
72 ignore = c_element_op;
73 ignore = block_2_ctile_map;
74#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
75}
76
77template <index_t BlockSize,
78 typename ADataType,
79 typename BDataType,
80 typename AccDataType,
81 typename CShuffleDataType,
82 typename CDataType,
83 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
84 typename AGridDesc,
85 typename BGridDesc,
86 typename CGridDesc_M_N,
87 typename AElementwiseOperation,
88 typename BElementwiseOperation,
89 typename CElementwiseOperation,
90 index_t MPerBlock,
91 index_t NPerBlock,
92 index_t KPerBlock,
93 index_t MPerWmma,
94 index_t NPerWmma,
95 index_t K1Value,
96 index_t MRepeat,
97 index_t NRepeat,
98 typename ABlockTransferThreadClusterLengths_K0_M_K1,
99 typename ABlockTransferThreadClusterArrangeOrder,
100 typename ABlockTransferSrcAccessOrder,
101 index_t ABlockTransferSrcVectorDim,
102 index_t ABlockTransferSrcScalarPerVector,
103 index_t ABlockTransferDstScalarPerVector_K1,
104 bool AThreadTransferSrcResetCoordinateAfterRun,
105 bool AEnableLds,
106 bool ABlockLdsExtraM,
107 typename BBlockTransferThreadClusterLengths_K0_N_K1,
108 typename BBlockTransferThreadClusterArrangeOrder,
109 typename BBlockTransferSrcAccessOrder,
110 index_t BBlockTransferSrcVectorDim,
111 index_t BBlockTransferSrcScalarPerVector,
112 index_t BBlockTransferDstScalarPerVector_K1,
113 bool BThreadTransferSrcResetCoordinateAfterRun,
114 bool BEnableLds,
115 bool BBlockLdsExtraN,
116 index_t CShuffleMRepeatPerShuffle,
117 index_t CShuffleNRepeatPerShuffle,
118 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
119 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
120 index_t NumGemmKPrefetchStage = 1,
124{
125 static constexpr auto I0 = Number<0>{};
126 static constexpr auto I1 = Number<1>{};
127 static constexpr auto I2 = Number<2>{};
128 static constexpr auto I3 = Number<3>{};
129 static constexpr auto I4 = Number<4>{};
130 static constexpr auto I5 = Number<5>{};
131 static constexpr auto I6 = Number<6>{};
132 static constexpr auto I7 = Number<7>{};
133
134 // FIX ME: To be deprecated
135 static constexpr auto K1 = Number<K1Value>{};
136
137 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
138 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
139 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
140
142
145 NumGemmKPrefetchStage,
146 LoopSched,
147 AEnableLds,
148 BEnableLds>())>;
149
150 // Describes how data is stored into the (LDS/VGPR) buffer from global memory
151 __host__ __device__ static constexpr auto MakeABlockDescriptor()
152 {
153 constexpr auto a_block_desc = [&]() {
154 if constexpr(AEnableLds)
155 {
156 // K0->M->K1 Per Block
157 constexpr auto K0PerBlock = KPerBlock / K1;
158 constexpr auto max_lds_align = K1;
159
160 if constexpr(ABlockLdsExtraM)
161 {
165 }
166 else
167 {
169 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
170 }
171 }
172 else
173 {
174 constexpr auto A_KRow = I2;
175 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
176 constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
177 // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
181 I1,
183 I1,
184 I1,
185 K1),
189 K1,
190 K1,
191 K1,
192 I1));
193 }
194 }();
195
196 return a_block_desc;
197 }
198
199 __host__ __device__ static constexpr auto MakeBBlockDescriptor()
200 {
201 constexpr auto b_block_desc = [&]() {
202 if constexpr(BEnableLds)
203 {
204 // K0->N->K1 Per Block
205 constexpr auto K0PerBlock = KPerBlock / K1;
206 constexpr auto max_lds_align = K1;
207
208 if constexpr(BBlockLdsExtraN)
209 {
213 }
214 else
215 {
217 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
218 }
219 }
220 else
221 {
222
223 constexpr auto B_KRow = I2;
224 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
225 constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
226 // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
230 I1,
232 I1,
233 I1,
234 K1),
238 K1,
239 K1,
240 K1,
241 I1));
242 }
243 }();
244
245 return b_block_desc;
246 }
247
// Builds the multi-index step for the A-side copy window. Run() stores the
// result as a_block_slice_copy_step and passes it to GridwiseGemmPipe::Run.
// Only the leading coordinate is non-zero:
//   AEnableLds == true  -> 3-D step (KPerBlock / K1, 0, 0)
//   AEnableLds == false -> 7-D step (KPerBlock / WmmaK, 0, 0, 0, 0, 0, 0)
// The rank in each branch matches the comments elsewhere in this file that
// describe the A block layout (K0->M->K1 versus the 7-D per-thread layout).
248 __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
249 {
250 constexpr auto a_block_copy_step = [&]() {
251 if constexpr(AEnableLds)
252 {
    // Number of K0 slices covered by one KPerBlock tile.
253 constexpr auto K0PerBlock = KPerBlock / K1;
254
255 return make_multi_index(K0PerBlock, 0, 0);
256 }
257 else
258 {
    // Number of WmmaK-deep chunks covered by one KPerBlock tile.
259 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
260
261 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
262 }
263 }();
264
265 return a_block_copy_step;
266 }
267
// Builds the multi-index step for the B-side copy window. Run() stores the
// result as b_block_slice_copy_step and passes it to GridwiseGemmPipe::Run.
// Only the leading coordinate is non-zero:
//   BEnableLds == true  -> 3-D step (KPerBlock / K1, 0, 0)
//   BEnableLds == false -> 7-D step (KPerBlock / WmmaK, 0, 0, 0, 0, 0, 0)
// Mirrors MakeABlockSliceCopyStep, keyed on BEnableLds instead of AEnableLds.
268 __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
269 {
270 constexpr auto b_block_copy_step = [&]() {
271 if constexpr(BEnableLds)
272 {
    // Number of K0 slices covered by one KPerBlock tile.
273 constexpr auto K0PerBlock = KPerBlock / K1;
274
275 return make_multi_index(K0PerBlock, 0, 0);
276 }
277 else
278 {
    // Number of WmmaK-deep chunks covered by one KPerBlock tile.
279 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
280
281 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
282 }
283 }();
284
285 return b_block_copy_step;
286 }
287
288 // Describes how data is read from the (LDS/VGPR) buffer
289 template <typename ABlockDesc_>
290 __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
291 {
292
293 constexpr auto a_wave_desc = [&]() {
294 if constexpr(AEnableLds)
295 {
296 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
297 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
298 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
299#ifdef __gfx12__
300 constexpr auto A_KRow = I2;
301#else
302 constexpr auto A_KRow = I1;
303#endif
304
306 ABlockDesc_{},
313 }
314 else
315 {
316 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
317 constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
318 constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
319 constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
320 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
321
322 // Err: the merge transform causes a non-constexpr issue
323
324 // return transform_tensor_descriptor(
325 // ABlockDesc_{},
326 // make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
327 // make_pass_through_transform(Number<MRepeat>{}),
328 // make_pass_through_transform(I1),
329 // make_pass_through_transform(I1),
330 // make_pass_through_transform(Number<A_K1>{})),
331 // make_tuple(Sequence<0, 3>{},
332 // Sequence<1>{},
333 // Sequence<2>{},
334 // Sequence<4>{},
335 // Sequence<5>{}),
336 // make_tuple(
337 // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
338 // Sequence<4>{}));
339
340 // Workaround, Freeze transform
343 I1,
345 I1,
346 Number<A_K1>{}));
347 }
348 }();
349
350 return a_wave_desc;
351 }
352
353 template <typename BBlockDesc_>
354 __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
355 {
356 constexpr auto b_wave_desc = [&]() {
357 if constexpr(BEnableLds)
358 {
359 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
360 constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
361 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
362#ifdef __gfx12__
363 constexpr auto B_KRow = I2;
364#else
365 constexpr auto B_KRow = I1;
366#endif
368 BBlockDesc_{},
375 }
376 else
377 {
378 // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
379 constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
380 constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
381 constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
382 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
383
384 // Workaround, Freeze transform
387 I1,
389 I1,
390 Number<B_K1>{}));
391 }
392 }();
393
394 return b_wave_desc;
395 }
396
397 __host__ __device__ static constexpr auto
398 // *Caution Here repeat is shuffle repeat
400 {
401 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
405 I1,
407
408 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
409 }
410
411 // The block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
// Validates a GEMM problem against this gridwise GEMM's tuning parameters.
//
// Compile-time checks (static_assert): K1 is known at compile time, and
// MPerBlock / NPerBlock are multiples of MPerWmma*MRepeat / NRepeat*NPerWmma.
//
// Runtime checks, in order; the function returns false at the first failure:
//   1. M and N derived from the A/B descriptors equal the C descriptor's
//      lengths, and K derived from A equals K derived from B;
//   2. M, N and K are multiples of MPerBlock, NPerBlock and KPerBlock;
//   3. GridwiseGemmPipe::IsSupported(K / KPerBlock);
//   4. block_2_ctile_map.CheckValidity(c_grid_desc_m_n);
//   5. the A and B element-space sizes in bytes are each <= 2^31.
// Failures of checks 1-3 print a diagnostic with printf when the CK_LOGGING
// environment switch is enabled; checks 4 and 5 fail silently.
// The C descriptor's byte size is not checked here.
//
// Parameters:
//   a_grid_desc, b_grid_desc - A/B grid descriptors (3-D when the matching
//                              *EnableLds flag is true, otherwise indexed up to I6).
//   c_grid_desc_m_n          - C grid descriptor; lengths I0 and I1 are read as M and N.
//   block_2_ctile_map        - tile map whose own CheckValidity is consulted.
// Returns: true only if every check above passes.
412 template <typename Block2CTileMap>
413 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
414 const BGridDesc& b_grid_desc,
415 const CGridDesc_M_N& c_grid_desc_m_n,
416 const Block2CTileMap& block_2_ctile_map)
417 {
418 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
419 "wrong! K1 need to be known at compile-time");
420
421 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
422 (NPerBlock % (NRepeat * NPerWmma)) == 0,
423 "Invalid tuning param!");
424
    // Returns the tuple (M, K) derived from the A descriptor.
    //   LDS layout:     M = len(I1),                   K = len(I0) * len(I2)
    //   non-LDS layout: M = len(I1) * len(I2) * len(I5),
    //                   K = len(I0) * len(I3) * len(I4) * len(I6)
425 const auto GetAProblemsizeMK = [&]() {
426 if constexpr(AEnableLds)
427 {
428 return make_tuple(a_grid_desc.GetLength(I1),
429 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
430 }
431 else
432 {
433 return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
434 a_grid_desc.GetLength(I5),
435 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
436 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
437 }
438 };
439
    // Returns the tuple (N, K) derived from the B descriptor, using the same
    // index pattern as GetAProblemsizeMK but keyed on BEnableLds.
440 const auto GetBProblemsizeNK = [&]() {
441 if constexpr(BEnableLds)
442 {
443 return make_tuple(b_grid_desc.GetLength(I1),
444 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
445 }
446 else
447 {
448 return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
449 b_grid_desc.GetLength(I5),
450 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
451 b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
452 }
453 };
454
455 const auto M = GetAProblemsizeMK()[I0];
456 const auto N = GetBProblemsizeNK()[I0];
457 const auto K = GetAProblemsizeMK()[I1];
458
    // Check 1: A, B and C must describe the same M x N x K problem.
459 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
460 K == GetBProblemsizeNK()[I1]))
461 {
462 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
463 {
464 printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
465 GetAProblemsizeMK()[I0],
466 GetAProblemsizeMK()[I1],
467 GetBProblemsizeNK()[I0],
468 GetBProblemsizeNK()[I1],
469 c_grid_desc_m_n.GetLength(I0),
470 c_grid_desc_m_n.GetLength(I1));
471 printf("GridwiseOp err: ProblemSize check");
472 }
473 return false;
474 }
475
    // Check 2: no partial tiles — each dimension must divide evenly by its block size.
476 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
477 {
478 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
479 {
480 printf("GridwiseOp err: ProblemSize division");
481 }
482 return false;
483 }
484
485 // check gridwise gemm pipeline
    // Check 3: the selected pipeline must accept this number of K-block iterations.
486 const auto num_k_loop = K / KPerBlock;
487
488 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
489 {
490 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
491 {
492 printf("GridwiseOp err: Pipeline not support this k_loop");
493 }
494 return false;
495 }
496
    // Check 4: defer to the tile map's own validity test for this C descriptor.
497 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
498 {
499 return false;
500 }
501
502 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
    // Check 5: A and B must each span at most 2^31 bytes.
    // NOTE(review): presumably a buffer-addressing limit — confirm the rationale.
503 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
504
505 if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
506 b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
507 {
508 return false;
509 }
510 return true;
511 }
512
// Reports whether the selected pipeline needs a main K loop for a problem of
// depth K. K is divided by KPerBlock using integer division, and the resulting
// K-block count is passed to GridwiseGemmPipe::CalculateHasMainLoop, whose
// result is returned unchanged.
// NOTE(review): K is not checked for divisibility by KPerBlock here;
// CheckValidity performs that check — confirm callers validate first.
513 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
514 {
515 const index_t num_loop = K / KPerBlock;
516
517 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
518 }
519
520 __host__ __device__ static constexpr auto
521 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
522 {
523 const auto M = c_grid_desc_m_n.GetLength(I0);
524 const auto N = c_grid_desc_m_n.GetLength(I1);
525
526 const auto MBlock = M / MPerBlock;
527 const auto NBlock = N / NPerBlock;
528
529 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
530 c_grid_desc_m_n,
535
536 return c_grid_desc_mblock_mperblock_nblock_nperblock;
537 }
538
539 // return block_id to C matrix tile idx (m0, n0) mapping
540 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
541 const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
542 {
544 c_grid_desc_m_n);
545 }
546
548 {
549 // LDS allocation for A and B: be careful of alignment
550
551 static constexpr auto max_lds_align = K1;
552
553 static constexpr auto a_block_space_size_aligned =
554 AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
556 : 0;
557 static constexpr auto b_block_space_size_aligned =
558 BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
560 : 0;
561
562 static constexpr auto a_block_space_offset = 0;
564
565 // LDS allocation for C shuffle in LDS
566 static constexpr auto c_shuffle_block_space_size =
568 .GetElementSpaceSize();
569
570 static constexpr auto c_shuffle_block_space_offset = 0;
571
572 static constexpr auto lds_size =
573 math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
574 a_block_space_size_aligned * sizeof(ADataType) +
575 b_block_space_size_aligned * sizeof(BDataType));
576 };
577
580 CGridDesc_M_N{}))>;
582 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
583
584 template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
585 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
586 const BDataType* __restrict__ p_b_grid,
587 CDataType* __restrict__ p_c_grid,
588 void* __restrict__ p_shared,
589 const AGridDesc& a_grid_desc,
590 const BGridDesc& b_grid_desc,
592 c_grid_desc_mblock_mperblock_nblock_nperblock,
593 const AElementwiseOperation& a_element_op,
594 const BElementwiseOperation& b_element_op,
595 const CElementwiseOperation& c_element_op,
596 const Block2CTileMap& block_2_ctile_map)
597 {
598 // clang-format off
599/*******************************************************************************/
600// Memory buffer zone.
601 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
602 p_a_grid, a_grid_desc.GetElementSpaceSize());
603 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
604 p_b_grid, b_grid_desc.GetElementSpaceSize());
606 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
607
608/*******************************************************************************/
609// BlockIdx.x -> [BlockId.m, BlockId.n]
610 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
611 if(!block_2_ctile_map.ValidCTileIndex(
612 block_work_idx,
613 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
614 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
615 { return; }
616
617 // Store BlockId into SGPR
618 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
619 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
620
621/*******************************************************************************/
622// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destination of BlockWise_Copy
623 const auto K = [&](){
624 if constexpr(AEnableLds){
625 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
626 }
627 else{
628 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
629 * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
630 }
631 }();
632
633 constexpr auto a_block_desc = MakeABlockDescriptor();
634 constexpr auto b_block_desc = MakeBBlockDescriptor();
635
636 auto a_block_trait = [&](){
637 // A matrix blockwise copy
638 if constexpr(AEnableLds)
639 {
640 constexpr auto K0PerBlock = KPerBlock/ K1;
642 static_cast<ADataType*>(p_shared),
644
645 auto a_blockwise_copy =
647/* typename SrcElementwiseOperation, */ AElementwiseOperation,
648/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
649/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
650/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
651/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
652/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
653/* typename SrcData, */ ADataType,
654/* typename DstData, */ ADataType,
655/* typename SrcDesc, */ decltype(a_grid_desc),
656/* typename DstDesc, */ decltype(a_block_desc),
657/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
658/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
659/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
660/* index_t DstVectorDim, */ 2,
661/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
662/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
663/* index_t SrcScalarStrideInVector, */ 1,
664/* index_t DstScalarStrideInVector, */ 1,
665/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
666/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
667 NumGemmKPrefetchStage>(
668 a_grid_desc,
669 make_multi_index(0, m_block_data_idx_on_grid, 0),
670 a_element_op,
671 a_block_desc,
672 make_multi_index(0, 0, 0),
674
675 return make_tuple(a_block_buf, a_blockwise_copy);
676 }
677 else
678 {
679 // Thread-wise copy
680 // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
681 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
682 constexpr auto K0PerWmma = WmmaK/2/K1Value;
684 a_block_desc.GetElementSpaceSize());
685
686 // Limitation: NumDim of Src and Dst descriptor should be identical
687 auto a_blockwise_copy =
689 ADataType,
690 decltype(a_grid_desc),
691 decltype(a_block_desc),
694 I1,
696 I1,
697 I1,
700 6,
701 ABlockTransferSrcScalarPerVector,
702 AThreadTransferSrcResetCoordinateAfterRun,
703 true>(
704 a_grid_desc,
706 m_block_data_idx_on_grid/(MWaves * MPerWmma),
708 0,
709 (get_thread_local_1d_id() % 32 )/ 16,
711 0));
712
713 return make_tuple(a_block_buf, a_blockwise_copy);
714 }
715 };
716
717 auto b_block_trait = [&](){
718 if constexpr(BEnableLds)
719 {
720 constexpr auto K0PerBlock = KPerBlock/ K1;
722 static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
724
725 auto b_blockwise_copy =
727 BElementwiseOperation,
731 BBlockTransferThreadClusterLengths_K0_N_K1,
732 BBlockTransferThreadClusterArrangeOrder,
733 BDataType,
734 BDataType,
735 decltype(b_grid_desc),
736 decltype(b_block_desc),
737 BBlockTransferSrcAccessOrder,
739 BBlockTransferSrcVectorDim,
740 2,
741 BBlockTransferSrcScalarPerVector,
742 BBlockTransferDstScalarPerVector_K1,
743 1,
744 1,
745 BThreadTransferSrcResetCoordinateAfterRun,
746 true,
747 NumGemmKPrefetchStage>(
748 b_grid_desc,
749 make_multi_index(0, n_block_data_idx_on_grid, 0),
750 b_element_op,
751 b_block_desc,
752 make_multi_index(0, 0, 0),
754
755 return make_tuple(b_block_buf, b_blockwise_copy);
756 }
757 else
758 {
759 // Thread-wise copy
760 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
761 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
762 constexpr auto K0PerWmma = WmmaK/2/K1Value;
764 b_block_desc.GetElementSpaceSize());
765
766 // Limitation: NumDim of Src and Dst descriptor should be identical
767 auto b_blockwise_copy =
769 BDataType,
770 decltype(b_grid_desc),
771 decltype(b_block_desc),
774 I1,
776 I1,
777 I1,
780 6,
781 BBlockTransferSrcScalarPerVector,
782 BThreadTransferSrcResetCoordinateAfterRun,
783 true>(
784 b_grid_desc,
786 n_block_data_idx_on_grid/(NWaves * NPerWmma),
788 0,
789 (get_thread_local_1d_id() % 32 )/ 16,
791 0));
792
793 return make_tuple(b_block_buf, b_blockwise_copy);
794 }
795 };
796
797 auto a_block_buf = a_block_trait()[I0];
798 auto a_blockwise_copy = a_block_trait()[I1];
799
800 auto b_block_buf = b_block_trait()[I0];
801 auto b_blockwise_copy = b_block_trait()[I1];
802/*******************************************************************************/
803 // GEMM
804 constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
805
806 auto blockwise_gemm =
807 BlockwiseGemmWMMA<BlockSize,
808 ADataType,
809 BDataType,
810 AccDataType,
811 decltype(MakeAWaveDescriptor(a_block_desc)),
812 decltype(MakeBWaveDescriptor(b_block_desc)),
813 MPerBlock,
814 NPerBlock,
815 KPerBlock,
816 MPerWmma,
817 NPerWmma,
818 MRepeat,
819 NRepeat,
820 KPack,
821 AEnableLds,
822 BEnableLds>{};
823
824 // Prepare Register for C matrix
825 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
826
827/*******************************************************************************/
828 // Shift Per SUB_K
829 constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
830 constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
831
832 // gridwise GEMM pipeline
833 const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
834 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
835 a_block_desc,
836 a_blockwise_copy,
837 a_grid_buf,
838 a_block_buf,
839 a_block_slice_copy_step,
840 b_grid_desc,
841 b_block_desc,
842 b_blockwise_copy,
843 b_grid_buf,
844 b_block_buf,
845 b_block_slice_copy_step,
846 blockwise_gemm,
847 c_thread_buf,
848 KBlockMainLoop);
849/*******************************************************************************/
850 // write out to C, implement shuffle
851 {
852 // C mapping in single thread.
853 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
854 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
855
856 // C mapping in single block
857 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
858 blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
859
860 constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
861 constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
862 constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
863 constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
864 constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
865
866 // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
867 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
869
870 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
871 static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
873
874 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
875 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
879 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
880 MWave, // MWave
881 MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
882 MAccVgprs)),
885 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
886 NWave, // NWave
887 NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
890
891 // calculate origin of thread output tensor on global memory
892 // blockwise GEMM c matrix starting index
893 const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
894
895 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
896 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
897
898 const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
900 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
903
904 const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
906 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
909
910 const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
911 make_multi_index(m_thread_data_on_block));
912
913 const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
914 make_multi_index(n_thread_data_on_block));
915
916 // shuffle: threadwise copy C from VGPR to LDS
917 auto c_thread_copy_vgpr_to_lds =
919 CShuffleDataType,
920 decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
921 decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
923 Sequence<CShuffleMRepeatPerShuffle,
924 I1,
925 I1,
926 CShuffleNRepeatPerShuffle,
927 I1,
928 I1,
929 MAccVgprs>,
931 6,
932 1, // vector write pixel
934 1,
935 true>{
936 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
938 m_thread_data_on_block_idx[I1],
939 m_thread_data_on_block_idx[I2],
940 0,
941 n_thread_data_on_block_idx[I1],
942 n_thread_data_on_block_idx[I2],
943 m_thread_data_on_block_idx[I3]),
945
946 // shuffle: blockwise copy C from LDS to global
947 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
948 ThisThreadBlock, // ThreadGroup
949 CElementwiseOperation, // ElementwiseOperation,
950 CGlobalMemoryDataOperation, // DstInMemOp,
951 Sequence<1,
952 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
953 1,
954 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
955 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
956 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
957 CShuffleDataType, // typename SrcData,
958 CDataType, // typename DstData,
959 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
960 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
961 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
962 3, // index_t VectorDim,
963 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
964 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
965 false> // bool ThreadTransferDstResetCoordinateAfterRun>
966 {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
967 make_multi_index(0, 0, 0, 0),
968 c_grid_desc_mblock_mperblock_nblock_nperblock,
969 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
970 c_element_op};
971
972 // space filling curve for local reg & global memory
973 // space filling curve for threadwise C in VGPR
974 constexpr auto sfc_c_vgpr =
977 Sequence<CShuffleMRepeatPerShuffle,
978 1,
979 1,
980 CShuffleNRepeatPerShuffle,
981 1,
982 1,
983 MAccVgprs>>{};
984
985 // space filling curve for shuffled blockwise C in global mem
986 constexpr auto sfc_c_global =
989 Sequence<1,
990 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
991 1,
992 CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
993
994 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
995
996 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
997
998 static_for<0, num_access, 1>{}([&](auto access_id) {
999 // make sure it's safe to write to LDS
1001
1002 // each thread write its data from VGPR to LDS
1003 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1004 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1005 c_thread_buf,
1006 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1007 c_shuffle_block_buf);
1008
1009 // make sure it's safe to read from LDS
1011
1012 // each block copy its data from LDS to global
1013 c_shuffle_block_copy_lds_to_global.Run(
1014 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1015 c_shuffle_block_buf,
1016 c_grid_desc_mblock_mperblock_nblock_nperblock,
1017 c_grid_buf);
1018
1019 if constexpr(access_id < num_access - 1)
1020 {
1021 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1022
1023 // move on C
1024 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1025 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1026 }
1027 });
1028 }
1029 // clang-format on
1030 }
1031};
1032
1033} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:37
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_wmma.hpp:550
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_wmma.hpp:585
SharedMemTrait
Definition gridwise_gemm_wmma.hpp:548
static constexpr auto c_shuffle_block_space_size
Definition gridwise_gemm_wmma.hpp:566
static constexpr auto b_block_space_size_aligned
Definition gridwise_gemm_wmma.hpp:557
static constexpr auto max_lds_align
Definition gridwise_gemm_wmma.hpp:551
static constexpr auto c_shuffle_block_space_offset
Definition gridwise_gemm_wmma.hpp:570
static constexpr auto lds_size
Definition gridwise_gemm_wmma.hpp:572
static constexpr auto a_block_space_size_aligned
Definition gridwise_gemm_wmma.hpp:553
static constexpr auto a_block_space_offset
Definition gridwise_gemm_wmma.hpp:562
static constexpr auto b_block_space_offset
Definition gridwise_gemm_wmma.hpp:563
Definition gridwise_gemm_wmma.hpp:124
__host__ static __device__ constexpr auto MakeAWaveDescriptor(const ABlockDesc_ &)
Definition gridwise_gemm_wmma.hpp:290
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:585
__host__ static __device__ constexpr auto MakeBBlockDescriptor()
Definition gridwise_gemm_wmma.hpp:199
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:413
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition gridwise_gemm_wmma.hpp:399
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_wmma.hpp:540
__host__ static __device__ constexpr auto MakeABlockDescriptor()
Definition gridwise_gemm_wmma.hpp:151
__host__ static __device__ constexpr auto MakeBWaveDescriptor(const BBlockDesc_ &)
Definition gridwise_gemm_wmma.hpp:354
__host__ static __device__ constexpr auto MakeABlockSliceCopyStep()
Definition gridwise_gemm_wmma.hpp:248
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_wmma.hpp:521
__host__ static __device__ constexpr auto MakeBBlockSliceCopyStep()
Definition gridwise_gemm_wmma.hpp:268
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma.hpp:513
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129