dynamic_quant_epilogue.hpp Source File

dynamic_quant_epilogue.hpp Source File

Composable Kernel: dynamic_quant_epilogue.hpp Source File
dynamic_quant_epilogue.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <bool kPadM_,
12 bool kPadN_,
13 bool UseSmoothInputScale_,
14 bool UseRawStore_ = true,
15 bool UseMax3_ = false>
17{
18 static constexpr bool kPadM = kPadM_;
19 static constexpr bool kPadN = kPadN_;
20 static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
21 static constexpr bool UseRawStore = UseRawStore_;
22 static constexpr bool UseMax3 = UseMax3_;
23};
24
25// this epilogue just stores out an M*N matrix, row major
26template <typename AccDataType_,
27 typename SmoothScaleDataType_,
28 typename YScaleDataType_,
29 typename ODataType_,
30 typename BlockShape_,
31 typename Traits_>
41
42// TODO: we should put descriptor creation function into policy
43template <typename Problem_, typename Policy_ = void>
45{
52 static constexpr bool kPadM = Problem::Traits::kPadM;
53 static constexpr bool kPadN = Problem::Traits::kPadN;
54 static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
55 static constexpr bool UseMax3 = Problem::Traits::UseMax3;
56
62
68
74
101
103 {
104 auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
105 return reduce_crosswarp_sync.GetSmemSize();
106 }
107
 // Shared quantization body used by both operator() overloads.
 // Steps, as visible in this listing:
 //   1. reduce o_acc_tile to row_absmax using max(acc, abs(v)) seeded with 0,
 //   2. combine the partial results with reduce_sync and reduce_crosswarp_sync,
 //   3. derive y_scale element-wise from row_absmax and store it (cast to YScaleDataType),
 //   4. divide every tile element by the y_scale entry selected by idx[0],
 //   5. store the tile, cast to ODataType, to o_dram_window_tmp.
 // smem is only handed to reduce_crosswarp_sync; GetSmemSize() reports that object's
 // GetSmemSize().
 // NOTE(review): listing lines 149 and 164 are absent from this extract, so the y_scale
 // formula and whatever follows store_tile_raw cannot be described from here.
108 template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
109 CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
110 YScaleWindow& y_scale_window,
111 const OAccTile& o_acc_tile,
112 void* smem)
113 {
114 auto reduce = GetBlockReduce2d();
115 auto reduce_sync = GetBlockReduce2dSync();
116 auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
117
 // work on a mutable copy; the input tile is taken by const reference
118 auto o_acc_tmp = o_acc_tile;
119
 // binary combiner: running maximum of absolute values
120 const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
121
122 auto row_absmax = [&]() {
 // length at index 1 of the tile distribution's ys-to-d descriptor
123 constexpr auto y_size_per_row =
124 OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
125 number<1>{});
 // the max3 path is compiled in only when it is requested, AccDataType is float,
 // and y_size_per_row is even
126 if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
127 {
128 // fast max3+abs implementation
 // ternary combiner: a single v_max3_f32 with abs() applied to the two new values
129 const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
130 float rtn;
131 asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
132 : "=v"(rtn)
133 : "v"(acc_), "v"(v_0_), "v"(v_1_));
134 return rtn;
135 };
 // presumably sequence<1, 2> makes reduce feed two elements per call to f_max3
 // -- TODO confirm against block_reduce2d
136 return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
137 }
138 else
139 {
140 return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
141 }
142 }();
 // both sync stages use the binary f_absmax, including after the max3 path
143 reduce_sync(row_absmax, f_absmax);
144 reduce_crosswarp_sync(row_absmax, smem, f_absmax);
145
146 // here y_scale is still AccDataType; it is converted to YScaleDataType when stored below
147 auto y_scale = tile_elementwise_in(
148 [&](const auto& v_) {
150 },
151 row_absmax);
152
153 store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
154
 // scale each element by the y_scale entry addressed by the first index component
155 sweep_tile(o_acc_tmp, [&](auto idx) {
156 constexpr auto row_id = make_tuple(idx[number<0>{}]);
157 o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
158 });
159
160 // TODO: this is ugly
 // raw store is used only when UseRawStore is set and at least one padding flag is on
161 if constexpr(UseRawStore && (kPadM || kPadN))
162 {
163 store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
165 }
166 else
167 {
168 store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
169 }
170 }
171
172 // TODO: this function assumes the store-out vector size is the same as the size of
173 // OAccTile's last dimension; how do we fix this?
174
175 // Smooth Dynamic Quant
 // Overload that takes a smooth-scale window: every element of a copy of o_acc_tile is
 // multiplied by the loaded smooth-scale value selected by idx[1] (converted to
 // AccDataType), and the scaled tile is then passed to Impl() together with the output
 // window, the y-scale window and smem.
 // NOTE(review): listing line 187 (the initializer of sm_scale_window) is absent from this
 // extract, so how the local window is built from sm_scale_window_ cannot be stated here.
176 template <typename ODramWindowTmp,
177 typename SmoothScaleWindow,
178 typename YScaleWindow,
179 typename OAccTile>
180 CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
181 const SmoothScaleWindow& sm_scale_window_,
182 YScaleWindow& y_scale_window,
183 const OAccTile& o_acc_tile,
184 void* smem)
185 {
186 const auto sm_scale_window =
188
189 auto sm_scale = load_tile(sm_scale_window);
190
 // mutable copy, because o_acc_tile is a const reference
191 auto o_acc_tmp = o_acc_tile;
192
 // apply the smooth scale; the scale is indexed by the second index component only
193 sweep_tile(o_acc_tmp, [&](auto idx) {
194 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
195 const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
196 o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
197 });
198
 // the rest of the pipeline is shared with the non-smooth overload
199 Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
200 }
201
202 // Dynamic Quant
 // Overload without a smooth-scale window: forwards o_dram_window_tmp, y_scale_window,
 // o_acc_tile and smem unchanged to Impl(). Nothing is returned explicitly.
203 template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
204 CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
205 YScaleWindow& y_scale_window,
206 const OAccTile& o_acc_tile,
207 void* smem)
208 {
209 Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
210 }
211};
212} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:1063
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:46
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_reduce2d.hpp:334
Definition block_reduce2d.hpp:46
Definition block_reduce2d_problem.hpp:15
Definition block_reduce2d.hpp:224
Definition dynamic_quant_epilogue.hpp:45
static constexpr bool UseMax3
Definition dynamic_quant_epilogue.hpp:55
remove_cvref_t< typename Problem::BlockShape > BlockShape
Definition dynamic_quant_epilogue.hpp:51
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition dynamic_quant_epilogue.hpp:49
remove_cvref_t< typename Problem::ODataType > ODataType
Definition dynamic_quant_epilogue.hpp:50
static constexpr bool kPadM
Definition dynamic_quant_epilogue.hpp:52
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition dynamic_quant_epilogue.hpp:48
static CK_TILE_HOST_DEVICE constexpr auto GetBlockReduce2dCrossWarpSync()
Definition dynamic_quant_epilogue.hpp:69
static CK_TILE_HOST_DEVICE constexpr auto GetBlockReduce2dSync()
Definition dynamic_quant_epilogue.hpp:63
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition dynamic_quant_epilogue.hpp:204
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition dynamic_quant_epilogue.hpp:47
static constexpr bool kPadN
Definition dynamic_quant_epilogue.hpp:53
static constexpr bool UseRawStore
Definition dynamic_quant_epilogue.hpp:54
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition dynamic_quant_epilogue.hpp:180
static CK_TILE_HOST_DEVICE constexpr auto GetBlockReduce2d()
Definition dynamic_quant_epilogue.hpp:57
remove_cvref_t< Problem > Problem
Definition dynamic_quant_epilogue.hpp:46
static CK_TILE_DEVICE constexpr auto MakeSmoothInputScaleTileDistribution()
Definition dynamic_quant_epilogue.hpp:75
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition dynamic_quant_epilogue.hpp:102
CK_TILE_DEVICE auto Impl(ODramWindowTmp &o_dram_window_tmp, YScaleWindow &y_scale_window, const OAccTile &o_acc_tile, void *smem)
Definition dynamic_quant_epilogue.hpp:109
Definition dynamic_quant_epilogue.hpp:33
remove_cvref_t< YScaleDataType_ > YScaleDataType
Definition dynamic_quant_epilogue.hpp:36
remove_cvref_t< ODataType_ > ODataType
Definition dynamic_quant_epilogue.hpp:37
remove_cvref_t< Traits_ > Traits
Definition dynamic_quant_epilogue.hpp:39
remove_cvref_t< BlockShape_ > BlockShape
Definition dynamic_quant_epilogue.hpp:38
remove_cvref_t< SmoothScaleDataType_ > SmoothScaleDataType
Definition dynamic_quant_epilogue.hpp:35
remove_cvref_t< AccDataType_ > AccDataType
Definition dynamic_quant_epilogue.hpp:34
Definition dynamic_quant_epilogue.hpp:17
static constexpr bool UseSmoothInputScale
Definition dynamic_quant_epilogue.hpp:20
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192