15template <
typename Problem>
24 float scale_reg_f = 0.f;
25 if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
30 else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
35 else if constexpr(std::is_same_v<BQDataType, float>)
41 static_assert(
false,
"BQDataType must be float, fp8_t or bf8_t.");
52template <
typename Problem_,
53 typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
58 template <
typename PipelineProblem_,
typename GemmPolicy_>
71 static constexpr index_t kBlockSize = Problem::kBlockSize;
72 static constexpr auto Scheduler = Problem::Scheduler;
75 static constexpr index_t MPerBlock = BlockGemmShape::kM;
76 static constexpr index_t NPerBlock = BlockGemmShape::kN;
77 static constexpr index_t KPerBlock = BlockGemmShape::kK;
79 static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
80 static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
82 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
86 static constexpr index_t MWarp = config.template at<1>();
87 static constexpr index_t NWarp = config.template at<2>();
92 static_assert(
MWarp == BlockGemmShape::BlockWarps::at(
I0{}),
93 "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
94 static_assert(
NWarp == BlockGemmShape::BlockWarps::at(
I1{}),
95 "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
96 static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(
I0{}),
97 "Error! WarpGemm's M is not consistent with BlockGemmShape!");
98 static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(
I1{}),
99 "Error! WarpGemm's N is not consistent with BlockGemmShape!");
105 static constexpr index_t QScalesPerBlockRow =
107 static constexpr index_t QScalesPerWarpGemmRow =
// The condition requires QuantGroupSize::kK to be a whole multiple of WarpGemm::kK;
// the diagnostic below states that same direction.
112 static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
113 "Error! QuantGroupSize::kK should be a multiple of WarpGemm::kK");
114 static_assert(QScalesPerWarpGemmRow == 1,
115 "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
117 "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
119 static_assert(KPerBlock / QuantGroupSize::kK > 0,
120 "Error! Each row of blockgemm should have a separate scale");
123 "Error! Warps should cover all Block tile!");
125 "Error! Warps should cover all Block tile!");
132 static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
133 (std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
134 std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
135 (std::is_same_v<BQDataType, float> ||
136 std::is_same_v<BQDataType, ck_tile::fp8_t> ||
137 std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
138 (std::is_same_v<ComputeDataType, fp8_t> ||
139 std::is_same_v<ComputeDataType, bf8_t>) &&
140 std::is_same_v<CDataType, fp32_t>);
142 static constexpr index_t InterWaveSchedulingMacClusters = 1;
144 static constexpr index_t KPack = WarpGemm::kKPerThread;
149 using Traits = GemmTraits_<Problem_, Policy_>;
179 static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
202 constexpr index_t KPerThread = Traits::KPerThread;
203 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
205 constexpr index_t KPerInnerLoop =
206 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
208 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
214 constexpr auto a_block_outer_dstr_encoding =
222 a_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
224 return a_block_dstr_encode;
229 constexpr index_t KPerThread = Traits::KPerThread;
230 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
231 constexpr index_t KPerInnerLoop =
232 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
233 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
239 constexpr auto b_block_outer_dstr_encoding =
248 b_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
250 return b_block_dstr_encode;
254 template <GemmPipelineScheduler Scheduler,
typename GemmTraits>
259 template <
typename GemmTraits>
262 static constexpr auto ALdsTileDistr =
264 static constexpr auto BLdsTileDistr =
270 ALdsTile a_warp_tile_;
271 BLdsTile b_warp_tile_;
273 template <
typename ASmemBlockWindow,
typename BSmemBlockWindow>
275 const BSmemBlockWindow& b_block_window)
277 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
279 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
280 std::is_same_v<ComputeDataType, bf8_t>);
281 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
287 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
289 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
290 std::is_same_v<ComputeDataType, bf8_t>);
291 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
300 template <
typename CBlockTensor,
301 typename BQBlockTensor,
302 typename ASmemBlockWindow,
303 typename BSmemBlockWindow>
305 BQBlockTensor& bq_block_tensor,
306 [[maybe_unused]] ASmemBlockWindow& a_block_window,
307 [[maybe_unused]] BSmemBlockWindow& b_block_window)
309 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
310 "The CDataType as defined in traits should be the same as corresponding "
311 "C block tensor data type!");
315 static_for<0, MIterPerWarp, 1>{}([&](
auto mIter) {
316 static_for<0, NIterPerWarp, 1>{}([&](
auto nIter) {
319 static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](
auto kQScale) {
320 static_for<0, Traits::KIterPerQScale, 1>{}([&](
auto kIterInQScale) {
321 constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
324 a_warp_tensor.get_thread_buffer() =
325 a_warp_tile_.get_y_sliced_thread_data(
330 b_warp_tensor.get_thread_buffer() =
331 b_warp_tile_.get_y_sliced_thread_data(
335 if constexpr(kIterInQScale == 0)
337 c_warp_tensor =
WarpGemm{}(a_warp_tensor, b_warp_tensor);
341 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
346 constexpr index_t reg_offset = [&]() {
347 if constexpr(GemmTraits::QuantGroupSize::kN >= (
NWarp * WarpGemm::kN))
348 return (nIter *
NWarp * WarpGemm::kN) /
349 GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
353 return nIter * Traits::KQPerBlock + kQScale;
357 constexpr auto tbuf_offset =
358 number<
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
361 CBlockTensor::PackedSize>{};
363 auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
365 static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
367 c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
368 (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
388 c_block_outer_dstr_encoding,
typename WarpGemm::CWarpDstrEncoding{});
392 return c_block_tensor;
395 template <
typename ASmemBlockWindow,
typename BSmemBlockWindow>
397 const BSmemBlockWindow& b_block_window)
399 block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
403 template <
typename CBlockTensor,
404 typename BQBlockTensor,
405 typename ASmemBlockWindow,
406 typename BSmemBlockWindow>
408 BQBlockTensor& bq_block_tensor,
409 const ASmemBlockWindow& a_block_window,
410 const BSmemBlockWindow& b_block_window)
412 block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
416 BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:258
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:265
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
unsigned int uint32_t
Definition stdint.h:126
Definition block_universal_gemm_as_bs_bquant_cr.hpp:56
GemmTraits_< Problem_, Policy_ > Traits
Definition block_universal_gemm_as_bs_bquant_cr.hpp:149
static constexpr auto a_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:181
BlockGemmBQuantBase< Problem_ > Base
Definition block_universal_gemm_as_bs_bquant_cr.hpp:157
static constexpr auto c_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:190
remove_cvref_t< InterleavedPKTypeLoader< ComputeDataType, UnaryOpSize_ > > Loader
Definition block_universal_gemm_as_bs_bquant_cr.hpp:159
typename WarpGemm::CWarpTensor CWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:177
remove_cvref_t< typename Traits::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:154
typename WarpGemm::BWarpTensor BWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:176
static constexpr index_t KIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:162
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, BQBlockTensor &bq_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:407
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:200
static constexpr auto a_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:188
number< 0 > I0
Definition block_universal_gemm_as_bs_bquant_cr.hpp:197
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:377
static constexpr index_t MIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:163
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:152
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_universal_gemm_as_bs_bquant_cr.hpp:227
typename WarpGemm::CWarpDstr CWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:173
typename WarpGemm::AWarpDstr AWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:171
static constexpr index_t NWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:167
static constexpr auto Scheduler
Definition block_universal_gemm_as_bs_bquant_cr.hpp:169
static constexpr index_t APackedSize
Definition block_universal_gemm_as_bs_bquant_cr.hpp:192
static constexpr index_t MWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:166
static constexpr index_t BPackedSize
Definition block_universal_gemm_as_bs_bquant_cr.hpp:194
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:155
static constexpr auto c_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:185
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:151
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:396
typename WarpGemm::AWarpTensor AWarpTensor
Definition block_universal_gemm_as_bs_bquant_cr.hpp:175
static constexpr index_t NIterPerWarp
Definition block_universal_gemm_as_bs_bquant_cr.hpp:164
number< 1 > I1
Definition block_universal_gemm_as_bs_bquant_cr.hpp:198
remove_cvref_t< typename Traits::BQDataType > BQDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:153
static constexpr auto b_warp_y_index_zeros
Definition block_universal_gemm_as_bs_bquant_cr.hpp:189
static constexpr auto b_warp_y_lengths
Definition block_universal_gemm_as_bs_bquant_cr.hpp:183
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition block_universal_gemm_as_bs_bquant_cr.hpp:160
typename WarpGemm::BWarpDstr BWarpDstr
Definition block_universal_gemm_as_bs_bquant_cr.hpp:172
Definition block_universal_gemm_as_bs_bquant_cr.hpp:17
remove_cvref_t< typename Problem::BQDataType > BQDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:18
static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)
Definition block_universal_gemm_as_bs_bquant_cr.hpp:22
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_bs_bquant_cr.hpp:19
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192