threadwise_tensor_slice_transfer_v3r1.hpp Source File

threadwise_tensor_slice_transfer_v3r1.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer_v3r1.hpp Source File
threadwise_tensor_slice_transfer_v3r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
14
15namespace ck {
16
17// Assume:
18// 1. src_desc and dst_desc are not known at compile-time
19// 2. SrcBuffer and DstBuffer are DynamicBuffer
20// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
21// 4. Use thread buffer
22template <typename SliceLengths,
23 typename SrcElementwiseOperation,
24 typename DstElementwiseOperation,
26 typename SrcData,
27 typename DstData,
28 typename SrcDesc,
29 typename DstDesc,
30 typename SrcDimAccessOrder,
31 typename DstDimAccessOrder,
32 index_t SrcVectorDim,
33 index_t DstVectorDim,
34 index_t SrcScalarPerVector_,
35 index_t DstScalarPerVector_,
36 index_t SrcScalarStrideInVector,
37 index_t DstScalarStrideInVector,
38 bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
39 // RunRead(), will be fused with MoveSrcSliceWindow to
40 // save addr computation
41 bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
42 // RunWrite(), will be fused with MoveDstSliceWindow to
43 // save addr computation
44 index_t NumThreadScratch = 1>
46{
47 static constexpr index_t nDim = SliceLengths::Size();
49
50 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
51 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
52
53 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
54 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
55
56 static constexpr auto I0 = Number<0>{};
57 static constexpr auto I1 = Number<1>{};
58 static constexpr auto I2 = Number<2>{};
59 static constexpr auto I3 = Number<3>{};
60 static constexpr auto I4 = Number<4>{};
61 static constexpr auto I5 = Number<5>{};
62 static constexpr auto I6 = Number<6>{};
63 static constexpr auto I7 = Number<7>{};
64 static constexpr auto I8 = Number<8>{};
65 static constexpr auto I10 = Number<10>{};
66 static constexpr auto I12 = Number<12>{};
67 static constexpr auto I13 = Number<13>{};
68 static constexpr auto I14 = Number<14>{};
69 static constexpr auto I16 = Number<16>{};
70
71 static constexpr index_t PackedSize = []() {
73 return 2;
74 else
75 return 1;
76 }();
77
78 static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
79 static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
80
82 const SrcDesc& src_desc,
83 const Index& src_slice_origin,
84 const SrcElementwiseOperation& src_element_op,
85 const DstDesc& dst_desc,
86 const Index& dst_slice_origin,
87 const DstElementwiseOperation& dst_element_op)
88 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
89 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
90 src_element_op_(src_element_op),
91 dst_element_op_(dst_element_op)
92 {
93 if constexpr((packed_size_v<SrcData>) > 1)
94 {
96 "SrcData != DstData");
97
98 static_assert(
99 SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
100 "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
101
102 static_assert(SrcVectorDim == DstVectorDim,
103 "Packed data type does not support transpose");
104 }
105 }
106
107 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
108 {
109 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
110 }
111
112 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
113 {
114 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
115 }
116
117 template <typename SrcBuffer, index_t ThreadScratchId = 0>
118 __device__ void RunRead(const SrcDesc& src_desc,
119 const SrcBuffer& src_buf,
121 {
122 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
123 SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
124 "wrong!");
125
126 static_assert(
128 "wrong! SrcBuffer and SrcData data type are inconsistent");
129
130 // scalar per access on each dim
131 // TODO: don't use lambda_scalar_per_access
132 constexpr auto src_scalar_per_access = generate_sequence(
134
135 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
136
137 static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
138 "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
139
140 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
141
142 constexpr auto ordered_src_access_lengths =
143 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
144
145 // make forward steps
146 const auto src_forward_steps = generate_tuple(
147 [&](auto i) {
148 Index forward_step_idx;
149
150 static_for<0, nDim, 1>{}([&](auto j) {
151 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
152 });
153
154 return make_tensor_coordinate_step(src_desc, forward_step_idx);
155 },
156 Number<nDim>{});
157
158 // make backward steps
159 const auto src_backward_steps = generate_tuple(
160 [&](auto i) {
161 Index backward_step_idx;
162
163 static_for<0, nDim, 1>{}([&](auto j) {
164 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
165 });
166
167 return make_tensor_coordinate_step(src_desc, backward_step_idx);
168 },
169 Number<nDim>{});
170
171 // loop over tensor and copy
172 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
173 // judge move forward or move backward
174 constexpr auto forward_sweep = [&]() {
176
177 forward_sweep_(I0) = true;
178
179 static_for<1, nDim, 1>{}([&](auto i) {
180 index_t tmp = ordered_src_access_idx[I0];
181
182 static_for<1, i, 1>{}([&](auto j) {
183 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
184 });
185
186 forward_sweep_(i) = tmp % 2 == 0;
187 });
188
189 return forward_sweep_;
190 }();
191
192 // calculate src data index
193 constexpr auto src_data_idx = [&]() {
194 Index ordered_idx;
195
196 static_for<0, nDim, 1>{}([&](auto i) {
197 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
198 : ordered_src_access_lengths[i] - 1 -
199 ordered_src_access_idx[i];
200 });
201
202 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
203 src_scalar_per_access;
204 }();
205
206 constexpr auto src_data_idx_seq = generate_sequence_v2(
207 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
208
209 // maintain a container record is_src_valid, waiting for RunWrite use.
210 const bool is_src_valid =
212 src_oob_thread_scratch_tuple_(thread_scratch_id)
213 .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
214
216 using dst_vector_t = typename dst_vector_type::type;
217 dst_vector_type op_r_v;
218
219 constexpr auto get_elem_op_vec_len = []() {
220 if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
221 {
222 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
223 return math::min(8, SrcScalarPerVector);
224 }
225 else if constexpr(is_detected<is_pack4_invocable_t,
226 decltype(src_element_op_)>::value)
227 {
228 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
229 return math::min(4, SrcScalarPerVector);
230 }
231 else if constexpr(is_detected<is_pack2_invocable_t,
232 decltype(src_element_op_)>::value)
233 {
234 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
235 return math::min(2, SrcScalarPerVector);
236 }
237 else
238 {
239 return 1;
240 }
241 };
242
243 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
244
245 using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
246 using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
247
248 using VectorSizeLookupTable = Tuple<Sequence<>,
249 Sequence<I1>,
250 Sequence<I2>,
251 Sequence<I2, I1>,
252 Sequence<I4>,
253 Sequence<I4, I1>,
254 Sequence<I4, I2>,
255 Sequence<I4, I2, I1>,
256 Sequence<I8>,
257 Sequence<I8, I1>,
258 Sequence<I8, I2>,
259 Sequence<I8, I2, I1>,
260 Sequence<I8, I4>,
261 Sequence<I8, I4, I1>,
262 Sequence<I8, I4, I2>,
263 Sequence<I8, I4, I2, I1>,
264 Sequence<I16>>;
265 using VectorOffsetsLookupTable = Tuple<Sequence<>,
266 Sequence<I0>,
267 Sequence<I0>,
268 Sequence<I0, I2>,
269 Sequence<I0>,
270 Sequence<I0, I4>,
271 Sequence<I0, I4>,
272 Sequence<I0, I4, I6>,
273 Sequence<I0>,
274 Sequence<I0, I8>,
275 Sequence<I0, I8>,
276 Sequence<I0, I8, I10>,
277 Sequence<I0, I8>,
278 Sequence<I0, I8, I12>,
279 Sequence<I0, I8, I12>,
280 Sequence<I0, I8, I12, I14>,
281 Sequence<I0>>;
282
283 static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
284 [&](auto v_idx) {
285 constexpr auto VectorLoadSize =
286 tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
287 constexpr auto LoadOffset =
288 tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
289
290 using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
291 using src_vector_container_t = typename src_vector_container::type;
292
293 src_vector_container src_vector =
294 src_vector_container{src_buf.template Get<src_vector_container_t>(
295 src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
296
297 static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
298 // apply the src elementwise op and convert to DstData under the hood if
299 // needed
300 src_element_op_(
301 op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
302 src_vector.template AsType<src_elem_op_vec_t>()[idx]);
303 });
304 });
305
306 // copy data from src_vector_container into src_thread_scratch_
307 src_thread_scratch_tuple_(thread_scratch_id)
308 .template SetAsType<dst_vector_t>(src_data_idx_seq,
309 op_r_v.template AsType<dst_vector_t>()[I0]);
310
311 constexpr auto move_on_dim = [&]() constexpr {
313
314 static_for<0, nDim, 1>{}([&](auto i) {
315 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
316
317 static_for<i + 1, nDim, 1>{}([&](auto j) {
318 move_on_dim_(i) &=
319 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
320 });
321 });
322
323 return move_on_dim_;
324 }();
325
326 // move src coord
327 static_for<0, nDim, 1>{}([&](auto i) {
328 if constexpr(move_on_dim[i])
329 {
330 if constexpr(forward_sweep[i])
331 {
333 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
334 }
335 else
336 {
338 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
339 }
340 }
341 });
342 });
343
344 // move src coordinate back to slice origin (or not)
345 if constexpr(SrcResetCoordinateAfterRun)
346 {
347 const auto src_reset_step =
349
350 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
351 }
352 }
353
354 template <typename SeqIdx, index_t ThreadScratchId = 0>
355 __device__ constexpr auto
357 {
359 return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
360 }
361
362 template <index_t ThreadScratchId>
363 __device__ void
365 {
366#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
367 static_ford<SliceLengths>{}([&](auto idx) {
368 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
369 });
370#else
371 // OOB Check
372 constexpr auto src_scalar_per_access = generate_sequence(
374
375 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
376
377 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
378
379 constexpr auto ordered_src_access_lengths =
380 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
381
382 // loop over tensor and copy
383 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
384 // judge move forward or move backward
385 constexpr auto forward_sweep = [&]() {
387
388 forward_sweep_(I0) = true;
389
390 static_for<1, nDim, 1>{}([&](auto i) {
391 index_t tmp = ordered_src_access_idx[I0];
392
393 static_for<1, i, 1>{}([&](auto j) {
394 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
395 });
396
397 forward_sweep_(i) = tmp % 2 == 0;
398 });
399
400 return forward_sweep_;
401 }();
402
403 // calculate src data index
404 constexpr auto src_data_idx = [&]() {
405 Index ordered_idx;
406
407 static_for<0, nDim, 1>{}([&](auto i) {
408 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
409 : ordered_src_access_lengths[i] - 1 -
410 ordered_src_access_idx[i];
411 });
412
413 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
414 src_scalar_per_access;
415 }();
416
417 constexpr auto src_data_idx_seq = generate_sequence_v2(
418 [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
419
421
422 auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
423 .template GetAsType<vector_t>(src_data_idx_seq);
424
425 const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
426 .template GetAsType<bool>(src_data_idx_seq);
427
428 auto op_r_v = is_src_valid ? op_r : vector_t(0);
429
430 src_thread_scratch_tuple_(thread_scratch_id)
431 .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
432 });
433
434 // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
435 // TODO make this logic more generic for more sub-dword datatype
436 if constexpr(SrcVectorDim != DstVectorDim &&
438 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
440 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
442 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
443 {
445 "in-register transpose is not supported for pk_i4_t");
447 "in-register transpose is not supported for f4x2_pk_t");
448 // each transpose does
449 // DstScalarPerVector # of src vectors in src_thread_scratch_
450 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
451 constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
452 constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
453
454 // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
455 // TODO: make this logic generic for all scenario
456 static_assert(SrcVectorDim != DstVectorDim, "wrong");
457
458 constexpr auto src_scalar_step_in_vector = generate_sequence(
460
461 constexpr auto dst_scalar_step_in_vector = generate_sequence(
463
464 constexpr auto scalar_per_access = generate_sequence(
467 DstVectorDim,
469 Number<nDim>{});
470
471 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
472
473 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
474 constexpr auto data_idx = access_idx * scalar_per_access;
475
476 constexpr auto data_idx_seq = generate_sequence_v2(
477 [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
478
481
482 // get DstScalarPerVector # of read-only references to src vectors from
483 // src_thread_scratch_
484 const auto src_vector_refs = generate_tie(
485 [&](auto i) -> const src_vector_t& {
486 // i increment corresponds to movement in DstVectorDim
487 return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
488 data_idx_seq + i * dst_scalar_step_in_vector);
489 },
491
492 // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
493 auto dst_vector_refs = generate_tie(
494 [&](auto i) -> dst_vector_t& {
495 // i increment corresponds to movement in SrcVectorDim
496 return dst_thread_scratch_.GetVectorTypeReference(
497 data_idx_seq + i * src_scalar_step_in_vector);
498 },
500
501 // do data transpose
503 src_vector_refs, dst_vector_refs);
504 });
505 }
506 else
507 {
508 constexpr auto packed_per_access = generate_sequence(
510
511 constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
512
513 static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
514 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
515 });
516 }
517#endif
518 }
519
520 template <typename DstBuffer, index_t ThreadScratchId = 0>
521 __device__ void RunWrite(const DstDesc& dst_desc,
522 DstBuffer& dst_buf,
524 {
525 // if there is transpose, it's done here
526 // if there is oob check, it's done here
527 // TODO move this elsewhere
529
530 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
531 DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
532 "wrong!");
533
534 static_assert(
536 "wrong! SrcBuffer or DstBuffer data type is wrong");
537
538 // src scalar per access on each dim
539 // TODO: don't use this
540 constexpr auto dst_scalar_per_access = generate_sequence(
542
543 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
544
545 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
546
547 constexpr auto ordered_dst_access_lengths =
548 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
549
550 // make forward steps
551 const auto dst_forward_steps = generate_tuple(
552 [&](auto i) {
553 Index forward_step_idx;
554
555 static_for<0, nDim, 1>{}([&](auto j) {
556 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
557 });
558
559 return make_tensor_coordinate_step(dst_desc, forward_step_idx);
560 },
561 Number<nDim>{});
562
563 // make backward steps
564 const auto dst_backward_steps = generate_tuple(
565 [&](auto i) {
566 Index backward_step_idx;
567
568 static_for<0, nDim, 1>{}([&](auto j) {
569 backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
570 });
571
572 return make_tensor_coordinate_step(dst_desc, backward_step_idx);
573 },
574 Number<nDim>{});
575
576 // loop over tensor and copy
577 static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
578 // judge move forward or move backward
579 constexpr auto forward_sweep = [&]() {
581
582 forward_sweep_(I0) = true;
583
584 static_for<1, nDim, 1>{}([&](auto i) {
585 index_t tmp = ordered_dst_access_idx[I0];
586
587 static_for<1, i, 1>{}([&](auto j) {
588 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
589 });
590
591 forward_sweep_(i) = tmp % 2 == 0;
592 });
593
594 return forward_sweep_;
595 }();
596
597 // calculate dst data index
598 constexpr auto dst_data_idx = [&]() {
599 Index ordered_idx;
600
601 static_for<0, nDim, 1>{}([&](auto i) {
602 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
603 : ordered_dst_access_lengths[i] - 1 -
604 ordered_dst_access_idx[i];
605 });
606
607 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
608 dst_scalar_per_access;
609 }();
610
611 constexpr auto dst_data_idx_seq = generate_sequence_v2(
612 [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
613
614 const bool is_dst_valid =
616
618 using dst_vector_t = typename dst_vector_type::type;
619
620 // copy data from dst_thread_scratch_ into dst_vector_container
621 auto dst_vector_container = dst_vector_type{
622 dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
623
624 static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
625 DstData dst_v;
626
627 // apply DstElementwiseOperation
628 dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
629 });
630
631 // copy data from dst_vector_container to dst_buf
632 dst_buf.template Set<dst_vector_t>(
633 dst_coord_.GetOffset() / PackedSize,
634 is_dst_valid,
635 dst_vector_container.template AsType<dst_vector_t>()[I0]);
636
637 constexpr auto move_on_dim = [&]() constexpr {
639
640 static_for<0, nDim, 1>{}([&](auto i) {
641 move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
642
643 static_for<i + 1, nDim, 1>{}([&](auto j) {
644 move_on_dim_(i) &=
645 ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
646 });
647 });
648
649 return move_on_dim_;
650 }();
651
652 // move dst coord
653 static_for<0, nDim, 1>{}([&](auto i) {
654 if constexpr(move_on_dim[i])
655 {
656 if constexpr(forward_sweep[i])
657 {
659 dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
660 }
661 else
662 {
664 dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
665 }
666 }
667 });
668 });
669
670 // move dst coordinate back to slice origin (or not)
671 if constexpr(DstResetCoordinateAfterRun)
672 {
673 const auto dst_reset_step =
675
676 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
677 }
678 }
679
680 __device__ static constexpr auto GetSrcCoordinateResetStep()
681 {
682 // scalar per access on each dim
683 // TODO: don't use lambda_scalar_per_access
684 constexpr auto src_scalar_per_access = generate_sequence(
686
687 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
688
689 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
690
691 constexpr auto ordered_src_access_lengths =
692 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
693
694 // judge move forward or move backward during the last iteration
695 constexpr auto forward_sweep = [&]() {
697
698 forward_sweep_(I0) = true;
699
700 static_for<1, nDim, 1>{}([&](auto i) {
701 index_t tmp = ordered_src_access_lengths[I0] - 1;
702
703 static_for<1, i, 1>{}([&](auto j) {
704 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
705 });
706
707 forward_sweep_(i) = tmp % 2 == 0;
708 });
709
710 return forward_sweep_;
711 }();
712
713 // calculate src data index after last iteration in RunRead(), if it has not being reset by
714 // RunRead()
715 constexpr auto src_data_idx = [&]() {
716 Index ordered_idx;
717
718 static_for<0, nDim, 1>{}([&](auto i) {
719 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
720 });
721
722 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
723 src_scalar_per_access;
724 }();
725
726 //
727 constexpr auto reset_src_data_step = [&]() {
728 Index reset_src_data_step_;
729
730 static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
731
732 return reset_src_data_step_;
733 }();
734
735 return reset_src_data_step;
736 }
737
738 __device__ static constexpr auto GetDstCoordinateResetStep()
739 {
740 // scalar per access on each dim
741 // TODO: don't use lambda_scalar_per_access
742 constexpr auto dst_scalar_per_access = generate_sequence(
744
745 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
746
747 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
748
749 constexpr auto ordered_dst_access_lengths =
750 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
751
752 // judge move forward or move backward during the last iteration
753 constexpr auto forward_sweep = [&]() {
755
756 forward_sweep_(I0) = true;
757
758 static_for<1, nDim, 1>{}([&](auto i) {
759 index_t tmp = ordered_dst_access_lengths[I0] - 1;
760
761 static_for<1, i, 1>{}([&](auto j) {
762 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
763 });
764
765 forward_sweep_(i) = tmp % 2 == 0;
766 });
767
768 return forward_sweep_;
769 }();
770
771 // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
772 // RunWrite()
773 constexpr auto dst_data_idx = [&]() {
774 Index ordered_idx;
775
776 static_for<0, nDim, 1>{}([&](auto i) {
777 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
778 });
779
780 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
781 dst_scalar_per_access;
782 }();
783
784 //
785 constexpr auto reset_dst_data_step = [&]() {
786 Index reset_dst_data_step_;
787
788 static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
789
790 return reset_dst_data_step_;
791 }();
792
793 return reset_dst_data_step;
794 }
795
796 // src_slice_origin_step_idx needs to be known at compile time, for performance reasons
797 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
798 const Index& src_slice_origin_step_idx)
799 {
800 // if src coord was not reset by RunRead(), then need to adjust the step here
801 const auto adjusted_step_idx =
802 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
803 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
804
805 // is it OK to construct a new step every time?
806 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
807
808 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
809 }
810
811 // dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
812 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
813 const Index& dst_slice_origin_step_idx)
814 {
815 // if dst coord was not reset by RunWrite(), then need to adjust the step here
816 const auto adjusted_step_idx =
817 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
818 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
819
820 // is it OK to construct a new step every time?
821 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
822
823 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
824 }
825
826 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
827 {
828 constexpr auto src_scalar_per_access = generate_sequence(
830
831 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
832
833 constexpr auto src_access_lengths_and_vector_length = container_push_back(
835
836 // 1st stage of transforms
837 constexpr auto desc0 =
838 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
839
840 // 2nd stage of transforms
841 constexpr auto transforms = generate_tuple(
842 [&](auto i) {
843 if constexpr(i == SrcVectorDim)
844 {
846 make_tuple(src_access_lengths_and_vector_length[i],
847 src_access_lengths_and_vector_length[Number<nDim>{}]));
848 }
849 else
850 {
851 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
852 }
853 },
854 Number<nDim>{});
855
856 constexpr auto low_dim_idss = generate_tuple(
857 [&](auto i) {
858 if constexpr(i == SrcVectorDim)
859 {
860 return Sequence<i.value, nDim>{};
861 }
862 else
863 {
864 return Sequence<i.value>{};
865 }
866 },
867 Number<nDim>{});
868
869 constexpr auto up_dim_idss =
870 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
871
872 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
873 }
874
875 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
876 {
877 constexpr auto src_scalar_per_access = generate_sequence(
879
880 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
881
882 return make_naive_tensor_descriptor_packed(src_access_lengths);
883 }
884
885 __device__ static constexpr auto GetDstThreadScratchDescriptor()
886 {
887 // 1st stage of transforms
888 constexpr auto dst_scalar_per_access = generate_sequence(
890
891 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
892
893 constexpr auto dst_access_lengths_and_vector_length = container_push_back(
895
896 constexpr auto desc0 =
897 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
898
899 // 2nd stage of transforms
900 constexpr auto transforms = generate_tuple(
901 [&](auto i) {
902 if constexpr(i == DstVectorDim)
903 {
905 make_tuple(dst_access_lengths_and_vector_length[i],
906 dst_access_lengths_and_vector_length[Number<nDim>{}]));
907 }
908 else
909 {
910 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
911 }
912 },
913 Number<nDim>{});
914
915 constexpr auto low_dim_idss = generate_tuple(
916 [&](auto i) {
917 if constexpr(i == DstVectorDim)
918 {
919 return Sequence<i.value, nDim>{};
920 }
921 else
922 {
923 return Sequence<i.value>{};
924 }
925 },
926 Number<nDim>{});
927
928 constexpr auto up_dim_idss =
929 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
930
931 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
932 }
933
934 private:
935 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
936 static constexpr auto src_oob_thread_scratch_desc_ =
937 decltype(GetSrcThreadScratchDescriptor()){};
938 static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
939
940 using SrcThreadScratch =
941 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
942 DstData, // apply data_convert with SrcThreadScratch
944 decltype(src_thread_scratch_desc_),
945 true>;
946
947 using SrcOOBThreadScratch =
948 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
949 bool, // apply data_convert with SrcThreadScratch
950 1,
951 decltype(src_oob_thread_scratch_desc_),
952 true>;
953
954 using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
955 DstData,
957 decltype(dst_thread_scratch_desc_),
958 true>;
959
962
963 DstThreadScratch dst_thread_scratch_;
964
965 SrcCoord src_coord_;
966 DstCoord dst_coord_;
967 const SrcElementwiseOperation src_element_op_;
968 const DstElementwiseOperation dst_element_op_;
969};
970
971} // namespace ck
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
static __device__ constexpr auto GetSrcOOBThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1.hpp:875
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:812
__device__ constexpr auto GetSrcThreadScratchIdx(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:356
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:797
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc &src_desc, const Index &src_slice_origin, const SrcElementwiseOperation &src_element_op, const DstDesc &dst_desc, const Index &dst_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:81
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1.hpp:680
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1.hpp:826
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number< ThreadScratchId > thread_scratch_id)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:364
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:118
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v3r1.hpp:738
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v3r1.hpp:885
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:521
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:107
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v3r1.hpp:112
Definition threadwise_tensor_slice_transfer_util.hpp:43
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition threadwise_tensor_slice_transfer_util.hpp:29
Definition data_type.hpp:42
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition functional3.hpp:97
Definition utility/transpose_vectors.hpp:16
Definition dtype_vector.hpp:30