device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp Source File

device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp Source File#

Composable Kernel: device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp Source File
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef DEVICE_CONV3D_FWD_XDL_HPP
5#define DEVICE_CONV3D_FWD_XDL_HPP
6
7#include <iostream>
8#include <memory>
9#include <sstream>
10#include "device.hpp"
11#include "device_conv_fwd.hpp"
12#include "common_header.hpp"
13#include "ck/utility/env.hpp"
14#include "tensor_layout.hpp"
16#include "tensor_descriptor.hpp"
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25/*
 * Batched GEMM kernel used by the conv3d forward device op in this file. The convolution's
 * batch dimension is split into num_batches sub-batches; every sub-batch runs the same GEMM
 * (same A/B/C grid descriptors) on pointers shifted by sub_batch_index * {a,b,c}_batch_stride.
 *
 * Work distribution: the grid is divided into num_batches equal, contiguous ranges of
 * workgroups, and workgroup i works on sub-batch i / (grid_size / num_batches).
 *
 * The body is empty unless compiled for gfx9 / gfx11 / gfx12 and
 * GridwiseGemm::IsValidCompilationParameter<>() is true.
 *
26 * \see \link impl/device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
27 */
28template <typename GridwiseGemm,
29 typename FloatAB,
30 typename FloatC,
31 typename AGridDesc_K0_M_K1,
32 typename BGridDesc_K0_N_K1,
33 typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 typename Block2CTileMap,
38 bool HasMainKBlockLoop>
39__global__ void
40#if CK_USE_LAUNCH_BOUNDS
42#endif
44 const FloatAB* __restrict__ p_a_grid,
45 const FloatAB* __restrict__ p_b_grid,
46 FloatC* __restrict__ p_c_grid,
47 const index_t num_batches,
48 const index_t a_batch_stride,
49 const index_t b_batch_stride,
50 const index_t c_batch_stride,
51 const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
52 const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
53 const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
54 const AElementwiseOperation a_element_op,
55 const BElementwiseOperation b_element_op,
56 const CElementwiseOperation c_element_op,
57 const Block2CTileMap block_2_ctile_map)
58{
59#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
60 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
61 {
// Map this workgroup to its sub-batch: the grid is split into num_batches equal ranges.
// readfirstlane broadcasts the value of the first active lane, so the result is uniform
// across the wave.
62 const index_t num_blocks_per_batch =
63 __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
64 const index_t g_idx =
65 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
66
// Element offsets of this sub-batch. Strides are widened to long_index_t before the
// multiplication so the product is not computed in index_t.
67 const long_index_t a_batch_offset =
68 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(a_batch_stride) * g_idx);
69 const long_index_t b_batch_offset =
70 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(b_batch_stride) * g_idx);
71 const long_index_t c_batch_offset =
72 __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(c_batch_stride) * g_idx);
73
// Statically sized shared-memory scratch buffer for the gridwise GEMM.
74 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75
76 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
77 p_b_grid + b_batch_offset,
78 p_c_grid + c_batch_offset,
79 p_shared,
80 a_grid_desc_k0_m_k1,
81 b_grid_desc_k0_n_k1,
82 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
83 a_element_op,
84 b_element_op,
85 c_element_op,
86 block_2_ctile_map);
87 }
88#else
// Unsupported target: the kernel body is empty, so mark every parameter as used.
89 ignore = p_a_grid;
90 ignore = p_b_grid;
91 ignore = p_c_grid;
92 ignore = num_batches;
93 ignore = a_batch_stride;
94 ignore = b_batch_stride;
95 ignore = c_batch_stride;
96 ignore = a_grid_desc_k0_m_k1;
97 ignore = b_grid_desc_k0_n_k1;
98 ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
99 ignore = a_element_op;
100 ignore = b_element_op;
101 ignore = c_element_op;
102 ignore = block_2_ctile_map;
103#endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
104}
105
106// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
107template <typename InDataType,
108 typename WeiDataType, // WeiDataType must be the same as InDataType
109 typename OutDataType,
110 typename AccDataType,
111 typename InElementwiseOperation,
112 typename WeiElementwiseOperation,
113 typename OutElementwiseOperation,
114 ConvolutionForwardSpecialization ConvForwardSpecialization,
115 ck::index_t BlockSize,
116 ck::index_t MPerBlock,
117 ck::index_t NPerBlock,
118 ck::index_t K0PerBlock,
119 ck::index_t K1,
120 ck::index_t MPerXDL,
121 ck::index_t NPerXDL,
122 ck::index_t MXdlPerWave,
123 ck::index_t NXdlPerWave,
124 typename ABlockTransferThreadClusterLengths_K0_M_K1,
125 typename ABlockTransferThreadClusterArrangeOrder,
126 typename ABlockTransferSrcAccessOrder,
127 ck::index_t ABlockTransferSrcVectorDim,
128 ck::index_t ABlockTransferSrcScalarPerVector,
129 ck::index_t ABlockTransferDstScalarPerVector_K1,
130 bool ABlockLdsAddExtraM,
131 typename BBlockTransferThreadClusterLengths_K0_N_K1,
132 typename BBlockTransferThreadClusterArrangeOrder,
133 typename BBlockTransferSrcAccessOrder,
134 ck::index_t BBlockTransferSrcVectorDim,
135 ck::index_t BBlockTransferSrcScalarPerVector,
136 ck::index_t BBlockTransferDstScalarPerVector_K1,
137 bool BBlockLdsAddExtraN,
138 ck::index_t CThreadTransferSrcDstVectorDim,
139 ck::index_t CThreadTransferDstScalarPerVector>
141 : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
142
143{
145
// NXdlPerWave values for the wave64 (true) and wave32 (false) GridwiseGemm instantiations,
// presumably computed by GetNXdlPerWave from GET_NXDL_PER_WAVE_IMPL.
// NOTE(review): GridwiseGemm64 is instantiated with math::max(NXdlPerWave64, 1), which
// suggests 0 marks an unsupported configuration — confirm against device_base.hpp.
147 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
148 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
149
// GEMM operand types: A = input, B = weight, C = output.
150 using ADataType = InDataType;
151 using BDataType = WeiDataType;
152 using CDataType = OutDataType;
153 // TODO make A/B datatype different
// The gridwise GEMM is instantiated with a single A/B element type (InDataType), which is why
// WeiDataType must match InDataType.
154 using ABDataType = InDataType;
155
// Compile-time indices for tuple and descriptor-dimension access.
156 static constexpr auto I0 = Number<0>{};
157 static constexpr auto I1 = Number<1>{};
158 static constexpr auto I2 = Number<2>{};
159 static constexpr auto I3 = Number<3>{};
160
161 /*
162 * \brief Choose the sub-batch size N1 for splitting the batch dimension \p N into
163 * N = B * N1, such that the input and output memory of one sub-batch stays within the
164 * value range of index_t and each sub-batch can be handled by one GridwiseGemm problem.
 *
 * Scans n0 = 1, 2, ..., N and returns the first (and therefore largest) divisor
 * n1 = N / n0 of N with n1 * max(Do*Ho*Wo*K, Di*Hi*Wi*C) < NumericLimits<index_t>::Max().
 * Only the first three entries of input_spatial_lengths / output_spatial_lengths are read.
 *
 * \return N1, the number of images per sub-batch (callers derive B = N / N1). For N == 0
 *         the scan does not run and 1 is returned.
 * \throws std::runtime_error if no divisor satisfies the bound, i.e. even a single image
 *         exceeds the index_t range.
165 */
167 const index_t K,
168 const index_t C,
169 std::vector<ck::index_t> input_spatial_lengths,
170 std::vector<ck::index_t> output_spatial_lengths)
171 {
172 const index_t Di = input_spatial_lengths[0];
173 const index_t Hi = input_spatial_lengths[1];
174 const index_t Wi = input_spatial_lengths[2];
175
176 const index_t Do = output_spatial_lengths[0];
177 const index_t Ho = output_spatial_lengths[1];
178 const index_t Wo = output_spatial_lengths[2];
179
180 // N1 should satisfy that
181 // 1) N % N1 = 0;
182 // 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1)
183 // 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1)
184 //
185 // Do NOT confuse (B, N1) in this function with (B, N1) in gridwise GEMM.
// Sentinel: for N > 0, N + 1 never divides N, so the B * N1 != N check below reports a
// failed search.
186 auto N1 = N + 1;
187
// Per-image element count of the larger of the output and input tensors, computed in 64 bits.
188 const auto stride =
189 math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C);
190 const index_t max_stride = NumericLimits<index_t>::Max();
191
// n1 shrinks as n0 grows, so the first hit is the largest admissible sub-batch size.
192 for(index_t n0 = 1; n0 <= N; ++n0)
193 {
194 index_t n1 = N / n0;
195 if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride)
196 {
197 N1 = n1;
198 break;
199 }
200 }
201
202 const auto B = N / N1;
203 if(B * N1 != N)
204 {
205 throw std::runtime_error(__func__ +
206 std::string(": failed to find num_subbatches for conv3d.\n"));
207 }
208
209 return N1;
210 }
211
// Build the implicit-GEMM grid descriptors for one conv3d problem. The Argument constructor
// unpacks the returned tuple as [I0] = A (input, K0 x M x K1), [I1] = B (weight, K0 x N x K1)
// and [I2] = C (output, M x N).
// NOTE(review): the lowering is presumably done by
// transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad, whose parameters match
// the in/wei/out descriptors, strides, dilations, pads and Number<K1> passed below — confirm.
212 static auto
214 const index_t K,
215 const index_t C,
216 std::vector<ck::index_t> input_spatial_lengths,
217 std::vector<ck::index_t> filter_spatial_lengths,
218 std::vector<ck::index_t> output_spatial_lengths,
219 std::vector<ck::index_t> conv_filter_strides,
220 std::vector<ck::index_t> conv_filter_dilations,
221 std::vector<ck::index_t> input_left_pads,
222 std::vector<ck::index_t> input_right_pads)
223 {
// Each vector must hold at least the three (D, H, W) entries indexed below.
// NOTE(review): output_spatial_lengths is also indexed [0..2] but is not asserted.
224 assert(input_spatial_lengths.size() > 2);
225 assert(filter_spatial_lengths.size() > 2);
226 assert(conv_filter_strides.size() > 2);
227 assert(conv_filter_dilations.size() > 2);
228 assert(input_left_pads.size() > 2);
229 assert(input_right_pads.size() > 2);
230
231 const index_t Di = input_spatial_lengths[0];
232 const index_t Hi = input_spatial_lengths[1];
233 const index_t Wi = input_spatial_lengths[2];
234 const index_t Z = filter_spatial_lengths[0];
235 const index_t Y = filter_spatial_lengths[1];
236 const index_t X = filter_spatial_lengths[2];
237
238 const index_t Do = output_spatial_lengths[0];
239 const index_t Ho = output_spatial_lengths[1];
240 const index_t Wo = output_spatial_lengths[2];
241
// Only ConvolutionForwardSpecialization::Default is implemented; any other value is rejected
// at compile time.
242 static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default,
243 "Wrong! This specialization not implemented!");
244
// Tensor descriptors for input (N, Di, Hi, Wi, C), weight (K, Z, Y, X, C) and
// output (N, Do, Ho, Wo, K), per the variable names.
245 const auto in_desc_n_di_hi_wi_c =
247 const auto wei_desc_k_z_y_x_c =
249 const auto out_desc_n_do_ho_wo_k =
251
253 in_desc_n_di_hi_wi_c,
254 wei_desc_k_z_y_x_c,
255 out_desc_n_do_ho_wo_k,
256 make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]),
258 conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]),
259 make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]),
260 make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]),
261 Number<K1>{});
262
263 return descs;
264 }
265
267 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}))>;
268
272
// Gridwise GEMM instantiation, parameterized on NXdlPerWave so that separate wave64 and wave32
// variants (GridwiseGemm64 / GridwiseGemm32) can be built from NXdlPerWave64 / NXdlPerWave32.
// NOTE(review): the A/B thread-cluster arrange order, source access order and source vector
// dim are hard-coded below (Sequence<1, 0, 2>, Sequence<1, 0, 2>, 2). The class's
// A/BBlockTransferThreadClusterArrangeOrder, A/BBlockTransferSrcAccessOrder and
// A/BBlockTransferSrcVectorDim template parameters are therefore not forwarded. The trailing
// 7 is presumably the C src/dst vector dim and likewise ignores CThreadTransferSrcDstVectorDim.
// Confirm this is intended.
273 template <index_t NXdlPerWave_>
275 BlockSize,
276 InDataType,
277 AccDataType,
278 OutDataType,
283 InElementwiseOperation,
284 WeiElementwiseOperation,
285 OutElementwiseOperation,
286 MPerBlock,
287 NPerBlock,
288 K0PerBlock,
289 MPerXDL,
290 NPerXDL,
291 K1,
292 MXdlPerWave,
293 NXdlPerWave_,
294 ABlockTransferThreadClusterLengths_K0_M_K1,
295 Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
296 Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
297 2,
298 ABlockTransferSrcScalarPerVector,
299 ABlockTransferDstScalarPerVector_K1,
300 false, // AThreadTransferSrcResetCoordinateAfterRun,
301 ABlockLdsAddExtraM,
302 BBlockTransferThreadClusterLengths_K0_N_K1,
303 Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
304 Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
305 2,
306 BBlockTransferSrcScalarPerVector,
307 BBlockTransferDstScalarPerVector_K1,
308 false, // BThreadTransferSrcResetCoordinateAfterRun,
309 BBlockLdsAddExtraN,
311 7,
312 CThreadTransferDstScalarPerVector>;
315
317 decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
318 using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
319
320 // Argument: host-side description of one conv3d forward problem, lowered to GEMM descriptors
// plus per-sub-batch pointer strides. M01 / N01 are forwarded to
// GridwiseGemm::MakeDefaultBlock2CTileMap.
321 struct Argument : public BaseArgument
322 {
323 Argument(const InDataType* p_in,
324 const WeiDataType* p_wei,
325 OutDataType* p_out,
326 const index_t N,
327 const index_t K,
328 const index_t C,
329 std::vector<ck::index_t> input_spatial_lengths,
330 std::vector<ck::index_t> filter_spatial_lengths,
331 std::vector<ck::index_t> output_spatial_lengths,
332 std::vector<ck::index_t> conv_filter_strides,
333 std::vector<ck::index_t> conv_filter_dilations,
334 std::vector<ck::index_t> input_left_pads,
335 std::vector<ck::index_t> input_right_pads,
336 index_t M01,
337 index_t N01,
338 InElementwiseOperation in_element_op,
339 WeiElementwiseOperation wei_element_op,
340 OutElementwiseOperation out_element_op)
341 : p_a_grid_{p_in},
342 p_b_grid_{p_wei},
343 p_c_grid_{p_out},
344 M01_{M01},
345 N01_{N01},
346 in_element_op_{in_element_op},
347 wei_element_op_{wei_element_op},
348 out_element_op_{out_element_op}
349 {
// Split N into num_subbatches_ groups of subbatch_size images so that one group's memory
// stays within index_t. GetMaxAllowableSubBatchSize throws if that is impossible.
350 const index_t subbatch_size =
351 GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths);
352 num_subbatches_ = N / subbatch_size;
353
// NOTE(review): the descriptors are assumed to describe a single sub-batch (subbatch_size
// images), since the batch strides below equal one descriptor's element-space size —
// confirm against the first argument of this call.
354 const auto descs =
356 K,
357 C,
358 input_spatial_lengths,
359 filter_spatial_lengths,
360 output_spatial_lengths,
361 conv_filter_strides,
362 conv_filter_dilations,
363 input_left_pads,
364 input_right_pads);
365
366 a_grid_desc_k0_m_k1_ = descs[I0];
367 b_grid_desc_k0_n_k1_ = descs[I1];
368 c_grid_desc_m_n_ = descs[I2];
369
371 GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
372
// Per-sub-batch element strides. The weights are shared by every sub-batch, so B is never
// offset.
373 a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize();
374 b_batch_stride_ = 0;
375 c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize();
376
// The C descriptor consumed by the kernel is only built for a valid GEMM configuration.
// Otherwise it keeps its default value, and IsSupportedArgument / Invoker reject the
// argument via CheckValidity.
377 if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
381 {
383 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
384 }
385 }
386
387 // private: (left public — Invoker::RunImp and IsSupportedArgument read these fields directly)
388 const InDataType* p_a_grid_;
389 const WeiDataType* p_b_grid_;
390 OutDataType* p_c_grid_;
402 InElementwiseOperation in_element_op_;
403 WeiElementwiseOperation wei_element_op_;
404 OutElementwiseOperation out_element_op_;
405 };
406
407 // Invoker: validates an Argument and launches kernel_gemm_xdlops_v2r3_for_conv3d for all of its
// sub-batches in a single grid.
408 struct Invoker : public BaseInvoker
409 {
411
// Launch the conv3d kernel using the given GridwiseGemm instantiation.
// Throws std::runtime_error if GridwiseGemm::CheckValidity rejects the descriptors.
// Returns the time reported by launch_and_time_kernel.
// NOTE(review): the parameter is declared as typename GridwiseGemm::Argument, yet the body
// reads fields of this op's Argument (num_subbatches_, *_batch_stride_,
// c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_). Confirm the intended parameter type.
412 template <typename GridwiseGemm>
413 float RunImp(const typename GridwiseGemm::Argument& arg,
414 const StreamConfig& stream_config = StreamConfig{})
415 {
// Optional diagnostics, printed only when CK_LOGGING is enabled in the environment.
416 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
417 {
418 std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
419 std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
420 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
421 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
422
423 std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
424 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
425 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
426
427 std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
428 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
429 }
430
431 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
432 arg.b_grid_desc_k0_n_k1_,
433 arg.c_grid_desc_m_n_,
434 arg.block_2_ctile_map_))
435 {
436 throw std::runtime_error(
437 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
438 }
439
// The grid holds one full set of C tiles per sub-batch. The kernel recovers the sub-batch
// index as block_id / (grid_size / num_subbatches_).
440 const index_t grid_size =
441 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) *
442 arg.num_subbatches_;
443
444 const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
445
// HasMainKBlockLoop is a compile-time kernel parameter, so dispatch on the runtime value here.
446 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
447
448 float ave_time = 0;
449 if(has_main_k0_block_loop)
450 {
451 const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d<
452 GridwiseGemm,
453 InDataType,
454 OutDataType,
458 InElementwiseOperation,
459 WeiElementwiseOperation,
460 OutElementwiseOperation,
462 true>;
// Dynamic LDS size is 0 because the kernel declares its shared-memory buffer statically.
463 ave_time = launch_and_time_kernel(stream_config,
464 kernel,
465 dim3(grid_size),
466 dim3(BlockSize),
467 0,
468 arg.p_a_grid_,
469 arg.p_b_grid_,
470 arg.p_c_grid_,
471 arg.num_subbatches_,
472 arg.a_batch_stride_,
473 arg.b_batch_stride_,
474 arg.c_batch_stride_,
475 arg.a_grid_desc_k0_m_k1_,
476 arg.b_grid_desc_k0_n_k1_,
477 arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
478 arg.in_element_op_,
479 arg.wei_element_op_,
480 arg.out_element_op_,
481 arg.block_2_ctile_map_);
482 }
483 else
484 {
485 const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d<
486 GridwiseGemm,
487 InDataType,
488 OutDataType,
492 InElementwiseOperation,
493 WeiElementwiseOperation,
494 OutElementwiseOperation,
496 false>;
497
498 ave_time = launch_and_time_kernel(stream_config,
499 kernel,
500 dim3(grid_size),
501 dim3(BlockSize),
502 0,
503 arg.p_a_grid_,
504 arg.p_b_grid_,
505 arg.p_c_grid_,
506 arg.num_subbatches_,
507 arg.a_batch_stride_,
508 arg.b_batch_stride_,
509 arg.c_batch_stride_,
510 arg.a_grid_desc_k0_m_k1_,
511 arg.b_grid_desc_k0_n_k1_,
512 arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
513 arg.in_element_op_,
514 arg.wei_element_op_,
515 arg.out_element_op_,
516 arg.block_2_ctile_map_);
517 }
518
519 return ave_time;
520 }
521
523
524 // polymorphic
// Forwards to the typed Run overload (presumably generated by INVOKER_RUN_IMPL).
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a
// BaseArgument of another type leads to a null dereference.
525 float Run(const BaseArgument* p_arg,
526 const StreamConfig& stream_config = StreamConfig{}) override
527 {
528 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
529 }
530 };
531
// Reports whether this instantiation's template parameters are supported. No restriction is
// enforced yet, so every instantiation is accepted.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool all_parameters_accepted = true;
    return all_parameters_accepted;
}
537
// Returns whether this op can run \p arg. The result is false when the guard below rejects it;
// otherwise it is the result of GridwiseGemm::CheckValidity on the argument's descriptors.
// NOTE(review): the guard condition is presumably a device-capability check such as
// ck::is_xdl_wmma_supported (referenced in this file) — confirm.
538 static bool IsSupportedArgument(const Argument& arg)
539 {
541 {
542 return false;
543 }
544 return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
548 }
549
550 // polymorphic
551 bool IsSupportedArgument(const BaseArgument* p_arg) override
552 {
553 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
554 }
555
556 static auto MakeArgument(const InDataType* p_in,
557 const WeiDataType* p_wei,
558 OutDataType* p_out,
559 const index_t N,
560 const index_t K,
561 const index_t C,
562 std::vector<ck::index_t> input_spatial_lengths,
563 std::vector<ck::index_t> filter_spatial_lengths,
564 std::vector<ck::index_t> output_spatial_lengths,
565 std::vector<ck::index_t> conv_filter_strides,
566 std::vector<ck::index_t> conv_filter_dilations,
567 std::vector<ck::index_t> input_left_pads,
568 std::vector<ck::index_t> input_right_pads,
569 InElementwiseOperation in_element_op,
570 WeiElementwiseOperation wei_element_op,
571 OutElementwiseOperation out_element_op)
572 {
573 return Argument{p_in,
574 p_wei,
575 p_out,
576 N,
577 K,
578 C,
579 input_spatial_lengths,
580 filter_spatial_lengths,
581 output_spatial_lengths,
582 conv_filter_strides,
583 conv_filter_dilations,
584 input_left_pads,
585 input_right_pads,
586 1,
587 1,
588 in_element_op,
589 wei_element_op,
590 out_element_op};
591 }
592
593 static auto MakeInvoker() { return Invoker{}; }
594
595 // polymorphic
596 std::unique_ptr<BaseArgument>
597 MakeArgumentPointer(const void* p_in,
598 const void* p_wei,
599 void* p_out,
600 const index_t N,
601 const index_t K,
602 const index_t C,
603 std::vector<ck::index_t> input_spatial_lengths,
604 std::vector<ck::index_t> filter_spatial_lengths,
605 std::vector<ck::index_t> output_spatial_lengths,
606 std::vector<ck::index_t> conv_filter_strides,
607 std::vector<ck::index_t> conv_filter_dilations,
608 std::vector<ck::index_t> input_left_pads,
609 std::vector<ck::index_t> input_right_pads,
610 InElementwiseOperation in_element_op,
611 WeiElementwiseOperation wei_element_op,
612 OutElementwiseOperation out_element_op) override
613
614 {
615 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in),
616 static_cast<const WeiDataType*>(p_wei),
617 static_cast<OutDataType*>(p_out),
618 N,
619 K,
620 C,
621 input_spatial_lengths,
622 filter_spatial_lengths,
623 output_spatial_lengths,
624 conv_filter_strides,
625 conv_filter_dilations,
626 input_left_pads,
627 input_right_pads,
628 1,
629 1,
630 in_element_op,
631 wei_element_op,
632 out_element_op);
633 }
634
635 // polymorphic
636 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
637 {
638 return std::make_unique<Invoker>(Invoker{});
639 }
640
641 std::string GetTypeString() const override
642 {
643 auto str = std::stringstream();
644
645 // clang-format off
646 str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K"
647 << "<"
648 << BlockSize << ", "
649 << MPerBlock << ", "
650 << NPerBlock << ", "
651 << K0PerBlock << ", "
652 << K1 << ", "
653 << MPerXDL << ", "
654 << NPerXDL << ", "
655 << MXdlPerWave << ", "
656 << NXdlPerWave << ", "
657 << ABlockTransferSrcScalarPerVector << ", "
658 << ABlockTransferDstScalarPerVector_K1 << ", "
659 << BBlockTransferSrcScalarPerVector << ", "
660 << BBlockTransferDstScalarPerVector_K1 << ", "
661 << CShuffleMXdlPerWavePerShuffle << ", "
662 << CShuffleNXdlPerWavePerShuffle << ", "
663 << CBlockTransferScalarPerVector_NWaveNPerXdl
664 << ">";
665 // clang-format on
666
667 return str.str();
668 }
669};
670
671} // namespace device
672} // namespace tensor_operation
673} // namespace ck
674#endif
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Default
Definition convolution_forward_specialization.hpp:16
__global__ void kernel_gemm_xdlops_v2r3_for_conv3d(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t num_batches, const index_t a_batch_stride, const index_t b_batch_stride, const index_t c_batch_stride, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:43
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(const TensorDescriptor< In... > &in_grid_desc_n_di_hi_wi_c, const TensorDescriptor< Wei... > &wei_k_z_y_x_c_grid_desc, const TensorDescriptor< Out... > &out_n_do_ho_wo_k_grid_desc, const ConvStrides &conv_strides, const ConvDilations &conv_dilations, const InLeftPads &in_left_pads, const InRightPads &in_right_pads, Number< GemmK1Value >)
Definition transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp:28
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
InElementwiseOperation in_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:402
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:398
Argument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, index_t M01, index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:323
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:397
WeiElementwiseOperation wei_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:403
OutElementwiseOperation out_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:404
const InDataType * p_a_grid_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:388
const WeiDataType * p_b_grid_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:389
Block2CTileMap block_2_ctile_map_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:399
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:396
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:395
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:413
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:525
DeviceOp::Argument Argument
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:410
typename GridwiseGemm::DefaultBlock2CTileMap Block2CTileMap
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:318
DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K DeviceOp
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:144
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:269
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:270
static auto MakeArgument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:556
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:538
static constexpr auto NXdlPerWave32
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:148
static constexpr auto I3
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:159
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:314
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:213
InDataType ABDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:154
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:271
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, AccDataType, OutDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:274
InDataType ADataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:150
static constexpr auto I2
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:158
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:636
std::string GetTypeString() const override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:641
static constexpr auto I1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:157
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:313
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:597
remove_cvref_t< decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}))> ABCGridDescs
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:266
static auto MakeInvoker()
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:593
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})) CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:316
WeiDataType BDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:151
OutDataType CDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:152
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:551
static constexpr auto I0
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:156
static index_t GetMaxAllowableSubBatchSize(const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:166
static constexpr bool IsValidCompilationParameter()
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:532
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:147
Definition device_conv_fwd.hpp:25
#define CK_ENV(name)
Definition utility/env.hpp:129