device_batched_gemm_gemm_xdl_cshuffle.hpp Source File

device_batched_gemm_gemm_xdl_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_gemm_xdl_cshuffle.hpp Source File
device_batched_gemm_gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
// GPU kernel entry point for the batched GEMM + GEMM operation.
// Each workgroup derives the batch it belongs to from its 1-D block id, offsets the
// A / B0 / B1 / C base pointers by that batch's offset, and then delegates all of the
// computation to GridwiseGemm::Run<HasMainKBlockLoop>().
// HasMainKBlockLoop is a compile-time flag that is forwarded unchanged to GridwiseGemm::Run.
// NOTE(review): listing line 42 (between the #if/#endif pair below) is not present in this
// view; presumably it carries the __launch_bounds__ attribute -- confirm against the header.
25template <typename GridwiseGemm,
26 typename FloatAB,
27 typename FloatC,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename AccElementwiseOperation,
31 typename B1ElementwiseOperation,
32 typename CElementwiseOperation,
33 typename AGridDesc_AK0_M_AK1,
34 typename BGridDesc_BK0_N_BK1,
35 typename B1GridDesc_BK0_N_BK1,
36 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
37 typename Block2CTileMap,
38 typename ComputeBasePtrOfStridedBatch,
39 bool HasMainKBlockLoop>
40__global__ void
41#if CK_USE_LAUNCH_BOUNDS
43#endif
44 kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
45 const FloatAB* __restrict__ p_b_grid,
46 const FloatAB* __restrict__ p_b1_grid,
47 FloatC* __restrict__ p_c_grid,
48 const AElementwiseOperation a_element_op,
49 const BElementwiseOperation b_element_op,
50 const AccElementwiseOperation acc_element_op,
51 const B1ElementwiseOperation b1_element_op,
52 const CElementwiseOperation c_element_op,
53 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
54 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
55 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
56 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
57 c_grid_desc_mblock_mperblock_nblock_nperblock,
58 const Block2CTileMap block_2_ctile_map,
59 const index_t batch_count,
60 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
61{
    // The real body is only compiled for gfx9 / gfx11 / gfx12 targets; every other target
    // takes the #else branch below, which does no work.
62#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
63 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
64 {
    // Shared-memory scratch buffer, sized by the gridwise GEMM and passed to Run() below.
65 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    // The host launches (blocks per batch) * batch_count blocks, so an integer division of
    // the grid size recovers the per-batch block count and the block id yields the batch.
    // NOTE(review): readfirstlane presumably keeps these wave-uniform values in scalar
    // registers -- confirm.
66 const index_t num_blocks_per_batch =
67 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
68 const index_t g_idx =
69 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
70
    // Per-batch offsets (64-bit long_index_t) that are added to the base pointers below.
71 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
72 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
73 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
74 static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
75 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
76 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
77 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
78 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
79
    // Run the fused GEMM + GEMM for this batch; batch_count itself is not forwarded.
80 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
81 p_b_grid + b_batch_offset,
82 p_b1_grid + b1_batch_offset,
83 p_c_grid + c_batch_offset,
84 p_shared,
85 a_element_op,
86 b_element_op,
87 acc_element_op,
88 b1_element_op,
89 c_element_op,
90 a_grid_desc_ak0_m_ak1,
91 b_grid_desc_bk0_n_bk1,
92 b1_grid_desc_bk0_n_bk1,
93 c_grid_desc_mblock_mperblock_nblock_nperblock,
94 block_2_ctile_map);
95 }
96#else
    // Unsupported target: the kernel is a no-op. Every parameter is assigned to `ignore`,
    // presumably to silence unused-parameter warnings.
97 ignore = p_a_grid;
98 ignore = p_b_grid;
99 ignore = p_b1_grid;
100 ignore = p_c_grid;
101 ignore = a_element_op;
102 ignore = b_element_op;
103 ignore = acc_element_op;
104 ignore = b1_element_op;
105 ignore = c_element_op;
106 ignore = a_grid_desc_ak0_m_ak1;
107 ignore = b_grid_desc_bk0_n_bk1;
108 ignore = b1_grid_desc_bk0_n_bk1;
109 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
110 ignore = block_2_ctile_map;
111 ignore = batch_count;
112 ignore = compute_base_ptr_of_batch;
113#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
114}
115
116// Computes C = A * B0 * B1
117//              ^^^^^^        (Acc0 = A * B0)
118//              ^^^^^^^^^^^   (Acc1 = Acc0 * B1)
119template <typename ALayout,
120 typename BLayout, // B0Layout
121 typename B1Layout,
122 typename CLayout,
123 typename ADataType,
124 typename BDataType,
125 typename B1DataType,
126 typename CDataType,
127 typename GemmAccDataType,
128 typename CShuffleDataType,
129 typename AElementwiseOperation,
130 typename BElementwiseOperation,
131 typename AccElementwiseOperation,
132 typename B1ElementwiseOperation,
133 typename CElementwiseOperation,
134 GemmSpecialization GemmSpec,
135 index_t NumGemmKPrefetchStage,
136 index_t BlockSize,
137 index_t MPerBlock,
138 index_t NPerBlock, // Gemm0NPerBlock
139 index_t KPerBlock, // Gemm0KPerBlock
140 index_t Gemm1NPerBlock,
141 index_t Gemm1KPerBlock,
142 index_t AK1,
143 index_t BK1,
144 index_t B1K1,
145 index_t MPerXDL,
146 index_t NPerXDL,
147 index_t MXdlPerWave,
148 index_t NXdlPerWave,
149 index_t Gemm1NXdlPerWave,
150 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
151 typename ABlockTransferThreadClusterArrangeOrder,
152 typename ABlockTransferSrcAccessOrder,
153 index_t ABlockTransferSrcVectorDim,
154 index_t ABlockTransferSrcScalarPerVector,
155 index_t ABlockTransferDstScalarPerVector_AK1,
156 bool ABlockLdsExtraM,
157 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
158 typename BBlockTransferThreadClusterArrangeOrder,
159 typename BBlockTransferSrcAccessOrder,
160 index_t BBlockTransferSrcVectorDim,
161 index_t BBlockTransferSrcScalarPerVector,
162 index_t BBlockTransferDstScalarPerVector_BK1,
163 bool BBlockLdsExtraN,
164 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
165 typename B1BlockTransferThreadClusterArrangeOrder,
166 typename B1BlockTransferSrcAccessOrder,
167 index_t B1BlockTransferSrcVectorDim,
168 index_t B1BlockTransferSrcScalarPerVector,
169 index_t B1BlockTransferDstScalarPerVector_BK1,
170 bool B1BlockLdsExtraN,
171 index_t CShuffleMXdlPerWavePerShuffle,
172 index_t CShuffleNXdlPerWavePerShuffle,
173 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
174 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
177 BLayout,
178 B1Layout,
179 CLayout,
180 ADataType,
181 BDataType,
182 B1DataType,
183 CDataType,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 AccElementwiseOperation,
187 B1ElementwiseOperation,
188 CElementwiseOperation>
189{
191
192 static constexpr auto MXdlPerWave64 =
193 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, true>();
194 static constexpr auto MXdlPerWave32 =
195 GetNXdlPerWave2<BlockSize, NPerBlock, MPerBlock, NPerXDL, MPerXDL, NXdlPerWave, false>();
196 static constexpr auto I0 = Number<0>{};
197 static constexpr auto I1 = Number<1>{};
198 static constexpr auto I2 = Number<2>{};
199
200 static constexpr auto matrix_padder =
202 MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
203
// Builds the grid descriptor for A from its raw extents (MRaw x KRaw) and its stride.
// Steps visible here: create a naive 2-D (M, K) descriptor, pad it with
// matrix_padder.PadADescriptor_M_K, then split K as AK0 = K / AK1 ahead of the final
// transform_tensor_descriptor call.
// NOTE(review): listing lines 207, 212 and 227-230 are missing from this view. The two
// branches differ only in which dimension receives StrideA, so the missing conditions
// presumably select on ALayout; the missing transform arguments presumably produce the
// [AK0, M, AK1] shape named in the function -- confirm both against the real header.
204 static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
205 {
206 const auto a_grid_desc_mraw_kraw = [&]() {
208 {
    // Strides (StrideA, 1): K is the unit-stride dimension.
209 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
210 make_tuple(StrideA, I1));
211 }
213 {
    // Strides (1, StrideA): M is the unit-stride dimension.
214 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
215 make_tuple(I1, StrideA));
216 }
217 }();
218
219 const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
220
    // Lengths are read back from the padded descriptor, not from MRaw / KRaw.
221 const auto M = a_grid_desc_m_k.GetLength(I0);
222 const auto K = a_grid_desc_m_k.GetLength(I1);
223
    // Integer division; assumes K is a multiple of AK1 -- TODO confirm where that is enforced.
224 const auto AK0 = K / AK1;
225
226 return transform_tensor_descriptor(a_grid_desc_m_k,
231 }
232
// Builds the grid descriptor for B0 from its raw extents and its stride.
// Note the parameter order is (KRaw, NRaw) while the naive descriptor is laid out (N, K).
// Steps visible here: create the naive 2-D (N, K) descriptor, pad it with
// matrix_padder.PadBDescriptor_N_K, then split K as BK0 = K / BK1 ahead of the final
// transform_tensor_descriptor call.
// NOTE(review): listing lines 236, 241 and 256-259 are missing from this view. The missing
// branch conditions presumably select on BLayout and the missing transform arguments
// presumably produce the [BK0, N, BK1] shape named in the function -- confirm.
233 static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
234 {
235 const auto b_grid_desc_nraw_kraw = [&]() {
237 {
    // Strides (1, StrideB): N is the unit-stride dimension.
238 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
239 make_tuple(I1, StrideB));
240 }
242 {
    // Strides (StrideB, 1): K is the unit-stride dimension.
243 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
244 make_tuple(StrideB, I1));
245 }
246 }();
247
248 const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
249
    // Lengths are read back from the padded descriptor, not from NRaw / KRaw.
250 const auto N = b_grid_desc_n_k.GetLength(I0);
251 const auto K = b_grid_desc_n_k.GetLength(I1);
252
    // Integer division; assumes K is a multiple of BK1 -- TODO confirm where that is enforced.
253 const auto BK0 = K / BK1;
254
255 return transform_tensor_descriptor(b_grid_desc_n_k,
260 }
261
262 // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
// Builds the grid descriptor for B1 (the second GEMM's B operand).
// Here KRaw / NRaw are the second GEMM's K and N; the Argument constructor in this file
// passes (NRaw, Gemm1NRaw, StrideB1), i.e. the first GEMM's N serves as the second GEMM's K.
// Steps visible here: create the naive 2-D (N, K) descriptor, pad it with
// matrix_padder.PadB1Descriptor_N_K, then split K as B1K0 = K / B1K1.
// NOTE(review): listing lines 266, 271, 285 and 287-290 are missing from this view (branch
// conditions, the `return` call and its remaining arguments). The conditions presumably
// select on B1Layout -- confirm against the real header.
263 static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
264 {
265 const auto b1_grid_desc_nraw_kraw = [&]() {
267 {
    // Strides (1, StrideB): N is the unit-stride dimension.
268 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
269 make_tuple(I1, StrideB));
270 }
272 {
    // Strides (StrideB, 1): K is the unit-stride dimension.
273 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
274 make_tuple(StrideB, I1));
275 }
276 }();
277
278 const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
279
    // Lengths are read back from the padded descriptor, not from NRaw / KRaw.
280 const auto N = b1_grid_desc_n_k.GetLength(I0);
281 const auto K = b1_grid_desc_n_k.GetLength(I1);
282
    // Integer division; assumes K is a multiple of B1K1 -- TODO confirm where that is enforced.
283 const auto B1K0 = K / B1K1;
284
286 b1_grid_desc_n_k,
291 }
292
// Builds the (M, N) grid descriptor for the output C: a naive 2-D descriptor over
// (MRaw, NRaw) that is then padded with matrix_padder.PadCDescriptor_M_N.
// NOTE(review): listing lines 296 and 301 (the branch conditions) are missing from this
// view; since the branches differ only in which dimension receives StrideC they presumably
// select on CLayout -- confirm against the real header.
293 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
294 {
295 const auto c_grid_desc_mraw_nraw = [&]() {
297 {
    // Strides (StrideC, 1): N is the unit-stride dimension.
298 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
299 make_tuple(StrideC, I1));
300 }
302 {
    // Strides (1, StrideC): M is the unit-stride dimension.
303 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
304 make_tuple(I1, StrideC));
305 }
306 }();
307
308 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
309 }
310
312 {
314 index_t BatchStrideB,
315 index_t BatchStrideB1,
316 index_t BatchStrideC)
317 : BatchStrideA_(BatchStrideA),
318 BatchStrideB_(BatchStrideB),
319 BatchStrideB1_(BatchStrideB1),
320 BatchStrideC_(BatchStrideC)
321 {
322 }
323
// Returns the offset of batch g_idx within A: g_idx * BatchStrideA_.
// Despite the name this is an offset, not a pointer; the kernel adds it to p_a_grid.
// The stride is widened to long_index_t before the multiply so the product is 64-bit.
324 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
325 {
326 return g_idx * static_cast<long_index_t>(BatchStrideA_);
327 }
328
// Returns the offset of batch g_idx within B0: g_idx * BatchStrideB_.
// Despite the name this is an offset, not a pointer; the kernel adds it to p_b_grid.
// The stride is widened to long_index_t before the multiply so the product is 64-bit.
329 __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
330 {
331 return g_idx * static_cast<long_index_t>(BatchStrideB_);
332 }
333
// Returns the offset of batch g_idx within B1: g_idx * BatchStrideB1_.
// Despite the name this is an offset, not a pointer; the kernel adds it to p_b1_grid.
// The stride is widened to long_index_t before the multiply so the product is 64-bit.
334 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
335 {
336 return g_idx * static_cast<long_index_t>(BatchStrideB1_);
337 }
338
// Returns the offset of batch g_idx within C: g_idx * BatchStrideC_.
// Despite the name this is an offset, not a pointer; the kernel adds it to p_c_grid.
// The stride is widened to long_index_t before the multiply so the product is 64-bit.
339 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
340 {
341 return g_idx * static_cast<long_index_t>(BatchStrideC_);
342 }
343
344 private:
345 index_t BatchStrideA_;
346 index_t BatchStrideB_;
347 index_t BatchStrideB1_;
348 index_t BatchStrideC_;
349 };
350
354 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
355
356 // GridwiseGemm
357 template <index_t MXdlPerWave_>
359 ADataType, // TODO: distinguish A/B datatype
360 GemmAccDataType,
361 CShuffleDataType,
362 CDataType,
363 AElementwiseOperation,
364 BElementwiseOperation,
365 AccElementwiseOperation,
366 B1ElementwiseOperation,
367 CElementwiseOperation,
373 NumGemmKPrefetchStage,
374 BlockSize,
375 MPerBlock,
376 NPerBlock,
377 KPerBlock,
378 Gemm1NPerBlock,
379 Gemm1KPerBlock,
380 AK1,
381 BK1,
382 B1K1,
383 MPerXDL,
384 NPerXDL,
385 MXdlPerWave_,
386 NXdlPerWave,
387 Gemm1NXdlPerWave,
388 ABlockTransferThreadClusterLengths_AK0_M_AK1,
389 ABlockTransferThreadClusterArrangeOrder,
390 ABlockTransferSrcAccessOrder,
391 ABlockTransferSrcVectorDim,
392 ABlockTransferSrcScalarPerVector,
393 ABlockTransferDstScalarPerVector_AK1,
394 true,
395 ABlockLdsExtraM,
396 BBlockTransferThreadClusterLengths_BK0_N_BK1,
397 BBlockTransferThreadClusterArrangeOrder,
398 BBlockTransferSrcAccessOrder,
399 BBlockTransferSrcVectorDim,
400 BBlockTransferSrcScalarPerVector,
401 BBlockTransferDstScalarPerVector_BK1,
402 true,
403 BBlockLdsExtraN,
404 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
405 B1BlockTransferThreadClusterArrangeOrder,
406 B1BlockTransferSrcAccessOrder,
407 B1BlockTransferSrcVectorDim,
408 B1BlockTransferSrcScalarPerVector,
409 B1BlockTransferDstScalarPerVector_BK1,
410 false,
411 B1BlockLdsExtraN,
412 CShuffleMXdlPerWavePerShuffle,
413 CShuffleNXdlPerWavePerShuffle,
414 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
415 CShuffleBlockTransferScalarPerVector_NPerBlock,
416 LoopSched>;
419
420 // Argument
// Host-side bundle of everything one launch needs: the four device pointers, the grid
// descriptors derived from the raw problem sizes and strides, the block-to-tile map, the
// element-wise operators, the batch count and the per-batch stride helper.
// NOTE(review): listing lines 449-451, 479-483 and 489-490 are missing from this view
// (the A / B0 descriptor initialisers and several member declarations); members such as
// a_grid_desc_ak0_m_ak1_ and batch_count_ are used here but declared on the missing lines.
421 struct Argument : public BaseArgument
422 {
    // MRaw / NRaw / KRaw describe the first GEMM; Gemm1NRaw is the second GEMM's N
    // (the output width, "O").
423 Argument(const ADataType* p_a_grid,
424 const BDataType* p_b_grid,
425 const B1DataType* p_b1_grid,
426 CDataType* p_c_grid,
427 index_t MRaw,
428 index_t NRaw,
429 index_t KRaw,
430 index_t Gemm1NRaw, // = ORaw
431 index_t Batch,
432 index_t StrideA,
433 index_t StrideB,
434 index_t StrideB1,
435 index_t StrideC,
436 index_t BatchStrideA,
437 index_t BatchStrideB,
438 index_t BatchStrideB1,
439 index_t BatchStrideC,
440 AElementwiseOperation a_element_op,
441 BElementwiseOperation b_element_op,
442 AccElementwiseOperation acc_element_op,
443 B1ElementwiseOperation b1_element_op,
444 CElementwiseOperation c_element_op)
445 : p_a_grid_{p_a_grid},
446 p_b_grid_{p_b_grid},
447 p_b1_grid_{p_b1_grid},
448 p_c_grid_{p_c_grid},
    // B1 is described with K = NRaw (first GEMM's N) and N = Gemm1NRaw; C is MRaw x Gemm1NRaw.
452 DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
453 c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)},
    // The tile map is always built through GridwiseGemm64, whichever variant is launched.
454 block_2_ctile_map_{GridwiseGemm64::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
455 a_element_op_{a_element_op},
456 b_element_op_{b_element_op},
457 acc_element_op_{acc_element_op},
458 b1_element_op_{b1_element_op},
459 c_element_op_{c_element_op},
460 batch_count_(Batch),
461 compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
462 raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
463 {
464 }
465
    // Writes the four grid descriptors to std::cout for debugging.
466 void Print() const
467 {
468 std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
469 std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
470 std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
471 std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
472 }
473
    // Members are intentionally public (the access specifier below is commented out);
    // Invoker and IsSupportedArgument read them directly.
474 // private:
475 const ADataType* p_a_grid_;
476 const BDataType* p_b_grid_;
477 const B1DataType* p_b1_grid_;
478 CDataType* p_c_grid_;
484 AElementwiseOperation a_element_op_;
485 BElementwiseOperation b_element_op_;
486 AccElementwiseOperation acc_element_op_;
487 B1ElementwiseOperation b1_element_op_;
488 CElementwiseOperation c_element_op_;
491
    // Unpadded {M, N, K, O} lengths, in that order, exactly as passed to the constructor.
492 // For robust IsSupportedArgument() check
493 std::vector<index_t> raw_lengths_m_n_k_o_;
494 };
495
496 // Invoker
// Launches the kernel for a prepared Argument and returns the value produced by
// launch_and_time_kernel.
// NOTE(review): listing lines 499, 504, 530-532, 552-554, 556 and 558 are missing from this
// view (an alias, the condition guarding the throw, and several kernel template / launch
// arguments).
497 struct Invoker : public BaseInvoker
498 {
500
    // Launches the kernel instantiated for the given GridwiseGemm variant.
    // Throws std::runtime_error when the guarding check on the missing listing line 504
    // fails (presumably !IsSupportedArgument(arg) -- confirm).
501 template <typename GridwiseGemm>
502 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
503 {
505 {
506 throw std::runtime_error("wrong! unsupported argument");
507 }
508 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
509 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
510 arg.c_grid_desc_m_n_);
    // One set of tile blocks per batch; the kernel divides by batch_count to undo this.
511 const index_t grid_size =
512 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
513
    // K of the first GEMM, rebuilt as AK0 * AK1 from the A descriptor's lengths.
514 // Gemm0_K
515 const auto K =
516 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
517
518 float ave_time = 0;
519
    // Instantiates the kernel for a compile-time HasMainKBlockLoop value and launches it
    // with grid_size blocks of BlockSize threads and 0 bytes of dynamic LDS.
520 auto launch_kernel = [&](auto has_main_k_block_loop_) {
521 const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
522 GridwiseGemm,
523 ADataType, // TODO: distinguish A/B datatype
524 CDataType,
525 AElementwiseOperation,
526 BElementwiseOperation,
527 AccElementwiseOperation,
528 B1ElementwiseOperation,
529 CElementwiseOperation,
533 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
534 typename GridwiseGemm::DefaultBlock2CTileMap,
535 ComputeBasePtrOfStridedBatch,
536 has_main_k_block_loop_>;
537
538 return launch_and_time_kernel(stream_config,
539 kernel,
540 dim3(grid_size),
541 dim3(BlockSize),
542 0,
543 arg.p_a_grid_,
544 arg.p_b_grid_,
545 arg.p_b1_grid_,
546 arg.p_c_grid_,
547 arg.a_element_op_,
548 arg.b_element_op_,
549 arg.acc_element_op_,
550 arg.b1_element_op_,
551 arg.c_element_op_,
555 c_grid_desc_mblock_mperblock_nblock_nperblock,
557 arg.batch_count_,
559 };
560
561 // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so only
562 // Gemm0's K loop needs to be considered here
563 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
564 {
565 ave_time = launch_kernel(integral_constant<bool, true>{});
566 }
567 else
568 {
569 ave_time = launch_kernel(integral_constant<bool, false>{});
570 }
571
572 return ave_time;
573 }
    // Picks the GridwiseGemm variant from the warp size: GridwiseGemm64 when
    // get_warp_size() == 64 and MXdlPerWave64 > 0, otherwise GridwiseGemm32 when
    // MXdlPerWave32 > 0. Returns 0 without launching when the selected variant is disabled.
574 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
575 {
576 if(get_warp_size() == 64)
577 {
578 if constexpr(MXdlPerWave64 > 0)
579 {
580 return RunImp<GridwiseGemm64>(arg, stream_config);
581 }
582 }
583 else
584 {
585 if constexpr(MXdlPerWave32 > 0)
586 {
587 return RunImp<GridwiseGemm32>(arg, stream_config);
588 }
589 }
590 return 0;
591 }
    // Type-erased entry point: downcasts p_arg and forwards to the overload above.
    // NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg
    // must really point to this operation's Argument.
592 // polymorphic
593 float Run(const BaseArgument* p_arg,
594 const StreamConfig& stream_config = StreamConfig{}) override
595 {
596 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
597 }
598 };
599
// Compile-time validity check for this instantiation. Currently a placeholder that always
// returns true (see the TODO below).
600 static constexpr bool IsValidCompilationParameter()
601 {
602 // TODO: properly implement this check
603 return true;
604 }
605
// Returns whether this instantiation can run the problem described by arg.
// Checks visible here, in order: an early rejection whose condition is missing from this
// view, the scalar-per-vector divisibility of each tensor's lowest extent, and a
// warp-size-dependent final check. Falls through to `false` when the variant selected by
// the warp size is disabled (MXdlPerWave64 / MXdlPerWave32 not > 0).
// NOTE(review): listing lines 608, 621, 623, 625, 627, 640-645 and 651-656 are missing from
// this view. The extent expressions presumably pick a raw length per layout, and the final
// checks presumably call GridwiseGemm64 / GridwiseGemm32 CheckValidity -- confirm.
606 static bool IsSupportedArgument(const Argument& arg)
607 {
609 {
610 return false;
611 }
612 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
613 // vector is out of bounds
614 const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
615 const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
616 const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
617 const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
618
619 // Check scalar per vector requirement
620 const auto a_extent_lowest =
622 const auto b_extent_lowest =
624 const auto b1_extent_lowest =
626 const auto c_extent_lowest =
628
    // Every lowest extent must be a whole multiple of its vector width.
629 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
630 b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
631 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
632 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
633 {
634 return false;
635 }
636
637 if(get_warp_size() == 64)
638 {
639 if constexpr(MXdlPerWave64 > 0)
640 {
646 }
647 }
648 else
649 {
650 if constexpr(MXdlPerWave32 > 0)
651 {
657 }
658 }
659 return false;
660 }
661
662 // polymorphic
// Type-erased overload: downcasts p_arg to this operation's Argument and forwards to the
// static overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so passing a
// BaseArgument of a different concrete type (or nullptr) is undefined behaviour.
663 bool IsSupportedArgument(const BaseArgument* p_arg) override
664 {
665 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
666 }
667
// Typed factory: returns an Argument by value, forwarding every parameter to the Argument
// constructor in the same order.
668 static auto MakeArgument(const ADataType* p_a,
669 const BDataType* p_b,
670 const B1DataType* p_b1,
671 CDataType* p_c,
672 index_t MRaw,
673 index_t NRaw,
674 index_t KRaw,
675 index_t Gemm1NRaw,
676 index_t Batch,
677 index_t StrideA,
678 index_t StrideB,
679 index_t StrideB1,
680 index_t StrideC,
681 index_t BatchStrideA,
682 index_t BatchStrideB,
683 index_t BatchStrideB1,
684 index_t BatchStrideC,
685 AElementwiseOperation a_element_op,
686 BElementwiseOperation b_element_op,
687 AccElementwiseOperation acc_element_op,
688 B1ElementwiseOperation b1_element_op,
689 CElementwiseOperation c_element_op)
690 {
691 return Argument{p_a, p_b, p_b1, p_c, MRaw,
692 NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
693 StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
694 BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
695 b1_element_op, c_element_op};
696 }
697
// Typed factory: returns a default-constructed Invoker by value.
698 static auto MakeInvoker() { return Invoker{}; }
699
700 // polymorphic
// Type-erased factory: casts the void pointers to the operation's element types with
// static_cast and returns a heap-allocated Argument owned by a std::unique_ptr.
// No check is made that the pointers really refer to data of those types; that is the
// caller's responsibility.
701 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
702 const void* p_b,
703 const void* p_b1,
704 void* p_c,
705 index_t MRaw,
706 index_t NRaw,
707 index_t KRaw,
708 index_t Gemm1NRaw,
709 index_t Batch,
710 index_t StrideA,
711 index_t StrideB,
712 index_t StrideB1,
713 index_t StrideC,
714 index_t BatchStrideA,
715 index_t BatchStrideB,
716 index_t BatchStrideB1,
717 index_t BatchStrideC,
718 AElementwiseOperation a_element_op,
719 BElementwiseOperation b_element_op,
720 AccElementwiseOperation acc_element_op,
721 B1ElementwiseOperation b1_element_op,
722 CElementwiseOperation c_element_op) override
723 {
724 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
725 static_cast<const BDataType*>(p_b),
726 static_cast<const B1DataType*>(p_b1),
727 static_cast<CDataType*>(p_c),
728 MRaw,
729 NRaw,
730 KRaw,
731 Gemm1NRaw,
732 Batch,
733 StrideA,
734 StrideB,
735 StrideB1,
736 StrideC,
737 BatchStrideA,
738 BatchStrideB,
739 BatchStrideB1,
740 BatchStrideC,
741 a_element_op,
742 b_element_op,
743 acc_element_op,
744 b1_element_op,
745 c_element_op);
746 }
747
748 // polymorphic
// Type-erased factory: returns a heap-allocated Invoker owned by a std::unique_ptr.
// The heap object is constructed from a temporary Invoker{}.
749 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
750 {
751 return std::make_unique<Invoker>(Invoker{});
752 }
753
754 // polymorphic
// Returns a human-readable identifier for this instantiation of the form
// "DeviceBatchedGemmGemm_Xdl_CShuffle<...>", listing the tile parameters followed by the
// GEMM specialization name.
// NOTE(review): MPerBlock is streamed twice (2nd and 7th fields). The second occurrence
// presumably stands for the second GEMM's M tile -- confirm it is intentional before
// changing it, since callers may match on this string.
755 std::string GetTypeString() const override
756 {
757 auto str = std::stringstream();
758
759 // clang-format off
760 str << "DeviceBatchedGemmGemm_Xdl_CShuffle"
761 << "<"
762 << BlockSize << ", "
763 << MPerBlock << ", "
764 << NPerBlock << ", "
765 << KPerBlock << ", "
766 << AK1 << ", "
767 << BK1 << ", "
768 << MPerBlock << ", "
769 << Gemm1NPerBlock << ", "
770 << Gemm1KPerBlock << ", "
771 << B1K1 << ", "
772 << getGemmSpecializationString(GemmSpec) << ">";
773 // clang-format on
774
775 return str.str();
776 }
777};
778
779} // namespace device
780} // namespace tensor_operation
781} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:44
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:80
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:223
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:315
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:422
const B1DataType * p_b1_grid_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:477
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:490
index_t batch_count_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:489
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:479
AElementwiseOperation a_element_op_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:484
void Print() const
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:466
CDataType * p_c_grid_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:478
const ADataType * p_a_grid_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:475
BElementwiseOperation b_element_op_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:485
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:487
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:482
const BDataType * p_b_grid_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:476
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:480
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:423
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:481
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:483
AccElementwiseOperation acc_element_op_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:486
CElementwiseOperation c_element_op_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:488
std::vector< index_t > raw_lengths_m_n_k_o_
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:493
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:334
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:324
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:339
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:313
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:329
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:498
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:574
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:593
DeviceOp::Argument Argument
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:499
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:502
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:189
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:606
static constexpr auto matrix_padder
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:200
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:663
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:293
GridwiseBatchedGemmGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:358
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_b1, void *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:701
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:749
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:352
static constexpr auto I0
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:196
static auto MakeInvoker()
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:698
decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)) B1GridDesc_BK0_N_BK1
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:353
static constexpr auto MXdlPerWave32
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:194
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:204
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:233
static constexpr auto I2
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:198
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:263
static constexpr auto MXdlPerWave64
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:192
static constexpr auto I1
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:197
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:354
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:351
DeviceBatchedGemmGemm_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:190
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const B1DataType *p_b1, CDataType *p_c, index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw, index_t Batch, index_t StrideA, index_t StrideB, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:668
GridwiseGemmBase< MXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:418
GridwiseGemmBase< math::max(MXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:417
std::string GetTypeString() const override
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:755
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_gemm_xdl_cshuffle.hpp:600
Definition device_batched_gemm_gemm.hpp:29
Definition matrix_padder.hpp:63