grouped_gemm_kernel.hpp Source File

grouped_gemm_kernel.hpp Source File

Composable Kernel: grouped_gemm_kernel.hpp Source File
grouped_gemm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
13#include "ck_tile/host.hpp"
14
15#include <hip/hip_runtime.h>
16
17namespace ck_tile {
18
26
27template <index_t NumDTensor = 0>
29{
31 const void* b_ptr_,
32 const std::array<const void*, NumDTensor>& ds_ptr_,
33 void* e_ptr_,
34 index_t k_batch_,
35 index_t M_,
36 index_t N_,
37 index_t K_,
38 index_t stride_A_,
39 index_t stride_B_,
40 const std::array<index_t, NumDTensor>& stride_Ds_,
41 index_t stride_E_)
42 : a_ptr(a_ptr_),
43 b_ptr(b_ptr_),
44 ds_ptr(ds_ptr_),
45 e_ptr(e_ptr_),
46 M(M_),
47 N(N_),
48 K(K_),
49 stride_A(stride_A_),
50 stride_B(stride_B_),
51 stride_Ds(stride_Ds_),
52 stride_E(stride_E_),
53 k_batch(k_batch_)
54 {
55 }
56
57 const void* a_ptr;
58 const void* b_ptr;
59 const std::array<const void*, NumDTensor> ds_ptr;
60 union
61 {
62 void* e_ptr;
63 void* c_ptr;
64 };
65
71 const std::array<index_t, NumDTensor> stride_Ds;
72 union
73 {
76 };
77
79};
80
81template <index_t NumDTensor = 0>
101
102template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
104{
108
112
117
123
124 static constexpr index_t NumDTensor_ = DsDataType::size();
125
127 static_assert(
129 "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
130
132 static_assert(
134 "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
135
139 "C/CLayout and C/EDataType must be scalars.");
140
143
144 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
145 static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
146
147 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
148 {
149 // clang-format off
150 using P_ = GemmPipeline;
151
152 return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
153 concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
154 concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
155 concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
156 (UsePersistentKernel ? "Persistent" : "NonPersistent"),
157 (NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
158 (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer"));
159 // clang-format on
160 }
161
162 CK_TILE_HOST static auto
163 GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs<>>& gemm_descs) -> std::size_t
164 {
165 return gemm_descs.size() * sizeof(GemmTransKernelArg<NumDTensor_>);
166 }
167
168 CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
169 {
170 return group_count * sizeof(GemmTransKernelArg<NumDTensor_>);
171 }
172
173 CK_TILE_HOST static auto BlockSize() -> dim3
174 {
175 if(is_wave32())
176 {
177 return dim3(kBlockSize / 2);
178 }
179 else
180 {
181 return dim3(kBlockSize);
182 }
183 }
184
191 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
192 {
193 using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
195 int occupancy;
197 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
198 const int grid_size = get_available_compute_units(s) * occupancy;
199 return dim3(grid_size, 1, 1);
200 }
201
202 CK_TILE_HOST static auto
203 GridSize(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
204 {
205 index_t grid_size = 0;
206 for(const auto& it_desc : gemm_descs)
207 {
208 const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
209 grid_size += local_grid_size * it_desc.k_batch;
210 }
211 return dim3(grid_size, 1, 1);
212 }
213
214 CK_TILE_HOST static auto
215 MakeKargs(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
216 -> std::vector<GemmTransKernelArg<NumDTensor_>>
217 {
218 std::vector<GemmTransKernelArg<NumDTensor_>> gemm_kernel_args_;
219 index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
220 index_t grid_size = 0;
221 gemm_kernel_args_.reserve(group_count);
222
223 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
224 {
225 const index_t M = gemm_descs[i].M;
226 const index_t N = gemm_descs[i].N;
227 const index_t K = gemm_descs[i].K;
228
229 if(M == 0 || N == 0 || K == 0)
230 {
231 continue;
232 }
233
234 const index_t stride_a = gemm_descs[i].stride_A;
235 const index_t stride_b = gemm_descs[i].stride_B;
236 const index_t stride_e = gemm_descs[i].stride_E;
237 auto stride_ds = gemm_descs[i].stride_Ds;
238
239 const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
240
241 const index_t block_start = grid_size;
242 const index_t block_end = grid_size + grid_size_grp;
243
244 grid_size += grid_size_grp;
245
247 {type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
248 {type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
249 {gemm_descs[i].ds_ptr},
250 type_convert<CDataType*>(gemm_descs[i].e_ptr),
251 M,
252 N,
253 K,
254 {stride_a},
255 {stride_b},
256 stride_ds,
257 stride_e,
258 gemm_descs[i].k_batch};
259
260 gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
261 }
262
263 return gemm_kernel_args_;
264 }
265
266 CK_TILE_HOST static bool
268 {
269 for(const auto& karg : kargs)
270 {
271 if(!Base::IsSupportedArgument(karg.group_karg))
272 {
273 return false;
274 }
275 }
276 return true;
277 }
278
279 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
280 {
281 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
282 }
283
285 const tuple<index_t, index_t>& block_idx_2d,
286 const index_t block_idx_z) const
287 {
288
289 static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
290 "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
291
292 const auto [iM, iN] = block_idx_2d;
293
294 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
295 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
296
297 const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
298
299 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
300 splitk_batch_offset.as_k_split_offset[0];
301 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
302 splitk_batch_offset.bs_k_split_offset[0];
303 CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
304
305 // allocate LDS
306 __shared__ char smem_ptr_0[GetSmemSize()];
307
308 // TO DO:
309 // Can we simplify this branching logic?
310 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
311 {
312
313 __shared__ char smem_ptr_1[GetSmemSize()];
315 b_ptr,
316 c_ptr,
317 kargs.ds_ptr,
318 smem_ptr_0,
319 smem_ptr_1,
320 kargs,
321 splitk_batch_offset,
322 i_m,
323 i_n);
324 }
325 else // SingleSmemBuffer
326 {
327
328 if constexpr(UsePersistentKernel)
329 {
331 b_ptr,
332 kargs.ds_ptr,
333 c_ptr,
334 smem_ptr_0,
335 kargs,
336 splitk_batch_offset,
337 i_m,
338 i_n);
339 }
340 else // Non-persistent kernel
341 {
342 Base::RunGemm({a_ptr},
343 {b_ptr},
344 kargs.ds_ptr,
345 c_ptr,
346 smem_ptr_0,
347 kargs,
348 splitk_batch_offset,
349 i_m,
350 i_n);
351 }
352 }
353 }
354
373 CK_TILE_DEVICE static void
375 const BDataType* b_ptr,
376 const std::array<const void*, NumDTensor_>& ds_ptr,
377 CDataType* c_ptr,
378 void* smem_ptr_0,
380 const typename Base::SplitKBatchOffset& splitk_batch_offset,
381 const index_t block_idx_m,
382 const index_t block_idx_n)
383 {
384 // Create Gemm tensor views, pad views and tile windows
385 const auto& gemm_tensor_views_tuple =
386 Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
387 {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
388
389 const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
390 auto gemm_tile_windows =
391 Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
392 const auto& a_block_window = gemm_tile_windows.at(Base::I0);
393 const auto& b_block_window = gemm_tile_windows.at(Base::I1);
394 const auto& d_block_window = gemm_tile_windows.at(Base::I2);
395
396 // Get hot-loop and tail configuration
397 const index_t num_loop =
398 amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
399 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
400 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
401
402 // Run GEMM pipeline
403 const auto& c_block_tile = GemmPipeline{}.template operator()(
404 a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
405 // Run Epilogue Pipeline
406 auto& c_block_window = gemm_tile_windows.at(Base::I3);
407 EpiloguePipeline{}.template
408 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
409 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
410 }
411
431 CK_TILE_DEVICE static void
433 const BDataType* b_ptr,
434 CDataType* c_ptr,
435 const std::array<const void*, NumDTensor_>& ds_ptr,
436 void* __restrict__ smem_ptr_0,
437 void* __restrict__ smem_ptr_1,
439 const typename Base::SplitKBatchOffset& splitk_batch_offset,
440 const index_t block_idx_m,
441 const index_t block_idx_n)
442 {
443 // Create Gemm tensor views, pad views and tile windows
444 const auto& gemm_tensor_views_tuple =
445 Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
446 {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
447
448 const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
449 auto gemm_tile_windows =
450 Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
451 const auto& a_block_window = gemm_tile_windows.at(Base::I0);
452 const auto& b_block_window = gemm_tile_windows.at(Base::I1);
453 const auto& d_block_window = gemm_tile_windows.at(Base::I2);
454
455 // Get hot-loop and tail configuration
456 const index_t num_loop =
457 amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
458 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
459
460 // Run GEMM pipeline with compile-time branching
461 const auto& c_block_tile = [&]() {
462 if constexpr(GemmPipeline::Preshuffle)
463 {
464 // Preshuffle version - without has_hot_loop parameter
465 return GemmPipeline{}.template operator()(a_block_window[Base::I0],
466 b_block_window[Base::I0],
467 num_loop,
468 tail_num,
469 smem_ptr_0,
470 smem_ptr_1);
471 }
472 else
473 {
474 // Regular version - with has_hot_loop parameter
475 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
476 return GemmPipeline{}.template operator()(a_block_window[Base::I0],
477 b_block_window[Base::I0],
478 num_loop,
479 has_hot_loop,
480 tail_num,
481 smem_ptr_0,
482 smem_ptr_1);
483 }
484 }();
485
486 // Run Epilogue Pipeline
487 auto& c_block_window = gemm_tile_windows.at(Base::I3);
488 EpiloguePipeline{}.template
489 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
490 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
491 }
492
494 index_t block_id,
495 index_t group_count) const
496 {
497 index_t left = 0;
498 index_t right = group_count;
499 index_t group_id = index_t((left + right) >> 1);
500
501 while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
502 block_id < gemm_desc_ptr[group_id].block_end)) &&
503 left <= right)
504 {
505 if(block_id < gemm_desc_ptr[group_id].block_start)
506 {
507 right = group_id;
508 }
509 else
510 {
511 left = group_id;
512 }
513 group_id = index_t((left + right) >> 1);
514 }
515
516 return group_id;
517 }
518
519 // For non-persistent kernels
520 template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
521 CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
522 index_t group_count) const
523 {
524 const index_t block_id = ck_tile::get_block_1d_id();
525 const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
526 cast_pointer_to_generic_address_space(gemm_descs_const));
527
528 const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
529 const auto& kargs = gemm_desc_ptr[group_id];
530
531 const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
532 const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
533 0,
534 kargs.group_karg.M,
535 kargs.group_karg.N,
536 (block_id - kargs.block_start) % grid_size_2d);
537 Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
538 }
539
540 // For persistent kernels
541 template <bool U = UsePersistentKernel,
542 typename = std::enable_if_t<U>,
543 typename = void> // extra template parameter to avoid redefinition
544 CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
545 const index_t group_count) const
546 {
547 const index_t grid_size = ck_tile::get_grid_size();
548 const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
549 cast_pointer_to_generic_address_space(gemm_descs_const));
550 index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
551 index_t cum_grid_size = 0;
552 for(index_t group_id = 0; group_id < group_count; ++group_id)
553 {
554 const auto& kargs = gemm_desc_ptr[group_id].group_karg;
555 const auto& k_batch = kargs.k_batch;
556 const auto block_start = cum_grid_size;
557 cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
558 while(block_id < cum_grid_size)
559 {
560 const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
561 const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
562 0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
563 Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
564 block_id = block_id + grid_size; // advance to next block
565 // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
566 if(block_id >= cum_grid_size)
567 {
568 break; // exit the loop if all blocks are processed
569 }
570 }
571 }
572 }
573};
574
575} // namespace ck_tile
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE index_t get_block_1d_id()
Definition arch.hpp:98
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition arch.hpp:307
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE index_t get_grid_size()
Definition arch.hpp:89
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
STL namespace.
Definition grouped_gemm_kernel.hpp:83
GemmTransKernelArg(UniversalGemmKernelArgs< 1, 1, NumDTensor > &&karg, index_t bl_start, index_t bl_end)
Definition grouped_gemm_kernel.hpp:89
UniversalGemmKernelArgs< 1, 1, NumDTensor > group_karg
Definition grouped_gemm_kernel.hpp:84
GemmTransKernelArg(UniversalGemmKernelArgs< 1, 1, NumDTensor > &&karg)
Definition grouped_gemm_kernel.hpp:96
ck_tile::index_t block_start
Definition grouped_gemm_kernel.hpp:85
ck_tile::index_t block_end
Definition grouped_gemm_kernel.hpp:86
The Grouped GEMM kernel host arguments.
Definition grouped_gemm_kernel.hpp:29
void * e_ptr
Definition grouped_gemm_kernel.hpp:62
index_t stride_E
Definition grouped_gemm_kernel.hpp:74
CK_TILE_HOST GroupedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition grouped_gemm_kernel.hpp:30
index_t stride_C
Definition grouped_gemm_kernel.hpp:75
index_t k_batch
Definition grouped_gemm_kernel.hpp:78
index_t stride_A
Definition grouped_gemm_kernel.hpp:69
index_t M
Definition grouped_gemm_kernel.hpp:66
void * c_ptr
Definition grouped_gemm_kernel.hpp:63
index_t stride_B
Definition grouped_gemm_kernel.hpp:70
const void * b_ptr
Definition grouped_gemm_kernel.hpp:58
const void * a_ptr
Definition grouped_gemm_kernel.hpp:57
index_t N
Definition grouped_gemm_kernel.hpp:67
index_t K
Definition grouped_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_gemm_kernel.hpp:59
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_gemm_kernel.hpp:71
Definition grouped_gemm_kernel.hpp:104
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_gemm_kernel.hpp:109
static constexpr index_t NumDTensor_
Definition grouped_gemm_kernel.hpp:124
static CK_TILE_HOST auto GridSize(const std::vector< GroupedGemmHostArgs< NumDTensor_ > > &gemm_descs)
Definition grouped_gemm_kernel.hpp:203
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
Definition grouped_gemm_kernel.hpp:544
static CK_TILE_HOST auto GetWorkSpaceSize(index_t group_count) -> std::size_t
Definition grouped_gemm_kernel.hpp:168
static CK_TILE_HOST auto MakeKargs(const std::vector< GroupedGemmHostArgs< NumDTensor_ > > &gemm_descs) -> std::vector< GemmTransKernelArg< NumDTensor_ > >
Definition grouped_gemm_kernel.hpp:215
static CK_TILE_DEVICE void RunGemmWithPipelineSelection(const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor_ > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition grouped_gemm_kernel.hpp:374
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition grouped_gemm_kernel.hpp:114
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
Definition grouped_gemm_kernel.hpp:284
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition grouped_gemm_kernel.hpp:115
static CK_TILE_HOST bool IsSupportedArgument(const std::vector< GemmTransKernelArg< NumDTensor_ > > &kargs)
Definition grouped_gemm_kernel.hpp:267
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Base
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition grouped_gemm_kernel.hpp:107
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize() -> index_t
Definition grouped_gemm_kernel.hpp:279
OffsettedTile1DPartitioner< TilePartitioner > OffsetTile1DPartitioner
ALayout and ADataType are expected to be scalars, not a tuple.
Definition grouped_gemm_kernel.hpp:141
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition grouped_gemm_kernel.hpp:120
static constexpr index_t kBlockSize
Definition grouped_gemm_kernel.hpp:144
static CK_TILE_HOST auto GetWorkSpaceSize(const std::vector< GroupedGemmHostArgs<> > &gemm_descs) -> std::size_t
Definition grouped_gemm_kernel.hpp:163
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition grouped_gemm_kernel.hpp:116
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_gemm_kernel.hpp:111
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg< NumDTensor_ > *gemm_desc_ptr, index_t block_id, index_t group_count) const
Definition grouped_gemm_kernel.hpp:493
static CK_TILE_HOST auto BlockSize() -> dim3
Definition grouped_gemm_kernel.hpp:173
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition grouped_gemm_kernel.hpp:121
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS(const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const std::array< const void *, NumDTensor_ > &ds_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const UniversalGemmKernelArgs< 1, 1, NumDTensor_ > &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition grouped_gemm_kernel.hpp:432
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition grouped_gemm_kernel.hpp:191
static CK_TILE_HOST const std::string GetName()
Definition grouped_gemm_kernel.hpp:147
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_gemm_kernel.hpp:122
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_gemm_kernel.hpp:110
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, C/E.
Definition grouped_gemm_kernel.hpp:119
static constexpr bool UsePersistentKernel
Definition grouped_gemm_kernel.hpp:145
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, index_t group_count) const
Definition grouped_gemm_kernel.hpp:521
GroupedGemmKernel< TilePartitioner, GemmPipeline, EpiloguePipeline > Kernel
Definition grouped_gemm_kernel.hpp:142
Struct used to calculate offset tile indices.
Definition gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the raw 1D tile indices.
Definition gemm_tile_partitioner.hpp:192
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
index_t splitted_k
Definition universal_gemm_kernel.hpp:370
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:94
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:92
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:88
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:90
The Universal GEMM kernel template.
Definition universal_gemm_kernel.hpp:154
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition universal_gemm_kernel.hpp:955
static constexpr auto I2
Definition universal_gemm_kernel.hpp:238
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition universal_gemm_kernel.hpp:853
static constexpr auto I3
Definition universal_gemm_kernel.hpp:239
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition universal_gemm_kernel.hpp:754
static constexpr auto I1
Definition universal_gemm_kernel.hpp:237
static constexpr auto I0
Definition universal_gemm_kernel.hpp:236
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
Definition ck_tile/host/stream_config.hpp:30
Definition tile/core/container/tuple.hpp:192