blockwise_softmax.hpp Source File

blockwise_softmax.hpp Source File

Composable Kernel: blockwise_softmax.hpp Source File
blockwise_softmax.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
25template <index_t BlockSize,
26 typename AccDataType,
27 typename ThreadMap_M_K, // thread_id to m_k
28 typename ThreadClusterDesc_M_K,
29 typename ThreadSliceDesc_M_K,
30 bool IgnoreNaN = false>
32{
33 static constexpr auto I0 = Number<0>{};
34 static constexpr auto I1 = Number<1>{};
35 static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
36 static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
37
39 make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
40
42 IgnoreNaN,
43 ThreadwiseReduction<AccDataType,
44 ThreadSliceDesc_M_K,
47 false,
49 ThreadwiseReduction<AccDataType,
50 ThreadSliceDesc_M_K,
53 false>>::type;
54
56 IgnoreNaN,
57 ThreadwiseReduction<AccDataType,
58 ThreadSliceDesc_M_K,
61 false,
63 ThreadwiseReduction<AccDataType,
64 ThreadSliceDesc_M_K,
67 false>>::type;
68
69 using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
70
72 BlockSize,
74 ThreadMap_M_K,
76 false>;
77
79 BlockSize,
81 ThreadMap_M_K,
83 false>;
84
86
87 template <typename CThreadBuffer, typename WorkspaceBuffer>
88 __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
89 {
90 // find max value
91 static_for<0, MRepeat, 1>{}([&](auto I) {
92 max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
93 });
94 ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
95 static_for<0, MRepeat, 1>{}([&](auto I) {
96 BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
98 });
99
100 // calculate exp for elements, P=exp(s-max)
101 static_for<0, MRepeat, 1>{}([&](auto iM) {
102 static_for<0, KRepeat, 1>{}([&](auto iK) {
103 auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
104 in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
105 ? 0
106 : math::exp(in_thread_buf[offset] - max_value_buf(iM));
107 });
108 });
109
110 // sum data
111 static_for<0, MRepeat, 1>{}([&](auto I) {
112 sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
113 });
114 ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
115 static_for<0, MRepeat, 1>{}([&](auto I) {
116 BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
118 });
119 }
120
123};
124
125} // namespace ck
__host__ T exp(T x)
Definition math_v2.hpp:391
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Blockwise softmax.
Definition blockwise_softmax.hpp:32
decltype(make_naive_tensor_descriptor_packed( make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))) ThreadSliceDesc_M
Definition blockwise_softmax.hpp:38
PartitionedBlockwiseReduction_v2< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadMap_M_K, reduce::Add, false > BlockwiseSumReduce
Definition blockwise_softmax.hpp:78
decltype(ThreadClusterDesc_M_K{}.GetLengths()) ThreadClusterLengths_M_K
Definition blockwise_softmax.hpp:69
__host__ __device__ void Run(CThreadBuffer &in_thread_buf, WorkspaceBuffer &reduce_work_buf)
Definition blockwise_softmax.hpp:88
StaticBuffer< AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true > BufferType
Definition blockwise_softmax.hpp:85
static constexpr index_t MRepeat
Definition blockwise_softmax.hpp:35
PartitionedBlockwiseReduction_v2< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadMap_M_K, reduce::Max, false > BlockwiseMaxReduce
Definition blockwise_softmax.hpp:71
static constexpr index_t KRepeat
Definition blockwise_softmax.hpp:36
static constexpr auto I0
Definition blockwise_softmax.hpp:33
typename conditional< IgnoreNaN, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add, false, detail::AccumulateWithNanIgnore< reduce::Add, AccDataType > >, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add, false > >::type ThreadwiseSumReduce
Definition blockwise_softmax.hpp:55
typename conditional< IgnoreNaN, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max, false, detail::AccumulateWithNanIgnore< reduce::Max, AccDataType > >, ThreadwiseReduction< AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max, false > >::type ThreadwiseMaxReduce
Definition blockwise_softmax.hpp:41
static constexpr auto I1
Definition blockwise_softmax.hpp:34
BufferType sum_value_buf
Definition blockwise_softmax.hpp:122
BufferType max_value_buf
Definition blockwise_softmax.hpp:121
Definition reduction_functions_blockwise.hpp:101
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:116
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:17
Definition reduction_operator.hpp:37
Definition reduction_operator.hpp:163
Definition functional2.hpp:33