gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp Source File

gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp Source File

Composable Kernel: gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp Source File
gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
15 typename XDataType,
16 typename DyDataType,
17 typename DxDataType,
18 typename ScaleDataType,
19 typename DscaleDbiasDataType,
20 typename MeanVarDataType,
21 typename DyElementwiseOp,
22 typename XYGridDesc_M_K,
23 typename DscaleDbiasGridDesc_M_K,
24 typename MeanVarGridDesc_M,
25 typename ScaleBiasGridDesc_M>
27 const XYGridDesc_M_K x_grid_desc_m_k,
28 const XYGridDesc_M_K dy_grid_desc_m_k,
29 const XYGridDesc_M_K dx_grid_desc_m_k,
30 const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
31 const MeanVarGridDesc_M mean_var_grid_desc_m,
32 const ScaleBiasGridDesc_M scale_grid_desc_m,
33 const ScaleBiasGridDesc_M bias_grid_desc_m,
34 index_t blkgroup_size,
35 long_index_t reduce_size,
36 index_t num_xy_k_block_tile_iteration,
37 index_t num_dscale_dbias_k_block_tile_iteration,
38 const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
39 const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
40 const MeanVarDataType* const __restrict__ p_mean,
41 const MeanVarDataType* const __restrict__ p_inv_var,
42 const XDataType* const __restrict__ p_x,
43 const DyDataType* const __restrict__ p_dy,
44 const ScaleDataType* const __restrict__ p_scale,
45 const DyElementwiseOp dy_elementwise_op,
46 DxDataType* const __restrict__ p_dx,
47 DscaleDbiasDataType* const __restrict__ p_dscale,
48 DscaleDbiasDataType* const __restrict__ p_dbias)
49{
50 GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
51 dy_grid_desc_m_k,
52 dx_grid_desc_m_k,
53 dscale_dbias_grid_desc_m_k,
54 mean_var_grid_desc_m,
55 scale_grid_desc_m,
56 bias_grid_desc_m,
57 blkgroup_size,
58 reduce_size,
59 num_xy_k_block_tile_iteration,
60 num_dscale_dbias_k_block_tile_iteration,
61 p_reduce_dscale,
62 p_reduce_dbias,
63 p_mean,
64 p_inv_var,
65 p_x,
66 p_dy,
67 p_scale,
68 dy_elementwise_op,
69 p_dx,
70 p_dscale,
71 p_dbias);
72};
73
74template <typename XDataType,
75 typename DyDataType,
76 typename DxDataType,
77 typename AccDataType,
78 typename ScaleDataType,
79 typename DscaleDbiasDataType,
80 typename MeanVarDataType,
81 typename DyElementwiseOp,
82 typename XYGridDesc_M_K,
83 typename DscaleDbiasGridDesc_M_K,
84 typename MeanVarGridDesc_M,
85 typename ScaleBiasGridDesc_M,
86 index_t BlockSize,
87 index_t MThreadClusterSize,
88 index_t KThreadClusterSize,
89 index_t MThreadSliceSize,
90 index_t KThreadSliceSize,
91 index_t XDyDxVectorDim,
92 index_t XSrcVectorSize,
93 index_t DySrcVectorSize,
94 index_t DxDstVectorSize,
95 index_t ScaleSrcVectorSize,
96 index_t DscaleDbiasDstVectorSize,
97 index_t MeanVarSrcVectorSize>
99{
100 static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
101 MThreadSliceSize % DySrcVectorSize == 0 &&
102 MThreadSliceSize % DxDstVectorSize == 0) ||
103 (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
104 KThreadSliceSize % DySrcVectorSize == 0 &&
105 KThreadSliceSize % DxDstVectorSize == 0),
106 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
107
108 static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);
109
111
114
117
118 static constexpr auto thread_cluster_desc =
120
125
127 BlockSize,
131 false>;
132
137 false>;
138
140
141 static constexpr auto I0 = Number<0>{};
142 static constexpr auto I1 = Number<1>{};
143
144 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
145 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
146
147 // clang-format off
148 // Two of the steps of Multiblock BatchNorm Backward
149 // Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
150 // Step 2: calculate dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance) element-wise
151 // clang-format on
152 __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
153 const XYGridDesc_M_K& dy_grid_desc_m_k,
154 const XYGridDesc_M_K& dx_grid_desc_m_k,
155 const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
156 const MeanVarGridDesc_M& mean_var_grid_desc_m,
157 const ScaleBiasGridDesc_M& scale_grid_desc_m,
158 const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
159 index_t blkgroup_size,
160 long_index_t reduce_size,
161 index_t num_xy_k_block_tile_iteration,
162 index_t num_dscale_dbias_k_block_tile_iteration,
163 const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
164 const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
165 const MeanVarDataType* const __restrict__ p_mean,
166 const MeanVarDataType* const __restrict__ p_inv_var,
167 const XDataType* const __restrict__ p_x,
168 const DyDataType* const __restrict__ p_dy,
169 const ScaleDataType* const __restrict__ p_scale,
170 const DyElementwiseOp dy_elementwise_op,
171 DxDataType* const __restrict__ p_dx,
172 DscaleDbiasDataType* const __restrict__ p_dscale,
173 DscaleDbiasDataType* const __restrict__ p_dbias)
174 {
175 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
176
177 auto reduce_work_buf =
178 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
179
181 reduce_dscale_thread_buf;
183 reduce_dbias_thread_buf;
184
187
189 x_thread_buf;
191 dy_thread_buf;
193 dx_thread_buf;
194
197 inv_var_thread_buf;
199
200 const index_t thread_local_id = get_thread_local_1d_id();
201 const index_t block_global_id = get_block_1d_id();
202 const index_t blkgroup_id = block_global_id / blkgroup_size;
203 const index_t block_local_id = block_global_id % blkgroup_size;
204
205 const auto thread_cluster_idx =
206 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
207
208 const auto thread_m_cluster_id = thread_cluster_idx[I0];
209 const auto thread_k_cluster_id = thread_cluster_idx[I1];
210
211 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
212 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
213 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
214 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
216 constexpr auto thread_buffer_desc_m =
218 constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
220
221 // clang-format off
222 // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
223 // clang-format on
224
225 auto threadwise_dscale_dbias_load_m_k =
226 ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
227 AccDataType,
228 DscaleDbiasGridDesc_M_K,
229 decltype(thread_buffer_desc_m_1),
230 ThreadBufferLengths_M_1,
232 1,
233 1,
234 1,
235 true>(
236 dscale_dbias_grid_desc_m_k,
237 make_multi_index(blkgroup_id * M_BlockTileSize +
238 thread_m_cluster_id * MThreadSliceSize,
239 thread_k_cluster_id * 1));
240
241 auto threadwise_dscale_dbias_store_m =
243 DscaleDbiasDataType,
244 decltype(thread_buffer_desc_m),
245 ScaleBiasGridDesc_M,
247 ThreadBufferLengths_M,
249 0,
250 DscaleDbiasDstVectorSize,
252 1,
253 true>(
254 dscale_dbias_grid_desc_m,
255 make_multi_index(blkgroup_id * M_BlockTileSize +
256 thread_m_cluster_id * MThreadSliceSize),
257 PassThroughOp{});
258
259 const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
260 p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
261
262 const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
263 p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
264
265 auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
266 p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
267
268 auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
269 p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
270
271 constexpr auto dscale_dbias_thread_copy_step_m_k =
272 make_multi_index(0, KThreadClusterSize * 1);
273
275 dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
276 dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
277 });
278
279 for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
280 ++reducedTiles)
281 {
282 threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
283 reduce_dscale_global_buf,
284 thread_buffer_desc_m_1,
285 make_tuple(I0, I0),
286 reduce_dscale_thread_buf);
287
288 threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
289 reduce_dbias_global_buf,
290 thread_buffer_desc_m_1,
291 make_tuple(I0, I0),
292 reduce_dbias_thread_buf);
293
294 ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
295 ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
296
297 threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
298 dscale_dbias_thread_copy_step_m_k);
299 }
300
302 if constexpr(I > 0)
304
305 BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
307 BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
308 });
309
310 threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
311 make_tuple(I0),
312 dscale_thread_buf,
313 dscale_dbias_grid_desc_m,
314 dscale_global_buf);
315
316 threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
317 make_tuple(I0),
318 dbias_thread_buf,
319 dscale_dbias_grid_desc_m,
320 dbias_global_buf);
321
322 // clang-format off
323 // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance), where N = reduce_size
324 // clang-format on
325
326 const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
327
328 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
329 AccDataType,
330 XYGridDesc_M_K,
331 decltype(thread_buffer_desc_m_k),
332 ThreadBufferLengths_M_K,
334 XDyDxVectorDim,
335 XSrcVectorSize,
336 1,
337 true>(
338 x_grid_desc_m_k,
339 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
340 workSizePerBlock * block_local_id +
341 thread_k_cluster_id * KThreadSliceSize));
342
343 auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
344 AccDataType,
345 XYGridDesc_M_K,
346 decltype(thread_buffer_desc_m_k),
347 ThreadBufferLengths_M_K,
349 XDyDxVectorDim,
350 DySrcVectorSize,
351 1,
352 true>(
353 dy_grid_desc_m_k,
354 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
355 workSizePerBlock * block_local_id +
356 thread_k_cluster_id * KThreadSliceSize));
357
358 auto threadwise_dx_store =
360 DxDataType,
361 decltype(thread_buffer_desc_m_k),
362 XYGridDesc_M_K,
364 ThreadBufferLengths_M_K,
366 XDyDxVectorDim,
367 DxDstVectorSize,
369 1,
370 true>(
371 dx_grid_desc_m_k,
373 blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
374 workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
375 PassThroughOp{});
376
377 auto threadwise_scale_load =
379 AccDataType,
380 ScaleBiasGridDesc_M,
381 decltype(thread_buffer_desc_m),
382 ThreadBufferLengths_M,
384 0,
385 ScaleSrcVectorSize,
386 1,
387 true>(
388 scale_grid_desc_m,
389 make_multi_index(blkgroup_id * M_BlockTileSize +
390 thread_m_cluster_id * MThreadSliceSize));
391
392 auto threadwise_mean_var_load =
393 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
394 AccDataType,
395 MeanVarGridDesc_M,
396 decltype(thread_buffer_desc_m),
397 ThreadBufferLengths_M,
399 0,
400 MeanVarSrcVectorSize,
401 1,
402 true>(
403 mean_var_grid_desc_m,
404 make_multi_index(blkgroup_id * M_BlockTileSize +
405 thread_m_cluster_id * MThreadSliceSize));
406
407 const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
408 p_x, x_grid_desc_m_k.GetElementSpaceSize());
409
410 const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
411 p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
412
414 p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
415
416 const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
417 p_scale, scale_grid_desc_m.GetElementSpaceSize());
418
419 const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
420 p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
421
422 const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
423 p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());
424
425 threadwise_scale_load.Run(scale_grid_desc_m,
426 scale_global_buf,
427 thread_buffer_desc_m,
428 make_tuple(I0),
429 scale_thread_buf);
430
431 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
432 mean_global_buf,
433 thread_buffer_desc_m,
434 make_tuple(I0),
435 mean_thread_buf);
436
437 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
438 inv_var_global_buf,
439 thread_buffer_desc_m,
440 make_tuple(I0),
441 inv_var_thread_buf);
442
443 constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
444
445 AccDataType inv_reduce_size =
447
448 for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
449 {
450 threadwise_x_load.Run(x_grid_desc_m_k,
451 x_global_buf,
452 thread_buffer_desc_m_k,
453 make_tuple(I0, I0),
454 x_thread_buf);
455
456 threadwise_dy_load.Run(dy_grid_desc_m_k,
457 dy_global_buf,
458 thread_buffer_desc_m_k,
459 make_tuple(I0, I0),
460 dy_thread_buf);
461
463 AccDataType multiplier =
464 inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];
465
467 constexpr auto offset =
468 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
469
470 dy_elementwise_op(dy_thread_buf(Number<offset>{}),
471 dy_thread_buf[Number<offset>{}]);
472
473 AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
474 inv_var_thread_buf[iM];
475
476 AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
477
478 dx_thread_buf(Number<offset>{}) =
479 multiplier *
480 (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
481 dbias_thread_buf[iM] - tmpVal);
482 });
483 });
484
485 threadwise_dx_store.Run(thread_buffer_desc_m_k,
486 make_tuple(I0, I0),
487 dx_thread_buf,
488 dx_grid_desc_m_k,
489 dx_global_buf);
490
491 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
492 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
493 threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
494 }
495 };
496};
497
498} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__global__ void kernel_reduce_second_half_batchnorm_backward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, const DscaleDbiasDataType *const __restrict__ p_reduce_dscale, const DscaleDbiasDataType *const __restrict__ p_reduce_dbias, const MeanVarDataType *const __restrict__ p_mean, const MeanVarDataType *const __restrict__ p_inv_var, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:99
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:110
static constexpr index_t K_BlockTileSize
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:145
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< 1 >{}))) ThreadReduceSrcDesc_M_1
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:121
static constexpr index_t M_BlockTileSize
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:144
static constexpr bool reorder_thread_cluster
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:108
static __device__ void Run(const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &dy_grid_desc_m_k, const XYGridDesc_M_K &dx_grid_desc_m_k, const DscaleDbiasGridDesc_M_K &dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M &mean_var_grid_desc_m, const ScaleBiasGridDesc_M &scale_grid_desc_m, const ScaleBiasGridDesc_M &dscale_dbias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, const DscaleDbiasDataType *const __restrict__ p_reduce_dscale, const DscaleDbiasDataType *const __restrict__ p_reduce_dbias, const MeanVarDataType *const __restrict__ p_mean, const MeanVarDataType *const __restrict__ p_inv_var, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:152
static constexpr auto I1
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:142
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:139
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:112
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:115
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ck::reduce::Add, false > BlockwiseReduce
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:126
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M, ck::reduce::Add, false > ThreadwiseReduce
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:133
static constexpr auto I0
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:141
static constexpr auto thread_cluster_desc
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:118
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:123
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340