convolution_parameter.hpp Source File

convolution_parameter.hpp Source File

Composable Kernel: convolution_parameter.hpp Source File
tile/host/convolution_parameter.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <numeric>
8#include <iterator>
9#include <vector>
10
11namespace ck_tile {
12namespace conv {
13
15{
17 ck_tile::index_t group_count,
18 ck_tile::index_t n_batch,
19 ck_tile::index_t n_out_channels,
20 ck_tile::index_t n_in_channels,
21 const std::vector<ck_tile::index_t>& filters_len,
22 const std::vector<ck_tile::index_t>& input_len,
23 const std::vector<ck_tile::index_t>& strides,
24 const std::vector<ck_tile::index_t>& dilations,
25 const std::vector<ck_tile::index_t>& left_pads,
26 const std::vector<ck_tile::index_t>& right_pads)
27 : num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
28 G_(static_cast<ck_tile::long_index_t>(group_count)),
29 N_(static_cast<ck_tile::long_index_t>(n_batch)),
30 K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
31 C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
39 {
44 static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
46 {
47 throw(std::runtime_error(
48 "ConvParam::ConvParam: "
49 "parameter size is different from number of declared dimensions!"));
50 }
51
52 for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
53 {
54 filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
55 input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
56 conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
57 conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
58 input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
59 input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
60
61 // XEff = (X - 1) * conv_dilation_w + 1;
62 // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
63 const ck_tile::long_index_t x_eff =
64 (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
65
66 output_spatial_lengths_[i] =
67 (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
68 conv_filter_strides_[i] +
69 1;
70 }
71 }
72
74 ck_tile::long_index_t group_count,
76 ck_tile::long_index_t n_out_channels,
77 ck_tile::long_index_t n_in_channels,
78 const std::vector<ck_tile::long_index_t>& filters_len,
79 const std::vector<ck_tile::long_index_t>& input_len,
80 const std::vector<ck_tile::long_index_t>& strides,
81 const std::vector<ck_tile::long_index_t>& dilations,
82 const std::vector<ck_tile::long_index_t>& left_pads,
83 const std::vector<ck_tile::long_index_t>& right_pads)
84 : num_dim_spatial_(n_dim),
85 G_(group_count),
86 N_(n_batch),
87 K_(n_out_channels),
88 C_(n_in_channels),
89 filter_spatial_lengths_(filters_len),
90 input_spatial_lengths_(input_len),
92 conv_filter_strides_(strides),
93 conv_filter_dilations_(dilations),
94 input_left_pads_(left_pads),
95 input_right_pads_(right_pads)
96 {
101 static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
103 {
104 throw(std::runtime_error(
105 "ConvParam::ConvParam: "
106 "parameter size is different from number of declared dimensions!"));
107 }
108
109 for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
110 {
111 // XEff = (X - 1) * conv_dilation_w + 1;
112 // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
113 const ck_tile::long_index_t x_eff =
114 (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
115
116 output_spatial_lengths_[i] =
117 (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
118 conv_filter_strides_[i] +
119 1;
120 }
121 }
122
128
// Filter length per spatial dimension (e.g. Y, X for 2D); set from the constructor's filters_len.
129 std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
// Input image length per spatial dimension (e.g. Hi, Wi for 2D); set from input_len.
130 std::vector<ck_tile::long_index_t> input_spatial_lengths_;
// Output length per spatial dimension, derived by the constructors as
// (in + left_pad + right_pad - ((filter - 1) * dilation + 1)) / stride + 1 (integer division).
131 std::vector<ck_tile::long_index_t> output_spatial_lengths_;
132
// Convolution stride and dilation per spatial dimension; set from strides / dilations.
133 std::vector<ck_tile::long_index_t> conv_filter_strides_;
134 std::vector<ck_tile::long_index_t> conv_filter_dilations_;
135
// Left/right input padding per spatial dimension; both are added to the input length
// when the constructors compute output_spatial_lengths_.
136 std::vector<ck_tile::long_index_t> input_left_pads_;
137 std::vector<ck_tile::long_index_t> input_right_pads_;
138
139 std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
140 {
142 }
143
144 std::size_t GetFlops() const
145 {
146 // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
147 return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
148 std::accumulate(std::begin(output_spatial_lengths_),
149 std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
150 1,
151 std::multiplies<>()) *
152 std::accumulate(std::begin(filter_spatial_lengths_),
153 std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
154 1,
155 std::multiplies<>());
156 }
157
158 template <typename InDataType>
159 std::size_t GetInputByte() const
160 {
161 // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
162 return sizeof(InDataType) *
163 (G_ * N_ * C_ *
164 std::accumulate(std::begin(input_spatial_lengths_),
165 std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
166 1,
167 std::multiplies<>()));
168 }
169
170 template <typename WeiDataType>
171 std::size_t GetWeightByte() const
172 {
173 // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
174 return sizeof(WeiDataType) *
175 (G_ * K_ * C_ *
176 std::accumulate(std::begin(filter_spatial_lengths_),
177 std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
178 1,
179 std::multiplies<>()));
180 }
181
182 template <typename OutDataType>
183 std::size_t GetOutputByte() const
184 {
185 // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
186 return sizeof(OutDataType) * (G_ * N_ * K_ *
187 std::accumulate(std::begin(output_spatial_lengths_),
188 std::end(output_spatial_lengths_),
189 static_cast<std::size_t>(1),
190 std::multiplies<std::size_t>()));
191 }
192
193 template <typename InDataType, typename WeiDataType, typename OutDataType>
194 std::size_t GetByte() const
195 {
198 }
199};
200
202{
203 std::string msg;
204
205 msg += "Following arguments (depending on number of spatial dims):\n"
206 " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
207 " G, N, K, C, \n"
208 " <filter spatial dimensions>, (ie Y, X for 2D)\n"
209 " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
210 " <strides>, (ie Sy, Sx for 2D)\n"
211 " <dilations>, (ie Dy, Dx for 2D)\n"
212 " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
213 " <right padding>, (ie RightPy, RightPx for 2D)\n";
214
215 return msg;
216}
217
219parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
220{
221 const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
222 const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
223 const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
224 const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
225
226 std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
227 std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
228 std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
229 std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
230 std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
231 std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
232
233 for(int i = 0; i < num_dim_spatial; ++i)
234 {
235 filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
236 }
237
238 for(int i = 0; i < num_dim_spatial; ++i)
239 {
240 input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
241 }
242
243 for(int i = 0; i < num_dim_spatial; ++i)
244 {
245 conv_filter_strides[i] = std::stol(argv[arg_idx++]);
246 }
247
248 for(int i = 0; i < num_dim_spatial; ++i)
249 {
250 conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
251 }
252
253 for(int i = 0; i < num_dim_spatial; ++i)
254 {
255 input_left_pads[i] = std::stol(argv[arg_idx++]);
256 }
257
258 for(int i = 0; i < num_dim_spatial; ++i)
259 {
260 input_right_pads[i] = std::stol(argv[arg_idx++]);
261 }
262
263 return ck_tile::conv::ConvParam{num_dim_spatial,
264 G,
265 N,
266 K,
267 C,
268 filter_spatial_lengths,
269 input_spatial_lengths,
270 conv_filter_strides,
271 conv_filter_dilations,
272 input_left_pads,
273 input_right_pads};
274}
275
276} // namespace conv
277} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/host/convolution_host_tensor_descriptor_helper.hpp:11
CK_TILE_HOST ck_tile::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char *const argv[])
Definition tile/host/convolution_parameter.hpp:219
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
Definition tile/host/convolution_parameter.hpp:201
Definition tile/core/algorithm/cluster_descriptor.hpp:13
int64_t long_index_t
Definition integer.hpp:11
int32_t index_t
Definition integer.hpp:9
Definition tile/host/convolution_parameter.hpp:15
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:130
std::size_t GetWeightByte() const
Definition tile/host/convolution_parameter.hpp:171
ck_tile::long_index_t K_
Definition tile/host/convolution_parameter.hpp:126
ck_tile::long_index_t num_dim_spatial_
Definition tile/host/convolution_parameter.hpp:123
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > GetOutputSpatialLengths() const
Definition tile/host/convolution_parameter.hpp:139
std::vector< ck_tile::long_index_t > input_right_pads_
Definition tile/host/convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition tile/host/convolution_parameter.hpp:124
std::size_t GetInputByte() const
Definition tile/host/convolution_parameter.hpp:159
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition tile/host/convolution_parameter.hpp:133
std::size_t GetFlops() const
Definition tile/host/convolution_parameter.hpp:144
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition tile/host/convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition tile/host/convolution_parameter.hpp:127
std::size_t GetByte() const
Definition tile/host/convolution_parameter.hpp:194
ck_tile::long_index_t N_
Definition tile/host/convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition tile/host/convolution_parameter.hpp:136
ConvParam(ck_tile::long_index_t n_dim, ck_tile::long_index_t group_count, ck_tile::long_index_t n_batch, ck_tile::long_index_t n_out_channels, ck_tile::long_index_t n_in_channels, const std::vector< ck_tile::long_index_t > &filters_len, const std::vector< ck_tile::long_index_t > &input_len, const std::vector< ck_tile::long_index_t > &strides, const std::vector< ck_tile::long_index_t > &dilations, const std::vector< ck_tile::long_index_t > &left_pads, const std::vector< ck_tile::long_index_t > &right_pads)
Definition tile/host/convolution_parameter.hpp:73
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition tile/host/convolution_parameter.hpp:134
ConvParam(ck_tile::index_t n_dim, ck_tile::index_t group_count, ck_tile::index_t n_batch, ck_tile::index_t n_out_channels, ck_tile::index_t n_in_channels, const std::vector< ck_tile::index_t > &filters_len, const std::vector< ck_tile::index_t > &input_len, const std::vector< ck_tile::index_t > &strides, const std::vector< ck_tile::index_t > &dilations, const std::vector< ck_tile::index_t > &left_pads, const std::vector< ck_tile::index_t > &right_pads)
Definition tile/host/convolution_parameter.hpp:16
std::size_t GetOutputByte() const
Definition tile/host/convolution_parameter.hpp:183