device_grouped_conv_bwd_weight_wmma_cshuffle.hpp Source File

device_grouped_conv_bwd_weight_wmma_cshuffle.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight_wmma_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <index_t NDimSpatial,
26 typename InLayout,
27 typename WeiLayout,
28 typename OutLayout,
29 typename InDataType,
30 typename WeiDataType,
31 typename OutDataType,
32 typename AccDataType,
33 typename InElementwiseOperation,
34 typename WeiElementwiseOperation,
35 typename OutElementwiseOperation,
36 ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
37 index_t BlockSize,
38 index_t MPerBlock,
39 index_t NPerBlock,
40 index_t K0PerBlock,
41 index_t K1,
42 index_t MPerWMMA,
43 index_t NPerWMMA,
44 index_t MRepeat,
45 index_t NRepeat,
46 typename ABlockTransferThreadClusterLengths_K0_M_K1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferSrcScalarPerVector,
51 index_t ABlockTransferDstScalarPerVector_K1,
52 bool ABlockLdsAddExtraM,
53 typename BBlockTransferThreadClusterLengths_K0_N_K1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 index_t BBlockTransferSrcVectorDim,
57 index_t BBlockTransferSrcScalarPerVector,
58 index_t BBlockTransferDstScalarPerVector_K1,
59 bool BBlockLdsAddExtraN,
60 index_t CShuffleMRepeatPerShuffle,
61 index_t CShuffleNRepeatPerShuffle,
62 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
64 index_t NumGemmKPrefetchStage = 1,
67 typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
69 : public DeviceGroupedConvBwdWeight<NDimSpatial,
70 InLayout,
71 WeiLayout,
72 OutLayout,
73 InDataType,
74 WeiDataType,
75 OutDataType,
76 InElementwiseOperation,
77 WeiElementwiseOperation,
78 OutElementwiseOperation>
79{
81
82 using ADataType = OutDataType;
83 using BDataType = InDataType;
84 using CDataType = WeiDataType;
85
86 using AElementwiseOperation = OutElementwiseOperation;
87 using BElementwiseOperation = InElementwiseOperation;
88 using CElementwiseOperation = WeiElementwiseOperation;
89
90 // TODO make A/B datatype different
91 using ABDataType = InDataType;
92
93 static constexpr auto I0 = Number<0>{};
94 static constexpr auto I1 = Number<1>{};
95 static constexpr auto I2 = Number<2>{};
96 static constexpr auto I3 = Number<3>{};
97 static constexpr auto I4 = Number<4>{};
98 static constexpr auto I5 = Number<5>{};
99
100 static constexpr auto GemmK1Number = Number<K1>{};
101 static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number;
102
103 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
104 constexpr static auto
106 const index_t Do,
107 const index_t Ho,
108 const index_t Wo,
109 const index_t K,
110 const std::array<index_t, NDimSpatial + 3>& output_strides)
111 {
112 const index_t WoStride = output_strides[5];
113 const auto KStride = Number<1>{};
114 return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
115 make_tuple(WoStride, KStride));
116 }
117
118 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
119 constexpr static auto
121 const index_t Di,
122 const index_t Hi,
123 const index_t Wi,
124 const index_t C,
125 const std::array<index_t, NDimSpatial + 3>& input_strides)
126 {
127 const index_t NStride = input_strides[1];
128 const index_t DiStride = input_strides[3];
129 const index_t HiStride = input_strides[4];
130 const index_t WiStride = input_strides[5];
131 const auto CStride = input_strides[2];
132 if constexpr(ConvBackwardWeightSpecialization ==
134 {
135 return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
136 make_tuple(WiStride, CStride));
137 }
138 else
139 {
141 make_tuple(N, Di, Hi, Wi, C),
142 make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
143 }
144 }
145
146 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
147 constexpr static auto
149 const index_t Z,
150 const index_t Y,
151 const index_t X,
152 const index_t C,
153 const std::array<index_t, NDimSpatial + 3>& weights_strides)
154 {
155 const auto CStride = Number<1>{};
156 const auto KStride = weights_strides[1];
157 return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
158 make_tuple(KStride, CStride));
159 }
160
161 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
163 const index_t N,
164 const index_t K,
165 const index_t C,
166 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
167 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
168 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
169 const std::array<index_t, NDimSpatial + 3>& input_strides,
170 const std::array<index_t, NDimSpatial + 3>& weights_strides,
171 const std::array<index_t, NDimSpatial + 3>& output_strides,
172 const std::array<index_t, NDimSpatial>& conv_filter_strides,
173 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
174 const std::array<index_t, NDimSpatial>& input_left_pads,
175 const std::array<index_t, NDimSpatial>& input_right_pads)
176 {
177 using namespace ck;
178
179 const index_t Di = input_spatial_lengths[0];
180 const index_t Hi = input_spatial_lengths[1];
181 const index_t Wi = input_spatial_lengths[2];
182
183 const index_t Do = output_spatial_lengths[0];
184 const index_t Ho = output_spatial_lengths[1];
185 const index_t Wo = output_spatial_lengths[2];
186
187 const index_t Z = filter_spatial_lengths[0];
188 const index_t Y = filter_spatial_lengths[1];
189 const index_t X = filter_spatial_lengths[2];
190
191 const index_t ConvStrideD = conv_filter_strides[0];
192 const index_t ConvStrideH = conv_filter_strides[1];
193 const index_t ConvStrideW = conv_filter_strides[2];
194
195 const index_t ConvDilationD = conv_filter_dilations[0];
196 const index_t ConvDilationH = conv_filter_dilations[1];
197 const index_t ConvDilationW = conv_filter_dilations[2];
198
199 const index_t InLeftPadD = input_left_pads[0];
200 const index_t InLeftPadH = input_left_pads[1];
201 const index_t InLeftPadW = input_left_pads[2];
202
203 const index_t InRightPadD = input_right_pads[0];
204 const index_t InRightPadH = input_right_pads[1];
205 const index_t InRightPadW = input_right_pads[2];
206
207 const index_t GemmKTotal = N * Do * Ho * Wo;
208 const index_t GemmM = K;
209 const index_t GemmN = C * Z * X * Y;
210
211 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
212 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
213
214 const index_t GemmK0 =
215 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
216 const index_t GemmKPad = GemmK0 * GemmK1Number;
217
218 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
219 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
220 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
221
222 if constexpr(ConvBackwardWeightSpecialization ==
224 {
225 // A: output tensor
226 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
227 out_grid_desc,
228 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
232
233 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
234 out_gemmkpad_gemmm_grid_desc,
239
240 // B: input tensor
241 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
242 in_grid_desc,
243 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
247
248 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
249 in_gemmkpad_gemmn_grid_desc,
254
255 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
256 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
257 wei_grid_desc);
258 }
259 else
260 {
261 // A: output tensor
262 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
263 out_grid_desc,
264 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
268
269 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
270 out_gemmkpad_gemmm_grid_desc,
275
276 // B: input tensor
277 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
278 in_grid_desc,
280 make_pad_transform(Di, InLeftPadD, InRightPadD),
281 make_pad_transform(Hi, InLeftPadH, InRightPadH),
282 make_pad_transform(Wi, InLeftPadW, InRightPadW),
288
289 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
290 in_n_dip_hip_wip_c_grid_desc,
293 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
294 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
295 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
303 Sequence<7>{}));
304
305 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
306 in_n_z_do_y_ho_x_wo_c_grid_desc,
308 make_merge_transform(make_tuple(N, Do, Ho, Wo))),
311
312 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
313 in_gemmktotal_gemmn_grid_desc,
314 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
318
319 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
320 in_gemmkpad_gemmn_grid_desc,
325
326 // Pad
327 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
329 out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
331 make_right_pad_transform(GemmM, PadGemmM),
335
336 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
338 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
340 make_right_pad_transform(GemmN, PadGemmN),
344
345 const auto wei_gemmm_gemmn_pad_grid_desc =
346 transform_tensor_descriptor(wei_grid_desc,
347 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
348 make_right_pad_transform(GemmN, PadGemmN)),
351
352 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
353 in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
354 wei_gemmm_gemmn_pad_grid_desc);
355 }
356 }
357
358 template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
359 static auto GetABCGridDesc()
360 {
361 const index_t dim = 1;
362 const std::array<index_t, NDimSpatial> lengths{1, 1, 1};
363 const std::array<index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
364 const std::array<index_t, NDimSpatial> params{1, 1, 1};
366 dim,
367 dim,
368 lengths,
369 lengths,
370 lengths,
371 strides,
372 strides,
373 strides,
374 params,
375 params,
376 params,
377 params);
378 }
379
381
385
386 using CShuffleDataType = AccDataType;
387
389 // DataType Family
390 ADataType,
391 BDataType,
392 AccDataType,
394 Tuple<>,
395 CDataType,
396 // InMemory Data Descriptor
399 Tuple<>,
401 // ElementwiseOp Family
406 // Tiling Family
407 MPerBlock,
408 NPerBlock,
409 KPerBlock,
410 MPerWMMA,
411 NPerWMMA,
412 K1,
413 MRepeat,
414 NRepeat,
415 // ThreadCluster Family
416 BlockSize,
417 ABlockTransferThreadClusterLengths_K0_M_K1,
418 ABlockTransferThreadClusterArrangeOrder,
419 ABlockTransferSrcAccessOrder,
420 ABlockTransferSrcVectorDim,
421 ABlockTransferSrcScalarPerVector,
422 ABlockTransferDstScalarPerVector_K1,
423 false,
424 true,
425 ABlockLdsAddExtraM,
426 BBlockTransferThreadClusterLengths_K0_N_K1,
427 BBlockTransferThreadClusterArrangeOrder,
428 BBlockTransferSrcAccessOrder,
429 BBlockTransferSrcVectorDim,
430 BBlockTransferSrcScalarPerVector,
431 BBlockTransferDstScalarPerVector_K1,
432 false,
433 true,
434 BBlockLdsAddExtraN,
435 CShuffleMRepeatPerShuffle,
436 CShuffleNRepeatPerShuffle,
437 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
438 CShuffleBlockTransferScalarPerVector_NPerBlock,
439 NumGemmKPrefetchStage,
440 LoopSched,
441 PipelineVer>;
442
445
448 CGridDesc_M_N{}));
449
451 CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));
452
453 struct Argument : public BaseArgument
454 {
455 Argument(const InDataType* p_in_grid,
456 WeiDataType* p_wei_grid,
457 const OutDataType* p_out_grid,
458 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
459 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
460 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
461 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
462 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
463 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
464 const std::array<index_t, NDimSpatial>& conv_filter_strides,
465 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
466 const std::array<index_t, NDimSpatial>& input_left_pads,
467 const std::array<index_t, NDimSpatial>& input_right_pads,
468 InElementwiseOperation in_element_op,
469 WeiElementwiseOperation wei_element_op,
470 OutElementwiseOperation out_element_op,
471 index_t split_k)
472 : p_a_grid_{p_out_grid},
473 p_b_grid_{p_in_grid},
474 p_c_grid_{p_wei_grid},
481 a_element_op_{out_element_op},
482 b_element_op_{in_element_op},
483 c_element_op_{wei_element_op},
484 Conv_G_{a_g_n_c_wis_lengths[0]},
485 Conv_N_{a_g_n_c_wis_lengths[1]},
486 Conv_K_{b_g_k_c_xs_lengths[1]},
487 Conv_C_{a_g_n_c_wis_lengths[2]},
491 conv_filter_strides_{conv_filter_strides},
492 input_left_pads_{input_left_pads},
493 input_right_pads_{input_right_pads},
494 k_batch_{split_k}
495 {
496 constexpr index_t spatial_offset = 3;
497 std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
498 end(a_g_n_c_wis_lengths),
500 std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
501 end(b_g_k_c_xs_lengths),
503 std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
504 end(e_g_n_k_wos_lengths),
506
507 const auto descs =
509 Conv_N_,
510 Conv_K_,
511 Conv_C_,
515 a_g_n_c_wis_strides,
516 b_g_k_c_xs_strides,
517 e_g_n_k_wos_strides,
518 conv_filter_strides,
519 conv_filter_dilations,
520 input_left_pads,
521 input_right_pads);
522
525 c_grid_desc_m_n_ = descs[I2];
526
528 c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */);
529
530 // A/B/C Batch Stride
531 compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
532 compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
533 compute_ptr_offset_of_batch_.BatchStrideE_ =
534 Conv_K_ * Conv_C_ *
535 std::accumulate(begin(filter_spatial_lengths_),
537 index_t{1},
538 std::multiplies<>{});
539
544 {
548 }
549 }
550
558
560
561 // for computing batch offset
562 ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
563
564 OutElementwiseOperation a_element_op_;
565 InElementwiseOperation b_element_op_;
566 WeiElementwiseOperation c_element_op_;
567
568 // for checking IsSupportedArgument()
573 std::array<index_t, NDimSpatial> input_spatial_lengths_;
574 std::array<index_t, NDimSpatial> filter_spatial_lengths_;
575 std::array<index_t, NDimSpatial> output_spatial_lengths_;
576 const std::array<index_t, NDimSpatial>& conv_filter_strides_;
577 const std::array<index_t, NDimSpatial>& input_left_pads_;
578 const std::array<index_t, NDimSpatial>& input_right_pads_;
580 };
581
582 // Invoker
583 struct Invoker : public BaseInvoker
584 {
586
587 void Print(const Argument& arg)
588 {
589 std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
590 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
591 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
592 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;
593
594 std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
595 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
596 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
597 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;
598
599 std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
600 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
601 }
602
603 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
604 {
605 if(stream_config.log_level_ > 0)
606 {
607 Print(arg);
608 }
609
614 {
615 throw std::runtime_error(
616 "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
617 "setting");
618 }
619
620 const index_t grid_size =
621 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
622
623 const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
624
625 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
626
627 auto launch_kernel = [&](auto has_main_k_block_loop) {
628 constexpr bool has_main_loop = has_main_k_block_loop.value;
629
631 GridwiseGemm,
632 ADataType,
633 BDataType,
635 CDataType,
636 OutElementwiseOperation,
637 InElementwiseOperation,
638 WeiElementwiseOperation,
644 ComputePtrOffsetOfStridedBatch<>,
645 has_main_loop>;
646
647 using EmptyTuple = Tuple<>;
648 return launch_and_time_kernel(stream_config,
649 kernel,
650 dim3(grid_size),
651 dim3(BlockSize),
652 0,
653 arg.p_a_grid_,
654 arg.p_b_grid_,
655 EmptyTuple{}, // Ds
656 arg.p_c_grid_,
657 arg.a_element_op_,
658 arg.b_element_op_,
659 arg.c_element_op_,
660 arg.Conv_G_,
667 };
668
669 if(has_main_k0_block_loop)
670 {
671 return launch_kernel(integral_constant<bool, true>{});
672 }
673 else
674 {
675 return launch_kernel(integral_constant<bool, false>{});
676 }
677 }
678
679 float Run(const BaseArgument* p_arg,
680 const StreamConfig& stream_config = StreamConfig{}) override
681 {
682 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
683 }
684 };
685
/// Compile-time hook for rejecting invalid template-parameter combinations.
/// No checks are implemented yet, so every instantiation is reported valid.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
691
692 static bool IsSupportedArgument(const Argument& arg)
693 {
694 // check device
696 {
698 {
699 return false;
700 }
701 }
702 else
703 {
704 return false;
705 }
706
707 // TODO: Add support for split_k > 1
708 if(arg.k_batch_ != 1)
709 {
710 return false;
711 }
712
715 {
716 return false;
717 }
718
719 if constexpr(ConvBackwardWeightSpecialization ==
721 {
722 // check if it's a 1x1 convolution with stride=1 and no padding
723 for(int i = 0; i < NDimSpatial; i++)
724 {
725 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
726 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
727 {
728 return false;
729 }
730 }
731 }
732
733 // vector load A/B matrix from global memory
734 if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
735 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
736 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
737 {
738 return false;
739 }
740
741 // vector store C matrix into global memory
742 if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
743 {
744 return false;
745 }
746
747 // Gridwise GEMM size
752 }
753
754 bool IsSupportedArgument(const BaseArgument* p_arg) override
755 {
756 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
757 }
758
759 static auto
760 MakeArgument(const InDataType* p_in_grid,
761 WeiDataType* p_wei_grid,
762 const OutDataType* p_out_grid,
763 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
764 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
765 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
766 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
767 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
768 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
769 const std::array<index_t, NDimSpatial>& conv_filter_strides,
770 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
771 const std::array<index_t, NDimSpatial>& input_left_pads,
772 const std::array<index_t, NDimSpatial>& input_right_pads,
773 InElementwiseOperation in_element_op,
774 WeiElementwiseOperation wei_element_op,
775 OutElementwiseOperation out_element_op,
776 const index_t split_k)
777 {
778 return Argument{p_in_grid,
779 p_wei_grid,
780 p_out_grid,
781 a_g_n_c_wis_lengths, // input
782 a_g_n_c_wis_strides,
783 b_g_k_c_xs_lengths, // weight
784 b_g_k_c_xs_strides,
785 e_g_n_k_wos_lengths, // output
786 e_g_n_k_wos_strides,
787 conv_filter_strides,
788 conv_filter_dilations,
789 input_left_pads,
790 input_right_pads,
791 in_element_op,
792 wei_element_op,
793 out_element_op,
794 split_k};
795 }
796
797 static auto MakeInvoker() { return Invoker{}; }
798
799 std::unique_ptr<BaseArgument>
800 MakeArgumentPointer(const void* p_in_grid,
801 void* p_wei_grid,
802 const void* p_out_grid,
803 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
804 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
805 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
806 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
807 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
808 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
809 const std::array<index_t, NDimSpatial>& conv_filter_strides,
810 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
811 const std::array<index_t, NDimSpatial>& input_left_pads,
812 const std::array<index_t, NDimSpatial>& input_right_pads,
813 InElementwiseOperation in_element_op,
814 WeiElementwiseOperation wei_element_op,
815 OutElementwiseOperation out_element_op,
816 const index_t split_k) override
817 {
818 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
819 static_cast<WeiDataType*>(p_wei_grid),
820 static_cast<const OutDataType*>(p_out_grid),
821 a_g_n_c_wis_lengths, // input
822 a_g_n_c_wis_strides,
823 b_g_k_c_xs_lengths, // weight
824 b_g_k_c_xs_strides,
825 e_g_n_k_wos_lengths, // output
826 e_g_n_k_wos_strides,
827 conv_filter_strides,
828 conv_filter_dilations,
829 input_left_pads,
830 input_right_pads,
831 in_element_op,
832 wei_element_op,
833 out_element_op,
834 split_k);
835 }
836
837 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
838 {
839 return std::make_unique<Invoker>(Invoker{});
840 }
841
842 std::string GetTypeString() const override
843 {
844 auto str = std::stringstream();
845
846 // clang-format off
847 str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"
848 << "<"
849 << BlockSize << ", "
850 << MPerBlock << ", "
851 << NPerBlock << ", "
852 << K0PerBlock << ", "
853 << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
854 << K1 << ", "
855 << ABlockTransferSrcScalarPerVector << ", "
856 << ABlockTransferDstScalarPerVector_K1 << ", "
857 << BBlockTransferSrcScalarPerVector << ", "
858 << BBlockTransferDstScalarPerVector_K1 << ", "
859 << CShuffleMRepeatPerShuffle << ", "
860 << CShuffleNRepeatPerShuffle << ", "
861 << CShuffleBlockTransferScalarPerVector_NPerBlock
862 << ">";
863 // clang-format on
864
865 return str.str();
866 }
867};
868
869} // namespace device
870} // namespace tensor_operation
871} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:454
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:557
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:564
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:562
const index_t k_batch_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:579
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:565
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:555
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:554
std::array< index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:575
const std::array< index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:576
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:556
std::array< index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:574
std::array< index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:573
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:571
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:553
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:559
const std::array< index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:578
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:552
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:551
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:572
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, index_t split_k)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:455
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:569
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:570
WeiElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:566
const std::array< index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:577
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:584
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:679
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:603
void Print(const Argument &arg)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:587
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:585
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:79
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:383
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:382
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:86
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:692
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:148
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const index_t split_k)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:760
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const index_t split_k) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:800
static constexpr auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:120
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:686
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:446
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, Tuple<>, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWMMA, NPerWMMA, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, true, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, true, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseGemm
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:388
DeviceGroupedConvBwdWeight_Wmma_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:80
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( CGridDesc_M_N{}, I1, I1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:450
static constexpr index_t KPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:101
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:96
InDataType BDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:83
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:87
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:97
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:754
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:93
AccDataType CShuffleDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:386
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:91
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:837
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:98
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:95
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:797
static constexpr auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:105
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:88
static constexpr auto GemmK1Number
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:100
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:94
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:380
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:359
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:842
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:84
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:162
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:443
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:384
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:82
Definition device_grouped_conv_bwd_weight.hpp:29