device_batched_gemm_reduce_xdl_cshuffle.hpp File Reference
device_batched_gemm_reduce_xdl_cshuffle.hpp File Reference
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

Go to the source code of this file.
Namespaces | |
| namespace | ck |
| namespace | ck::tensor_operation |
| namespace | ck::tensor_operation::device |
Functions | |
template<typename GridwiseGemm,
         typename FloatAB,
         typename FloatC,
         typename ReducePtrsGlobal,
         typename AElementwiseOperation,
         typename BElementwiseOperation,
         typename CElementwiseOperation,
         typename ReduceInElementwiseOperations,
         typename ReduceAccElementwiseOperations,
         typename AGridDesc_AK0_M_AK1,
         typename BGridDesc_BK0_N_BK1,
         typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
         typename ReduceGridDescriptor_MBlock_MPerBlock,
         typename ComputeBasePrtOfBatch,
         typename Block2CTileMap,
         bool HasMainK0BlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_reduce_xdl_cshuffle_v1(
    const FloatAB *__restrict__ p_a_grid,
    const FloatAB *__restrict__ p_b_grid,
    FloatC *__restrict__ p_c_grid,
    ReducePtrsGlobal p_reduces_grid,
    const index_t batch_count,
    const AElementwiseOperation a_element_op,
    const BElementwiseOperation b_element_op,
    const CElementwiseOperation c_element_op,
    const ReduceInElementwiseOperations reduce_in_element_ops,
    const ReduceAccElementwiseOperations reduce_out_element_ops,
    const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
    const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
    const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock,
    const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
    const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
    const Block2CTileMap block_2_ctile_map)