thread_group_tensor_slice_transfer_v4r1.hpp Source File

thread_group_tensor_slice_transfer_v4r1.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v4r1.hpp Source File
thread_group_tensor_slice_transfer_v4r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
23template <typename ThreadGroup,
24 typename SrcElementwiseOperation,
25 typename DstElementwiseOperation,
27 typename BlockSliceLengths,
28 typename ThreadClusterLengths,
29 typename ThreadClusterArrangeOrder,
30 typename SrcData,
31 typename DstData,
32 typename SrcDesc,
33 typename DstDesc,
34 typename SrcDimAccessOrder,
35 typename DstDimAccessOrder,
36 index_t SrcVectorDim,
37 index_t DstVectorDim,
38 index_t SrcScalarPerVector,
39 index_t DstScalarPerVector,
40 index_t SrcScalarStrideInVector,
41 index_t DstScalarStrideInVector,
42 bool ThreadTransferSrcResetCoordinateAfterRun,
43 bool ThreadTransferDstResetCoordinateAfterRun,
44 index_t NumThreadScratch = 1>
46{
48
49 static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
50
52
54 const SrcDesc& src_desc,
55 const Index& src_block_slice_origin,
56 const SrcElementwiseOperation& src_element_op,
57 const DstDesc& dst_desc,
58 const Index& dst_block_slice_origin,
59 const DstElementwiseOperation& dst_element_op)
60 : threadwise_transfer_(src_desc,
62 src_element_op,
63 dst_desc,
65 dst_element_op)
66
67 {
70 nDim == ThreadClusterLengths::Size() &&
71 nDim == ThreadClusterArrangeOrder::Size() &&
72 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
73 "wrong! nDim not consistent");
74
75 static_assert(
76 is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
77 "wrong! threads should be mapped to cover entire slicing window");
78
79 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
80 "wrong! ThreadGroup::GetNumOfThread() too small");
81
82 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
83 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
84 {
85 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
86 make_multi_index(ThreadGroup::GetThreadId()));
87
88 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
89
90 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
91 src_block_slice_origin + thread_data_idx_begin);
92 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
93 dst_block_slice_origin + thread_data_idx_begin);
94 }
95 }
96
97 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
98 {
99 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
100 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
101 {
102 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
103 make_multi_index(ThreadGroup::GetThreadId()));
104
105 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
106
107 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
108 src_block_slice_origin + thread_data_idx_begin);
109 }
110 }
111
112 template <typename SeqIdx, index_t ThreadScratchId = 0>
113 __device__ constexpr auto GetSrcThreadScratchIdx()
114 {
115 return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
116 }
117
118 template <typename SrcBuffer, index_t ThreadScratchId = 0>
119 __device__ void RunRead(const SrcDesc& src_desc,
120 const SrcBuffer& src_buf,
122 {
123 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
124 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
125 {
126 threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
127 }
128 }
129
130 template <typename DstBuffer, index_t ThreadScratchId = 0>
131 __device__ void RunWrite(const DstDesc& dst_desc,
132 DstBuffer& dst_buf,
134 {
135 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
136 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
137 {
138 threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
139 }
140 }
141
142 template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
143 __device__ void Run(const SrcDesc& src_desc,
144 const SrcBuffer& src_buf,
145 const DstDesc& dst_desc,
146 DstBuffer& dst_buf,
147 Number<ThreadScratchId> thread_scratch_id)
148 {
149 RunRead(src_desc, src_buf, thread_scratch_id);
150 RunWrite(dst_desc, dst_buf, thread_scratch_id);
151 }
152
153 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
154 {
155 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
156 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
157 {
158 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
159 }
160 }
161
162 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
163 {
164 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
165 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
166 {
167 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
168 }
169 }
170
171 private:
172 static constexpr auto thread_cluster_desc_ =
173 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
174
175 using ThreadwiseTransfer =
176 ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
177 SrcElementwiseOperation,
178 DstElementwiseOperation,
179 DstInMemOp,
180 SrcData,
181 DstData,
182 SrcDesc,
183 DstDesc,
184 SrcDimAccessOrder,
185 DstDimAccessOrder,
186 SrcVectorDim,
187 DstVectorDim,
188 SrcScalarPerVector,
189 DstScalarPerVector,
190 SrcScalarStrideInVector,
191 DstScalarStrideInVector,
192 ThreadTransferSrcResetCoordinateAfterRun,
193 ThreadTransferDstResetCoordinateAfterRun,
194 NumThreadScratch>;
195
196 ThreadwiseTransfer threadwise_transfer_;
197};
198
199} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:143
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r1.hpp:51
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r1.hpp:47
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1.hpp:49
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:97
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:119
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:153
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:162
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const SrcElementwiseOperation &src_element_op, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:53
__device__ constexpr auto GetSrcThreadScratchIdx()
Definition thread_group_tensor_slice_transfer_v4r1.hpp:113
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:131
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:118
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:521
Definition type.hpp:177