ck::wmma_type< WmmaInstr::wmma_i32_16x16x16_iu8_gfx12, WaveSize, typename std::enable_if_t< WaveSize==32||WaveSize==64 > > Struct Template Reference
ck::wmma_type< WmmaInstr::wmma_i32_16x16x16_iu8_gfx12, WaveSize, typename std::enable_if_t< WaveSize==32||WaveSize==64 > > Struct Template Reference
#include <wmma_gemm.hpp>
Public Member Functions | |
| template<index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC, bool neg_a = true, bool neg_b = true, bool clamp = false> | |
| __device__ void | run (const FloatA &a, const FloatB &b, FloatC &reg_c) const |
Static Public Attributes | |
| static constexpr index_t | m_per_wmma = 16 |
| static constexpr index_t | n_per_wmma = 16 |
| static constexpr index_t | k_per_wmma = 16 |
| static constexpr index_t | acc_data_size = 4 |
| static constexpr index_t | acc_pack_number = 1 |
| static constexpr index_t | num_thread_per_subgroups = n_per_wmma |
| static constexpr index_t | wave_size = Number<WaveSize>{} |
| static constexpr index_t | num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size |
| static constexpr index_t | num_subgroups = wave_size / num_thread_per_subgroups |
Member Function Documentation
◆ run()
template<index_t WaveSize>
template<index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC, bool neg_a = true, bool neg_b = true, bool clamp = false>
|
inline |
Member Data Documentation
◆ acc_data_size
template<index_t WaveSize>
|
static constexpr |
◆ acc_pack_number
template<index_t WaveSize>
|
static constexpr |
◆ k_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ m_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ n_per_wmma
template<index_t WaveSize>
|
static constexpr |
◆ num_acc_vgprs_per_wave
template<index_t WaveSize>
|
static constexpr |
◆ num_subgroups
template<index_t WaveSize>
|
static constexpr |
◆ num_thread_per_subgroups
template<index_t WaveSize>
|
static constexpr |
◆ wave_size
template<index_t WaveSize>
|
static constexpr |
The documentation for this struct was generated from the following file: