blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Blockwise MX (scaled) GEMM pipeline, v3, for MoE with gate/up fusion on a pre-shuffled B:
// one A tile (staged through LDS) is multiplied against two B tiles, "gate" and "up", which are
// read from global memory straight into registers, each with its own scales and accumulator.
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 2
14// LocalSharedMemoryBuffer: 2 (double-buffered A tile; B does not go through LDS)
15
// NOTE(review): the line(s) completing this primary-template declaration are not present in this
// listing; the partial specialization that follows takes the same parameters minus the scheduler.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
63 ThreadBlockSize,
64 ScaleBlockSize,
65 ADataType,
66 AScaleDataType,
67 BDataType,
68 BScaleDataType,
69 ATileDesc,
70 BTileDesc,
71 AMmaTileDesc,
72 BMmaTileDesc,
73 ABlockTransferSrcScalarPerVector,
74 BBlockTransferSrcScalarPerVector,
75 MPerBlock,
76 NPerBlock,
77 KPerBlock,
78 MPerXDL,
79 NPerXDL,
80 MRepeat,
81 NRepeat,
82 KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::A_K1;
120 using Base::I0;
121 using Base::I1;
122 using Base::KRepeat;
123 using Base::MWaves;
124 using Base::NWaves;
125 using Base::WaveSize;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
137 using Base::GetWaveIdx;
140
143
144 using Base::AMmaKStride;
145 using Base::APackedSize;
146 using Base::BMmaKStride;
147 using Base::BPackedSize;
148 using Base::KThreadChunk;
149
150 using Base::KXdlPack;
151 using Base::MXdlPack;
152 using Base::NXdlPack;
153
154 using AccType = typename Base::AccType;
155 using Tuple5 = typename Base::Tuple5;
158
// Number of K blocks fetched from global memory before the main loop (Run's "Global prefetch 1"
// and "Global prefetch 2"); BlockHasHotloop() requires num_loop to exceed it.
159 static constexpr index_t PrefetchStages = 2;
// Run() switches the A register refill to the other LDS buffer once m0 reaches
// MRepeat - LocalPrefetchStages (SwitchM).
160 static constexpr index_t LocalPrefetchStages = 2;
// NOTE(review): PrefillStages, GlobalBufferNum and HotloopLocalBufSwitch are not referenced in
// the part of this file visible here.
161 static constexpr index_t PrefillStages = 1;
162 static constexpr index_t GlobalBufferNum = 1;
163 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
164
// Scale loads issued per hot-loop iteration, consumed by HotLoopScheduler(): one per
// (MRepeat / MXdlPack) x (KRepeat / KXdlPack) pack for A. B is counted twice because the gate and
// up scales are loaded by separate thread copies (b_scale_thread_copy / b_scale_thread_copy_up).
165 static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
166 static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
// Immediate passed to __builtin_amdgcn_s_waitcnt in Run(). 3952 == 0xF70; async_vmcnt % 16 fills
// bits [3:0] and async_vmcnt / 16 is placed at bit 14 (16384 == 1 << 14).
// NOTE(review): presumably the gfx9-style S_WAITCNT layout (vmcnt[3:0], expcnt[6:4] = 7,
// lgkmcnt[11:8] = 0xF, vmcnt[5:4] at [15:14]) so that only vmcnt is waited on; async_vmcnt is
// defined on a line missing from this listing - confirm it is < 64 and matches the target ISA.
169 static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
170
171 static constexpr auto ScalesPerKBlockSize =
172 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
173
174 //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
175 static constexpr auto ScalesPerXdlopsRun =
176 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
177
178 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
179 static constexpr auto ScalesPerXdlopsRunPerThread =
180 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
181
// Number of mx_scale_t scale values packed into one AScaleDataType / BScaleDataType element;
// the asserts require a whole number of packed elements per KXdlPack x M/NXdlPack scale group.
183 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
184 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
185 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
186 "A scale pack data type too large!");
187 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
188 "B scale pack data type too large!");
191
192 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
193 {
194 return num_loop > PrefetchStages;
195 }
196
197 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
198 {
199 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
200 }
201
// Builds the instruction-scheduling plan for one hot-loop iteration of Run() by emitting
// __builtin_amdgcn_sched_group_barrier(mask, count, 0) hints. Masks used below (see the inline
// comments): 0x008 = MFMA, 0x020 = VMEM (global) read, 0x100 = DS (LDS) read.
// The iteration is modelled as num_total_stages == MRepeat stages of num_mfma_perstage MFMAs:
//   * stage 1 covers the first (MRepeat - 2) stages and spreads num_buffer_load_stage1 loads
//     (B gate+up tiles plus A and B scales) over them, some stages taking one extra load;
//   * stage 2 spreads the A-tile loads (num_buffer_load_stage2) at num_buffer_load_stage2 / 2
//     per stage (presumably over the last two stages - its loop header is missing here).
// In every stage the A DS reads are attached to the last num_ds_read_a_mfma_perstage MFMAs.
// NOTE(review): (num_total_stages - 2), buffer_load_perstage_less and buffer_load_perstage_stage2
// are used as divisors, so this presumably requires MRepeat > 2 and at least as many loads as
// stages in each group - confirm the supported configurations guarantee that.
202 __device__ static constexpr auto HotLoopScheduler()
203 {
204 // A/B split schedule
205 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// NOTE(review): the rest of this conditional (original lines 208-209) is missing from this listing.
206 constexpr auto num_ds_read_inst_a =
207 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
210
// B loads are counted twice: Run() loads a gate tile and an up tile per K block.
211 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
212 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
213 constexpr auto num_buffer_load_stage1 =
214 num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale;
215
216 constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a;
217
// x2: Run() issues two xdlops_gemm.Run calls (A * Gate and A * Up) per (m0, k0, n0).
218 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2;
219 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
220
// ds_read_a_mfma_rate = number of A DS reads grouped behind one MFMA in the barriers below.
221 constexpr auto ds_read_a_issue_cycle =
222 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
223 constexpr auto ds_read_a_mfma_rate =
224 math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle);
225
226 // constexpr auto num_dsread_a_mfma =
227 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
228
229 constexpr auto num_total_stages = MRepeat;
230
231 // Group num_mfma_perstage num_ds_read_a_perstage
232 // since we want to reuse a local register buffer
233 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
234 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
235
236 constexpr auto num_ds_read_a_mfma_perstage =
237 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
238
// NOTE(review): num_ds_read_a_prefetch_stages is not referenced in the visible code.
239 constexpr auto num_ds_read_a_prefetch_stages = 2;
240
// Stage-1 loads per stage: ceil ("more") and floor ("less") of loads / (MRepeat - 2).
241 constexpr auto buffer_load_perstage_more =
242 math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
243 constexpr auto buffer_load_perstage_less =
244 math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
245 constexpr auto buffer_load_perstage_stage2 =
246 math::integer_divide_floor((num_buffer_load_stage2), 2);
247
// == num_buffer_load_stage1 % (num_total_stages - 2): how many stage-1 stages take the
// "more" load count; the remaining stage-1 stages take the "less" count.
248 constexpr auto buffer_load_stages_more =
249 num_buffer_load_stage1 -
250 math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
251 ((num_total_stages - 2));
252
// One VMEM read is emitted on every MFMA whose index within the stage is a multiple of the
// corresponding interval.
253 constexpr auto buffer_load_issue_point_interval_more =
254 num_mfma_perstage / buffer_load_perstage_more;
255 constexpr auto buffer_load_issue_point_interval_less =
256 num_mfma_perstage / buffer_load_perstage_less;
257 constexpr auto buffer_load_issue_point_interval_stage2 =
258 num_mfma_perstage / buffer_load_perstage_stage2;
259
260 // Stage 1
261 // global read more
// NOTE(review): the enclosing per-stage loop header (original line 262) is missing from this
// listing; the second `});` below closes it.
263 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265
266 if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
267 {
268 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
269 }
270
271 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
272 {
273 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
274 }
275 });
276 });
277
278 // global read less
279 static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
280 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
281 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
282 if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
283 {
284 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
285 }
286 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
287 {
288 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
289 }
290 });
291 });
292
293 // Stage 2, Sync
294 // lds synchronization, prefetch next loop local A
// NOTE(review): the enclosing per-stage loop header (original line 295) is missing from this
// listing; the second `});` below closes it.
296 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
297 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
298 if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
299 {
300 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
301 }
302 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
303 {
304 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
305 }
306 });
307 });
308 }
309
310 template <bool HasMainLoop,
311 TailNumber TailNum,
312 typename AGridDesc,
313 typename ABlockDesc,
314 typename ABlockTransfer,
315 typename AGridBuffer,
316 typename ABlockBuffer,
317 typename ABlockTransferStep,
318 typename BGridDesc,
319 typename BBlockDesc,
320 typename BBlockTransfer,
321 typename BGridBuffer,
322 typename BBlockBuffer,
323 typename BBlockTransferStep,
324 typename CThreadBuffer,
325 typename AScaleGridBuffer,
326 typename AScaleGridDesc,
327 typename AScaleThreadTransfer,
328 typename BScaleGridBuffer,
329 typename BScaleGridDesc,
330 typename BScaleThreadTransfer>
331 __device__ void Run(
332 // ABlockCopy
333 const AGridDesc& a_grid_desc,
334 const ABlockDesc& a_block_desc,
335 ABlockTransfer& a_blockwise_copy,
336 const AGridBuffer& a_grid_buf,
337 ABlockBuffer& a_block_bufs,
338 const ABlockTransferStep& a_block_copy_step,
339 // B Gate and Up
340 const BGridDesc& b_grid_desc,
341 const BBlockDesc& b_block_desc,
342 BBlockTransfer& b_blockwise_copy,
343 BBlockTransfer& b_blockwise_copy_up,
344 const BGridBuffer& b_grid_buf,
345 const BGridBuffer& b_grid_buf_up,
346 BBlockBuffer& b_block_bufs,
347 const BBlockTransferStep& b_block_copy_step,
348 // CThread
349 CThreadBuffer& c_thread_buf,
350 CThreadBuffer& c_thread_buf_up,
351 // A and B scales
352 const AScaleGridDesc& a_scale_grid_desc,
353 AScaleThreadTransfer& a_scale_thread_copy,
354 const AScaleGridBuffer& a_scale_grid_buf,
355 // Gate and Up scale
356 const BScaleGridDesc& b_scale_grid_desc,
357 BScaleThreadTransfer& b_scale_thread_copy,
358 BScaleThreadTransfer& b_scale_thread_copy_up,
359 const BScaleGridBuffer& b_scale_grid_buf,
360 const BScaleGridBuffer& b_scale_grid_buf_up,
361 index_t num_loop) const
362 {
363 ignore = b_block_bufs;
365 a_thread_desc_.GetElementSpaceSize());
367 b_thread_desc_.GetElementSpaceSize());
369 b_thread_desc_.GetElementSpaceSize());
370
371 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
372 StaticallyIndexedArray<decltype(b_thread_buf_up), Number<2>{}> b_thread_bufs_up;
373 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
374
376 a_scale_thread_desc.GetElementSpaceSize());
377
379 b_scale_thread_desc.GetElementSpaceSize());
381 b_scale_thread_desc.GetElementSpaceSize());
382
383 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
384 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
385 StaticallyIndexedArray<decltype(b_scale_thread_buf_up), Number<2>{}> b_scale_thread_bufs_up;
386
387 // Global prefetch 1
388 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0));
389 b_blockwise_copy.Run(
390 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
391 b_blockwise_copy_up.Run(
392 b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I0));
393
394 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
395 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
396 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
397
398 // Prefetch a_scales
399 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
400 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
401 a_scale_thread_copy.Run(a_scale_grid_desc,
402 a_scale_grid_buf,
404 make_tuple(m0, k0, I0),
405 a_scale_thread_bufs(I0));
406
407 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
408 make_multi_index(0, I1, 0));
409 });
410 a_scale_thread_copy.MoveSrcSliceWindow(
411 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
412 });
413
414 // restore row id and advance to the next set of scales
415 a_scale_thread_copy.MoveSrcSliceWindow(
416 a_scale_grid_desc,
417 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
418
419 // Prefetch b_scales_gate
420 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
421 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
422 b_scale_thread_copy.Run(b_scale_grid_desc,
423 b_scale_grid_buf,
425 make_tuple(n0, k0, I0),
426 b_scale_thread_bufs(I0));
427
428 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
429 make_multi_index(0, I1, 0));
430 });
431 b_scale_thread_copy.MoveSrcSliceWindow(
432 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
433 });
434
435 // restore col id and advance to the next set of scales
436 // NWaves * NPerXDL * NRepeat == NPerBlock
437 b_scale_thread_copy.MoveSrcSliceWindow(
438 b_scale_grid_desc,
439 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
440
441 // Prefetch b_scales_up
442 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
443 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
444 b_scale_thread_copy_up.Run(b_scale_grid_desc,
445 b_scale_grid_buf_up,
447 make_tuple(n0, k0, I0),
448 b_scale_thread_bufs_up(I0));
449
450 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
451 make_multi_index(0, I1, 0));
452 });
453 b_scale_thread_copy_up.MoveSrcSliceWindow(
454 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
455 });
456
457 // restore col id and advance to the next set of scales
458 // NWaves * NPerXDL * NRepeat == NPerBlock
459 b_scale_thread_copy_up.MoveSrcSliceWindow(
460 b_scale_grid_desc,
461 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
462
463 // Local prefetch 1, sync the async load
464 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
467 static_for<0, KRepeat, 1>{}([&](auto k) {
468 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
469 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
470 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
471 [&](auto chunk) {
472 constexpr auto a_k_step_chunk =
473 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
474 a_thread_copy_.Run(
478 a_block_bufs(I0),
482 a_thread_buf);
483 });
484 });
485 });
486
487 // Global prefetch 2
488 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
489 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
490
491 // Initialize C
492 c_thread_buf.Clear();
493 __builtin_amdgcn_sched_barrier(0);
494 constexpr index_t SwitchM = MRepeat - LocalPrefetchStages;
495 // main body
496 if constexpr(HasMainLoop)
497 {
498 // loop over k with the step KPerBlock
499 index_t i = 0;
500 do
501 {
502 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
503 b_blockwise_copy.Run(b_grid_desc,
504 b_grid_buf,
505 b_block_desc,
506 b_block_origin_idx,
507 b_thread_bufs(scale_mem_buf));
508 b_blockwise_copy_up.Run(b_grid_desc,
509 b_grid_buf_up,
510 b_block_desc,
511 b_block_origin_idx,
512 b_thread_bufs_up(scale_mem_buf));
513
514 // Prefetch a_scales
515 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
516 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
517 a_scale_thread_copy.Run(a_scale_grid_desc,
518 a_scale_grid_buf,
520 make_tuple(m0, k0, I0),
521 a_scale_thread_bufs(scale_mem_buf));
522
523 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
524 make_multi_index(0, I1, 0));
525 });
526 a_scale_thread_copy.MoveSrcSliceWindow(
527 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
528 });
529
530 // restore row id and advance to the next set of scales
531 a_scale_thread_copy.MoveSrcSliceWindow(
532 a_scale_grid_desc,
533 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
534
535 // Prefetch b_scales_gate
536 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
537 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
538 b_scale_thread_copy.Run(b_scale_grid_desc,
539 b_scale_grid_buf,
541 make_tuple(n0, k0, I0),
542 b_scale_thread_bufs(scale_mem_buf));
543
544 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
545 make_multi_index(0, I1, 0));
546 });
547 b_scale_thread_copy.MoveSrcSliceWindow(
548 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
549 });
550
551 // restore col id and advance to the next set of scales
552 // NWaves * NPerXDL * NRepeat == NPerBlock
553 b_scale_thread_copy.MoveSrcSliceWindow(
554 b_scale_grid_desc,
555 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
556
557 // Prefetch b_scales_up
558 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
559 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
560 b_scale_thread_copy_up.Run(b_scale_grid_desc,
561 b_scale_grid_buf_up,
563 make_tuple(n0, k0, I0),
564 b_scale_thread_bufs_up(scale_mem_buf));
565
566 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
567 make_multi_index(0, I1, 0));
568 });
569 b_scale_thread_copy_up.MoveSrcSliceWindow(
570 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
571 });
572
573 // restore col id and advance to the next set of scales
574 // NWaves * NPerXDL * NRepeat == NPerBlock
575 b_scale_thread_copy_up.MoveSrcSliceWindow(
576 b_scale_grid_desc,
577 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
578
579 // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
580 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
581 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
582
583 static_for<0, MRepeat, 1>{}([&](auto m0) {
584 constexpr auto im_major = m0 / MXdlPack;
585 constexpr auto im_minor = m0 % MXdlPack;
586 static_for<0, KRepeat, 1>{}([&](auto k0) {
587 constexpr auto ik_major = k0 / KXdlPack;
588 constexpr auto ik_minor = k0 % KXdlPack;
589 static_for<0, NRepeat, 1>{}([&](auto n0) {
590 constexpr auto in_major = n0 / NXdlPack;
591 constexpr auto in_minor = n0 % NXdlPack;
592
593 constexpr index_t a_scale_offset =
594 a_scale_thread_desc.CalculateOffset(
595 make_tuple(im_major, ik_major, I0));
596 constexpr index_t b_scale_offset =
597 b_scale_thread_desc.CalculateOffset(
598 make_tuple(in_major, ik_major, I0));
599
600 static_assert(0 < ScalesPerXdlopsRunPerThread,
601 "Must have at least one scale per Xdlops "
602 "per Thread.");
603
605 a_scale_thread_vec;
607 b_scale_thread_vec;
609 b_scale_thread_vec_up;
610
611 // Pack scale_thread_buf into scale_thread_vec
613 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
614 a_scale_thread_bufs(
615 scale_comp_buf)[Number<a_scale_offset + s>{}];
616 });
617 // B Gate scale
619 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
620 b_scale_thread_bufs(
621 scale_comp_buf)[Number<b_scale_offset + s>{}];
622 });
623 // B Up scale
625 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
626 b_scale_thread_bufs_up(
627 scale_comp_buf)[Number<b_scale_offset + s>{}];
628 });
629
632 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
633
634 static_for<0, KPack, 1>{}([&](auto ik) {
635 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
636 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
637 make_tuple(I0, I0, im_minor, k0, ik))>{}];
638 b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
639 [scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
640 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
641 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
642 b_thread_bufs_up
643 [scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
644 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
645 });
646
647 using mfma_input_type_a =
648 typename vector_type<ComputeTypeA,
649 xdlops_gemm.K1PerXdlops /
650 APackedSize>::type;
651 using mfma_input_type_b =
652 typename vector_type<ComputeTypeB,
653 xdlops_gemm.K1PerXdlops /
654 BPackedSize>::type;
655
656 using mfma_scale_input_type_a =
657 typename vector_type<AScaleDataType,
659 using mfma_scale_input_type_b =
660 typename vector_type<BScaleDataType,
662
663 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
664 make_tuple(im_major, in_major, im_minor, in_minor, 0));
665
666 // MFMA accumulation A * Gate
667 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
668 ik_minor * NXdlPack + in_minor>(
669 a_thread_vec.template AsType<mfma_input_type_a>(),
670 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
671 b_thread_vec.template AsType<mfma_input_type_b>(),
672 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
673 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
674
675 // MFMA accumulation A * Up
676 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
677 ik_minor * NXdlPack + in_minor>(
678 a_thread_vec.template AsType<mfma_input_type_a>(),
679 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
680 b_thread_vec_up.template AsType<mfma_input_type_b>(),
681 b_scale_thread_vec_up
682 .template AsType<mfma_scale_input_type_b>(),
683 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
684 });
685 });
686
687 if constexpr(m0.value == SwitchM)
688 {
689 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
691 a_blockwise_copy.Run(a_grid_desc,
692 a_grid_buf,
693 a_block_desc,
694 a_block_bufs(scale_comp_buf));
695 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
696 }
697
698 constexpr auto lds_buf =
699 m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf;
700
701 static_for<0, KRepeat, 1>{}([&](auto k) {
702 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
703 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
704 static_for<0,
705 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
706 1>{}([&](auto chunk) {
707 constexpr auto a_k_step_chunk =
708 k_step +
709 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
710 a_thread_copy_.Run(
713 (MRepeat / MXdlPack)>{},
714 I0,
716 I0,
718 a_block_bufs(Number<lds_buf>{}),
721 I0,
723 k,
725 a_thread_buf);
726 });
727 });
728 });
729
730 if constexpr(MPerBlock >= 64)
732 __builtin_amdgcn_sched_barrier(0);
733 };
734
735 LoopFunc(I0, I1);
736 LoopFunc(I1, I0);
737
738 i += 2;
739 } while(i < (num_loop - 2));
740 }
741
742 // tail
743 if constexpr(TailNum == TailNumber::Even)
744 {
745 b_blockwise_copy.Run(
746 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
747 b_blockwise_copy_up.Run(
748 b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I1));
749
750 // Prefetch a_scales_up
751 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
752 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
753 a_scale_thread_copy.Run(a_scale_grid_desc,
754 a_scale_grid_buf,
756 make_tuple(m0, k0, I0),
757 a_scale_thread_bufs(I1));
758
759 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
760 make_multi_index(0, I1, 0));
761 });
762 a_scale_thread_copy.MoveSrcSliceWindow(
763 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
764 });
765
766 // Prefetch b_scales_gate
767 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
768 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
769 b_scale_thread_copy.Run(b_scale_grid_desc,
770 b_scale_grid_buf,
772 make_tuple(n0, k0, I0),
773 b_scale_thread_bufs(I1));
774
775 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
776 make_multi_index(0, I1, 0));
777 });
778 b_scale_thread_copy.MoveSrcSliceWindow(
779 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
780 });
781
782 // Prefetch b_scales_up
783 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
784 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
785 b_scale_thread_copy_up.Run(b_scale_grid_desc,
786 b_scale_grid_buf_up,
788 make_tuple(n0, k0, I0),
789 b_scale_thread_bufs_up(I1));
790
791 b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
792 make_multi_index(0, I1, 0));
793 });
794 b_scale_thread_copy_up.MoveSrcSliceWindow(
795 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
796 });
797
798 static_for<0, MRepeat, 1>{}([&](auto m0) {
799 constexpr auto im_major = m0 / MXdlPack;
800 constexpr auto im_minor = m0 % MXdlPack;
801 static_for<0, KRepeat, 1>{}([&](auto k0) {
802 constexpr auto ik_major = k0 / KXdlPack;
803 constexpr auto ik_minor = k0 % KXdlPack;
804 static_for<0, NRepeat, 1>{}([&](auto n0) {
805 constexpr auto in_major = n0 / NXdlPack;
806 constexpr auto in_minor = n0 % NXdlPack;
807
808 constexpr index_t a_scale_offset =
809 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
810 constexpr index_t b_scale_offset =
811 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
812
813 static_assert(0 < ScalesPerXdlopsRunPerThread,
814 "Must have at least one scale per Xdlops "
815 "per Thread.");
816
820
821 // Pack scale_thread_buf into scale_thread_vec
823 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
824 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
825 });
826 // B Gate scale
828 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
829 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
830 });
831 // B Up scale
833 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
834 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
835 });
836
839 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
840
841 static_for<0, KPack, 1>{}([&](auto ik) {
842 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
843 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
844 make_tuple(I0, I0, im_minor, k0, ik))>{}];
845 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
846 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
847 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
848 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
849 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
850 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
851 });
852
853 using mfma_input_type_a =
854 typename vector_type<ComputeTypeA,
855 xdlops_gemm.K1PerXdlops / APackedSize>::type;
856 using mfma_input_type_b =
857 typename vector_type<ComputeTypeB,
858 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
859
860 using mfma_scale_input_type_a =
862 using mfma_scale_input_type_b =
864
865 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
866 make_tuple(im_major, in_major, im_minor, in_minor, 0));
867
868 // MFMA accumulation A * Gate
869 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
870 ik_minor * NXdlPack + in_minor>(
871 a_thread_vec.template AsType<mfma_input_type_a>(),
872 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
873 b_thread_vec.template AsType<mfma_input_type_b>(),
874 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
875 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
876
877 // MFMA accumulation A * Gate
878 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
879 ik_minor * NXdlPack + in_minor>(
880 a_thread_vec.template AsType<mfma_input_type_a>(),
881 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
882 b_thread_vec_up.template AsType<mfma_input_type_b>(),
883 b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
884 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
885 });
886 });
887 if constexpr(m0.value == SwitchM)
888 {
889 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
891 }
892
893 constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
894
895 static_for<0, KRepeat, 1>{}([&](auto k) {
896 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
897 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
898 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
899 [&](auto chunk) {
900 constexpr auto a_k_step_chunk =
901 k_step +
902 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
903 a_thread_copy_.Run(
906 (MRepeat / MXdlPack)>{},
907 I0,
909 I0,
911 a_block_bufs(Number<lds_buf>{}),
915 a_thread_buf);
916 });
917 });
918 });
919
920 static_for<0, MRepeat, 1>{}([&](auto m0) {
921 constexpr auto im_major = m0 / MXdlPack;
922 constexpr auto im_minor = m0 % MXdlPack;
923 static_for<0, KRepeat, 1>{}([&](auto k0) {
924 constexpr auto ik_major = k0 / KXdlPack;
925 constexpr auto ik_minor = k0 % KXdlPack;
926 static_for<0, NRepeat, 1>{}([&](auto n0) {
927 constexpr auto in_major = n0 / NXdlPack;
928 constexpr auto in_minor = n0 % NXdlPack;
929
930 constexpr index_t a_scale_offset =
931 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
932 constexpr index_t b_scale_offset =
933 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
934
935 static_assert(0 < ScalesPerXdlopsRunPerThread,
936 "Must have at least one scale per Xdlops "
937 "per Thread.");
938
942
943 // Pack scale_thread_buf into scale_thread_vec
945 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
946 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
947 });
949 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
950 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
951 });
953 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
954 b_scale_thread_bufs_up(I1)[Number<b_scale_offset + s>{}];
955 });
956
959 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
960
961 static_for<0, KPack, 1>{}([&](auto ik) {
962 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
963 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
964 make_tuple(I0, I0, im_minor, k0, ik))>{}];
965 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
966 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
967 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
968 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
969 b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
970 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
971 });
972
973 using mfma_input_type_a =
974 typename vector_type<ComputeTypeA,
975 xdlops_gemm.K1PerXdlops / APackedSize>::type;
976 using mfma_input_type_b =
977 typename vector_type<ComputeTypeB,
978 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
979
980 using mfma_scale_input_type_a =
982 using mfma_scale_input_type_b =
984
985 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
986 make_tuple(im_major, in_major, im_minor, in_minor, 0));
987
988 // MFMA accumulation A * Gate
989 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
990 ik_minor * NXdlPack + in_minor>(
991 a_thread_vec.template AsType<mfma_input_type_a>(),
992 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
993 b_thread_vec.template AsType<mfma_input_type_b>(),
994 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
995 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
996
997 // MFMA accumulation A * Up
998 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
999 ik_minor * NXdlPack + in_minor>(
1000 a_thread_vec.template AsType<mfma_input_type_a>(),
1001 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
1002 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1003 b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
1004 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1005 });
1006 });
1007 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
1008 {
1009 static_for<0, KRepeat, 1>{}([&](auto k) {
1010 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
1011 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
1012 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1013 [&](auto chunk) {
1014 constexpr auto a_k_step_chunk =
1015 k_step +
1016 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1017 a_thread_copy_.Run(
1020 (MRepeat / MXdlPack)>{},
1021 I0,
1023 I0,
1025 a_block_bufs(I1),
1027 make_tuple(I0,
1028 I0,
1030 k,
1032 a_thread_buf);
1033 });
1034 });
1035 }
1036 });
1037 }
1038 else if constexpr(TailNum == TailNumber::Odd)
1039 {
1040 static_for<0, MRepeat, 1>{}([&](auto m0) {
1041 constexpr auto im_major = m0 / MXdlPack;
1042 constexpr auto im_minor = m0 % MXdlPack;
1043 static_for<0, KRepeat, 1>{}([&](auto k0) {
1044 constexpr auto ik_major = k0 / KXdlPack;
1045 constexpr auto ik_minor = k0 % KXdlPack;
1046 static_for<0, NRepeat, 1>{}([&](auto n0) {
1047 constexpr auto in_major = n0 / NXdlPack;
1048 constexpr auto in_minor = n0 % NXdlPack;
1049
1050 constexpr index_t a_scale_offset =
1051 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
1052 constexpr index_t b_scale_offset =
1053 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
1054
1055 static_assert(0 < ScalesPerXdlopsRunPerThread,
1056 "Must have at least one scale per Xdlops "
1057 "per Thread.");
1058
1062
1063 // Pack scale_thread_buf into scale_thread_vec
1065 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
1066 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
1067 });
1068 // B Gate scale
1070 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
1071 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
1072 });
1073 // B Up scale
1075 b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
1076 b_scale_thread_bufs_up(I0)[Number<b_scale_offset + s>{}];
1077 });
1078
1081 vector_type<ComputeTypeB, KPack> b_thread_vec_up;
1082
1083 static_for<0, KPack, 1>{}([&](auto ik) {
1084 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1085 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1086 make_tuple(I0, I0, im_minor, k0, ik))>{}];
1087 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1088 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1089 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
1090 b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
1091 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
1092 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
1093 });
1094
1095 using mfma_input_type_a =
1096 typename vector_type<ComputeTypeA,
1097 xdlops_gemm.K1PerXdlops / APackedSize>::type;
1098 using mfma_input_type_b =
1099 typename vector_type<ComputeTypeB,
1100 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
1101
1102 using mfma_scale_input_type_a =
1104 using mfma_scale_input_type_b =
1106
1107 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
1108 make_tuple(im_major, in_major, im_minor, in_minor, 0));
1109
1110 // MFMA accumulation A * Gate
1111 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
1112 ik_minor * NXdlPack + in_minor>(
1113 a_thread_vec.template AsType<mfma_input_type_a>(),
1114 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
1115 b_thread_vec.template AsType<mfma_input_type_b>(),
1116 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
1117 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1118
1119 // MFMA accumulation A * up
1120 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
1121 ik_minor * NXdlPack + in_minor>(
1122 a_thread_vec.template AsType<mfma_input_type_a>(),
1123 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
1124 b_thread_vec_up.template AsType<mfma_input_type_b>(),
1125 b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
1126 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
1127 });
1128 });
1129 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
1130 {
1131 static_for<0, KRepeat, 1>{}([&](auto k) {
1132 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
1133 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
1134 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
1135 [&](auto chunk) {
1136 constexpr auto a_k_step_chunk =
1137 k_step +
1138 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
1139 a_thread_copy_.Run(
1142 (MRepeat / MXdlPack)>{},
1143 I0,
1145 I0,
1147 a_block_bufs(I0),
1149 make_tuple(I0,
1150 I0,
1152 k,
1154 a_thread_buf);
1155 });
1156 });
1157 }
1158 });
1159 }
1160 }
1161
1162 // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack]
1163 // Order: 1 0 3 2 4
1164 static constexpr auto ARegBuf = 2;
1167
1171 decltype(a_thread_desc_),
1174 4,
1175 A_K1,
1176 A_K1>;
1178
1179 // TODO: make this field protected when a_scale_thread_copy_ is moved
1180 // here
1183 Number<KRepeat / KXdlPack>{},
1185
1186 // TODO: make this field protected when b_scale_thread_copy_ is moved
1187 // here
1190 Number<KRepeat / KXdlPack>{},
1192
1193 protected:
1194 // using Base::a_thread_copy_;
1195 // using Base::a_thread_desc_;
1196 using Base::b_thread_copy_;
1197 using Base::b_thread_desc_;
1198 using Base::c_thread_desc_;
1199};
1200
1201} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_bufs, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, BScaleThreadTransfer &b_scale_thread_copy_up, const BScaleGridBuffer &b_scale_grid_buf, const BScaleGridBuffer &b_scale_grid_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp:331
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp:102
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_m3_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, KThreadChunk >, Sequence< 0, 1, 2, 3, 4 >, 4, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp:1168
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp:38
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10