device_gemm_xdl_cshuffle_v3r1.hpp Source File

device_gemm_xdl_cshuffle_v3r1.hpp Source File

Composable Kernel: device_gemm_xdl_cshuffle_v3r1.hpp Source File
device_gemm_xdl_cshuffle_v3r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <typeinfo>
9
20
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
29template <typename ALayout,
30 typename BLayout,
31 typename DsLayout,
32 typename CLayout,
33 typename ADataType,
34 typename BDataType,
35 typename DsDataType,
36 typename CDataType,
37 typename GemmAccDataType,
38 typename CShuffleDataType,
39 typename AElementwiseOperation,
40 typename BElementwiseOperation,
41 typename CElementwiseOperation,
42 GemmSpecialization GemmSpec,
43 index_t BlockSize,
44 index_t MPerBlock,
45 index_t NPerBlock,
46 index_t KPerBlock,
47 index_t AK1,
48 index_t BK1,
49 index_t MPerXDL,
50 index_t NPerXDL,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
59 bool ABlockLdsExtraM,
60 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
61 typename BBlockTransferThreadClusterArrangeOrder,
62 typename BBlockTransferSrcAccessOrder,
63 index_t BBlockTransferSrcVectorDim,
64 index_t BBlockTransferSrcScalarPerVector,
65 index_t BBlockTransferDstScalarPerVector_BK1,
66 bool BBlockLdsExtraN,
67 index_t CShuffleMXdlPerWavePerShuffle,
68 index_t CShuffleNXdlPerWavePerShuffle,
69 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
70 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73 typename ReduceDataType = CDataType,
74 typename ComputeTypeA = CDataType,
75 typename ComputeTypeB = ComputeTypeA>
77 BLayout,
78 DsLayout,
79 CLayout,
80 ADataType,
81 BDataType,
82 DsDataType,
83 CDataType,
84 AElementwiseOperation,
85 BElementwiseOperation,
86 CElementwiseOperation>
87{
89 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
90 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
91
92 static constexpr index_t NumDTensor = DsDataType::Size();
93
95
96 // GridwiseGemm
97 template <index_t NXdlPerWave_>
99 ALayout,
100 BLayout,
101 CLayout,
102 ADataType,
103 BDataType,
104 GemmAccDataType,
105 CShuffleDataType,
106 ReduceDataType,
107 AElementwiseOperation,
108 BElementwiseOperation,
110 GemmSpec,
111 BlockSize,
112 MPerBlock,
113 NPerBlock,
114 KPerBlock,
115 AK1,
116 BK1,
117 MPerXDL,
118 NPerXDL,
119 MXdlPerWave,
120 NXdlPerWave_,
121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
122 ABlockTransferThreadClusterArrangeOrder,
123 ABlockTransferSrcAccessOrder,
124 ABlockTransferSrcVectorDim,
125 ABlockTransferSrcScalarPerVector,
126 ABlockTransferDstScalarPerVector_AK1,
127 false,
128 ABlockLdsExtraM,
129 BBlockTransferThreadClusterLengths_BK0_N_BK1,
130 BBlockTransferThreadClusterArrangeOrder,
131 BBlockTransferSrcAccessOrder,
132 BBlockTransferSrcVectorDim,
133 BBlockTransferSrcScalarPerVector,
134 BBlockTransferDstScalarPerVector_BK1,
135 false,
136 BBlockLdsExtraN,
137 CShuffleMXdlPerWavePerShuffle,
138 CShuffleNXdlPerWavePerShuffle,
139 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
140 CShuffleBlockTransferScalarPerVector_NPerBlock,
141 BlkGemmPipeSched,
142 BlkGemmPipelineVer,
143 ComputeTypeA,
144 ComputeTypeB>;
147
148 struct Argument : public GridwiseGemm64::Argument
149 {
150 Argument(const ADataType* p_a_grid_,
151 const BDataType* p_b_grid_,
152 const std::array<const void*, NumDTensor> p_ds_,
153 CDataType* p_c_grid_,
154 index_t M_,
155 index_t N_,
156 index_t K_,
157 index_t StrideA_,
158 index_t StrideB_,
159 std::array<ck::index_t, NumDTensor> StrideDs_,
160 index_t StrideC_,
161 index_t k_batch_)
162 : GridwiseGemm64::Argument(p_a_grid_,
163 p_b_grid_,
164 reinterpret_cast<ReduceDataType*>(p_c_grid_),
165 M_,
166 N_,
167 K_,
168 StrideA_,
169 StrideB_,
170 StrideC_,
171 k_batch_,
172 true),
173 p_ds(p_ds_),
174 StrideDs(StrideDs_)
175 {
176 }
177
178 const std::array<const void*, NumDTensor> p_ds;
179 std::array<ck::index_t, NumDTensor> StrideDs;
180 };
181
183 using OutElementwiseOperation = CElementwiseOperation;
184
186 [](auto i) {
187 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
188 if constexpr(std::is_same<CLayout, DLayout>::value)
190 else
191 return Number<1>{};
192 },
194
196 ReduceDataType, // InDataType,
197 DsDataType, // DsDatatype
198 GemmAccDataType, // AccDataType,
199 CDataType, // OutDataType,
200 3, // Rank
201 1, // NumReduceDim
202 ReduceAdd,
205 256, // BlockSize_,
206 CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_,
207 1, // KThreadSliceSize_,
208 0, // InSrcVectorDim_,
209 CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_,
210 CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
211 decltype(DsVectorLengthSequence)>;
212
213 // Invoker
214 struct Invoker : public BaseInvoker
215 {
// Runs the reduction stage through DeviceReduceInstance: the buffer at arg.p_workspace_ is
// described as a rank-3 tensor [KBatch, M, N], dimension 0 is reduced, and the [M, N] result
// is written to arg.p_c_grid. The D tensors (arg.p_ds) are passed along with [M, N] lengths.
// Returns the time reported by the reduce invoker.
// Throws std::runtime_error when the reduce instance does not support the argument.
216 float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
217 {
218 static constexpr index_t NumInDim = 3;
219 static constexpr index_t NumOutDim = 2;
220
221 std::array<ck::index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
222 std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};
223
224 std::array<ck::index_t, NumInDim> in_strides;
225 std::array<ck::index_t, NumOutDim> out_strides;
// Strides are derived from M and N only: each K-batch slice is M * N elements apart, and
// the M/N strides follow CLayout (row-major: N fastest, otherwise M fastest).
// NOTE(review): arg.StrideC is not used here, so the output is addressed as a packed
// M x N tensor -- confirm that callers never pass a padded C.
226 if constexpr(std::is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
227 {
228 in_strides = {arg.M * arg.N, arg.N, 1};
229 out_strides = {arg.N, 1};
230 }
231 else
232 {
233 in_strides = {arg.M * arg.N, 1, arg.M};
234 out_strides = {1, arg.M};
235 }
236
// Reduce over dimension 0 of the input, i.e. the KBatch dimension of in_lengths.
237 std::array<int, 1> reduce_dims{0};
238
239 std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
240 std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
241
// Every D tensor gets the output lengths; its stride pair is ordered by its own layout,
// using the caller-supplied arg.StrideDs[i] as the non-unit stride.
242 static_for<0, NumDTensor, 1>{}([&](auto i) {
243 DsLengths[i] = out_lengths;
244
245 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
246 if constexpr(std::is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
247 {
248 DsStrides[i] = {arg.StrideDs[i], 1};
249 }
250 else
251 {
252 DsStrides[i] = {1, arg.StrideDs[i]};
253 }
254 });
255
256 auto reduce = DeviceReduceInstance{};
257
// NOTE(review): the final argument and the closing of this call (listing line 269) are not
// present in this extracted view of the file.
258 auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
259 in_strides,
260 DsLengths,
261 DsStrides,
262 out_lengths,
263 out_strides,
264 reduce_dims,
265 arg.p_workspace_,
266 arg.p_ds,
267 arg.p_c_grid,
268 PassThrough{},
270
271 auto invoker_ptr = reduce.MakeInvokerPointer();
272
273 float ave_time = 0;
274
// Only launch when the reduce instance accepts the argument; otherwise fail loudly.
275 if(reduce.IsSupportedArgument(argument_ptr.get()))
276 {
277 ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
278 }
279 else
280 {
281 throw std::runtime_error(
282 "The runtime parameters seems not supported by the device instance, exiting!");
283 }
284
285 return ave_time;
286 }
287
// Launches the GEMM kernel for the given GridwiseGemm instantiation and, when the reduce
// path is active, follows it with RunReduce. Steps, in order:
//   1. copy arg_ reinterpreted as GridwiseGemm::Argument;
//   2. on the reduce path, require a workspace and redirect p_c_grid to it;
//   3. validate the argument and compute the grid size and the per-batch K extent;
//   4. pick a kernel instantiation from the pipeline version and the K-loop tail number;
//   5. run the reduction and add its time.
// Returns the accumulated time. Throws std::runtime_error when the workspace is missing on
// the reduce path or when GridwiseGemm::CheckValidity rejects the argument.
288 template <typename GridwiseGemm>
289 float RunImp(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
290 {
// Local, mutable copy: p_c_grid may be redirected below without touching the caller's arg_.
// NOTE(review): Argument derives from GridwiseGemm64::Argument; for any other GridwiseGemm
// this cast presumably relies on the two Argument types sharing a layout -- confirm.
291 auto arg = *reinterpret_cast<const typename GridwiseGemm::Argument*>(&arg_);
292
// By De Morgan this condition is: IsReduceAdd() || NumDTensor > 0 || CDataType differs from
// ReduceDataType. The same predicate guards the RunReduce call at the end of this function.
293 if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
294 std::is_same<CDataType, ReduceDataType>::value))
295 {
296 if(arg.p_workspace_ == nullptr)
297 {
298 throw std::runtime_error("using reduce , but empty workspace!");
299 }
300
// The GEMM kernel writes into the workspace; RunReduce later reads it from arg_.
301 arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
302 }
303
304 if(stream_config.log_level_ > 0)
305 {
306 arg.Print();
307 }
308
309 if(!GridwiseGemm::CheckValidity(arg))
310 {
311 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
312 }
313
314 index_t gdx, gdy, gdz;
315 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
316
317 float ave_time = 0;
318
// K_split = ceil(K / (KBatch * KPerBlock)) * KPerBlock; it feeds the main-loop and
// tail-number queries below.
319 index_t k_grain = arg.KBatch * KPerBlock;
320 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
321
322 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
323
// Launch helper shared by every kernel variant below; stores the measured time in ave_time.
// With stream_config.flush_cache set, a rotating-memory helper is sized from the A (M * K)
// and B (K * N) byte counts and run_flush_cache is handed to the launch call along with the
// kernel; otherwise the kernel goes straight to launch_and_time_kernel.
324 const auto Run = [&](const auto& kernel) {
325 if(stream_config.flush_cache)
326 {
328 arg,
329 stream_config.rotating_count,
330 arg.M * arg.K * sizeof(ADataType),
331 arg.K * arg.N * sizeof(BDataType));
332 rotating_mem.Print();
333
334 auto run_flush_cache = [&]() {
335 // flush icache
337 // rotating mem
338 rotating_mem.Next();
339 };
340
342 stream_config,
343 run_flush_cache,
344 kernel,
345 dim3(gdx, gdy, gdz),
346 dim3(BlockSize),
347 0,
348 arg);
349 }
350 else
351 {
352 ave_time = launch_and_time_kernel(
353 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
354 }
355 };
356
// Occupancy template argument for the kernels: 1 for the Intrawave scheduler, 2 otherwise.
357 constexpr index_t minimum_occupancy =
358 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
359
360 if(has_main_k_block_loop)
361 {
362 // Tail number always full
363 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
364 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
365 {
366
367 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
368 true,
370 minimum_occupancy>;
371 Run(kernel);
372 }
373 // Tail number could be One to Seven
// One and Full are always checked; Two..Seven are only compiled in when the pipeline's
// PrefetchStages exceeds that count. If no check matches, no kernel is launched here.
374 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
375 {
376 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
377 {
378 const auto kernel =
379 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
380 true,
382 minimum_occupancy,
384 Run(kernel);
385 }
386 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
387 {
388 const auto kernel =
389 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
390 true,
392 minimum_occupancy,
394 Run(kernel);
395 }
396
397 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
398 {
399 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
400 {
401 const auto kernel =
402 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
403 true,
405 minimum_occupancy,
407 Run(kernel);
408 }
409 }
410
411 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
412 {
413 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
414 {
415 const auto kernel =
416 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
417 true,
419 minimum_occupancy,
421 Run(kernel);
422 }
423 }
424
425 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
426 {
427 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
428 {
429 const auto kernel =
430 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
431 true,
433 minimum_occupancy,
435 Run(kernel);
436 }
437 }
438
439 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
440 {
441 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
442 {
443 const auto kernel =
444 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
445 true,
447 minimum_occupancy,
449 Run(kernel);
450 }
451 }
452
453 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
454 {
455 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
456 {
457 const auto kernel =
458 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
459 true,
461 minimum_occupancy,
463 Run(kernel);
464 }
465 }
466
467 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
468 {
469 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
470 {
471 const auto kernel =
472 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
473 true,
475 minimum_occupancy,
477 Run(kernel);
478 }
479 }
480 }
481 // Tail number could be Odd or Even
482 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
483 {
484
485 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
486 {
487 const auto kernel =
489 true,
491 minimum_occupancy,
493 Run(kernel);
494 }
495 else
496 {
497 const auto kernel =
499 true,
501 minimum_occupancy,
503 Run(kernel);
504 }
505 }
// Any remaining pipeline version: again split on an Odd tail versus everything else.
506 else
507 {
508 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
509 {
510 const auto kernel =
511 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
512 true,
514 minimum_occupancy,
516 Run(kernel);
517 }
518 else
519 {
520 const auto kernel =
521 kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
522 true,
524 minimum_occupancy,
526 Run(kernel);
527 }
528 }
529 }
530 else
531 {
// NOTE(review): without a main K loop only pipeline v1 launches a kernel; for every other
// version this branch does nothing and ave_time stays 0 before the reduce step -- confirm
// that this is intended.
532 // Tail number always 1
533 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
534 {
535
536 const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
537 false,
539 minimum_occupancy>;
540 Run(kernel);
541 }
542 }
543
// Same predicate as at the top: when the GEMM wrote into the workspace, reduce it into the
// real C buffer. The original arg_ is passed, whose p_c_grid was not redirected.
544 if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
545 std::is_same<CDataType, ReduceDataType>::value))
546 {
547 // reduce c data
548 ave_time += RunReduce(arg_, stream_config);
549 }
550 return ave_time;
551 }
552
554
555 // polymorphic
556 float Run(const BaseArgument* p_arg,
557 const StreamConfig& stream_config = StreamConfig{}) override
558 {
559 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
560 }
561 };
562
// Compile-time hook for rejecting unsupported template parameter combinations.
// Currently accepts everything.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
568
// Reports whether this instance can run the problem described by arg.
// Visible checks, in order: an early rejection, a K-divisibility test, then a branch on the
// warp size. If the NXdlPerWave constant for the active warp size is not positive, the
// function falls through to `return false`.
569 static bool IsSupportedArgument(const Argument& arg)
570 {
572 {
573 return false;
574 }
// K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the specializations
// that pad K (MKPadding, NKPadding, MNKPadding, KPadding).
575 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
576 GemmSpec == GemmSpecialization::NKPadding ||
577 GemmSpec == GemmSpecialization::MNKPadding ||
578 GemmSpec == GemmSpecialization::KPadding))
579 {
580 return false;
581 }
582
// Pick the gridwise GEMM flavour from the warp size; each side is compiled in only when its
// NXdlPerWave value is positive.
583 if(get_warp_size() == 64)
584 {
585 if constexpr(NXdlPerWave64 > 0)
586 {
588 }
589 }
590 else
591 {
592 if constexpr(NXdlPerWave32 > 0)
593 {
// The 32-wide path views arg as a GridwiseGemm32::Argument.
595 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
596 }
597 }
598 return false;
599 }
600
601 // polymorphic
602 bool IsSupportedArgument(const BaseArgument* p_arg) override
603 {
604 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
605 }
606
607 static auto MakeArgument(const ADataType* p_a,
608 const BDataType* p_b,
609 const std::array<const void*, NumDTensor> p_ds,
610 CDataType* p_c,
611 index_t M,
612 index_t N,
613 index_t K,
614 index_t StrideA,
615 index_t StrideB,
616 std::array<ck::index_t, NumDTensor> StrideDs,
617 index_t StrideC,
618 index_t KBatch,
619 AElementwiseOperation,
620 BElementwiseOperation,
621 CElementwiseOperation)
622 {
623 return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
624 }
625
626 static auto MakeInvoker() { return Invoker{}; }
627
628 // polymorphic
629 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
630 const void* p_b,
631 std::array<const void*, NumDTensor> p_ds,
632 void* p_c,
633 index_t M,
634 index_t N,
635 index_t K,
636 index_t StrideA,
637 index_t StrideB,
638 std::array<ck::index_t, NumDTensor> StrideDs,
639 index_t StrideC,
640 index_t KBatch,
641 AElementwiseOperation,
642 BElementwiseOperation,
643 CElementwiseOperation) override
644 {
645 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
646 static_cast<const BDataType*>(p_b),
647 p_ds,
648 static_cast<CDataType*>(p_c),
649 M,
650 N,
651 K,
652 StrideA,
653 StrideB,
654 StrideDs,
655 StrideC,
656 KBatch);
657 }
658
659 // polymorphic
660 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
661 {
662 return std::make_unique<Invoker>(Invoker{});
663 }
664
665 // polymorphic
// Builds a human-readable description of this instance: the GEMM specialization, the first
// letter of each layout name, block size, block/wave tile shapes, wave map, vector read
// widths, pipeline scheduler and version, and the pipeline's prefetch stage count.
666 std::string GetTypeString() const override
667 {
668 auto str = std::stringstream();
669
// Local lookup tables from the scheduler/version enum values to their printed names.
670 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
673
674 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
680
681 // clang-format off
682 str << "DeviceGemmXdlUniversalReduce"
683 << "<"
684 << getGemmSpecializationString(GemmSpec) << ", "
// Only the first character of each layout's name is printed.
685 << std::string(ALayout::name)[0]
686 << std::string(BLayout::name)[0]
687 << std::string(CLayout::name)[0]
688 << ">"
689 << " BlkSize: "
690 << BlockSize << ", "
691 << "BlkTile: "
692 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
693 << "WaveTile: "
694 << MPerXDL<<"x"<<NPerXDL << ", "
695 << "WaveMap: "
696 << MXdlPerWave<<"x" << NXdlPerWave<<", "
697 << "VmemReadVec: "
698 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
699 << "BlkGemmPipelineScheduler: "
700 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
701 << "BlkGemmPipelineVersion: "
702 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
703 << "BlkGemmPipelinePrefetchStages: "
// The prefetch stage count is always taken from the GridwiseGemm64 instantiation.
704 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
705 // clang-format on
706
707 return str.str();
708 }
709
710 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
711 {
712 auto arg = *dynamic_cast<const Argument*>(p_arg);
713
714 if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
715 std::is_same<CDataType, ReduceDataType>::value))
716 {
717 std::cout << "using workspace" << std::endl;
718 return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
719 }
720
721 return 0;
722 }
723};
724
725} // namespace device
726} // namespace tensor_operation
727} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3r1.hpp:149
const std::array< const void *, NumDTensor > p_ds
Definition device_gemm_xdl_cshuffle_v3r1.hpp:178
std::array< ck::index_t, NumDTensor > StrideDs
Definition device_gemm_xdl_cshuffle_v3r1.hpp:179
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, const std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< ck::index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:150
Definition device_gemm_xdl_cshuffle_v3r1.hpp:215
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:556
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3r1.hpp:216
float RunImp(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3r1.hpp:289
Definition device_gemm_xdl_cshuffle_v3r1.hpp:87
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:629
static constexpr index_t NumDTensor
Definition device_gemm_xdl_cshuffle_v3r1.hpp:92
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:710
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:660
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3r1.hpp:90
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition device_gemm_xdl_cshuffle_v3r1.hpp:195
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3r1.hpp:563
static constexpr auto DsVectorLengthSequence
Definition device_gemm_xdl_cshuffle_v3r1.hpp:185
ck::reduce::Add ReduceAdd
Definition device_gemm_xdl_cshuffle_v3r1.hpp:182
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3r1.hpp:89
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:569
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:602
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3r1.hpp:98
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:607
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3r1.hpp:626
CElementwiseOperation OutElementwiseOperation
Definition device_gemm_xdl_cshuffle_v3r1.hpp:183
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_gemm_xdl_cshuffle_v3r1.hpp:94
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3r1.hpp:146
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3r1.hpp:145
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:666
Definition device_gemm_v2.hpp:57
Definition device_reduce_threadwise_multi_d.hpp:47
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition device_reduce_threadwise_multi_d.hpp:363
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition flush_cache.hpp:299