flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File

flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File

Composable Kernel: flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File
flatmm_pipeline_agmem_bgmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12template <typename Problem>
14{
15 static constexpr index_t PrefetchStages = 2;
16
17 CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
18 {
19 return num_loop > PrefetchStages;
20 }
21
23 {
24 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
25 }
26 template <typename RunFunction>
27 CK_TILE_HOST_DEVICE static auto
28 TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
29 {
30 if(TailNumber::Even == tail_num)
31 {
32 return run_func(bool_constant<true>{},
34 }
35 else if(TailNumber::Odd == tail_num)
36 {
37 return run_func(bool_constant<true>{},
39 }
40 // return run_func(bool_constant<true>{}, integral_constant<TailNumber,
41 // TailNumber::Empty>{});
42 }
43};
44
45template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
47{
52
56
59
60 static constexpr auto config =
61 BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
62
63 using WG = remove_cvref_t<decltype(config.template at<0>())>;
64
65 static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
66 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
67
68 static constexpr index_t BlockSize = Problem::kBlockSize;
69 static constexpr index_t WaveSize = get_warp_size();
70
71 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
72 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
73 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
74
75 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
76 static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
77
78 static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
79 static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
80 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
81
82 static constexpr bool kPadM = Problem::kPadM;
83 static constexpr bool kPadN = Problem::kPadN;
84 static constexpr bool kPadK = Problem::kPadK;
85
86 static constexpr index_t kLdsAlignmentInBytes = 16;
87 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
88 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
89
90 static constexpr auto I0 = number<0>();
91 static constexpr auto I1 = number<1>();
92 static constexpr auto I2 = number<2>();
93 static constexpr auto idxM = I0;
94 static constexpr auto idxN = I1;
95 static constexpr auto idxK = I2;
99
100 static constexpr index_t MWarp = config.template at<1>();
101 static constexpr index_t NWarp = config.template at<2>();
102
103 static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
104 static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
105 static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
106
109
112
113 static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
117
118 static constexpr bool HasHotLoop = Problem::HasHotLoop;
119 static constexpr auto TailNum = Problem::TailNum;
120
121/*
122defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
123defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
124defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
125defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
126
127defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
128defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
129defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
130defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
131
132defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
133defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
134*/
135
136// #if (defined(USING_MFMA_16x16x32_F8) || \
137// defined(USING_MFMA_32x32x16_F8) || \
138// defined(USING_MFMA_16x16x16_F16) || \
139// defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
140// static constexpr auto mfma_per_wg = 2;
141// static constexpr auto dsread_per_wg = 1;
142// #elif (defined(USING_MFMA_16x16x32_F16) || \
143// defined(USING_MFMA_32x32x16_F16) || \
144// defined(USING_MFMA_16x16x128_F4) || \
145// defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
146// static constexpr auto mfma_per_wg = 1;
147// static constexpr auto dsread_per_wg = 1;
148// #elif (defined(USING_MFMA_16x16x128_F8) || \
149// defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
150// static constexpr auto mfma_per_wg = 1;
151// static constexpr auto dsread_per_wg = 2;
152// #endif
153#ifdef __gfx942__
154 static constexpr index_t mfma_per_wg = 2;
155#else
156 static constexpr index_t mfma_per_wg = 1;
157#endif
158 static constexpr index_t dsread_per_wg =
159 WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
160 static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
161
166 static constexpr index_t Aload_rep = dswrite_rep;
167 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
168 static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
169 static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
170
174
175 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
176 {
177 // clang-format off
178 return concat('_', "pipeline_AGmemBGmemCRegV1",
180 concat('x', WG::kM, WG::kN, WG::kK),
182 concat('x', kPadM, kPadN, kPadK));
183 // clang-format on
184 }
185
186 // For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false.
187 static constexpr bool DoubleSmemBuffer = false;
188
189 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
190
192 {
193 return PipelinePolicy::template GetSmemSize<Problem>();
194 }
195
196 CK_TILE_HOST_DEVICE static constexpr auto
197 SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
198 {
199 // Init inst order
200 index_t max_data_inst = dsread_perM > load_perM
201 ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
202 : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
203 index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
204 index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
205
206 index_t inst_order[NIterPerWarp * 10];
207#pragma unroll
208 for(int idx = 0; idx < NIterPerWarp * 10; idx++)
209 {
210 inst_order[idx] = 0;
211 }
212
213 index_t index = 0;
214#pragma unroll
215 for(int j = 0; j < max_data_inst; j++)
216 {
217 if(dswrite_perM > j)
218 {
219 inst_order[index] = 1;
220 index++;
221 }
222 if(load_perM > j)
223 {
224 inst_order[index] = 2;
225 index++;
226 }
227 if(dsread_perM > j)
228 {
229 inst_order[index] = 3;
230 index++;
231 }
232 }
233
234// Schedule IGLP
235#pragma unroll
236 for(int j = 0; j < mfma_perM_perK; j++)
237 {
238 index_t inst_idx = 0;
239 if(j == 0)
240 ;
241 else if(j == 1)
242 inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
243 else if(j == 2)
244 inst_idx = mfma_perM_perK - 1;
245 else
246 inst_idx = mfma_perM_perK - j;
247
248 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
249
250#pragma unroll
251 for(int r = 0; r < round_data_inst; r++)
252 {
253 if(r % 2 == 0)
254 {
255 if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
256 {
257 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
258 }
259 if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
260 {
261 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
262 }
263 if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
264 {
265 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
266 }
267 }
268 else
269 {
270 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
271 {
272 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
273 }
274 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
275 {
276 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
277 }
278 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
279 {
280 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
281 }
282 }
283 }
284 }
285 }
286 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
287 {
288 // Keypoint of pipeline optimize is workload balance in time
289 // instruction schedule example(128X256X256, 1X4, 16X16X128):
290 // Iter MNK MFMA ds_read ds_write A_load b_load
291 // -1 M6N0: 57 - 8 - -
292 // -1 M6N1: 58 1 - - -
293 // -1 M6N2: 59 - - 7 -
294 // -1 M6N3: 60 2 - - -
295 // -1 M7N0: 61 - - - -
296 // -1 M7N1: 62 3 - - -
297 // -1 M7N2: 63 - - 8 -
298 // -1 M7N3: 64 4 - - -
299 // 0 M0N0K0: 1 - - - 1
300 // 0 M0N1: 2 5 - - -
301 // 0 M0N2: 3 - - - 2
302 // 0 M0N3: 4 6 - - -
303 // 0 M1N0: 5 - - - 3
304 // 0 M1N1: 6 7 - - -
305 // 0 M1N2: 7 - - - 4
306 // 0 M1N3: 8 8 - - -
307 // 0 M2N0: 9 - - - 5
308 // 0 M2N1: 10 9 - - -
309 // 0 M2N2: 11 - - - 6
310 // 0 M2N3: 12 10 - - -
311 // 0 M3N0: 13 - 1 - 7
312 // 0 M3N1: 14 11 - - -
313 // 0 M3N2: 15 - - - 8
314 // 0 M3N3: 16 12 - - -
315 // 0 M4N0: 17 - 2 - -
316 // 0 M4N1: 18 13 - - -
317 // 0 M4N2: 19 - - 1 -
318 // 0 M4N3: 20 14 - - -
319 // 0 M5N0: 21 - 3 - -
320 // 0 M5N1: 22 15 - - -
321 // 0 M5N2: 23 - - 2 -
322 // 0 M5N3: 24 16 - - -
323 // 0 M6N0: 25 - 4 - -
324 // 0 M6N1: 26 17 - - -
325 // 0 M6N2: 27 - - 3 -
326 // 0 M6N3: 28 18 - - -
327 // 0 M7N0: 29 - - - -
328 // 0 M7N1: 30 19 - - -
329 // 0 M7N2: 31 - - 4 -
330 // 0 M7N3: 32 20 - - -
331 // 0 M0N0K1: 33 - - - 9
332 // 0 M0N1: 34 21 - - -
333 // 0 M0N2: 35 - - - 10
334 // 0 M0N3: 36 22 - - -
335 // 0 M1N0: 37 - - - 11
336 // 0 M1N1: 38 23 - - -
337 // 0 M1N2: 39 - - - 12
338 // 0 M1N3: 40 24 - - -
339 // 0 M2N0: 41 - - - 13
340 // 0 M2N1: 42 25 - - -
341 // 0 M2N2: 43 - - - 14
342 // 0 M2N3: 44 26 - - -
343 // 0 M3N0: 45 - 5 - 15
344 // 0 M3N1: 46 27 - - -
345 // 0 M3N2: 47 - - - 16
346 // 0 M3N3: 48 28 - - -
347 // 0 M4N0: 49 - 6 - -
348 // 0 M4N1: 50 29 - - -
349 // 0 M4N2: 51 - - 5 -
350 // 0 M4N3: 52 30 - - -
351 // 0 M5N0: 53 - 7 - -
352 // 0 M5N1: 54 31 - - -
353 // 0 M5N2: 55 - - 6 -
354 // 0 M5N3: 56 32 - - -
355 // 0 M6N0: 57 - 8 - -
356 // 0 M6N1: 58 1 - - -
357 // 0 M6N2: 59 - - 7 -
358 // 0 M6N3: 60 2 - - -
359 // 0 M7N0: 61 - - - -
360 // 0 M7N1: 62 3 - - -
361 // 0 M7N2: 63 - - 8 -
362 // 0 M7N3: 64 4 - - -
363
364#pragma unroll
365 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
366 {
367#pragma unroll
368 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
369 {
370 index_t dsread_perM = 0;
371 index_t dswrite_perM = 0;
372 index_t load_perM = 0;
373
374 // Calculate ds_read number per M
375 dsread_perM = dsread_per_wg;
376
377 // Calculate ds_write number per M
378 if(mIter == 0)
379 {
380 dswrite_perM =
383 : 0;
384 }
385 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
386 {
387 dswrite_perM = 0;
388 }
389 else
390 {
391 dswrite_perM = (dswrite_num_perK -
392 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
394 : 0;
395 }
396 // Add ds write when ds write data > needed
397 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
398 {
399 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
400 dswrite_perM = 1;
401 }
402
403 // Calculate buffer_load number per M
404 if(mIter < HalfMIter)
405 {
406 load_perM =
407 ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
408 : 0) +
409 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
410 : 0);
411 }
412 else
413 {
414 load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
415 ? Aload_rep
416 : 0;
417 }
418 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
419 }
420 }
421 // Add Aload when Aload data > needed
422 if(Aload_num_perK == 0)
423 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
424 __builtin_amdgcn_sched_barrier(0);
425 }
426
428 {
429#pragma unroll
430 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
431 {
432#pragma unroll
433 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
434 {
435 index_t dsread_perM = 0;
436 index_t dswrite_perM = 0;
437 index_t load_perM = 0;
438
439 // Calculate ds_read number per M
440 dsread_perM = dsread_per_wg;
441
442 // Calculate ds_write number per M
443 if(mIter == 0)
444 {
445 dswrite_perM =
448 : 0;
449 }
450 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
451 {
452 dswrite_perM = 0;
453 }
454 else
455 {
456 dswrite_perM = (dswrite_num_perK -
457 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
459 : 0;
460 }
461 // Add ds write when ds write data > needed
462 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
463 {
464 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
465 dswrite_perM = 1;
466 }
467
468 // Calculate buffer_load number per M
469 if(mIter < HalfMIter)
470 {
471 load_perM =
472 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
473 : 0);
474 }
475 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
476 }
477 }
478 __builtin_amdgcn_sched_barrier(0);
479 }
480
482 {
483#pragma unroll
484 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
485 {
486#pragma unroll
487 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
488 {
489 index_t dsread_perM = 0;
490 index_t dswrite_perM = 0;
491 index_t load_perM = 0;
492
493 // Calculate ds_read number per M
494 if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
495 dsread_perM = dsread_per_wg;
496
497 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
498 }
499 }
500 // __builtin_amdgcn_sched_barrier(0);
501 }
502
503 template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
504 CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
505 const AElementFunction& a_element_func,
506 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
507 index_t num_loop,
508 void* p_smem_ping,
509 void* p_smem_pong) const
510 {
511 static_assert(
512 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
513 "wrong!");
514
515 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
516 "wrong!");
517 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
518 "wrong!");
519
520 constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
521 const index_t iMWarp = get_warp_id() / NWarp;
522
523 using CWarpDstr = typename WG::CWarpDstr;
524 using CWarpTensor = typename WG::CWarpTensor;
525
526 constexpr auto c_warp_y_lengths =
527 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
528 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
529
530 __builtin_amdgcn_sched_barrier(0);
531
532 // A tile in LDS
533 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
534 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
535
536 constexpr auto a_lds_block_desc =
537 PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
538
539 auto a_lds_block_ping =
540 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
541 auto a_lds_block_pong =
542 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
543
544 // A DRAM tile window for load
545 auto a_copy_dram_window =
546 make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
548 a_dram_block_window_tmp.get_window_origin(),
549 PipelinePolicy::template MakeADramTileDistribution<Problem>());
550
551 auto a_copy_lds_window_ping =
552 make_tile_window(a_lds_block_ping,
554 {0, 0},
555 PipelinePolicy::template MakeADramTileDistribution<Problem>());
556
557 auto a_copy_lds_window_pong =
558 make_tile_window(a_lds_block_pong,
560 {0, 0},
561 PipelinePolicy::template MakeADramTileDistribution<Problem>());
562
563 // ping-pong window for A LDS
564 auto a_warp_window_ping_tmp =
565 make_tile_window(a_lds_block_ping,
567 {iMWarp * WG::kM, 0},
568 PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
569
570 auto a_warp_window_pong_tmp =
571 make_tile_window(a_lds_block_pong,
573 {iMWarp * WG::kM, 0},
574 PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
575
577 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
579 a_warp_windows_ping;
580
582 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
584 a_warp_windows_pong;
585
586 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
587 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
588 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
589 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
590
591 move_tile_window(a_warp_windows_ping(mIter)(kIter),
592 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
593 move_tile_window(a_warp_windows_pong(mIter)(kIter),
594 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
595 });
596 });
597
598 // Block GEMM
599 auto block_flatmm = BlockFlatmm();
600 // Acc register tile
601 auto c_block_tile = block_flatmm.MakeCBlockTile();
602
603 // B flat DRAM window for load
604 auto b_flat_distribution =
605 PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
606 auto b_flat_dram_window = // tile_window_with_static_distribution
608 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
610 b_flat_dram_block_window_tmp.get_window_origin(),
611 b_flat_distribution);
612
613 // pingpong buffer for B
615 statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
617 b_flat_dram_windows;
618
620 statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
622 b_warp_tensor_ping;
623
625 statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
627 b_warp_tensor_pong;
628
629 // HEAD
630 // Prefetch A0
631 auto a_block_tile = load_tile(a_copy_dram_window);
632 // move A window to next k
633 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
634
635 // prefetch B
636 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
637 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
638 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
639
640 move_tile_window(b_flat_dram_windows(nIter)(kIter),
641 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
642
643 b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
644 });
645 });
646 // move B window to next flat K
647 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
648
649 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
650 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
651 __builtin_amdgcn_sched_barrier(0);
652
653 // Prefetch A1
654 a_block_tile = load_tile(a_copy_dram_window);
655 // move A window to next k
656 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
657
658 // initialize C
659 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
660
662
663 // preload A00,A10... from lds
664 statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
665 m_preload>
666 a_warp_tensor;
667
668 static_for<0, m_preload, 1>{}([&](auto loadIter) {
669 constexpr auto mIter = loadIter % MIterPerWarp;
670 constexpr auto kIter = loadIter / MIterPerWarp;
671 a_warp_tensor(loadIter) =
672 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
673 });
674 __builtin_amdgcn_sched_barrier(0);
675
676 // MAIN LOOP
677 index_t iCounter = (num_loop - 1) / 2;
678 while(iCounter > 0)
679 {
680 // prefetch B(2i+1)
681 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
682 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
683 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
684
685 move_tile_window(b_flat_dram_windows(nIter)(kIter),
686 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
687
688 b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
689 });
690 });
691
692 // Prefill A(2i+1)
693 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
694 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
695
696 // Prefetch A(2i+2)
697 a_block_tile = load_tile(a_copy_dram_window);
698 // move A window to next k
699 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
700
701 // GEMM 2i
702 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
703 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
704 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
705 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
706 // read C warp tensor from C block tensor
707 CWarpTensor c_warp_tensor;
708
709 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
710 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
711 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
712
713 // warp GEMM
714 WG{}(c_warp_tensor,
715 a_warp_tensor(number<AwarpIter>{}),
716 b_warp_tensor_ping(nIter)(kIter));
717
718 // write C warp tensor into C block tensor
719 c_block_tile.set_y_sliced_thread_data(
720 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
721 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
722 c_warp_tensor.get_thread_buffer());
723 });
724 // preload next A from lds
725 if constexpr((kIter * MIterPerWarp + mIter) <
727 {
728 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
729 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
730 a_warp_tensor(number<AwarpIter>{}) =
731 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
732 }
733
734 // barrier
735 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
736 {
738 }
739 });
740 });
741
742 // move B window to next flat K
743 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
744
745 static_for<0, m_preload, 1>{}([&](auto loadIter) {
746 constexpr auto mIter = loadIter % MIterPerWarp;
747 constexpr auto kIter = loadIter / MIterPerWarp;
748 a_warp_tensor(loadIter) =
749 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
750 });
752
753 // Next K
754
755 // prefetch B(2i+2)
756 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
757 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
758 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
759
760 move_tile_window(b_flat_dram_windows(nIter)(kIter),
761 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
762
763 b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
764 });
765 });
766
767 // Prefill A(2i+2)
768 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
769 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
770
771 // Prefetch A(2i+3)
772 a_block_tile = load_tile(a_copy_dram_window);
773 // move A window to next k
774 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
775
776 // GEMM 2i+1
777 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
778 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
779 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
780 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
781 // read C warp tensor from C block tensor
782 CWarpTensor c_warp_tensor;
783 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
784 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
785 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
786
787 // warp GEMM
788 WG{}(c_warp_tensor,
789 a_warp_tensor(number<AwarpIter>{}),
790 b_warp_tensor_pong(nIter)(kIter));
791
792 // write C warp tensor into C block tensor
793 c_block_tile.set_y_sliced_thread_data(
794 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
795 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
796 c_warp_tensor.get_thread_buffer());
797 });
798 // preload next A from lds
799 if constexpr((kIter * MIterPerWarp + mIter) <
801 {
802 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
803 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
804 a_warp_tensor(number<AwarpIter>{}) =
805 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
806 }
807
808 // barrier
809 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
810 {
812 }
813 });
814 });
815
816 // move B window to next flat K
817 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
818
819 static_for<0, m_preload, 1>{}([&](auto loadIter) {
820 constexpr auto mIter = loadIter % MIterPerWarp;
821 constexpr auto kIter = loadIter / MIterPerWarp;
822 a_warp_tensor(loadIter) =
823 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
824 });
826
827 iCounter--;
828 }
829
830 // TAIL
831 if constexpr(TailNum == TailNumber::Even)
832 {
833 // prefetch B(loopK)
834 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
835 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
836 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
837
838 move_tile_window(b_flat_dram_windows(nIter)(kIter),
839 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
840
841 b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
842 });
843 });
844
845 // Prefill A(loopK)
846 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
847 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
848
849 // GEMM loopK-1
850 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
851 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
852 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
853 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
854 // read C warp tensor from C block tensor
855 CWarpTensor c_warp_tensor;
856
857 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
858 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
859 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
860
861 // warp GEMM
862 WG{}(c_warp_tensor,
863 a_warp_tensor(number<AwarpIter>{}),
864 b_warp_tensor_ping(nIter)(kIter));
865
866 // write C warp tensor into C block tensor
867 c_block_tile.set_y_sliced_thread_data(
868 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
869 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
870 c_warp_tensor.get_thread_buffer());
871 });
872 // preload next A from lds
873 if constexpr((kIter * MIterPerWarp + mIter) <
875 {
876 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
877 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
878 a_warp_tensor(number<AwarpIter>{}) =
879 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
880 }
881
882 // barrier
883 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
884 {
886 }
887 });
888 });
889
890 static_for<0, m_preload, 1>{}([&](auto loadIter) {
891 constexpr auto mIter = loadIter % MIterPerWarp;
892 constexpr auto kIter = loadIter / MIterPerWarp;
893 a_warp_tensor(loadIter) =
894 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
895 });
896
898
899 // GEMM loopK
900 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
901 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
902 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
903 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
904 // read C warp tensor from C block tensor
905 CWarpTensor c_warp_tensor;
906
907 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
908 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
909 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
910
911 // warp GEMM
912 WG{}(c_warp_tensor,
913 a_warp_tensor(number<AwarpIter>{}),
914 b_warp_tensor_pong(nIter)(kIter));
915
916 // write C warp tensor into C block tensor
917 c_block_tile.set_y_sliced_thread_data(
918 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
919 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
920 c_warp_tensor.get_thread_buffer());
921 });
922 if constexpr((kIter * MIterPerWarp + mIter) <
924 {
925 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
926 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
927 a_warp_tensor(number<AwarpIter>{}) =
928 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
929 }
930 // barrier
931 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
932 {
934 }
935 });
936 });
938 }
939 else if constexpr(TailNum == TailNumber::Odd)
940 {
941 // GEMM loopK
942 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
943 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
944 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
945 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
946 // read C warp tensor from C block tensor
947 CWarpTensor c_warp_tensor;
948
949 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
950 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
951 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
952
953 // warp GEMM
954 WG{}(c_warp_tensor,
955 a_warp_tensor(number<AwarpIter>{}),
956 b_warp_tensor_ping(nIter)(kIter));
957
958 // write C warp tensor into C block tensor
959 c_block_tile.set_y_sliced_thread_data(
960 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
961 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
962 c_warp_tensor.get_thread_buffer());
963 });
964 // preload next A from lds
965 if constexpr((kIter * MIterPerWarp + mIter) <
967 {
968 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
969 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
970 a_warp_tensor(number<AwarpIter>{}) =
971 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
972 }
973
974 // barrier
975 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
976 {
978 }
979 });
980 });
982 }
983
984 return c_block_tile;
985 }
986
987 template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
988 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
989 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
990 index_t num_loop,
991 void* p_smem_ping,
992 void* p_smem_pong) const
993 {
994 return operator()(
995 a_dram_block_window_tmp,
996 [](const ADataType & a) { return a; },
997 b_flat_dram_block_window_tmp,
998 num_loop,
999 p_smem_ping,
1000 p_smem_pong);
1001 }
1002};
1003
1004} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:14
static CK_TILE_HOST constexpr bool BlockHasHotloop(index_t num_loop)
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:17
static constexpr index_t PrefetchStages
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:15
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool, TailNumber tail_num)
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:28
static CK_TILE_HOST constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:22
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:47
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:63
static constexpr index_t dsread_num_perK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:162
static constexpr index_t Bload_num_perK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:167
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:427
remove_cvref_t< typename Problem::CLayout > CLayout
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:55
static constexpr auto idxK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:95
static constexpr index_t GetVectorSizeA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:78
remove_cvref_t< typename Problem::CDataType > CDataType
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:50
remove_cvref_t< typename Problem::ADataType > ADataType
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:48
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:481
static constexpr bool DoubleSmemBuffer
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:187
static constexpr index_t flatNPerWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:76
static constexpr index_t KPerBlockPerIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:111
static constexpr index_t kNPerBlock
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:72
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:96
static constexpr index_t mfma_per_wg
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:156
static constexpr index_t MIterPerWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:103
static constexpr index_t kMPerBlock
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:71
static constexpr index_t NWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:101
static constexpr index_t KFlatPerBlockPerIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:107
static constexpr bool UsePersistentKernel
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:88
static constexpr index_t kKPerBlock
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:73
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:98
static constexpr auto I0
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:90
static constexpr index_t GetVectorSizeB()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:79
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:988
static constexpr index_t DsWritePreIssue
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:65
static constexpr bool kPadM
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:82
static constexpr index_t Bload_rep
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:169
static constexpr auto I1
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:91
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:191
remove_cvref_t< typename Problem::ALayout > ALayout
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:53
static constexpr index_t flatKPerWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:75
static constexpr index_t mfma_perM_perK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:171
static constexpr bool kPadK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:84
static constexpr index_t MWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:100
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:51
static constexpr index_t kLdsAlignmentInBytes
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:86
static constexpr index_t NFlatPerBlockPerIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:108
static constexpr index_t KIterPerWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:105
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:97
static constexpr index_t NIterPerWarp
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:104
remove_cvref_t< typename Problem::BLayout > BLayout
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:54
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:189
static constexpr auto idxM
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:93
static constexpr index_t dswrite_rep
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:164
static constexpr index_t dswrite_mIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:172
static CK_TILE_HOST_DEVICE constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:197
static constexpr index_t NumWaveGroups
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:87
static constexpr auto idxN
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:94
static CK_TILE_HOST const std::string GetName()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:175
static constexpr auto I2
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:92
static constexpr index_t dswrite_num_perK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:163
static constexpr auto config
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:60
static constexpr index_t Aload_num_perK
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:165
static constexpr index_t K1
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:113
remove_cvref_t< typename Problem::BDataType > BDataType
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:49
static constexpr auto TailNum
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:119
static constexpr bool HasHotLoop
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:118
static constexpr index_t HalfMIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:168
static constexpr index_t DsReadPreload
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:66
static constexpr index_t WaveSize
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:69
static constexpr index_t dsread_per_wg
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:158
static constexpr index_t dswrite_kIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:173
static constexpr index_t GetVectorSizeC()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:80
static constexpr bool kPadN
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:83
remove_cvref_t< decltype(PipelinePolicy::template GetBlockFlatmm< Problem >())> BlockFlatmm
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:57
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:286
static constexpr index_t BlockSize
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:68
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:504
static constexpr index_t MPerBlockPerIter
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:110
static constexpr index_t m_preload
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:114
static constexpr index_t Aload_rep
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:166
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43