device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp Source File#

Composable Kernel: device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// out[N, Ho, Wo, K] =
25// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
26template <
27 typename InDataType,
28 typename WeiDataType,
29 typename OutDataType,
30 typename AccDataType,
31 typename InElementwiseOperation,
32 typename WeiElementwiseOperation,
33 typename OutElementwiseOperation,
34 ConvolutionForwardSpecialization ConvForwardSpecialization,
35 ck::index_t BlockSize,
36 ck::index_t MPerBlock,
37 ck::index_t NPerBlock,
38 ck::index_t K0PerBlock,
39 ck::index_t K1,
40 ck::index_t MPerXDL,
41 ck::index_t NPerXDL,
42 ck::index_t MXdlPerWave,
43 ck::index_t NXdlPerWave,
44 typename ABlockTransferThreadClusterLengths_K0_M_K1,
45 typename ABlockTransferThreadClusterArrangeOrder,
46 typename ABlockTransferSrcAccessOrder,
47 ck::index_t ABlockTransferSrcVectorDim,
48 ck::index_t ABlockTransferSrcScalarPerVector,
49 ck::index_t ABlockTransferDstScalarPerVector_K1,
50 bool ABlockLdsAddExtraM,
51 typename BBlockTransferThreadClusterLengths_K0_N_K1,
52 typename BBlockTransferThreadClusterArrangeOrder,
53 typename BBlockTransferSrcAccessOrder,
54 ck::index_t BBlockTransferSrcVectorDim,
55 ck::index_t BBlockTransferSrcScalarPerVector,
56 ck::index_t BBlockTransferDstScalarPerVector_K1,
57 bool BBlockLdsAddExtraN,
58 index_t CShuffleMXdlPerWavePerShuffle,
59 index_t CShuffleNXdlPerWavePerShuffle,
60 typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
61 index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
62struct
64 : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
65 WeiElementwiseOperation,
66 OutElementwiseOperation>
67{
68 using DeviceOp =
70
72 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
73 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
74
75 using ADataType = InDataType;
76 using BDataType = WeiDataType;
77 using CDataType = OutDataType;
78
79 // TODO make A/B datatype different
80 using ABDataType = InDataType;
81
82 // TODO make it support any # of spatial dimensions
83 static constexpr index_t NDimSpatial = 2;
84
85 static constexpr auto I0 = Number<0>{};
86 static constexpr auto I1 = Number<1>{};
87 static constexpr auto I2 = Number<2>{};
88 static constexpr auto I3 = Number<3>{};
89 static constexpr auto I4 = Number<4>{};
90
91 static constexpr auto K1Number = Number<K1>{};
92 static constexpr auto GemmK1Number = K1Number;
93
94 static auto
98 std::vector<ck::index_t> input_spatial_lengths,
99 std::vector<ck::index_t> filter_spatial_lengths,
100 std::vector<ck::index_t> output_spatial_lengths,
101 std::vector<ck::index_t> conv_filter_strides,
102 std::vector<ck::index_t> conv_filter_dilations,
103 std::vector<ck::index_t> input_left_pads,
104 std::vector<ck::index_t> input_right_pads)
105 {
106 using namespace ck;
107
108 const index_t Hi = input_spatial_lengths[0];
109 const index_t Wi = input_spatial_lengths[1];
110
111 const index_t Ho = output_spatial_lengths[0];
112 const index_t Wo = output_spatial_lengths[1];
113
114 const index_t Y = filter_spatial_lengths[0];
115 const index_t X = filter_spatial_lengths[1];
116
117 const index_t ConvStrideH = conv_filter_strides[0];
118 const index_t ConvStrideW = conv_filter_strides[1];
119
120 const index_t ConvDilationH = conv_filter_dilations[0];
121 const index_t ConvDilationW = conv_filter_dilations[1];
122
123 const index_t InLeftPadH = input_left_pads[0];
124 const index_t InLeftPadW = input_left_pads[1];
125
126 const index_t InRightPadH = input_right_pads[0];
127 const index_t InRightPadW = input_right_pads[1];
128
129 const index_t GemmMRaw = N * Ho * Wo;
130 const index_t GemmN = K;
131
132 const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
133 const auto GemmMPad = GemmM - GemmMRaw;
134
135 if constexpr(ConvForwardSpecialization ==
137 { // 1x1, stride=1, pad=0
138 const index_t GemmK = Y * X * C;
139 assert(GemmK % GemmK1Number == 0);
140
141 const index_t GemmK0 = GemmK / GemmK1Number;
142
143 // A: input tensor
144 const auto in_gemmmraw_gemmk_grid_desc =
146
147 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
148 in_gemmmraw_gemmk_grid_desc,
150 make_right_pad_transform(GemmMRaw, GemmMPad)),
153
154 // B: weight tensor
155 const auto wei_gemmn_gemmk_grid_desc =
157
158 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
159 wei_gemmn_gemmk_grid_desc,
164
165 // C: output tensor
166 const auto out_gemmmraw_gemmn_grid_desc =
168
169 const auto out_gemmm_gemmn_grid_desc =
170 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
171 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
175
176 // C0: bias tensor: assume a contiguous vector
177 const auto bias_grid_desc_gemmm_gemmn =
179
180 // C1: residual tensor: assume same layout as output tensor
181 const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
182
183 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
184 wei_gemmk0_gemmn_gemmk1_grid_desc,
185 out_gemmm_gemmn_grid_desc,
186 bias_grid_desc_gemmm_gemmn,
187 resi_grid_desc_gemmm_gemmn);
188 }
189 else if constexpr(ConvForwardSpecialization ==
191 { // 1x1, pad=0
192 const index_t GemmK = Y * X * C;
193 assert(GemmK % GemmK1Number == 0);
194
195 const index_t GemmK0 = GemmK / GemmK1Number;
196
197 // A: input tensor
198 const auto in_n_hi_wi_c_grid_desc =
200
201 const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
202 in_n_hi_wi_c_grid_desc,
204 make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
205 make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
209
210 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
211 in_n_ho_wo_c_grid_desc,
213 make_merge_transform(make_tuple(N, Ho, Wo))),
216
217 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
218 in_gemmk0_gemmmraw_gemmk1_grid_desc,
220 make_right_pad_transform(GemmMRaw, GemmMPad),
224
225 // B: weight tensor
226 const auto wei_gemmn_gemmk_grid_desc =
228
229 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
230 wei_gemmn_gemmk_grid_desc,
235
236 // C: output tensor
237 const auto out_gemmmraw_gemmn_grid_desc =
239
240 const auto out_gemmm_gemmn_grid_desc =
241 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
242 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
246
247 // C0: bias tensor: assume a contiguous vector
248 const auto bias_grid_desc_gemmm_gemmn =
250
251 // C1: residual tensor: assume same layout as output tensor
252 const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
253
254 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
255 wei_gemmk0_gemmn_gemmk1_grid_desc,
256 out_gemmm_gemmn_grid_desc,
257 bias_grid_desc_gemmm_gemmn,
258 resi_grid_desc_gemmm_gemmn);
259 }
260 else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC)
261 { // C = odd value
262 const index_t GemmKRaw = Y * X * C;
263 const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
264 const index_t GemmKPad = GemmK - GemmKRaw;
265 const index_t GemmK0 = GemmK / GemmK1Number;
266
267 // A: input tensor
268 const auto in_n_hi_wi_c_grid_desc =
270
271 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
272 in_n_hi_wi_c_grid_desc,
274 make_pad_transform(Hi, InLeftPadH, InRightPadH),
275 make_pad_transform(Wi, InLeftPadW, InRightPadW),
279
280 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
281 in_n_hip_wip_c_grid_desc,
284 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
285 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
289
290 const auto in_gemmkraw_gemmmraw_grid_desc =
291 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
293 make_merge_transform(make_tuple(N, Ho, Wo))),
296
297 const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
298 in_gemmkraw_gemmmraw_grid_desc,
299 make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
300 make_right_pad_transform(GemmMRaw, GemmMPad)),
303
304 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
305 in_gemmk_gemmm_grid_desc,
310
311 // B: weight tensor
312 const auto wei_k_yxc_grid_desc =
314
315 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
316 wei_k_yxc_grid_desc,
318 make_right_pad_transform(GemmKRaw, GemmKPad)),
321
322 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
323 wei_gemmk_gemmn_grid_desc,
328
329 // C: output tensor
330 const auto out_nhowo_k_grid_desc =
332
333 const auto out_gemmmraw_gemmn_grid_desc =
334 transform_tensor_descriptor(out_nhowo_k_grid_desc,
339
340 const auto out_gemmm_gemmn_grid_desc =
341 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
342 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
346
347 // C0: bias tensor: assume a contiguous vector
348 const auto bias_grid_desc_gemmm_gemmn =
350
351 // C1: residual tensor: assume same layout as output tensor
352 const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
353
354 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
355 wei_gemmk0_gemmn_gemmk1_grid_desc,
356 out_gemmm_gemmn_grid_desc,
357 bias_grid_desc_gemmm_gemmn,
358 resi_grid_desc_gemmm_gemmn);
359 }
360 else
361 {
362 const index_t GemmK = Y * X * C;
363 assert(GemmK % GemmK1Number == 0);
364
365 const index_t GemmK0 = GemmK / GemmK1Number;
366
367 // A: input tensor
368 const auto in_n_hi_wi_c_grid_desc =
370
371 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
372 in_n_hi_wi_c_grid_desc,
374 make_pad_transform(Hi, InLeftPadH, InRightPadH),
375 make_pad_transform(Wi, InLeftPadW, InRightPadW),
379
380 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
381 in_n_hip_wip_c_grid_desc,
384 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
385 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
389
390 const auto in_gemmk_gemmmraw_grid_desc =
391 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
393 make_merge_transform(make_tuple(N, Ho, Wo))),
396
397 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
398 in_gemmk_gemmmraw_grid_desc,
403
404 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
405 in_gemmk0_gemmmraw_gemmk1_grid_desc,
407 make_right_pad_transform(GemmMRaw, GemmMPad),
411
412 // B: weight tensor
413 const auto wei_k_yxc_grid_desc =
415
416 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
417 wei_k_yxc_grid_desc,
421
422 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
423 wei_gemmk_gemmn_grid_desc,
428
429 // C: output tensor
430 const auto out_nhowo_k_grid_desc =
432
433 const auto out_gemmmraw_gemmn_grid_desc =
434 transform_tensor_descriptor(out_nhowo_k_grid_desc,
439
440 const auto out_gemmm_gemmn_grid_desc =
441 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
442 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
446
447 // C0: bias tensor: assume a contiguous vector
448 const auto bias_grid_desc_gemmm_gemmn =
450
451 // C1: residual tensor: assume same layout as output tensor
452 const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
453
454 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
455 wei_gemmk0_gemmn_gemmk1_grid_desc,
456 out_gemmm_gemmn_grid_desc,
457 bias_grid_desc_gemmm_gemmn,
458 resi_grid_desc_gemmm_gemmn);
459 }
460 }
461
463 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
464
467 using CGridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I2])>;
468 using C0GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I3])>;
469 using C1GridDesc_M_N = remove_cvref_t<decltype(GridDescs{}[I4])>;
470
472
473 // GridwiseGemm
474 template <index_t NXdlPerWave_>
476 BlockSize,
477 ABDataType, // TODO: distinguish A/B datatype
478 AccDataType,
479 CDataType,
486 InElementwiseOperation,
487 WeiElementwiseOperation,
488 OutElementwiseOperation,
489 MPerBlock,
490 NPerBlock,
491 K0PerBlock,
492 MPerXDL,
493 NPerXDL,
494 K1,
495 MXdlPerWave,
496 NXdlPerWave_,
497 ABlockTransferThreadClusterLengths_K0_M_K1,
498 Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
499 Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
500 2, // ABlockTransferSrcVectorDim,
501 ABlockTransferSrcScalarPerVector,
502 ABlockTransferDstScalarPerVector_K1,
503 false, // AThreadTransferSrcResetCoordinateAfterRun,
504 ABlockLdsAddExtraM,
505 BBlockTransferThreadClusterLengths_K0_N_K1,
506 Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
507 Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
508 2, // BBlockTransferSrcVectorDim,
509 BBlockTransferSrcScalarPerVector,
510 BBlockTransferDstScalarPerVector_K1,
511 false, // BThreadTransferSrcResetCoordinateAfterRun,
512 BBlockLdsAddExtraN,
513 CShuffleMXdlPerWavePerShuffle,
514 CShuffleNXdlPerWavePerShuffle,
515 CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
516 CBlockTransferScalarPerVector_NWaveNPerXdl>;
519
520 // Argument
521 struct Argument : public BaseArgument
522 {
523 Argument(const InDataType* p_in_grid,
524 const WeiDataType* p_wei_grid,
525 OutDataType* p_out_grid,
526 const OutDataType* p_bias_grid,
527 const OutDataType* p_resi_grid,
528 ck::index_t N,
529 ck::index_t K,
530 ck::index_t C,
531 std::vector<ck::index_t> input_spatial_lengths,
532 std::vector<ck::index_t> filter_spatial_lengths,
533 std::vector<ck::index_t> output_spatial_lengths,
534 std::vector<ck::index_t> conv_filter_strides,
535 std::vector<ck::index_t> conv_filter_dilations,
536 std::vector<ck::index_t> input_left_pads,
537 std::vector<ck::index_t> input_right_pads,
538 InElementwiseOperation in_element_op,
539 WeiElementwiseOperation wei_element_op,
540 OutElementwiseOperation out_element_op)
541 : p_a_grid_{p_in_grid},
542 p_b_grid_{p_wei_grid},
543 p_c_grid_{p_out_grid},
544 p_c0_grid_{p_bias_grid},
545 p_c1_grid_{p_resi_grid},
552 in_element_op_{in_element_op},
553 wei_element_op_{wei_element_op},
554 out_element_op_{out_element_op},
555 Conv_N_{N},
556 Conv_K_{K},
557 Conv_C_{C},
558 input_spatial_lengths_{input_spatial_lengths},
559 filter_spatial_lengths_{filter_spatial_lengths},
560 output_spatial_lengths_{output_spatial_lengths},
561 conv_filter_strides_{conv_filter_strides},
562 conv_filter_dilations_{conv_filter_dilations},
563 input_left_pads_{input_left_pads},
564 input_right_pads_{input_right_pads}
565 {
566 const auto descs =
568 K,
569 C,
570 input_spatial_lengths,
571 filter_spatial_lengths,
572 output_spatial_lengths,
573 conv_filter_strides,
574 conv_filter_dilations,
575 input_left_pads,
576 input_right_pads);
577
578 a_grid_desc_k0_m_k1_ = descs[I0];
579 b_grid_desc_k0_n_k1_ = descs[I1];
580 c_grid_desc_m_n_ = descs[I2];
581 c0_grid_desc_m_n_ = descs[I3];
582 c1_grid_desc_m_n_ = descs[I4];
583
585 }
586
587 // private:
598
600 InElementwiseOperation in_element_op_;
601 WeiElementwiseOperation wei_element_op_;
602 OutElementwiseOperation out_element_op_;
603 // for checking IsSupportedArgument()
607 std::vector<index_t> input_spatial_lengths_;
608 std::vector<index_t> filter_spatial_lengths_;
609 std::vector<index_t> output_spatial_lengths_;
610 std::vector<index_t> conv_filter_strides_;
611 std::vector<index_t> conv_filter_dilations_;
612 std::vector<index_t> input_left_pads_;
613 std::vector<index_t> input_right_pads_;
614 };
615
616 // Invoker
617 struct Invoker : public BaseInvoker
618 {
620
621 template <typename GridwiseGemm>
622 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
623 {
624 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
625 {
626 std::cout << DeviceOp{}.GetTypeString() << std::endl;
627 std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
628 << arg.Conv_C_ << ", " << std::endl;
629 std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
630 << arg.filter_spatial_lengths_[1] << ", " << std::endl;
631 std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
632 << arg.input_spatial_lengths_[1] << ", " << std::endl;
633 std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", "
634 << arg.output_spatial_lengths_[1] << ", " << std::endl;
635 std::cout << "Strides " << arg.conv_filter_strides_[0] << ", "
636 << arg.conv_filter_strides_[1] << ", " << std::endl;
637 std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", "
638 << arg.conv_filter_dilations_[1] << ", " << std::endl;
639 std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", "
640 << arg.input_left_pads_[1] << ", " << std::endl;
641 std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
642 << arg.input_right_pads_[1] << ", " << std::endl;
643
644 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
645 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
646 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
647
648 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
649 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
650 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
651
652 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
653 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
654
655 std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
656 << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
657
658 std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
659 << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
660 }
661
662 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
666 {
667 throw std::runtime_error(
668 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting");
669 }
670
671 const index_t grid_size =
673
674 const auto K =
675 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
676
677 float ave_time = 0;
678
679 auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
680 GridwiseGemm::
681 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
682 arg.c_grid_desc_m_n_);
683
684 auto c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
685 GridwiseGemm::
686 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
688
689 auto c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
690 GridwiseGemm::
691 MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
693 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
694 {
695 const auto kernel = kernel_gemm_xdlops_v3r3<
696 GridwiseGemm,
697 ADataType, // TODO: distiguish A/B datatype
698 CDataType,
702 typename GridwiseGemm::
703 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
705 typename GridwiseGemm::
706 C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
708 typename GridwiseGemm::
709 C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
710 InElementwiseOperation,
711 WeiElementwiseOperation,
712 OutElementwiseOperation,
713 Block2CTileMap,
714 true>;
715
716 ave_time = launch_and_time_kernel(
717 stream_config,
718 kernel,
719 dim3(grid_size),
720 dim3(BlockSize),
721 0,
722 arg.p_a_grid_,
723 arg.p_b_grid_,
724 arg.p_c_grid_,
725 arg.p_c0_grid_,
726 arg.p_c1_grid_,
729 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
730 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
731 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
732 arg.in_element_op_,
733 arg.wei_element_op_,
734 arg.out_element_op_,
736 }
737 else
738 {
739 const auto kernel = kernel_gemm_xdlops_v3r3<
740 GridwiseGemm,
741 ADataType, // TODO: distiguish A/B datatype
742 CDataType,
746 typename GridwiseGemm::
747 CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
749 typename GridwiseGemm::
750 C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
752 typename GridwiseGemm::
753 C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
754 InElementwiseOperation,
755 WeiElementwiseOperation,
756 OutElementwiseOperation,
757 Block2CTileMap,
758 false>;
759
760 ave_time = launch_and_time_kernel(
761 stream_config,
762 kernel,
763 dim3(grid_size),
764 dim3(BlockSize),
765 0,
766 arg.p_a_grid_,
767 arg.p_b_grid_,
768 arg.p_c_grid_,
769 arg.p_c0_grid_,
770 arg.p_c1_grid_,
773 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
774 c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
775 c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
776 arg.in_element_op_,
777 arg.wei_element_op_,
778 arg.out_element_op_,
780 }
781
782 return ave_time;
783 }
784
786
787 float Run(const BaseArgument* p_arg,
788 const StreamConfig& stream_config = StreamConfig{}) override
789 {
790 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
791 }
792 };
793
/// Compile-time validity check of the template configuration.
/// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
799
800 static bool IsSupportedArgument(const Argument& arg)
801 {
803 {
804 return false;
805 }
806 if constexpr(ConvForwardSpecialization ==
808 {
809 // check if it's 1x1, stride=1 conv
810 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
811 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
812 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
813 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
814 {
815 return false;
816 }
817 }
818 else if constexpr(ConvForwardSpecialization ==
820 {
821 // check if it's 1x1 conv
822 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
823 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
824 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
825 {
826 return false;
827 }
828 }
829
830 // vector load A/B matrix from global memory
831 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
832 arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
833 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
834 {
835 return false;
836 }
837
838 // vector store C matrix into global memory
839 if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
840 {
841 return false;
842 }
843
844 // Gridwise GEMM size
845 if(get_warp_size() == 64)
846 {
847 if constexpr(NXdlPerWave64 > 0)
848 {
853 }
854 }
855 else
856 {
857 if constexpr(NXdlPerWave32 > 0)
858 {
863 }
864 }
865 return false;
866 }
867
868 bool IsSupportedArgument(const BaseArgument* p_arg) override
869 {
870 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
871 }
872
873 static auto MakeArgument(const InDataType* p_in_grid,
874 const WeiDataType* p_wei_grid,
875 OutDataType* p_out_grid,
876 const OutDataType* p_bias_grid,
877 const OutDataType* p_resi_grid,
878 ck::index_t N,
879 ck::index_t K,
880 ck::index_t C,
881 std::vector<ck::index_t> input_spatial_lengths,
882 std::vector<ck::index_t> filter_spatial_lengths,
883 std::vector<ck::index_t> output_spatial_lengths,
884 std::vector<ck::index_t> conv_filter_strides,
885 std::vector<ck::index_t> conv_filter_dilations,
886 std::vector<ck::index_t> input_left_pads,
887 std::vector<ck::index_t> input_right_pads,
888 InElementwiseOperation in_element_op,
889 WeiElementwiseOperation wei_element_op,
890 OutElementwiseOperation out_element_op)
891 {
892 return Argument{p_in_grid,
893 p_wei_grid,
894 p_out_grid,
895 p_bias_grid,
896 p_resi_grid,
897 N,
898 K,
899 C,
900 input_spatial_lengths,
901 filter_spatial_lengths,
902 output_spatial_lengths,
903 conv_filter_strides,
904 conv_filter_dilations,
905 input_left_pads,
906 input_right_pads,
907 in_element_op,
908 wei_element_op,
909 out_element_op};
910 }
911
912 static auto MakeInvoker() { return Invoker{}; }
913
914 std::unique_ptr<BaseArgument>
915 MakeArgumentPointer(const void* p_in_grid,
916 const void* p_wei_grid,
917 void* p_out_grid,
918 const void* p_bias_grid,
919 const void* p_resi_grid,
920 ck::index_t N,
921 ck::index_t K,
922 ck::index_t C,
923 std::vector<ck::index_t> input_spatial_lengths,
924 std::vector<ck::index_t> filter_spatial_lengths,
925 std::vector<ck::index_t> output_spatial_lengths,
926 std::vector<ck::index_t> conv_filter_strides,
927 std::vector<ck::index_t> conv_filter_dilations,
928 std::vector<ck::index_t> input_left_pads,
929 std::vector<ck::index_t> input_right_pads,
930 InElementwiseOperation in_element_op,
931 WeiElementwiseOperation wei_element_op,
932 OutElementwiseOperation out_element_op) override
933 {
934 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
935 static_cast<const WeiDataType*>(p_wei_grid),
936 static_cast<OutDataType*>(p_out_grid),
937 static_cast<const OutDataType*>(p_bias_grid),
938 static_cast<const OutDataType*>(p_resi_grid),
939 N,
940 K,
941 C,
942 input_spatial_lengths,
943 filter_spatial_lengths,
944 output_spatial_lengths,
945 conv_filter_strides,
946 conv_filter_dilations,
947 input_left_pads,
948 input_right_pads,
949 in_element_op,
950 wei_element_op,
951 out_element_op);
952 }
953
954 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
955 {
956 return std::make_unique<Invoker>(Invoker{});
957 }
958
959 std::string GetTypeString() const override
960 {
961 auto str = std::stringstream();
962
963 // clang-format off
964 str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
965 << "<"
966 << BlockSize << ", "
967 << MPerBlock << ", "
968 << NPerBlock << ", "
969 << K0PerBlock << ", "
970 << K1 << ", "
971 << MPerXDL << ", "
972 << NPerXDL << ", "
973 << MXdlPerWave << ", "
974 << NXdlPerWave << ", "
975 << ABlockTransferSrcScalarPerVector << ", "
976 << ABlockTransferDstScalarPerVector_K1 << ", "
977 << BBlockTransferSrcScalarPerVector << ", "
978 << BBlockTransferDstScalarPerVector_K1 << ", "
979 << CShuffleMXdlPerWavePerShuffle << ", "
980 << CShuffleNXdlPerWavePerShuffle << ", "
981 << CBlockTransferScalarPerVector_NWaveNPerXdl
982 << ">";
983 // clang-format on
984
985 return str.str();
986 }
987};
988} // namespace device
989} // namespace tensor_operation
990} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ OddC
Definition convolution_forward_specialization.hpp:19
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__global__ void kernel_gemm_xdlops_v3r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdlops_v3r3.hpp:37
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:24
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition block_to_ctile_map.hpp:38
Definition gridwise_gemm_xdlops_v3r3.hpp:142
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Block2CTileMap block_2_ctile_map_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:599
InElementwiseOperation in_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:600
std::vector< index_t > filter_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:608
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:593
Argument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, const OutDataType *p_bias_grid, const OutDataType *p_resi_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:523
CDataType * p_c_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:590
WeiElementwiseOperation wei_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:601
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:594
std::vector< index_t > conv_filter_strides_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:610
const CDataType * p_c0_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:591
index_t Conv_C_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:606
std::vector< index_t > input_right_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:613
OutElementwiseOperation out_element_op_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:602
const BDataType * p_b_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:589
index_t Conv_K_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:605
std::vector< index_t > input_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:607
C1GridDesc_M_N c1_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:597
const ADataType * p_a_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:588
std::vector< index_t > output_spatial_lengths_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:609
std::vector< index_t > conv_filter_dilations_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:611
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:595
const CDataType * p_c1_grid_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:592
std::vector< index_t > input_left_pads_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:612
C0GridDesc_M_N c0_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:596
index_t Conv_N_
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:604
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:67
static constexpr auto NXdlPerWave32
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:73
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:794
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:868
static constexpr auto I4
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:89
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:95
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:954
InDataType ADataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:75
static constexpr index_t NDimSpatial
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:83
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl > GridwiseGemmBase
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:475
static auto MakeInvoker()
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:912
remove_cvref_t< decltype(GridDescs{}[I3])> C0GridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:468
remove_cvref_t< decltype(GridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:466
BlockToCTileMap_M00_N0_M01< MPerBlock, NPerBlock, CGridDesc_M_N > Block2CTileMap
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:471
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})) GridDescs
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:462
static constexpr auto I2
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:87
static constexpr auto I0
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:85
static constexpr auto GemmK1Number
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:92
static constexpr auto I3
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:88
InDataType ABDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:80
remove_cvref_t< decltype(GridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:465
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:518
WeiDataType BDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:76
static auto MakeArgument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, const OutDataType *p_bias_grid, const OutDataType *p_resi_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:873
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:800
remove_cvref_t< decltype(GridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:467
static constexpr auto I1
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:86
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:72
std::string GetTypeString() const override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:959
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:68
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:517
OutDataType CDataType
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:77
static constexpr auto K1Number
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:91
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, const void *p_wei_grid, void *p_out_grid, const void *p_bias_grid, const void *p_resi_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:915
remove_cvref_t< decltype(GridDescs{}[I4])> C1GridDesc_M_N
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:469
DeviceOp::Argument Argument
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:619
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:787
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp:622
Definition device_conv_fwd_bias_activation_add.hpp:19
#define CK_ENV(name)
Definition utility/env.hpp:129