17template <
typename Transforms,
18 typename LowerDimensionHiddenIdss,
19 typename UpperDimensionHiddenIdss,
20 typename BottomDimensionHiddenIds,
21 typename TopDimensionHiddenIds>
26 __host__ __device__
constexpr const auto&
GetTransforms()
const {
return transforms_; }
30 return LowerDimensionHiddenIdss{};
35 return UpperDimensionHiddenIdss{};
40 return TopDimensionHiddenIds{};
45 return BottomDimensionHiddenIds{};
58 static_assert(found ==
true,
59 "wrong! not found matching transformation and upper-dimension");
72 template <index_t IDim>
77 constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
84 constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
86 static_for<0, up_dim_ids.Size(), 1>{}([&](
auto idim_up) {
87 if constexpr(up_dim_ids[idim_up] == idim_hidden)
90 idim_up_found = idim_up;
96 return make_tuple(itran_found, idim_up_found, found);
101 return BottomDimensionHiddenIds::Size();
106 return TopDimensionHiddenIds::Size();
111 constexpr auto all_low_dim_ids =
113 LowerDimensionHiddenIdss{});
115 constexpr auto all_up_dim_ids =
117 UpperDimensionHiddenIdss{});
119 constexpr auto all_dim_ids =
merge_sequences(all_low_dim_ids, all_up_dim_ids);
125 return unique_sort_all_dim_ids::Size();
144 __host__ __device__
constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
153 "wrong! inconsistent # of transformations");
158 __host__ __device__
constexpr auto GetElementSize()
const {
return element_size_; }
162 __host__ __device__
constexpr index_t GetTopDimensionLength(
Number<I> idim)
const
168 __host__ __device__
constexpr index_t GetBottomDimensionLength(
Number<I> idim)
const
174 template <
typename TopIdx>
177 static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
178 "wrong! # of dimension inconsistent");
189 static_for<ntransform, 0, -1>{}([&](
auto itran_p1) {
199 tran.CalculateLowerIndex(idx_low, idx_up);
209 bool is_known =
true;
211 static_for<0, Transforms::Size(), 1>{}([&](
auto i) {
218 __host__ __device__
void Print()
const
221 printf(
"TensorAdaptor, ");
223 printf(
"transforms: ");
224 transforms_[i].Print();
225 printf(
"LowerDimensionHiddenIds:");
226 LowerDimensionHiddenIdss{}.At(i).Print();
227 printf(
"UpperDimensionHiddenIds:");
228 UpperDimensionHiddenIdss{}.At(i).Print();
231 printf(
"BottomDimensionHiddenIds:");
232 BottomDimensionHiddenIds::Print();
233 printf(
"TopDimensionHiddenIds:");
234 TopDimensionHiddenIds::Print();
240 Transforms transforms_;
244template <
typename TensorAdaptor0,
typename TensorAdaptor1>
246 const TensorAdaptor1& adaptor1)
248 static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
249 TensorAdaptor1::GetNumOfBottomDimension(),
253 const auto all_transforms =
257 constexpr index_t adaptor0_max_hidden_id = [&]() {
260 static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](
auto itran) {
262 TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
265 adaptor0_max_hidden_id_ =
267 TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].
value);
271 TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
274 adaptor0_max_hidden_id_ =
276 TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].
value);
280 return adaptor0_max_hidden_id_;
283 constexpr index_t adaptor1_min_hidden_id = [&]() {
286 static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](
auto itran) {
288 TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
293 constexpr index_t low_dim_hidden_id =
294 TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
296 bool is_bottom_dim =
false;
297 static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](
auto i) {
298 if constexpr(low_dim_hidden_id ==
299 TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
301 is_bottom_dim =
true;
307 adaptor1_min_hidden_id_ =
math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
312 TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
316 adaptor1_min_hidden_id_ =
318 TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].
value);
322 return adaptor1_min_hidden_id_;
325 constexpr index_t adaptor1_hidden_id_shift =
326 adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
328 constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
335 constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();
337 constexpr auto low_dim_hidden_ids_1 =
338 TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
341 constexpr auto low_dim_hidden_ids_1_mod = [&]()
constexpr {
342 auto low_dim_hidden_ids_1_mod_ =
to_multi_index(low_dim_hidden_ids_1);
346 low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
353 if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
354 TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
356 low_dim_hidden_ids_1_mod_(idim_low_1) =
357 TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
362 return low_dim_hidden_ids_1_mod_;
369 Number<TensorAdaptor1::GetNumOfTransform()>{});
371 constexpr auto all_low_dim_hidden_idss =
372 container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);
379 constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();
381 constexpr auto up_dim_hidden_ids_1 =
382 TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
385 constexpr auto up_dim_hidden_ids_1_mod = [&]()
constexpr {
386 auto up_dim_hidden_ids_1_mod_ =
to_multi_index(up_dim_hidden_ids_1);
390 up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
393 return up_dim_hidden_ids_1_mod_;
401 Number<TensorAdaptor1::GetNumOfTransform()>{});
403 constexpr auto all_up_dim_hidden_idss =
404 container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);
407 constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();
410 constexpr auto top_dim_hidden_ids =
418 remove_cv_t<
decltype(top_dim_hidden_ids)>>{all_transforms};
424template <
typename Transforms,
typename LowerDimensionOldTopIdss,
typename UpperDimensionNewTopIdss>
426 LowerDimensionOldTopIdss,
427 UpperDimensionNewTopIdss)
429 constexpr index_t ntransform = Transforms::Size();
431 static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
432 UpperDimensionNewTopIdss::Size() == ntransform,
436 constexpr auto all_low_dim_old_top_ids =
unpack(
437 [](
auto&&... xs)
constexpr {
return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
439 constexpr auto all_up_dim_new_top_ids =
unpack(
440 [](
auto&&... xs)
constexpr {
return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
446 constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
447 constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
450 constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
458 constexpr auto bottom_dim_hidden_ids =
462 constexpr auto top_dim_hidden_ids =
469 remove_cv_t<
decltype(top_dim_hidden_ids)>>{transforms};
472template <
typename X,
typename... Xs,
typename enable_if<
sizeof...(Xs) >= 2,
bool>::type =
false>
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto unpack(F &&f, X &&x)
Definition functional4.hpp:46
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr void set_container_subset(Array< T, N > &y, Sequence< Is... > picks, const Array< T, sizeof...(Is)> &x)
Definition utility/container_helper.hpp:363
__host__ __device__ constexpr auto merge_sequences(Seqs...)
Definition utility/sequence.hpp:768
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
__host__ static __device__ constexpr T Min()
Definition numeric_limits.hpp:310
Definition tensor_description/tensor_adaptor.hpp:23
__host__ static __device__ constexpr index_t GetNumOfTransform()
Definition tensor_description/tensor_adaptor.hpp:24
__host__ static __device__ constexpr index_t GetNumOfBottomDimension()
Definition tensor_description/tensor_adaptor.hpp:99
static constexpr index_t ndim_hidden_
Definition tensor_description/tensor_adaptor.hpp:129
__host__ static __device__ constexpr auto GetLowerDimensionHiddenIdss()
Definition tensor_description/tensor_adaptor.hpp:28
__host__ static __device__ constexpr index_t GetNumOfHiddenDimension()
Definition tensor_description/tensor_adaptor.hpp:109
__host__ __device__ void Print() const
Definition tensor_description/tensor_adaptor.hpp:218
__host__ static __device__ constexpr auto GetUpperDimensionHiddenIdss()
Definition tensor_description/tensor_adaptor.hpp:33
__host__ static __device__ constexpr auto GetBottomDimensionHiddenIds()
Definition tensor_description/tensor_adaptor.hpp:43
__host__ static __device__ constexpr index_t GetNumOfTopDimension()
Definition tensor_description/tensor_adaptor.hpp:104
static constexpr index_t ntransform_
Definition tensor_description/tensor_adaptor.hpp:128
__host__ __device__ constexpr TensorAdaptor()
Definition tensor_description/tensor_adaptor.hpp:144
remove_cv_t< decltype(InitializeElementSize(Transforms{}))> ElementSize
Definition tensor_description/tensor_adaptor.hpp:138
__host__ __device__ constexpr auto GetElementSize() const
Definition tensor_description/tensor_adaptor.hpp:158
__host__ __device__ constexpr TensorAdaptor(const Transforms &transforms)
Definition tensor_description/tensor_adaptor.hpp:147
__host__ static __device__ constexpr auto GetTransformAndItsUpperDimension(Number< IDim >)
Definition tensor_description/tensor_adaptor.hpp:73
MultiIndex< ndim_hidden_ > HiddenIndex
Definition tensor_description/tensor_adaptor.hpp:133
static constexpr index_t ndim_bottom_
Definition tensor_description/tensor_adaptor.hpp:130
__host__ static __device__ constexpr auto InitializeElementSize(const Transforms &transforms)
Definition tensor_description/tensor_adaptor.hpp:48
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition tensor_description/tensor_adaptor.hpp:207
MultiIndex< ndim_bottom_ > BottomIndex
Definition tensor_description/tensor_adaptor.hpp:134
__host__ static __device__ constexpr auto GetTopDimensionHiddenIds()
Definition tensor_description/tensor_adaptor.hpp:38
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition tensor_description/tensor_adaptor.hpp:175
__host__ __device__ constexpr const auto & GetTransforms() const
Definition tensor_description/tensor_adaptor.hpp:26
MultiIndex< ndim_top_ > TopIndex
Definition tensor_description/tensor_adaptor.hpp:135
static constexpr index_t ndim_top_
Definition tensor_description/tensor_adaptor.hpp:131
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition is_known_at_compile_time.hpp:14
Definition utility/sequence.hpp:618
Definition utility/math.hpp:211
Definition utility/math.hpp:217
Definition utility/math.hpp:34
Definition utility/sequence.hpp:543
Definition functional2.hpp:33