grouped_convolution_backward_data_kernel.hpp Source File

grouped_convolution_backward_data_kernel.hpp Source File

Composable Kernel: grouped_convolution_backward_data_kernel.hpp Source File
grouped_convolution_backward_data_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
16
17namespace ck_tile {
18
// Device-side argument pack for the grouped convolution backward-data kernel. It is filled
// on the host by the layout-specific constructors below.
// NOTE(review): this text comes from a rendered documentation listing. Several source
// lines (gaps in the leading line numbers, e.g. the struct name line) were dropped
// during extraction.
20template <typename GroupedConvTraitsType_, typename TilePartitioner_>
22{
24
// Conv -> implicit-GEMM transformer type; the trailing `true` argument enables split-N.
26 TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
27 GroupedConvTraitsType_::ConvSpecialization,
28 GroupedConvTraitsType_::VectorSizeA,
29 GroupedConvTraitsType_::VectorSizeB,
30 GroupedConvTraitsType_::VectorSizeC,
31 true>; // Split N enabled
// Number of additional D tensor pointers carried in ds_ptr.
32 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
33
// Compile-time dimension indices used with descriptor get_length().
34 static constexpr auto I0 = number<0>();
35 static constexpr auto I1 = number<1>();
36
// Constructor for 1-D convolutions (NWGC input, GKXC weight, NWGK output; enforced by the
// enable_if below).
//
// The backward-data pass is decomposed into one GEMM per filter phase i_xtilde in
// [0, XTilde), where XTilde = stride / gcd(stride, dilation). Each non-empty phase:
//   - stores its A/B/C descriptors at index gemm_count;
//   - owns the grid-x workgroup range ending at block_ends[gemm_count].
// NOTE(review): original lines 45 (signature), 61, 90, 100-105, 120 and 128 are missing
// from this listing. conv_filter_dilations is read below but its assignment (presumably
// line 61) is not shown.
37 template <
38 typename InLay = typename GroupedConvTraitsType_::InLayout,
39 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
40 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
41 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
42 std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
43 std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
44 bool>::type = false>
46 {
// Lengths in {G, N|K, C|K, spatial} order, narrowed to index_t.
47 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
48 static_cast<index_t>(args.N_),
49 static_cast<index_t>(args.C_),
50 static_cast<index_t>(args.input_spatial_lengths_[0])};
51 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
52 static_cast<index_t>(args.K_),
53 static_cast<index_t>(args.C_),
54 static_cast<index_t>(args.filter_spatial_lengths_[0])};
55 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
56 static_cast<index_t>(args.N_),
57 static_cast<index_t>(args.K_),
58 static_cast<index_t>(args.output_spatial_lengths_[0])};
59
60 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
62 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
63 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
64
65 k_batch = args.k_batch;
66
67 in_ptr = args.in_ptr;
68 wei_ptr = args.wei_ptr;
69 for(index_t d = 0; d < NumDTensor; d++)
70 {
71 ds_ptr[d] = args.ds_ptr[d];
72 }
73 out_ptr = args.out_ptr;
74
// Number of independent filter phases along W.
75 const index_t X = wei_g_k_c_xs_lengths[3];
76 const index_t ConvStrideW = conv_filter_strides[0];
77 const index_t ConvDilationW = conv_filter_dilations[0];
78 const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
79 const auto XTilde = ConvStrideW / GcdStrideDilationW;
80
81 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
82 {
// Filter taps belonging to this phase; a phase with no taps produces no GEMM.
83 const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
84
85 if(XDotSlice <= 0)
86 {
87 continue;
88 }
89
// NOTE(review): the guard condition (original line 90) is not shown. The comment below
// suggests it caps the descriptor arrays at MaxGroupedGemmGroupsNum. If so, gemm_count
// keeps counting past that size, and FindGroupId searches [0, gemm_count). Confirm that
// IsSupportedArgument rejects that case before launch.
91 {
92 gemm_count++;
93 // Avoid array segfault
94 continue;
95 }
96
97 tildes = {i_xtilde};
98
99 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
106 tildes};
107
108 auto grid_descs =
109 conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
110 GroupedConvTraitsType_::NDimSpatial>(1);
111
112 a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
113 b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
114 c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
115
// Workgroups needed to tile this sub-GEMM's M x N output.
116 const index_t grid_size_grp =
117 TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
118 c_grid_descs_m_n[gemm_count].get_length(I1));
119
121 block_ends[gemm_count] = grid_size_ + grid_size_grp;
122
123 grid_size_ += grid_size_grp;
124
// NOTE(review): these are overwritten by every phase. Presumably all phases report the
// same split. n_splits (used by GridSize) is not assigned in the visible lines;
// original line 128 presumably sets it — confirm.
125 // Get the actual split N from transformer
126 n_per_split = conv_to_gemm_transformer.GetN();
127 original_n = conv_to_gemm_transformer.GetOriginalN();
129
130 ++gemm_count;
131 }
// Element offsets between consecutive convolution groups for each operand.
132 group_stride_a = args.K_; // A: Out NWGK
133 group_stride_b = args.K_ * args.C_ *
134 std::accumulate(args.filter_spatial_lengths_.begin(),
135 args.filter_spatial_lengths_.end(),
136 1,
137 std::multiplies<index_t>()); // B: Wei GKXC
138 group_stride_c = args.C_; // C: In NWGC
139
// Element distance between consecutive batch (N) entries, used for split-N offsets.
140 input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0];
141 output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];
142
// One grid-y slice per convolution group.
143 GemmBatch = args.G_;
144 }
145
// Constructor for 2-D convolutions (NHWGC input, GKYXC weight, NHWGK output; enforced by
// the enable_if below).
//
// The backward-data pass is decomposed into one GEMM per filter phase
// (i_ytilde, i_xtilde), where each Tilde = stride / gcd(stride, dilation) per dimension.
// Only phases that contain at least one filter tap produce a GEMM.
// NOTE(review): original lines 154 (signature), 174, 214, 224-229, 244, 252, 265 and
// 267 are missing from this listing.
146 template <
147 typename InLay = typename GroupedConvTraitsType_::InLayout,
148 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
149 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
150 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
151 std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
152 std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
153 bool>::type = false>
155 {
// Lengths in {G, N|K, C|K, H, W} order, narrowed to index_t.
156 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
157 static_cast<index_t>(args.N_),
158 static_cast<index_t>(args.C_),
159 static_cast<index_t>(args.input_spatial_lengths_[0]),
160 static_cast<index_t>(args.input_spatial_lengths_[1])};
161 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
162 static_cast<index_t>(args.K_),
163 static_cast<index_t>(args.C_),
164 static_cast<index_t>(args.filter_spatial_lengths_[0]),
165 static_cast<index_t>(args.filter_spatial_lengths_[1])};
166 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
167 static_cast<index_t>(args.N_),
168 static_cast<index_t>(args.K_),
169 static_cast<index_t>(args.output_spatial_lengths_[0]),
170 static_cast<index_t>(args.output_spatial_lengths_[1])};
171
172 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
173 static_cast<index_t>(args.conv_filter_strides_[1])};
175 static_cast<index_t>(args.conv_filter_dilations_[1])};
176 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
177 static_cast<index_t>(args.input_left_pads_[1])};
178 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
179 static_cast<index_t>(args.input_right_pads_[1])};
180
181 k_batch = args.k_batch;
182
183 in_ptr = args.in_ptr;
184 wei_ptr = args.wei_ptr;
185 for(index_t d = 0; d < NumDTensor; d++)
186 {
187 ds_ptr[d] = args.ds_ptr[d];
188 }
189 out_ptr = args.out_ptr;
190
// Number of independent filter phases along H and W.
191 const index_t Y = wei_g_k_c_xs_lengths[3];
192 const index_t X = wei_g_k_c_xs_lengths[4];
193 const index_t ConvStrideH = conv_filter_strides[0];
194 const index_t ConvStrideW = conv_filter_strides[1];
195 const index_t ConvDilationH = conv_filter_dilations[0];
196 const index_t ConvDilationW = conv_filter_dilations[1];
197 const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
198 const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
199 const auto YTilde = ConvStrideH / GcdStrideDilationH;
200 const auto XTilde = ConvStrideW / GcdStrideDilationW;
201
202 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
203 {
204 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
205 {
206 const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
207 const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
208
// Skip phases that contain no filter taps.
209 if(XDotSlice * YDotSlice <= 0)
210 {
211 continue;
212 }
213
// NOTE(review): guard condition (original line 214) is not shown. If it caps
// gemm_count at MaxGroupedGemmGroupsNum, gemm_count still grows here, so an
// oversized problem must be rejected before launch — confirm.
215 {
216 gemm_count++;
217 // Avoid array segfault
218 continue;
219 }
220
221 tildes = {i_ytilde, i_xtilde};
222
223 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
230 tildes};
231
232 auto grid_descs = conv_to_gemm_transformer
233 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
234 GroupedConvTraitsType_::NDimSpatial>(1);
235
236 a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
237 b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
238 c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
239
// Workgroups needed to tile this sub-GEMM's M x N output.
240 const index_t grid_size_grp =
241 TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
242 c_grid_descs_m_n[gemm_count].get_length(I1));
243
245 block_ends[gemm_count] = grid_size_ + grid_size_grp;
246
247 grid_size_ += grid_size_grp;
248
// NOTE(review): overwritten for every phase; n_splits is not assigned in the
// visible lines (original line 252 presumably sets it) — confirm.
249 // Get the actual split N from transformer
250 n_per_split = conv_to_gemm_transformer.GetN();
251 original_n = conv_to_gemm_transformer.GetOriginalN();
253
254 ++gemm_count;
255 }
256 }
// Element offsets between consecutive convolution groups for each operand.
257 group_stride_a = args.K_; // A: Out NHWGK
258 group_stride_b = args.K_ * args.C_ *
259 std::accumulate(args.filter_spatial_lengths_.begin(),
260 args.filter_spatial_lengths_.end(),
261 1,
262 std::multiplies<index_t>()); // B: Wei GKYXC
263 group_stride_c = args.C_; // C: In NHWGC
264
// Element distance between consecutive batch (N) entries of input / output.
266 args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
268 args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
269
// One grid-y slice per convolution group.
270 GemmBatch = args.G_;
271 }
272
// Constructor for 3-D convolutions (NDHWGC input, GKZYXC weight, NDHWGK output; enforced
// by the enable_if below).
//
// One GEMM is generated per filter phase (i_ztilde, i_ytilde, i_xtilde), where each
// Tilde = stride / gcd(stride, dilation) per dimension. Phases that contain no filter
// taps are skipped.
// NOTE(review): original lines 281 (signature), 305, 356, 366-371, 386, 394, 410 and
// 412 are missing from this listing.
273 template <
274 typename InLay = typename GroupedConvTraitsType_::InLayout,
275 typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
276 typename OutLay = typename GroupedConvTraitsType_::OutLayout,
277 typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
278 std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
279 std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
280 bool>::type = false>
282 {
// Lengths in {G, N|K, C|K, D, H, W} order, narrowed to index_t.
283 in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
284 static_cast<index_t>(args.N_),
285 static_cast<index_t>(args.C_),
286 static_cast<index_t>(args.input_spatial_lengths_[0]),
287 static_cast<index_t>(args.input_spatial_lengths_[1]),
288 static_cast<index_t>(args.input_spatial_lengths_[2])};
289 wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
290 static_cast<index_t>(args.K_),
291 static_cast<index_t>(args.C_),
292 static_cast<index_t>(args.filter_spatial_lengths_[0]),
293 static_cast<index_t>(args.filter_spatial_lengths_[1]),
294 static_cast<index_t>(args.filter_spatial_lengths_[2])};
295 out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
296 static_cast<index_t>(args.N_),
297 static_cast<index_t>(args.K_),
298 static_cast<index_t>(args.output_spatial_lengths_[0]),
299 static_cast<index_t>(args.output_spatial_lengths_[1]),
300 static_cast<index_t>(args.output_spatial_lengths_[2])};
301
302 conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
303 static_cast<index_t>(args.conv_filter_strides_[1]),
304 static_cast<index_t>(args.conv_filter_strides_[2])};
306 static_cast<index_t>(args.conv_filter_dilations_[1]),
307 static_cast<index_t>(args.conv_filter_dilations_[2])};
308 input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
309 static_cast<index_t>(args.input_left_pads_[1]),
310 static_cast<index_t>(args.input_left_pads_[2])};
311 input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
312 static_cast<index_t>(args.input_right_pads_[1]),
313 static_cast<index_t>(args.input_right_pads_[2])};
314
315 k_batch = args.k_batch;
316
317 in_ptr = args.in_ptr;
318 wei_ptr = args.wei_ptr;
319 for(index_t d = 0; d < NumDTensor; d++)
320 {
321 ds_ptr[d] = args.ds_ptr[d];
322 }
323 out_ptr = args.out_ptr;
324
// Number of independent filter phases along D, H and W.
325 const index_t Z = wei_g_k_c_xs_lengths[3];
326 const index_t Y = wei_g_k_c_xs_lengths[4];
327 const index_t X = wei_g_k_c_xs_lengths[5];
328 const index_t ConvStrideD = conv_filter_strides[0];
329 const index_t ConvStrideH = conv_filter_strides[1];
330 const index_t ConvStrideW = conv_filter_strides[2];
331 const index_t ConvDilationD = conv_filter_dilations[0];
332 const index_t ConvDilationH = conv_filter_dilations[1];
333 const index_t ConvDilationW = conv_filter_dilations[2];
334 const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
335 const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
336 const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
337 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
338 const auto YTilde = ConvStrideH / GcdStrideDilationH;
339 const auto XTilde = ConvStrideW / GcdStrideDilationW;
340
341 for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
342 {
343 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
344 {
345 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
346 {
347 const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
348 const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
349 const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
350
// Skip phases that contain no filter taps.
351 if(ZDotSlice * XDotSlice * YDotSlice <= 0)
352 {
353 continue;
354 }
355
// NOTE(review): guard condition (original line 356) is not shown. If it caps
// gemm_count at MaxGroupedGemmGroupsNum, gemm_count still grows here, so an
// oversized problem must be rejected before launch — confirm.
357 {
358 gemm_count++;
359 // Avoid array segfault
360 continue;
361 }
362
363 tildes = {i_ztilde, i_ytilde, i_xtilde};
364
365 ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
372 tildes};
373
374 auto grid_descs = conv_to_gemm_transformer
375 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
376 GroupedConvTraitsType_::NDimSpatial>(1);
377
378 a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
379 b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
380 c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
381
// Workgroups needed to tile this sub-GEMM's M x N output.
382 const index_t grid_size_grp =
383 TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
384 c_grid_descs_m_n[gemm_count].get_length(I1));
385
387 block_ends[gemm_count] = grid_size_ + grid_size_grp;
388
389 grid_size_ += grid_size_grp;
390
// NOTE(review): overwritten for every phase; n_splits is not assigned in the
// visible lines (original line 394 presumably sets it) — confirm.
391 // Get the actual split N from transformer
392 n_per_split = conv_to_gemm_transformer.GetN();
393 original_n = conv_to_gemm_transformer.GetOriginalN();
395
396 ++gemm_count;
397 }
398 }
399 }
400
// Element offsets between consecutive convolution groups for each operand.
401 group_stride_a = args.K_; // A: Out NDHWGK
402 group_stride_b = args.K_ * args.C_ *
403 std::accumulate(args.filter_spatial_lengths_.begin(),
404 args.filter_spatial_lengths_.end(),
405 1,
406 std::multiplies<index_t>()); // B: Wei GKZYXC
407 group_stride_c = args.C_; // C: In NDHWGC
408
// Element distance between consecutive batch (N) entries of input / output.
409 input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
411 output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
413
414 GemmBatch = args.G_; // one grid-y slice per convolution group
415 }
416
// Capacity of the per-sub-GEMM arrays (grid descriptors, block_starts / block_ends).
417 static constexpr index_t MaxGroupedGemmGroupsNum = 128;
418
420 decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
421
425
// Leading non-spatial dims of the length arrays: {G, N|K, C|K}.
426 static constexpr index_t NonSpatialDims = 3;
430
436
441
// Type-erased tensor pointers. In backward-data, out and wei are read and in is written.
442 const void* out_ptr;
443 void* in_ptr;
444 std::array<const void*, NumDTensor> ds_ptr;
445 const void* wei_ptr;
446
450
453
457
// NOTE(review): n_splits multiplies grid z in GridSize, but no visible constructor line
// assigns it. If the dropped listing lines do not set it, split-N launches only one
// split — confirm.
458 // Split-N support fields - initialize to safe defaults
459 index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
460 index_t n_per_split = 1; // Batches per split (N_ from transformer)
461 index_t original_n = 1; // Original batch size before splitting
462 index_t input_batch_stride = 0; // Stride to next batch in input tensor
463 index_t output_batch_stride = 0; // Stride to next batch in output tensor
464};
465
// Grouped convolution backward-data kernel (struct name line dropped from this listing).
// It computes In = Out x Weight as implicit GEMMs, with A = output gradient,
// B = weight and C = input gradient. There is one sub-GEMM per filter phase and one
// grid-y slice per convolution group.
504template <typename GroupedConvTraitsType_,
505 typename TilePartitioner_,
506 typename GemmPipeline_,
507 typename EpiloguePipeline_>
509{
// NOTE(review): the kernel-args struct reads GroupedConvTraitsType_::NDimSpatial while
// this reads NDimSpatial_. Confirm that both members exist and agree.
510 static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
512 GroupedConvTraitsType_::ConvSpecialization;
519
524
526 static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
527
528 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
529
533
535
540
// Split-K is not implemented: operator() only decodes the split-N index from blockIdx.z.
541 // TODO: Enable this
542 static constexpr bool IsSplitKSupported = false;
543
544 static constexpr auto I0 = number<0>();
545 static constexpr auto I1 = number<1>();
546 static constexpr auto I2 = number<2>();
547 static constexpr auto I3 = number<3>();
548
// Only pipelines padded in M, N and K, with row-major GEMM layouts, are supported.
549 static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
550 "Not supported!");
551 static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
552 static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
553 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
554 "Not supported C GEMM layout!");
555
// Returns a '_'-joined identifier built from the kernel name, the GEMM pipeline name and
// the epilogue name (one further component, original line 560, is missing from this listing).
// NOTE(review): returning `const std::string` by value prevents callers from moving the result.
556 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
557 {
558 // clang-format off
559 return concat('_', "grouped_convolution_backward_data",
561 "gemm",
562 GemmPipeline::GetName(),
563 "epilogue",
564 EpiloguePipeline::GetName());
565 // clang-format on
566 }
567
// Launch grid (signature line dropped from this listing):
//   x = workgroups over all sub-GEMM tiles (grid_size_),
//   y = one slice per convolution group (GemmBatch),
//   z = split-N x split-K slices (n_splits * k_batch).
569 {
570 // enable batched grouped gemm
571 return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
572 }
573
574 CK_TILE_HOST static constexpr auto BlockSize()
575 {
576 return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
577 }
578
584
// LDS bytes needed per buffer (signature line dropped from this listing). The GEMM
// pipeline and the epilogue reuse the same buffer, so the larger of the two is reserved.
586 {
587 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
588 }
589
// Host-side check of whether this kernel instance can run the problem described by kargs.
// Returns false when the split-K setting, the convolution specialization, the tensor
// layouts or the vector access widths are unsupported. Layout and width failures are
// reported through CK_TILE_ERROR.
// NOTE(review): original lines 591, 594-595, 607, 616, 632 and 647 (signature tail and
// several conditions) are missing from this listing.
590 CK_TILE_HOST static bool
592 {
// k_batch > 1 is rejected for this vector-size configuration (condition partly dropped).
// Unlike the layout checks below, this path only logs when CK_TILE_LOGGING is enabled.
593 if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
596 {
597 if(kargs.k_batch != 1)
598 {
599 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
600 {
601 CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
602 }
603 return false;
604 }
605 }
606
608 {
609 return false;
610 }
611
// Weight lengths are ordered {G, K, C, spatial...}.
612 const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
613 const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
614
615 // check ConvSpecialization
617 {
618 // check if it's 1x1, stride=1 conv
619 for(index_t i = 0; i < NDimSpatial; ++i)
620 {
621 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
622 const index_t ConvStride = kargs.conv_filter_strides[i];
623 const index_t LeftPad = kargs.input_left_pads[i];
624 const index_t RightPad = kargs.input_right_pads[i];
625
626 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
627 {
628 return false;
629 }
630 }
631 }
633 {
634 // check if it's 1x1 conv
635 for(index_t i = 0; i < NDimSpatial; ++i)
636 {
637 const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
638 const index_t LeftPad = kargs.input_left_pads[i];
639 const index_t RightPad = kargs.input_right_pads[i];
640
641 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
642 {
643 return false;
644 }
645 }
646 }
648 {
// Presumably the Filter3x3 specialization: it requires C == 1 and every filter extent == 3.
649 if(ConvC != 1)
650 {
651 return false;
652 }
653 for(index_t i = 0; i < NDimSpatial; ++i)
654 {
// Spatial filter extents start after the 3 non-spatial dims {G, K, C}.
655 const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + 3];
656
// 3 here is the required filter extent, not a dimension index.
657 if(filter_spatial_dim != 3)
658 {
659 return false;
660 }
661 }
662 }
663
664 namespace ctc = tensor_layout::convolution;
665
666 if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
667 std::is_same_v<InLayout, ctc::NDHWGC>)
668 {
669 // Check access per C
// In backward-data the input image is the GEMM C operand (operator() derives c_ptr
// from in_ptr), so its C-contiguous access width is VectorSizeC.
670 if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
671 {
672 CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
673 return false;
674 }
675 }
676 else
677 {
678 CK_TILE_ERROR("Not supported input layout!");
679 return false;
680 }
681
682 // The weight is the GEMM B operand in backward-data, so it uses VectorSizeB.
683 if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
684 std::is_same_v<WeiLayout, ctc::GKYXC> ||
685 std::is_same_v<WeiLayout, ctc::GKZYXC>)
686 {
687 if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
688 {
689 CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
690 return false;
691 }
692 }
693 else
694 {
695 CK_TILE_ERROR("Not supported weight layout!");
696 return false;
697 }
698
// The output gradient is the GEMM A operand. It is only read (out_ptr is const), so
// VectorSizeA is a load width.
699 if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
700 std::is_same_v<OutLayout, ctc::NHWGK> ||
701 std::is_same_v<OutLayout, ctc::NDHWGK>)
702 {
703 if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
704 {
705 CK_TILE_ERROR("Conv K is not a multiple of vector load size for output image!");
706 return false;
707 }
708 }
709 else
710 {
711 CK_TILE_ERROR("Not supported output layout!");
712 return false;
713 }
714
715 return true;
716 }
717
// Builds global-memory tensor views for sub-GEMM group_id: A = output gradient,
// B = weight, C = input gradient, plus NumDTensor D views that share C's M x N
// descriptor.
// NOTE(review): the pointer parameter types (b_ptr as InDataType, c_ptr as WeiDataType)
// follow backward-weight operand roles. operator() passes a weight pointer as b and an
// input pointer as c, so the call chain only compiles when InDataType == WeiDataType.
// Several listing lines (720, 724, 730, 736, 742, 755, 758) are missing here.
718 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
719 CK_TILE_DEVICE static auto
721 const InDataType* b_ptr,
722 const std::array<const void*, NumDTensor>& ds_ptr,
723 WeiDataType* c_ptr,
725 const index_t group_id)
726 {
727 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
728 static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
729 const auto& a_tensor_view = [&]() {
731 a_ptr,
732 kargs.a_grid_descs_m_k[group_id]); // A: out
733 }();
734
735 const auto& b_tensor_view = [&]() {
737 b_ptr,
738 kargs.b_grid_descs_n_k[group_id]); // B: weight
739 }();
740
741 const auto& c_tensor_view = [&]() {
743 kargs.c_grid_descs_m_n[group_id]);
744 }();
745
746 const auto& ds_tensor_view = generate_tuple(
747 [&](auto i) {
748 static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
749 "Not supported!");
750 static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
751 "Not supported!");
752 static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
753 "Not supported!");
754
// ds_ptr holds `const void*`, and static_cast cannot remove const, so the element
// pointer must stay const (a_ptr, also const, is passed to a tensor view the same way).
756 static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
757 },
759
760 return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
761 }
762
// Wraps each tensor view from MakeGemmTensorViews (A, B, Ds, C, in that tuple order) with
// pad_tensor_view so partial tiles can be addressed. The tile lengths and pad flags sit
// on listing lines missing from this page.
763 template <typename TensorView>
764 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
765 {
766 const auto& a_pad_view = [&]() {
767 const auto& a_tensor_view = views.at(I0);
768 return pad_tensor_view(a_tensor_view,
772 }();
773
774 const auto& b_pad_view = [&]() {
775 const auto& b_tensor_view = views.at(I1);
776 return pad_tensor_view(b_tensor_view,
780 }();
781
782 const auto& ds_tensor_view = views.at(I2);
783 const auto& ds_pad_view = generate_tuple(
784 [&](auto i) {
785 return pad_tensor_view(ds_tensor_view[i],
789 },
791
792 const auto& c_pad_view = [&]() {
793 const auto& c_tensor_view = views.at(I3);
794 return pad_tensor_view(c_tensor_view,
798 }();
799
800 return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
801 }
802
// Creates per-workgroup tile windows on the padded views.
// Window origins:
//   A at (i_m, i_k), B at (i_k, i_n), Ds and C at (i_m, i_n).
// NOTE(review): B's origin is given K-first although its descriptor array is named
// b_grid_descs_n_k. Confirm which naming matches the transformer's B descriptor.
803 template <typename PadView>
804 CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
805 const index_t i_m,
806 const index_t i_n,
807 const index_t i_k = 0)
808 {
809 const auto& a_pad_view = views.at(I0);
810 const auto& b_pad_view = views.at(I1);
811 const auto& ds_pad_view = views.at(I2);
812 const auto& c_pad_view = views.at(I3);
813
814 const auto& a_block_window = [&]() {
815 return make_tile_window(a_pad_view,
818 {i_m, i_k});
819 }();
820
821 const auto& b_block_window = [&]() {
822 return make_tile_window(b_pad_view,
825 {i_k, i_n});
826 }();
827
828 const auto ds_block_window = generate_tuple(
829 [&](auto i) {
830 return make_tile_window(ds_pad_view[i],
833 {i_m, i_n});
834 },
836
837 auto c_block_window = make_tile_window(
838 c_pad_view,
840 {i_m, i_n});
841
842 return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
843 }
844
// Runs one output tile of sub-GEMM group_id with a single LDS buffer:
//   1. build the views and windows,
//   2. run the GEMM main loop over K,
//   3. hand the accumulator and the D windows to the epilogue, which writes C.
// NOTE(review): b_ptr is typed InDataType and c_ptr WeiDataType, but operator() passes
// the weight as b and the input as c. This only compiles when the two element types match.
857 CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
858 const InDataType* b_ptr,
859 const std::array<const void*, NumDTensor>& ds_ptr,
860 WeiDataType* c_ptr,
861 void* smem_ptr_0,
863 const index_t block_idx_m,
864 const index_t block_idx_n,
865 const index_t group_id)
866 {
867 // Create Gemm tensor views, pad views and tile windows
868 const auto& gemm_tensor_views_tuple =
870 a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
871
872 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
873 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
874
// K iterations come from the padded A view's K extent (A is M x K).
875 const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
876 gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
877
878 // Run GEMM cooperatively by whole workgroup.
879 const auto& a_block_window = gemm_tile_windows.at(I0);
880 const auto& b_block_window = gemm_tile_windows.at(I1);
881 const auto& d_block_window = gemm_tile_windows.at(I2);
882
883 const auto& c_block_tile = GemmPipeline{}.template operator()(
884 a_block_window, b_block_window, num_loop, smem_ptr_0);
885
886 // Run Epilogue Pipeline
887 auto& c_block_window = gemm_tile_windows.at(I3);
888
889 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
890 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
891 }
892
// Same as RunGemm, but for pipelines with DoubleSmemBuffer: the GEMM main loop receives
// two LDS buffers, and the epilogue reuses the first.
// NOTE(review): listing lines 914 (kargs parameter) and 921 are missing from this page.
908 CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
909 const InDataType* b_ptr,
910 const std::array<const void*, NumDTensor>& ds_ptr,
911 WeiDataType* c_ptr,
912 void* __restrict__ smem_ptr_0,
913 void* __restrict__ smem_ptr_1,
915 const index_t block_idx_m,
916 const index_t block_idx_n,
917 const index_t group_id)
918 {
919 // Create Gemm tensor views, pad views and tile windows
920 const auto& gemm_tensor_views_tuple =
922 a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
923 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
924 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
925
// K iterations come from the padded A view's full K extent (A is M x K), exactly as in
// RunGemm. Querying the A tile window instead would presumably yield the per-block
// window length rather than the whole GEMM-K, undercounting the loop.
926 const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
927 gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
928
929 // Run GEMM cooperatively by whole workgroup.
930 const auto& a_block_window = gemm_tile_windows.at(I0);
931 const auto& b_block_window = gemm_tile_windows.at(I1);
932 const auto& d_block_window = gemm_tile_windows.at(I2);
933
934 const auto& c_block_tile = GemmPipeline{}.template operator()(
935 a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
936
937 // Run Epilogue Pipeline
938 auto& c_block_window = gemm_tile_windows.at(I3);
939
940 EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
941 c_block_window, c_block_tile, d_block_window, smem_ptr_0);
942 }
943
// Maps a flat grid-x block id to the sub-GEMM whose range
// [block_starts[g], block_ends[g]) contains it. It binary-searches the half-open
// interval [left, right) = [0, gemm_count). The signature line is missing from this listing.
// NOTE(review): the updates keep left <= right at all times, so the `left <= right`
// clause never ends the loop. Termination relies on block_id lying inside some group's
// range: any id >= the last block_end, or an unfilled range slot, would spin forever.
945 index_t block_id) const
946 {
947 index_t left = 0;
948 index_t right = kargs.gemm_count;
949 index_t group_id = index_t((left + right) >> 1);
950
951 while((!(block_id >= kargs.block_starts[group_id] &&
952 block_id < kargs.block_ends[group_id])) &&
953 left <= right)
954 {
955 if(block_id < kargs.block_starts[group_id])
956 {
957 right = group_id;
958 }
959 else
960 {
961 left = group_id;
962 }
963 group_id = index_t((left + right) >> 1);
964 }
965
966 return group_id;
967 }
968
// Kernel entry (signature line missing from this listing). It decodes the launch grid:
//   blockIdx.x -> sub-GEMM (FindGroupId) and M/N tile inside it,
//   blockIdx.y -> convolution group (per-operand group strides),
//   blockIdx.z -> split-N slice (blockIdx.z / k_batch; split-K is not implemented).
// It then offsets the base pointers and runs the single- or double-LDS GEMM path.
970 {
971 const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
972 const index_t group_id = FindGroupId(kargs, blockIdX);
973
975 kargs.block_starts[group_id],
976 kargs.c_grid_descs_m_n[group_id].get_length(I0),
977 kargs.c_grid_descs_m_n[group_id].get_length(I1));
978
979 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
980 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
981
982 const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
983 const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
984 const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
985 const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
986
987 const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
988
989 // SplitN
990 const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
991 const index_t split_n_offset =
992 __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);
993
// Batch offsets are widened to long_index_t before multiplying, to avoid 32-bit overflow.
994 const long_index_t output_batch_offset =
995 static_cast<long_index_t>(split_n_offset) *
996 static_cast<long_index_t>(kargs.output_batch_stride);
997 const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
998 static_cast<long_index_t>(kargs.input_batch_stride);
999
1000 // SplitK
1001 // TODO: Implement SplitK support
1002 // const index_t split_k_idx =
1003 // __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
1004
1005 // options
1006 // conv_bwd_data = Out * Weight = In
// NOTE(review): b_ptr (weight) and c_ptr (input) are passed to RunGemm / RunGemm2LDS,
// whose parameters are typed const InDataType* / WeiDataType*. This only compiles when
// InDataType == WeiDataType.
1007 const OutDataType* a_ptr =
1008 static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
1009 const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
1010 InDataType* c_ptr =
1011 static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;
1012
1013 // allocate LDS
1014 __shared__ char smem_ptr_0[GetSmemSize()];
1015
1016 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1017 {
1018 __shared__ char smem_ptr_1[GetSmemSize()];
// NOTE(review): when this compile-time condition (partly dropped from the listing)
// is false, the workgroup performs no work and returns silently.
1019 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1020 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1022 {
1023 RunGemm2LDS(a_ptr,
1024 b_ptr,
1025 kargs.ds_ptr,
1026 c_ptr,
1027 smem_ptr_0,
1028 smem_ptr_1,
1029 kargs,
1030 i_m,
1031 i_n,
1032 group_id);
1033 }
1034 }
1035 else
1036 {
1037 if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
1038 GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
1040 {
1041 RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
1042 }
1043 }
1044 }
1045};
1046
1047} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/ops/common/tensor_layout.hpp:27
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_specialization.hpp:14
@ Filter3x3
Definition convolution_specialization.hpp:15
@ Filter1x1Pad0
Definition convolution_specialization.hpp:13
@ atomic_add
Definition arch.hpp:58
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
Definition tile/core/numeric/math.hpp:268
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
GroupedConvHostArgs< void *, const void *, const void *, PassThrough > GroupedConvBwdDataHostArgs
Definition grouped_convolution_utils.hpp:53
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
The Grouped Convolution kernel device arguments.
Definition grouped_convolution_backward_data_kernel.hpp:22
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition grouped_convolution_backward_data_kernel.hpp:428
static constexpr auto I1
Definition grouped_convolution_backward_data_kernel.hpp:35
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition grouped_convolution_backward_data_kernel.hpp:45
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition grouped_convolution_backward_data_kernel.hpp:432
std::array< const void *, NumDTensor > ds_ptr
Definition grouped_convolution_backward_data_kernel.hpp:444
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition grouped_convolution_backward_data_kernel.hpp:431
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition grouped_convolution_backward_data_kernel.hpp:451
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition grouped_convolution_backward_data_kernel.hpp:433
long_index_t group_stride_b
Definition grouped_convolution_backward_data_kernel.hpp:455
long_index_t group_stride_c
Definition grouped_convolution_backward_data_kernel.hpp:456
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition grouped_convolution_backward_data_kernel.hpp:452
const void * out_ptr
Definition grouped_convolution_backward_data_kernel.hpp:442
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition grouped_convolution_backward_data_kernel.hpp:423
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_convolution_backward_data_kernel.hpp:23
TransformConvBwdDataToGemm< GroupedConvTraitsType_::NDimSpatial, GroupedConvTraitsType_::ConvSpecialization, GroupedConvTraitsType_::VectorSizeA, GroupedConvTraitsType_::VectorSizeB, GroupedConvTraitsType_::VectorSizeC, true > ConvToGemmTransformer
Definition grouped_convolution_backward_data_kernel.hpp:25
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition grouped_convolution_backward_data_kernel.hpp:435
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition grouped_convolution_backward_data_kernel.hpp:422
const void * wei_ptr
Definition grouped_convolution_backward_data_kernel.hpp:445
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition grouped_convolution_backward_data_kernel.hpp:419
index_t n_per_split
Definition grouped_convolution_backward_data_kernel.hpp:460
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition grouped_convolution_backward_data_kernel.hpp:429
long_index_t group_stride_a
Definition grouped_convolution_backward_data_kernel.hpp:454
index_t GemmBatch
Definition grouped_convolution_backward_data_kernel.hpp:438
void * in_ptr
Definition grouped_convolution_backward_data_kernel.hpp:443
index_t n_splits
Definition grouped_convolution_backward_data_kernel.hpp:459
index_t gemm_count
Definition grouped_convolution_backward_data_kernel.hpp:440
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition grouped_convolution_backward_data_kernel.hpp:449
index_t original_n
Definition grouped_convolution_backward_data_kernel.hpp:461
index_t grid_size_
Definition grouped_convolution_backward_data_kernel.hpp:439
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition grouped_convolution_backward_data_kernel.hpp:434
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition grouped_convolution_backward_data_kernel.hpp:448
index_t k_batch
Definition grouped_convolution_backward_data_kernel.hpp:437
static constexpr auto I0
Definition grouped_convolution_backward_data_kernel.hpp:34
static constexpr index_t MaxGroupedGemmGroupsNum
Definition grouped_convolution_backward_data_kernel.hpp:417
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition grouped_convolution_backward_data_kernel.hpp:427
static constexpr index_t NumDTensor
Definition grouped_convolution_backward_data_kernel.hpp:32
index_t output_batch_stride
Definition grouped_convolution_backward_data_kernel.hpp:463
index_t input_batch_stride
Definition grouped_convolution_backward_data_kernel.hpp:462
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition grouped_convolution_backward_data_kernel.hpp:447
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition grouped_convolution_backward_data_kernel.hpp:424
static constexpr index_t NonSpatialDims
Definition grouped_convolution_backward_data_kernel.hpp:426
InPtr in_ptr
Definition grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition grouped_convolution_utils.hpp:40
index_t k_batch
Definition grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition grouped_convolution_utils.hpp:41
The Grouped Convolution Backward Data kernel template.
Definition grouped_convolution_backward_data_kernel.hpp:509
static CK_TILE_HOST constexpr GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition grouped_convolution_backward_data_kernel.hpp:580
static constexpr index_t NDimSpatial
Definition grouped_convolution_backward_data_kernel.hpp:510
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_convolution_backward_data_kernel.hpp:514
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition grouped_convolution_backward_data_kernel.hpp:804
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition grouped_convolution_backward_data_kernel.hpp:585
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition grouped_convolution_backward_data_kernel.hpp:764
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition grouped_convolution_backward_data_kernel.hpp:530
static constexpr index_t MaxGroupedGemmGroupsNum
Definition grouped_convolution_backward_data_kernel.hpp:538
static constexpr auto I1
Definition grouped_convolution_backward_data_kernel.hpp:545
static constexpr auto I3
Definition grouped_convolution_backward_data_kernel.hpp:547
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition grouped_convolution_backward_data_kernel.hpp:522
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition grouped_convolution_backward_data_kernel.hpp:536
static constexpr ConvolutionSpecialization ConvSpecialization
Definition grouped_convolution_backward_data_kernel.hpp:511
static CK_TILE_HOST constexpr auto BlockSize()
Definition grouped_convolution_backward_data_kernel.hpp:574
static constexpr index_t NumDTensor
Definition grouped_convolution_backward_data_kernel.hpp:526
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition grouped_convolution_backward_data_kernel.hpp:531
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_convolution_backward_data_kernel.hpp:515
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition grouped_convolution_backward_data_kernel.hpp:534
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_convolution_backward_data_kernel.hpp:513
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition grouped_convolution_backward_data_kernel.hpp:521
static constexpr index_t kBlockSize
Definition grouped_convolution_backward_data_kernel.hpp:528
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition grouped_convolution_backward_data_kernel.hpp:591
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition grouped_convolution_backward_data_kernel.hpp:517
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition grouped_convolution_backward_data_kernel.hpp:523
static constexpr auto I2
Definition grouped_convolution_backward_data_kernel.hpp:546
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition grouped_convolution_backward_data_kernel.hpp:720
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition grouped_convolution_backward_data_kernel.hpp:568
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition grouped_convolution_backward_data_kernel.hpp:516
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition grouped_convolution_backward_data_kernel.hpp:525
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition grouped_convolution_backward_data_kernel.hpp:944
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_backward_data_kernel.hpp:857
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_convolution_backward_data_kernel.hpp:908
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
Definition grouped_convolution_backward_data_kernel.hpp:969
static constexpr bool IsSplitKSupported
Definition grouped_convolution_backward_data_kernel.hpp:542
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition grouped_convolution_backward_data_kernel.hpp:520
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition grouped_convolution_backward_data_kernel.hpp:518
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_convolution_backward_data_kernel.hpp:532
static CK_TILE_HOST const std::string GetName()
Definition grouped_convolution_backward_data_kernel.hpp:556
static constexpr auto I0
Definition grouped_convolution_backward_data_kernel.hpp:544
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the 1D raw indices.
Definition gemm_tile_partitioner.hpp:192
Definition transform_conv_bwd_data_to_gemm.hpp:22
CK_TILE_HOST constexpr IndexType GetN() const
Definition transform_conv_bwd_data_to_gemm.hpp:119
CK_TILE_HOST constexpr IndexType GetOriginalN() const
Definition transform_conv_bwd_data_to_gemm.hpp:120
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition tile/host/convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition tile/host/convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition tile/host/convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition tile/host/convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition tile/host/convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134
Definition type_traits.hpp:115
Definition tile/core/container/sequence.hpp:49
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145