device_gemm_multiple_abd.hpp Source File

device_gemm_multiple_abd.hpp Source File

Composable Kernel: device_gemm_multiple_abd.hpp Source File
device_gemm_multiple_abd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
8#include "ck/tensor_operation/gpu/device/device_base.hpp"
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// GEMM:
15// input : A0[M, K], B0[K, N],
16// input : D0[M, N], D1[M, N], ...
17// output : E[M, N]
18// C = a_op(A) * b_op(B)
19// E = cde_op(C, D0, D1, ...)
20// Assume:
21// D0, D1, ... and E have the same layout
22template <typename AsLayout,
23 typename BsLayout,
24 typename DsLayout,
25 typename ELayout,
26 typename AsDataType,
27 typename BsDataType,
28 typename DsDataType,
29 typename EDataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CDEElementwiseOperation>
33struct DeviceGemmMultipleABD : public BaseOperator
34{
35 static constexpr index_t NumATensor = AsDataType::Size();
36 static constexpr index_t NumBTensor = BsDataType::Size();
37 static constexpr index_t NumDTensor = DsDataType::Size();
38
39 virtual std::unique_ptr<BaseArgument>
40 MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
41 std::array<const void*, NumBTensor> p_bs,
42 std::array<const void*, NumDTensor> p_ds,
43 void* p_e,
44 ck::index_t M,
45 ck::index_t N,
46 ck::index_t K,
47 std::array<ck::index_t, NumATensor> StrideAs,
48 std::array<ck::index_t, NumBTensor> StrideBs,
49 std::array<ck::index_t, NumDTensor> StrideDs,
50 ck::index_t StrideE,
51 AElementwiseOperation a_element_op,
52 BElementwiseOperation b_element_op,
53 CDEElementwiseOperation cde_element_op) = 0;
54
55 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
56};
57
58// GEMM:
59// input : A0[M, K], B0[K, N],
60// input : D0[M, N], D1[M, N], ...
61// output : E[M, N]
62// C = a_op(A) * b_op(B)
63// E = cde_op(C, D0, D1, ...)
64// Assume:
65// D0, D1, ... and E have the same layout
66template <typename AsLayout,
67 typename BsLayout,
68 typename DsLayout,
69 typename ELayout,
70 typename AsDataType,
71 typename BsDataType,
72 typename DsDataType,
73 typename EDataType,
74 typename AElementwiseOperation,
75 typename BElementwiseOperation,
76 typename CDEElementwiseOperation>
77struct DeviceGemmMultipleABDSplitK : public BaseOperator
78{
79 static constexpr index_t NumATensor = AsDataType::Size();
80 static constexpr index_t NumBTensor = BsDataType::Size();
81 static constexpr index_t NumDTensor = DsDataType::Size();
82
83 virtual std::unique_ptr<BaseArgument>
84 MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
85 std::array<const void*, NumBTensor> p_bs,
86 std::array<const void*, NumDTensor> p_ds,
87 void* p_e,
88 ck::index_t M,
89 ck::index_t N,
90 ck::index_t K,
91 std::array<ck::index_t, NumATensor> StrideAs,
92 std::array<ck::index_t, NumBTensor> StrideBs,
93 std::array<ck::index_t, NumDTensor> StrideDs,
94 ck::index_t StrideE,
95 ck::index_t KBatch,
96 AElementwiseOperation a_element_op,
97 BElementwiseOperation b_element_op,
98 CDEElementwiseOperation cde_element_op) = 0;
99
100 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
101};
102
110template <typename AsLayout,
111 typename BsLayout,
112 typename DsLayout,
113 typename ELayout,
114 typename AsDataType,
115 typename BsDataType,
116 typename DsDataType,
117 typename EDataType,
118 typename AElementwiseOperation,
119 typename BElementwiseOperation,
120 typename CDEElementwiseOperation>
121struct DeviceGemmMultipleABDSplitKWrapper : public DeviceGemmMultipleABD<AsLayout,
122 BsLayout,
123 DsLayout,
124 ELayout,
125 AsDataType,
126 BsDataType,
127 DsDataType,
128 EDataType,
129 AElementwiseOperation,
130 BElementwiseOperation,
131 CDEElementwiseOperation>
132{
133
134 using DeviceOp = DeviceGemmMultipleABDSplitK<AsLayout,
135 BsLayout,
136 DsLayout,
137 ELayout,
138 AsDataType,
139 BsDataType,
140 DsDataType,
141 EDataType,
142 AElementwiseOperation,
143 BElementwiseOperation,
144 CDEElementwiseOperation>;
145
146 static constexpr index_t NumATensor = AsDataType::Size();
147 static constexpr index_t NumBTensor = BsDataType::Size();
148 static constexpr index_t NumDTensor = DsDataType::Size();
149
150#ifndef __HIPCC_RTC__
151
152 explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
153 : p_op_(std::move(p_op))
154 {
155 }
156
157 bool IsSupportedArgument(const BaseArgument* p_arg) override
158 {
159 return p_op_->IsSupportedArgument(p_arg);
160 }
161 std::unique_ptr<BaseArgument>
162 MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
163 std::array<const void*, NumBTensor> p_bs,
164 std::array<const void*, NumDTensor> p_ds,
165 void* p_e,
166 ck::index_t M,
167 ck::index_t N,
168 ck::index_t K,
169 std::array<ck::index_t, NumATensor> StrideAs,
170 std::array<ck::index_t, NumBTensor> StrideBs,
171 std::array<ck::index_t, NumDTensor> StrideDs,
172 ck::index_t StrideE,
173 AElementwiseOperation a_element_op,
174 BElementwiseOperation b_element_op,
175 CDEElementwiseOperation cde_element_op) override
176 {
177 return p_op_->MakeArgumentPointer(p_as,
178 p_bs,
179 p_ds,
180 p_e,
181 M,
182 N,
183 K,
184 StrideAs,
185 StrideBs,
186 StrideDs,
187 StrideE,
188 1, // KBatch
189 a_element_op,
190 b_element_op,
191 cde_element_op);
192 }
193
194 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
195 {
196 return p_op_->MakeInvokerPointer();
197 }
198
199 std::string GetTypeString() const override { return p_op_->GetTypeString(); }
200
201 private:
202 std::unique_ptr<DeviceOp> p_op_;
203
204#endif // __HIPCC_RTC__
205};
206
207} // namespace device
208} // namespace tensor_operation
209} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
STL namespace.
Definition device_base.hpp:197
Definition device_gemm_multiple_abd.hpp:34
static constexpr index_t NumATensor
Definition device_gemm_multiple_abd.hpp:35
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumBTensor
Definition device_gemm_multiple_abd.hpp:36
static constexpr index_t NumDTensor
Definition device_gemm_multiple_abd.hpp:37
Definition device_gemm_multiple_abd.hpp:78
static constexpr index_t NumDTensor
Definition device_gemm_multiple_abd.hpp:81
static constexpr index_t NumATensor
Definition device_gemm_multiple_abd.hpp:79
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition device_gemm_multiple_abd.hpp:80
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_abd.hpp:194
DeviceGemmMultipleABDSplitK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation > DeviceOp
Definition device_gemm_multiple_abd.hpp:134
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_abd.hpp:157
static constexpr index_t NumBTensor
Definition device_gemm_multiple_abd.hpp:147
static constexpr index_t NumATensor
Definition device_gemm_multiple_abd.hpp:146
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, std::array< ck::index_t, NumATensor > StrideAs, std::array< ck::index_t, NumBTensor > StrideBs, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_abd.hpp:162
static constexpr index_t NumDTensor
Definition device_gemm_multiple_abd.hpp:148
DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr< DeviceOp > p_op)
Definition device_gemm_multiple_abd.hpp:152
std::string GetTypeString() const override
Definition device_gemm_multiple_abd.hpp:199