24template <
typename ADataType,
31 typename AElementwiseOperation,
32 typename BElementwiseOperation,
33 typename CElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
57 index_t CShuffleMRepeatPerShuffle,
58 index_t CShuffleNRepeatPerShuffle,
59 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
67 AElementwiseOperation,
68 BElementwiseOperation,
69 CElementwiseOperation>
80 template <index_t NXdlPerWave_>
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
104 ABlockTransferThreadClusterLengths_K0_M_K1,
105 ABlockTransferThreadClusterArrangeOrder,
106 ABlockTransferSrcAccessOrder,
107 ABlockTransferSrcVectorDim,
108 ABlockTransferSrcScalarPerVector,
109 ABlockTransferDstScalarPerVector_K1,
112 BBlockTransferThreadClusterLengths_K0_N_K1,
113 BBlockTransferThreadClusterArrangeOrder,
114 BBlockTransferSrcAccessOrder,
115 BBlockTransferSrcVectorDim,
116 BBlockTransferSrcScalarPerVector,
117 BBlockTransferDstScalarPerVector_K1,
120 CShuffleMRepeatPerShuffle,
121 CShuffleNRepeatPerShuffle,
122 CBlockTransferScalarPerVector_NWaveNPerXDL,
123 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
127 using Argument =
typename GridwiseGemm64::Argument;
132 template <
typename Argument_>
138 template <
typename Gr
idwiseGemm>
139 float RunImp(
const typename GridwiseGemm::Argument& karg,
142 if(stream_config.log_level_ > 0)
146 if(!GridwiseGemm::CheckValidity(karg))
148 throw std::runtime_error(
149 "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
153 dim3 grid_dims = karg.block_mapping.get_grid_dims();
160 if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
163 hipGetErrorString(hipMemsetAsync(karg.p_c_grid,
165 karg.M * karg.N *
sizeof(CDataType),
166 stream_config.stream_id_));
184 else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
187 char* workspace_semaphore =
reinterpret_cast<char*
>(karg.p_workspace_) +
188 karg.block_mapping.get_workspace_size_for_acc(
189 sizeof(
typename GridwiseGemm::FloatAcc));
190 auto preprocess = [&]() {
192 hipMemsetAsync(workspace_semaphore,
194 karg.block_mapping.get_workspace_size_for_semaphore(),
195 stream_config.stream_id_));
226 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
235 if constexpr(GridwiseGemm64::Block2CTileMap::ReductionStrategy ==
238 return p_arg->block_mapping.get_workspace_size(
244 if constexpr(GridwiseGemm32::Block2CTileMap::ReductionStrategy ==
247 return p_arg->block_mapping.get_workspace_size(
260 pArg_->p_workspace_ = p_workspace;
287 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(karg));
300 const BDataType* p_b,
308 AElementwiseOperation,
309 BElementwiseOperation,
310 CElementwiseOperation,
315 int occupancy = [&]() {
322 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
335 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
346 hipDeviceProp_t dev_prop;
348 rtn = hipGetDevice(&dev);
350 rtn = hipGetDeviceProperties(&dev_prop, dev);
352 num_cu = dev_prop.multiProcessorCount;
380 AElementwiseOperation,
381 BElementwiseOperation,
382 CElementwiseOperation,
383 index_t NumSKBlocks = 0)
override
388 int occupancy = [&]() {
395 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
408 rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
419 hipDeviceProp_t dev_prop;
421 rtn = hipGetDevice(&dev);
423 rtn = hipGetDeviceProperties(&dev_prop, dev);
425 num_cu = dev_prop.multiProcessorCount;
427 return std::make_unique<Argument>(
reinterpret_cast<const ADataType*
>(p_a),
428 reinterpret_cast<const BDataType*
>(p_b),
429 reinterpret_cast<CDataType*
>(p_c),
438 static_cast<uint32_t>(NumSKBlocks));
444 return std::make_unique<Invoker>(
Invoker{});
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB *p_a_grid, const typename GridwiseGemm::FloatAB *p_b_grid, typename GridwiseGemm::FloatC *p_c_grid, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping)
Definition gridwise_gemm_xdlops_streamk.hpp:28
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
unsigned int uint32_t
Definition stdint.h:126
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:1022
Definition gridwise_gemm_xdlops_streamk.hpp:115
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock >::GetTypeString static std::string GetTypeString()
Definition gridwise_gemm_xdlops_streamk.hpp:1163
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdlops_streamk.hpp:315
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock >::GetSharedMemoryNumberOfByte __host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdlops_streamk.hpp:289
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock >::FloatAcc AccDataType FloatAcc
Definition gridwise_gemm_xdlops_streamk.hpp:132
Definition device_base.hpp:197
Definition device_gemm_streamk.hpp:25
Definition device_gemm_xdl_streamk.hpp:131
void Print(const Argument_ &karg)
Definition device_gemm_xdl_streamk.hpp:133
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_streamk.hpp:223
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_streamk.hpp:139
Definition device_gemm_xdl_streamk.hpp:70
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_streamk.hpp:263
static constexpr auto I3
Definition device_gemm_xdl_streamk.hpp:78
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, uint32_t NumSKBlocks=0xffffffff)
Definition device_gemm_xdl_streamk.hpp:299
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_streamk.hpp:294
static auto MakeInvoker()
Definition device_gemm_xdl_streamk.hpp:368
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_streamk.hpp:442
std::string GetTypeString() const override
Definition device_gemm_xdl_streamk.hpp:448
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_streamk.hpp:124
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t NumSKBlocks=0) override
Definition device_gemm_xdl_streamk.hpp:371
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_streamk.hpp:125
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK< MPerBlock, NPerBlock, K0PerBlock *K1, StreamKReductionStrategy::Atomic >, ADataType, AccDataType, CDataType, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CBlockTransferScalarPerVector_NWaveNPerXDL, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock > GridwiseGemmBase
Definition device_gemm_xdl_streamk.hpp:81
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_streamk.hpp:127
static constexpr auto I2
Definition device_gemm_xdl_streamk.hpp:77
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl_streamk.hpp:269
static constexpr auto I0
Definition device_gemm_xdl_streamk.hpp:75
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_streamk.hpp:72
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_xdl_streamk.hpp:254
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_streamk.hpp:73
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_xdl_streamk.hpp:230
static constexpr auto I1
Definition device_gemm_xdl_streamk.hpp:76