reference_fused_moe.hpp Source File

reference_fused_moe.hpp Source File

Composable Kernel: reference_fused_moe.hpp Source File
reference_fused_moe.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10// [indexing implementation-1]
11// using M_a as constexpr block_size to partition all tokens into different slices
12// each slice map to one expert, and one expert can have multiple slices
13// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
14// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
15// tok-0 tok-1 tok-2 tok-3 tok-4
16// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
17// number)
18//
19// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
20// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
21// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
22//
23// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
24// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
25// * this could be larger than actual, since actual tokens are on GPU
26//
27// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
28// 0, 1, 2, 5]
29// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
30// -|- exp-5 -|
31// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
32// c, f, i, o]
33//
34// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
35//
36// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
37// * length is (max_num_tokens_padded + block_size - 1) / block_size
39// num_tokens_post_padded_ptr : [28]
40// num_sorted_tiles_ptr : [7]
41
42template <typename AccDataType, // you only need to explcitly set this one
43 typename Activation, // ck_tile::element_wise::Gelu
44 typename ADataType,
45 typename GDataType,
46 typename DDataType,
47 typename ODataType,
48 typename AScaleDataType,
49 typename GScaleDataType,
50 typename DScaleDataType,
51 typename YSmoothScaleDataType,
52 typename TopkWeightDataType,
53 typename IndexDataType>
55 const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
56 const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
57 const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
58 const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
59 const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
60 const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
61 const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
62 ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
63 const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
64 const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
66 sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
67 const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
68
70 token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
71
72 ck_tile::index_t block_m,
73 ck_tile::index_t tokens,
74 ck_tile::index_t experts,
75 ck_tile::index_t hidden_size,
76 ck_tile::index_t intermediate_size, // this size is for gate/up/down
78 ck_tile::index_t gate_only)
79{
80 assert(sorted_token_ids_host.get_num_of_dimension() == 1);
81 assert(sorted_weight_host.get_num_of_dimension() == 1);
82 assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
83 assert(num_sorted_tiles_host.get_element_size() == 1);
84 ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
85 ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
86 ck_tile::index_t intermediate_size_1 = intermediate_size;
87
88 ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
89
90 int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
91 // assert();
92 auto f = [&](auto i_flatten) {
93 ck_tile::index_t i_tile = i_flatten / block_m;
94 if(i_tile >= num_sorted_tiles)
95 return;
96 ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
97
98#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
99 ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
100 ck_tile::index_t i_topk = i_token >> 24;
101 i_token &= 0xffffff;
102 if(i_token >= tokens)
103 return;
104 (void)token_ids_host;
105#else
106 // TODO: better remove this in the future, or modify the token_id value
107 auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
108 for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
109 {
110 if(token_ids_host(token_id_, i_) == expert_id_)
111 return i_;
112 }
113 throw std::runtime_error("not correct token/expert pair\n");
114 return -1; // TODO: not correct!!
115 };
116 ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
117 if(i_token >= tokens)
118 return;
119 ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
120#endif
121 auto weight = sorted_weight_host.mData[i_flatten];
122
123 ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
124 // first gemm
125 for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
126 {
127 AccDataType acc = static_cast<AccDataType>(0);
128 for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
129 {
130 acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
131 type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
132 }
133 acc_0(0, i_n) = acc;
134 // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
135 }
136
137 ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
138 if(gate_only)
139 {
140 if(intermediate_size_1 != intermediate_size_0)
141 throw std::runtime_error(
142 "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
143 ", 1:" + std::to_string(intermediate_size_1));
144 for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
145 {
146 Activation{}(y(0, i_n), acc_0(0, i_n));
147 // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
148 }
149 }
150 else
151 {
152 if(intermediate_size_1 * 2 != intermediate_size_0)
153 throw std::runtime_error(
154 "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
155 ", 1:" + std::to_string(intermediate_size_1));
156 for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
157 {
158 AccDataType tmp;
159 Activation{}(tmp, acc_0(0, i_n));
160 y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
161 }
162 }
163
164 // second gemm, loop along gemm-n
165 ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
166 for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
167 {
168 AccDataType acc = static_cast<AccDataType>(0);
169 for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
170 {
171 acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
172 }
173 acc_1(0, i_n) = acc * weight; // multiple weight here
174 }
175
176 for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
177 {
178 out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
179 }
180 };
181
182 // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
183 make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
184
185 // reduce
186 auto r = [&](auto i_token) {
187 for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
188 {
189 AccDataType acc = type_convert<AccDataType>(0);
190 for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
191 {
192 acc += out_topk_tokens(i_token, i_topk, i_n);
193 }
194 o_host(i_token, i_n) = type_convert<ODataType>(acc);
195 }
196 };
197 make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
198
199 (void)num_sorted_tiles_host;
200 (void)sa_host;
201 (void)sg_host;
202 (void)sd_host;
203 (void)sy_host;
204}
205} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
void reference_fused_moe(const ck_tile::HostTensor< ADataType > &a_host, const ck_tile::HostTensor< GDataType > &g_host, const ck_tile::HostTensor< DDataType > &d_host, const ck_tile::HostTensor< AScaleDataType > &sa_host, const ck_tile::HostTensor< GScaleDataType > &sg_host, const ck_tile::HostTensor< DScaleDataType > &sd_host, const ck_tile::HostTensor< YSmoothScaleDataType > &sy_host, ck_tile::HostTensor< ODataType > &o_host, const ck_tile::HostTensor< IndexDataType > &sorted_token_ids_host, const ck_tile::HostTensor< TopkWeightDataType > &sorted_weight_host, const ck_tile::HostTensor< IndexDataType > &sorted_expert_ids_host, const ck_tile::HostTensor< IndexDataType > &num_sorted_tiles_host, const ck_tile::HostTensor< IndexDataType > &token_ids_host, ck_tile::index_t block_m, ck_tile::index_t tokens, ck_tile::index_t experts, ck_tile::index_t hidden_size, ck_tile::index_t intermediate_size, ck_tile::index_t topk, ck_tile::index_t gate_only)
Definition reference_fused_moe.hpp:54
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Activation
Definition gridwise_moe_gemm.hpp:31
Definition tile/host/host_tensor.hpp:336
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396
std::size_t get_element_size() const
Definition tile/host/host_tensor.hpp:398
Data mData
Definition tile/host/host_tensor.hpp:801