21#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
22#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
29template <
bool Use2LDS,
30 typename GridwiseGemm,
31 bool HasMainKBlockLoop,
36#if CK_USE_LAUNCH_BOUNDS
42#if defined(__gfx950__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
45 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
47 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
52 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
53 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
54 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
63template <
bool Use2LDS,
64 typename GridwiseGemm,
65 bool HasMainKBlockLoop,
70#if CK_USE_LAUNCH_BOUNDS
76#if defined(__gfx950__)
77 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
81 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
82 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
84 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
86 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
87 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
88 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
90 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
91 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
102template <
typename ALayout,
106 typename AScaleDataType,
108 typename BScaleDataType,
109 typename AccDataType,
110 typename CShuffleDataType,
112 typename AElementwiseOperation,
113 typename BElementwiseOperation,
114 typename CElementwiseOperation,
127 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
128 typename ABlockTransferThreadClusterArrangeOrder,
129 typename ABlockTransferSrcAccessOrder,
130 index_t ABlockTransferSrcVectorDim,
131 index_t ABlockTransferSrcScalarPerVector,
132 index_t ABlockTransferDstScalarPerVector_AK1,
133 bool AThreadTransferSrcResetCoordinateAfterRun,
135 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
136 typename BBlockTransferThreadClusterArrangeOrder,
137 typename BBlockTransferSrcAccessOrder,
138 index_t BBlockTransferSrcVectorDim,
139 index_t BBlockTransferSrcScalarPerVector,
140 index_t BBlockTransferDstScalarPerVector_BK1,
141 bool BThreadTransferSrcResetCoordinateAfterRun,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
151 typename ComputeTypeB =
153 bool PermuteA =
false,
154 bool PermuteB =
false>
213 "A scale pack data type too large!");
215 "B scale pack data type too large!");
239 auto K_t = K_Batch * KPerBlock;
240 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
245 auto K_t = K_Batch * KPerBlock;
246 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
260 auto K_t = K_Batch * KPerBlock;
261 return (K + K_t - 1) / K_t * KPerBlock;
267 auto K_t = K_Batch * KReadVec;
268 return (K + K_t - 1) / K_t * KReadVec;
281 template <
index_t MNXdlPerWave,
286 typename TileDesc_K0_MN_K1>
331 const auto a_grid_desc_mraw_kraw = [&]() {
344 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
345 GemmSpec == GemmSpecialization::MNKPadding)
348 const auto a_grid_desc_m_k =
362 return a_grid_desc_ak0_m_ak1;
364 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
365 GemmSpec == GemmSpecialization::MNPadding)
369 a_grid_desc_mraw_kraw,
375 return a_grid_desc_ak0_m_ak1;
377 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
378 GemmSpec == GemmSpecialization::NKPadding)
382 a_grid_desc_mraw_kraw,
394 return a_grid_desc_ak0_m_ak1;
400 a_grid_desc_mraw_kraw,
407 a_grid_desc_ak0_m_ak1,
415 a_grid_desc_permuted,
429 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
430 constexpr index_t WaveSize = BlockSize / (MWave *
NWave);
439 const auto b_grid_desc_nraw_kraw = [&]() {
453 GemmSpec != GemmSpecialization::Default),
454 "pk_i4_t does not support padding");
456 GemmSpec != GemmSpecialization::Default),
457 "f4x2_pk_t does not support padding");
459 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
460 GemmSpec == GemmSpecialization::MNKPadding)
463 const auto b_grid_desc_n_k =
477 return b_grid_desc_bk0_n_bk1;
479 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
480 GemmSpec == GemmSpecialization::MNPadding)
484 b_grid_desc_nraw_kraw,
490 return b_grid_desc_bk0_n_bk1;
492 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
493 GemmSpec == GemmSpecialization::MKPadding)
497 b_grid_desc_nraw_kraw,
509 return b_grid_desc_bk0_n_bk1;
513 if constexpr(!PermuteB)
517 b_grid_desc_nraw_kraw,
525 b_grid_desc_bk0_n_bk1,
533 b_grid_desc_permuted,
546 constexpr index_t BK01 = KPerBlock / BK1Value;
548 const index_t BK0_ = StrideB / BK1Value;
549 const index_t BK00 = BK0_ / BK01;
551 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
555 b_grid_desc_bk00_n_bk01_bk1_permute,
562 return b_grid_desc_bk0_n_bk1_permute;
567 template <
typename ABlockDesc_AK0_M_AK1>
568 __host__ __device__
static constexpr auto
571 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
574 ABlockDesc_AK0_M_AK1{});
577 template <
typename BBlockDesc_BK0_N_BK1>
578 __host__ __device__
static constexpr auto
581 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
584 BBlockDesc_BK0_N_BK1{});
587 __host__ __device__
static auto
590 const auto c_grid_desc_mraw_nraw = [&]() {
610 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
611 GemmSpec == GemmSpecialization::MNKPadding)
620 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
621 GemmSpec == GemmSpecialization::MKPadding)
625 c_grid_desc_mraw_nraw,
630 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
631 GemmSpec == GemmSpecialization::NKPadding)
635 c_grid_desc_mraw_nraw,
643 return c_grid_desc_mraw_nraw;
681 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
685 <<
", " <<
"KRead:" <<
KRead <<
", " <<
"KP:" <<
KPadded <<
", "
686 <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0 <<
", " <<
"MBlock: " <<
MBlock
687 <<
", " <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
713 const AScaleDataType* p_a_scale_grid_,
714 const BDataType* p_b_grid_,
715 const BScaleDataType* p_b_scale_grid_,
716 CDataType* p_c_grid_,
726 AElementwiseOperation a_element_op_,
727 BElementwiseOperation b_element_op_,
728 CElementwiseOperation c_element_op_,
729 bool is_reduce_ =
false)
793 if constexpr(!PermuteB)
799 const int k0_offset = karg.
KRead * karg.
N;
812 if(k_id < (karg.
KBatch - 1))
840 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
841 constexpr index_t WaveSize = BlockSize / (MWave *
NWave);
855 constexpr auto a_lds_block_desc =
867 return a_lds_block_desc_permuted;
874 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
875 constexpr auto M1 = MPerBlock / M0;
877 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
878 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
879 constexpr auto KThreadRead = WaveSize / MPerXdl;
880 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
882 constexpr auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
884 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
885 constexpr auto KThreadReadPerm =
886 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
887 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
891 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
893 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
895 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
901 Number<kfold * M0 / mpair>{},
920 a_lds_block_desc_permuted,
942 a_lds_block_desc_unmerged,
945 Number<KThreadWrite / kfold / KThreadReadPerm>{},
954 return a_lds_block_desc_ak0_m_ak1;
970 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
973 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
980 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1001 ABlockTransferSrcScalarPerVector,
1002 BBlockTransferSrcScalarPerVector,
1021 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1024 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1027 constexpr auto c_block_size =
1028 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1030 return math::max(a_block_space_size_aligned *
sizeof(ADataType),
1031 c_block_size *
sizeof(CShuffleDataType));
1039 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1040 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1041 "Invalid tuning param!");
1043 static_assert(KPerBlock % (ScaleBlockSize /
BPackedSize) == 0,
1044 "KPerBlock should be multiple of ScaleBlockSize");
1052 if(!(karg.M % MPerBlock == 0))
1056 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1057 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1070 if(!(karg.N % NPerBlock == 0))
1074 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1075 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1087 auto K_t = karg.KBatch * KPerBlock;
1088 if(!(karg.K % K_t == 0))
1092 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1093 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1094 <<
", in function: " << __func__ << std::endl;
1102 auto K_t = karg.KBatch * KReadVec;
1104 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1112 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1116 std::cout <<
"Arg K (" << karg.K
1117 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1118 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1119 << __LINE__ <<
", in function: " << __func__ << std::endl;
1126 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1130 std::cout <<
"Arg M (" << karg.M
1131 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1132 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1133 << __LINE__ <<
", in function: " << __func__ << std::endl;
1141 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1145 std::cout <<
"Arg N (" << karg.N
1146 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1147 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1148 << __LINE__ <<
", in function: " << __func__ << std::endl;
1155 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1159 std::cout <<
"Arg K (" << karg.K
1160 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1161 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1162 << __LINE__ <<
", in function: " << __func__ << std::endl;
1170 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1174 std::cout <<
"Arg N (" << karg.N
1175 <<
") value is not a multiple of "
1176 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1177 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1178 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1186 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1190 std::cout <<
"Arg M (" << karg.M
1191 <<
") value is not a multiple of "
1192 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1193 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1194 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1206 if(!karg.IsReduceAdd())
1210 std::cout <<
" KBatch: " << karg.KBatch <<
" > 1 is not support yet" << __FILE__
1211 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1220 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1224 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1235 const index_t num_loop = K / KPerBlock;
1237 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1242 const index_t num_loop = K / KPerBlock;
1244 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1247 template <
typename CGr
idDesc>
1249 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1258 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1266 template <
typename AGridDesc_AK0_M_K1,
1267 typename AScaleGridDesc_AM_AK,
1268 typename BGridDesc_BK0_N_K1,
1269 typename BScaleGridDesc_BN_AK,
1270 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1271 bool HasMainKBlockLoop,
1274 __device__
static void Run(
const ADataType* p_a_grid,
1275 const AScaleDataType* p_a_scale_grid,
1276 const BDataType* p_b_grid,
1277 const BScaleDataType* p_b_scale_grid,
1278 CDataType* p_c_grid,
1280 const Problem& problem,
1281 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1282 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1283 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1284 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1285 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1286 c_grid_desc_mblock_mperblock_nblock_nperblock)
1289 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1291 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1293 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1297 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1301 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1303 const AElementwiseOperation a_element_op{};
1304 const BElementwiseOperation b_element_op{};
1305 const CElementwiseOperation c_element_op{};
1308 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1310 const auto block_work_idx =
1313 if(!block_2_ctile_map.ValidCTileIndex(
1315 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1316 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1321 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1322 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1325 const index_t m_block_data_idx_on_grid =
1326 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1328 const index_t n_block_data_idx_on_grid =
1329 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1340 auto a_blockwise_copy =
1343 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1344 ABlockTransferThreadClusterArrangeOrder,
1347 decltype(a_grid_desc_ak0_m_ak1),
1348 decltype(a_block_desc_ak0_m_ak1),
1349 ABlockTransferSrcAccessOrder,
1350 ABlockTransferSrcVectorDim,
1352 ABlockTransferSrcScalarPerVector>(
1353 a_grid_desc_ak0_m_ak1,
1355 a_block_desc_ak0_m_ak1,
1359 auto b_blockwise_copy =
1362 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1363 BBlockTransferThreadClusterArrangeOrder,
1366 decltype(b_grid_desc_bk0_n_bk1),
1367 decltype(b_block_desc_bk0_n_bk1),
1368 BBlockTransferSrcAccessOrder,
1369 BBlockTransferSrcVectorDim,
1371 BBlockTransferSrcScalarPerVector>(
1372 b_grid_desc_bk0_n_bk1,
1374 b_block_desc_bk0_n_bk1,
1379 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1383 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1386 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) +
1387 a_block_space_size_aligned *
sizeof(ADataType)),
1388 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1394 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1396 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1398 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1399 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1419 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1420 const auto waveId_m = wave_idx[
I0];
1421 const auto waveId_n = wave_idx[
I1];
1429 auto thread_offset_shuffled =
1432 auto a_thread_offset_m = waveId_m;
1437 decltype(a_scale_grid_desc_am_ak),
1438 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1444 true>(a_scale_grid_desc_am_ak,
1449 auto b_thread_offset_n = waveId_n;
1454 decltype(b_scale_grid_desc_bn_ak),
1455 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1461 true>(b_scale_grid_desc_bn_ak,
1467 a_block_desc_ak0_m_ak1,
1471 a_block_slice_copy_step,
1472 b_grid_desc_bk0_n_bk1,
1473 b_block_desc_bk0_n_bk1,
1477 b_block_slice_copy_step,
1479 a_scale_grid_desc_am_ak,
1480 a_scale_thread_copy,
1482 b_scale_grid_desc_bn_ak,
1483 b_scale_thread_copy,
1485 num_k_block_main_loop);
1489 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1490 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1492 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
1493 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
1496 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1499 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1500 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1504 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1505 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1507 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1508 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1509 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1510 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1511 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1512 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1513 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1514 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1515 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
1516 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
1518 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1522 static_cast<CShuffleDataType*
>(p_shared),
1523 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1526 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1552 const auto c_thread_mtx_on_block =
1553 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1555 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1556 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1558 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1564 const auto m_thread_data_on_block_idx =
1565 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1568 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1574 const auto n_thread_data_on_block_idx =
1575 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1582 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1583 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1586 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1595 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1600 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1603 m_thread_data_on_block_idx[
I1],
1604 n_thread_data_on_block_idx[
I1],
1605 m_thread_data_on_block_idx[
I2],
1606 n_thread_data_on_block_idx[
I2],
1607 m_thread_data_on_block_idx[
I3],
1608 m_thread_data_on_block_idx[
I4],
1609 m_thread_data_on_block_idx[
I5],
1610 n_thread_data_on_block_idx[
I3]),
1616 CElementwiseOperation,
1617 CGlobalMemoryDataOperation,
1619 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1621 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>,
1622 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1626 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1627 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1630 CShuffleBlockTransferScalarPerVector_NPerBlock,
1633 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1635 c_grid_desc_mblock_mperblock_nblock_nperblock,
1640 constexpr auto sfc_c_vgpr =
1651 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
1653 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
1664 constexpr auto sfc_c_global =
1668 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1670 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>>{};
1672 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1674 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1681 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1682 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1684 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1685 c_shuffle_block_buf);
1691 c_shuffle_block_copy_lds_to_global.Run(
1692 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1693 c_shuffle_block_buf,
1694 c_grid_desc_mblock_mperblock_nblock_nperblock,
1697 if constexpr(access_id < num_access - 1)
1699 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1702 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1703 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1709 template <
bool HasMainKBlockLoop,
1712 __device__
static void Run(
const ADataType* p_a_grid,
1713 const AScaleDataType* p_a_scale_grid,
1714 const BDataType* p_b_grid,
1715 const BScaleDataType* p_b_scale_grid,
1716 CDataType* p_c_grid,
1718 const Problem& problem)
1721 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1723 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1725 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1726 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1728 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1744 Run<
decltype(a_grid_desc_ak0_m_ak1),
1745 decltype(a_scale_grid_desc_am_ak),
1746 decltype(b_grid_desc_bk0_n_bk1),
1747 decltype(b_scale_grid_desc_bn_ak),
1748 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1750 CGlobalMemoryDataOperation,
1758 a_grid_desc_ak0_m_ak1,
1759 a_scale_grid_desc_am_ak,
1760 b_grid_desc_bk0_n_bk1,
1761 b_scale_grid_desc_bn_ak,
1762 c_grid_desc_mblock_mperblock_nblock_nperblock);
1765 template <
typename AGridDesc_AK0_M_K1,
1766 typename AScaleGridDesc_AM_AK,
1767 typename BGridDesc_BK0_N_K1,
1768 typename BScaleGridDesc_BN_AK,
1769 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1770 bool HasMainKBlockLoop,
1773 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
1774 const AScaleDataType* p_a_scale_grid,
1775 const BDataType* p_b_grid,
1776 const BScaleDataType* p_b_scale_grid,
1777 CDataType* p_c_grid,
1780 const Problem& problem,
1781 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1782 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1783 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1784 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1785 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1786 c_grid_desc_mblock_mperblock_nblock_nperblock)
1789 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1790 const auto b_grid_buf =
1792 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1794 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1798 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1801 const auto b_scale_grid_buf =
1803 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1805 const CElementwiseOperation c_element_op{};
1808 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1810 const auto block_work_idx =
1813 if(!block_2_ctile_map.ValidCTileIndex(
1815 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1816 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1821 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1822 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1825 const index_t m_block_data_idx_on_grid =
1826 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1828 const index_t n_block_data_idx_on_grid =
1829 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave /
NXdlPack);
1840 auto a_blockwise_copy =
1843 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1844 ABlockTransferThreadClusterArrangeOrder,
1847 decltype(a_grid_desc_ak0_m_ak1),
1848 decltype(a_block_desc_ak0_m_ak1),
1849 ABlockTransferSrcAccessOrder,
1850 ABlockTransferSrcVectorDim,
1852 ABlockTransferSrcScalarPerVector>(
1853 a_grid_desc_ak0_m_ak1,
1855 a_block_desc_ak0_m_ak1,
1860 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1862 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1863 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
1868 decltype(b_grid_desc_bk0_n_bk1),
1869 decltype(b_block_desc_bk0_n_bk1),
1877 BBlockTransferSrcScalarPerVector,
1878 BThreadTransferSrcResetCoordinateAfterRun,
1879 true>(b_grid_desc_bk0_n_bk1,
1888 static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1891 static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1893 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
1899 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1901 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1903 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1904 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1924 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1925 const auto waveId_m = wave_idx[
I0];
1926 const auto waveId_n = wave_idx[
I1];
1934 auto thread_offset_shuffled =
1937 auto a_thread_offset_m = waveId_m;
1942 decltype(a_scale_grid_desc_am_ak),
1943 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1949 true>(a_scale_grid_desc_am_ak,
1954 auto b_thread_offset_n = waveId_n;
1959 decltype(b_scale_grid_desc_bn_ak),
1960 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1966 true>(b_scale_grid_desc_bn_ak,
1972 a_block_desc_ak0_m_ak1,
1976 a_block_slice_copy_step,
1977 b_grid_desc_bk0_n_bk1,
1978 b_block_desc_bk0_n_bk1,
1982 b_block_slice_copy_step,
1984 a_scale_grid_desc_am_ak,
1985 a_scale_thread_copy,
1987 b_scale_grid_desc_bn_ak,
1988 b_scale_thread_copy,
1990 num_k_block_main_loop);
1994 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1995 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1997 static_assert(CShuffleMXdlPerWavePerShuffle %
MXdlPack == 0 &&
1998 CShuffleNXdlPerWavePerShuffle %
NXdlPack == 0,
2001 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2005 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2006 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2010 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2011 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2013 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2014 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2015 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2016 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2017 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2018 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2019 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2020 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2021 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I8);
2022 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I9);
2024 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2028 static_cast<CShuffleDataType*
>(p_shared_0),
2029 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2032 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2058 const auto c_thread_mtx_on_block =
2059 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2061 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2062 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2064 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2070 const auto m_thread_data_on_block_idx =
2071 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2074 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2080 const auto n_thread_data_on_block_idx =
2081 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2088 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2089 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2092 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2101 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2106 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2109 m_thread_data_on_block_idx[
I1],
2110 n_thread_data_on_block_idx[
I1],
2111 m_thread_data_on_block_idx[
I2],
2112 n_thread_data_on_block_idx[
I2],
2113 m_thread_data_on_block_idx[
I3],
2114 m_thread_data_on_block_idx[
I4],
2115 m_thread_data_on_block_idx[
I5],
2116 n_thread_data_on_block_idx[
I3]),
2120 constexpr auto DWORD_BYTES = 4;
2121 constexpr auto atomic_vector_size = DWORD_BYTES /
sizeof(CDataType);
2123 constexpr auto CShuffleBlockTransferClusterLengths = [&]() {
2126 return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{};
2133 if constexpr(i == 3)
2136 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2138 CShuffleBlockTransferScalarPerVector_NPerBlock /
2139 atomic_vector_size>{};
2141 else if constexpr(i == 1)
2144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2146 CShuffleBlockTransferScalarPerVector_NPerBlock *
2147 atomic_vector_size>{};
2152 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
2160 constexpr auto CShuffleBlockTransferScalarPerVector = [&]() {
2163 return CShuffleBlockTransferScalarPerVector_NPerBlock;
2167 return atomic_vector_size;
2174 CElementwiseOperation,
2175 CGlobalMemoryDataOperation,
2177 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2179 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>,
2180 decltype(CShuffleBlockTransferClusterLengths),
2184 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2185 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2188 CShuffleBlockTransferScalarPerVector,
2191 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2193 c_grid_desc_mblock_mperblock_nblock_nperblock,
2198 constexpr auto sfc_c_vgpr =
2209 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
2211 CShuffleNXdlPerWavePerShuffle /
NXdlPack,
2222 constexpr auto sfc_c_global =
2226 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2228 CShuffleNXdlPerWavePerShuffle *
NWave * NPerXdl>>{};
2230 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2232 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
2239 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2240 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2242 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2243 c_shuffle_block_buf);
2249 c_shuffle_block_copy_lds_to_global.Run(
2250 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2251 c_shuffle_block_buf,
2252 c_grid_desc_mblock_mperblock_nblock_nperblock,
2255 if constexpr(access_id < num_access - 1)
2257 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2260 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2261 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2267 template <
bool HasMainKBlockLoop,
2270 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
2271 const AScaleDataType* p_a_scale_grid,
2272 const BDataType* p_b_grid,
2273 const BScaleDataType* p_b_scale_grid,
2274 CDataType* p_c_grid,
2277 const Problem& problem)
2282 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2283 const auto b_grid_desc_bk0_n_bk1 =
2287 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2288 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2290 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2295 const auto Padded_Scale_M =
2319 Run_2Lds<
decltype(a_grid_desc_ak0_m_ak1),
2320 decltype(a_scale_grid_desc_am_ak),
2321 decltype(b_grid_desc_bk0_n_bk1),
2322 decltype(b_scale_grid_desc_bn_ak),
2323 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2325 CGlobalMemoryDataOperation,
2334 a_grid_desc_ak0_m_ak1,
2335 a_scale_grid_desc_am_ak,
2336 b_grid_desc_bk0_n_bk1,
2337 b_scale_grid_desc_bn_ak,
2338 c_grid_desc_mblock_mperblock_nblock_nperblock);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp:37
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:711
__host__ Argument(const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:712
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:763
const BScaleDataType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:764
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:751
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:756
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:765
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:769
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:767
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:768
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:770
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:761
const AScaleDataType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:762
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:696
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:690
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:691
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:699
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:700
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:679
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:698
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:693
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:706
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:692
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:704
index_t StrideScaleA
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:694
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:705
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:697
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:703
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:702
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:701
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:695
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:650
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:776
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:835
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:831
index_t a_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:833
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:832
index_t b_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:834
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:156
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1274
__host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:287
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::scale_pack_size_b static constexpr index_t scale_pack_size_b
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:211
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::AK0Number static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:170
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I3 static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:161
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::mx_scale_t e8m0_bexp_t mx_scale_t
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:209
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BPackedSize static constexpr index_t BPackedSize
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:190
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:249
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:222
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:258
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::MXdlPack static constexpr auto MXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:179
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BK0Number static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:171
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::NWave static constexpr index_t NWave
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:204
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:207
__host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:569
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I4 static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:162
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1233
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I6 static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:164
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1012
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:253
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I7 static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:165
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I1 static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:159
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:217
static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:958
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KXdlPack static constexpr auto KXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:181
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:588
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KRepeat static constexpr index_t KRepeat
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:205
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::AK1Number static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:172
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1712
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:436
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I2 static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:160
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmMXBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:983
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:328
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I8 static constexpr auto I8
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:166
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:232
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::is_scale_mfma static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:177
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I9 static constexpr auto I9
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:167
static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:838
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:276
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KLane static constexpr index_t KLane
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:203
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::lcm_AK1_BK1 static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:175
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:427
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:2270
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::NLane static constexpr index_t NLane
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:202
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:227
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:264
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::NXdlPack static constexpr auto NXdlPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:180
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::KPack static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:192
__host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:579
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1037
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:243
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1248
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:237
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::scale_pack_size_a static constexpr index_t scale_pack_size_a
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:210
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I0 static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:158
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1773
static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:968
static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1240
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::I5 static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:163
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:271
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::is_single_rate_mfma static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:176
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:1263
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::BK1Number static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:173
ck::GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::APackedSize static constexpr index_t APackedSize
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:189
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129