9#include <initializer_list>
28template <
typename DeviceOp,
typename Gr
idwiseOp,
bool HasMainKBlockLoop, TailNumber TailNum>
30#if CK_USE_LAUNCH_BOUNDS
35#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
37 __shared__
char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
38 const index_t num_blocks_per_batch =
39 __builtin_amdgcn_readfirstlane(
get_grid_size() / arg.batch_count);
43 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
45 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
47 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
49 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
51 GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
52 arg.p_a_grid + a_batch_offset,
53 arg.p_b0_grid + b0_batch_offset,
54 arg.p_b1_grid + b1_batch_offset,
55 arg.p_c_grid + c_batch_offset,
60 arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
66 arg.block_2_ctile_map);
76template <
typename ALayout,
85 typename CShuffleDataType,
86 typename AElementwiseOperation,
87 typename B0ElementwiseOperation,
88 typename AccElementwiseOperation,
89 typename B1ElementwiseOperation,
90 typename CElementwiseOperation,
106 typename ABlockTransferThreadClusterLengths_K0_M_K1,
107 typename ABlockTransferThreadClusterArrangeOrder,
108 typename ABlockTransferSrcAccessOrder,
112 bool ABlockLdsAddExtraM,
113 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
114 typename B0BlockTransferThreadClusterArrangeOrder,
115 typename B0BlockTransferSrcAccessOrder,
119 bool B0BlockLdsAddExtraL,
120 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
121 typename B1BlockTransferThreadClusterArrangeOrder,
122 typename B1BlockTransferSrcAccessOrder,
126 bool B1BlockLdsAddExtraN,
127 index_t CShuffleMRepeatPerShuffle,
128 index_t CShuffleNRepeatPerShuffle,
129 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
130 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
141 AElementwiseOperation,
142 B0ElementwiseOperation,
143 AccElementwiseOperation,
144 B1ElementwiseOperation,
145 CElementwiseOperation>
166 __host__ __device__
static auto
168 const std::array<index_t, 3>& a_g_m_k_strides_vec)
175 __host__ __device__
static auto
177 const std::array<index_t, 3>& b0_g_l_k_strides_vec)
184 __host__ __device__
static auto
186 const std::array<index_t, 3>& b1_g_n_l_strides_vec)
204 : BatchStrideA_(BatchStrideA),
205 BatchStrideB0_(BatchStrideB0),
206 BatchStrideB1_(BatchStrideB1),
207 BatchStrideC_(BatchStrideC)
213 return g_idx *
static_cast<long_index_t>(BatchStrideA_);
218 return g_idx *
static_cast<long_index_t>(BatchStrideB0_);
223 return g_idx *
static_cast<long_index_t>(BatchStrideB1_);
228 return g_idx *
static_cast<long_index_t>(BatchStrideC_);
249 AElementwiseOperation,
250 B0ElementwiseOperation,
251 AccElementwiseOperation,
252 B1ElementwiseOperation,
253 CElementwiseOperation,
277 ABlockTransferThreadClusterLengths_K0_M_K1,
278 ABlockTransferThreadClusterArrangeOrder,
279 ABlockTransferSrcAccessOrder,
280 ABlockTransferSrcVectorDim,
281 ABlockTransferSrcScalarPerVector,
282 ABlockTransferDstScalarPerVector_K1,
285 B0BlockTransferThreadClusterLengths_K0_L_K1,
286 B0BlockTransferThreadClusterArrangeOrder,
287 B0BlockTransferSrcAccessOrder,
288 B0BlockTransferSrcVectorDim,
289 B0BlockTransferSrcScalarPerVector,
290 B0BlockTransferDstScalarPerVector_K1,
293 B1BlockTransferThreadClusterLengths_L0_N_L1,
294 B1BlockTransferThreadClusterArrangeOrder,
295 B1BlockTransferSrcAccessOrder,
296 B1BlockTransferSrcVectorDim,
297 B1BlockTransferSrcScalarPerVector,
298 B1BlockTransferDstScalarPerVector_L1,
301 CShuffleMRepeatPerShuffle,
302 CShuffleNRepeatPerShuffle,
303 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
304 CShuffleBlockTransferScalarPerVector_NPerBlock,
311 using arr3 = std::array<ck::index_t, 3>;
314 const B0DataType* p_b0_grid_,
315 const B1DataType* p_b1_grid_,
316 CDataType* p_c_grid_,
330 AElementwiseOperation a_element_op_,
331 B0ElementwiseOperation b0_element_op_,
332 AccElementwiseOperation acc_element_op_,
333 B1ElementwiseOperation b1_element_op_,
334 CElementwiseOperation c_element_op_)
361 ?
arr3{BatchStrideB1, 1, StrideB1}
362 :
arr3{BatchStrideB1, StrideB1, 1};
419 const char* curFunc = __func__;
420 auto print = [&curFunc](
const char* format, ...) ->
void {
423#if defined(__clang__)
424#pragma clang diagnostic push
425#pragma clang diagnostic ignored "-Wformat-nonliteral"
428 va_start(args, format);
429 std::vfprintf(stdout, format, args);
431#if defined(__clang__)
432#pragma clang diagnostic pop
434 std::cout <<
"In file: " << __FILE__ <<
", function: " << curFunc <<
"\n";
440 print(
"DeviceOp: Arch err\n");
444 if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
445 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
446 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
450 print(
"DeviceOp: gfx 11 does not support fp8\n");
457 print(
"DeviceOp: Acc0 Type err\n");
463 print(
"DeviceOp: A layout must be Row\n");
469 print(
"DeviceOp: B layout must be Column\n");
476 print(
"DeviceOp: B1 layout must be Column or Row\n");
482 print(
"DeviceOp: C layout must be Row\n");
490 print(
"Padding mode must be default or MNKO\n");
495 if constexpr(MPerWmma != 16 || LPerWmma != 16 ||
NPerWmma != 16)
497 print(
"M, L, N per Wmma must be 16\n");
505 arg.block_2_ctile_map))
511 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
512 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
513 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
514 const auto c_extent_lowest = arg.O;
516 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
517 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
518 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
519 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
521 print(
"DeviceOp: Data Transfer Vector scalar err\n");
526 const auto a_stride_lowest =
527 ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
528 const auto b0_stride_lowest =
529 B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
530 const auto b1_stride_lowest =
531 B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
532 const auto c_stride_lowest = arg.c_g_m_o_strides[2];
534 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
535 c_stride_lowest == 1))
537 print(
"DeviceOp: Data Vectorize transfer err\n");
566 auto launch_kernel = [&](
auto has_main_k_block_loop,
auto tail_number) {
567 constexpr bool has_loop =
decltype(has_main_k_block_loop)
::value;
574 stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
584 return launch_kernel(std::integral_constant<bool, true>{},
585 std::integral_constant<TailNumber, TailNumber::Full>{});
589 return launch_kernel(std::integral_constant<bool, false>{},
590 std::integral_constant<TailNumber, TailNumber::Full>{});
594 printf(
"Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
603 std::integral_constant<TailNumber, TailNumber::Full>{});
608 std::integral_constant<TailNumber, TailNumber::Even>{});
613 std::integral_constant<TailNumber, TailNumber::Odd>{});
617 printf(
"Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
623 printf(
"Invalid pipeline version!\n");
632 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
654 AElementwiseOperation a_element_op,
655 B0ElementwiseOperation b0_element_op,
656 AccElementwiseOperation acc_element_op,
657 B1ElementwiseOperation b1_element_op,
658 CElementwiseOperation c_element_op)
override
660 return std::make_unique<RawArg>(
static_cast<const ADataType*
>(p_a),
661 static_cast<const B0DataType*
>(p_b0),
662 static_cast<const B1DataType*
>(p_b1),
663 static_cast<CDataType*
>(p_c),
689 return std::make_unique<Invoker>(
Invoker{});
692 template <
typename T>
695 if constexpr(std::is_same_v<T, float>)
699 else if constexpr(std::is_same_v<T, ck::half_t>)
703 else if constexpr(std::is_same_v<T, ck::bhalf_t>)
707 else if constexpr(std::is_same_v<T, ck::f8_t>)
711 else if constexpr(std::is_same_v<T, ck::bf8_t>)
715 else if constexpr(std::is_same_v<T, int32_t>)
719 else if constexpr(std::is_same_v<T, int8_t>)
723 else if constexpr(std::is_same_v<T, ck::int4_t>)
736 auto str = std::stringstream();
738 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
742 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
750 str <<
"DeviceBatchedGemmGemm_Wmma_CShuffleV3"
755 << CLayout::name[0] <<
", "
770 << LTilePerBlock <<
", "
774 <<
"BlkGemmPipelineScheduler: "
775 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
776 <<
"BlkGemmPipelineVersion: "
777 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
778 <<
"BlkGemmPipelinePrefetchStages: "
779 << GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
@ Default
Definition tensor_specialization.hpp:12
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
@ MNKOPadding
Definition gemm_specialization.hpp:29
__global__ void kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:33
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
int64_t long_index_t
Definition ck.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:88
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:498
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CalculateKBlockLoopTailNum __host__ static __device__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:462
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:456
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:488
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:469
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:363
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:495
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:199
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:216
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB0, index_t BatchStrideB1, index_t BatchStrideC)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:200
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:211
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:226
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:221
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:556
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:629
DeviceOp::RawArg Argument
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:557
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:559
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:310
arr3 a_g_m_k_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:388
std::array< ck::index_t, 3 > arr3
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:311
arr3 b1_g_o_n_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:393
B1ElementwiseOperation b1_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:400
AElementwiseOperation a_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:397
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:409
AccElementwiseOperation acc_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:399
arr3 a_g_m_k_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:389
B1GridDesc b1_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:406
const B1DataType * p_b1_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:378
arr3 c_g_m_o_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:394
index_t N
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:383
index_t batch_count
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:386
arr3 b1_g_o_n_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:392
arr3 b0_g_n_k_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:390
CGridDesc_M_N c_grid_desc_m_n
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:407
arr3 b0_g_n_k_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:391
arr3 c_g_m_o_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:395
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:413
B0GridDesc b0_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:405
RawArg(const ADataType *p_a_grid_, const B0DataType *p_b0_grid_, const B1DataType *p_b1_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t O_, index_t Batch, index_t StrideA, index_t StrideB0, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB0, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op_, B0ElementwiseOperation b0_element_op_, AccElementwiseOperation acc_element_op_, B1ElementwiseOperation b1_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:313
CElementwiseOperation c_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:401
CDataType * p_c_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:379
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:411
index_t K
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:384
const B0DataType * p_b0_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:377
AGridDesc a_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:404
B0ElementwiseOperation b0_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:398
index_t O
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:385
index_t M
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:382
const ADataType * p_a_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:376
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:146
static constexpr const char * DataTypeToString()
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:693
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< 1, 1, 1, 1, 1 >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, TensorSpecialization::Default, TensorSpecialization::Default, TensorSpecialization::Default, TensorSpecialization::Default > Transform
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:157
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:550
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:194
GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer > GridwiseOp
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:239
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, 3 > &b1_g_n_l_lengths_vec, const std::array< index_t, 3 > &b1_g_n_l_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:185
static constexpr index_t NPerWmma
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:153
static bool IsSupportedArgument(const RawArg &arg)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:416
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:687
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:193
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, 3 > &a_g_m_k_lengths_vec, const std::array< index_t, 3 > &a_g_m_k_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:167
static auto MakeInvoker()
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:684
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, 3 > &b0_g_l_k_lengths_vec, const std::array< index_t, 3 > &b0_g_l_k_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:176
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:196
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:195
DeviceBatchedGemmGemm_Wmma_CShuffleV3 DeviceOp
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:147
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA, ck::index_t StrideB0, ck::index_t StrideB1, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB0, ck::index_t BatchStrideB1, ck::index_t BatchStrideC, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:637
std::string GetTypeString() const override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:734
static constexpr auto I0
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:149
Definition device_batched_gemm_gemm.hpp:29
#define CK_ENV(name)
Definition utility/env.hpp:129