rmsnorm2d_fwd_traits.hpp Source File

rmsnorm2d_fwd_traits.hpp Source File

Composable Kernel: rmsnorm2d_fwd_traits.hpp Source File
rmsnorm2d_fwd_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
11{
12 NO_ADD = 0,
13 // fused add before RMSNorm and store result to global
15 // fused add before RMSNorm, but not store result
17};
18
19// clang-format off
20template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
21template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
22template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
23template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
24// clang-format on
25
27{
29 SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant; needs an input x-scale and stores y_scale
30 DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
31};
32
33// clang-format off
34template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
35template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
36template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
37template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
38// clang-format on
39
41{
43 // T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
44 // architecture designed for a variety of NLP tasks. This option mimics T5's approach to
45 // RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
47};
48
49// clang-format off
50template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
51template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
52template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
53// clang-format on
54
55template <bool kPadN_,
56 bool kSaveInvRms_,
57 bool kSaveUnquant_,
58 bool kTwoPass_,
59 Rmsnorm2dFusedAddEnum kFusedAdd_,
60 Rmsnorm2dFusedQuantEnum kFusedQuant_,
61 Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
63{
64 static constexpr bool kPadN = kPadN_;
65 static constexpr bool kSaveInvRms = kSaveInvRms_;
66 static constexpr bool kSaveUnquant = kSaveUnquant_;
67 static constexpr bool kTwoPass = kTwoPass_;
68 static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
69 static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
70 static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
71};
72
73} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
Rmsnorm2dSensitiveEnum
Definition rmsnorm2d_fwd_traits.hpp:41
@ NO_SPECIFIC_MODEL
Definition rmsnorm2d_fwd_traits.hpp:42
@ T5_MODEL_LIKE
Definition rmsnorm2d_fwd_traits.hpp:46
Rmsnorm2dFusedQuantEnum
Definition rmsnorm2d_fwd_traits.hpp:27
@ NO_SWEEP
Definition layernorm2d_fwd_traits.hpp:41
@ SMOOTH_DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:42
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
Rmsnorm2dFusedAddEnum
Definition rmsnorm2d_fwd_traits.hpp:11
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
@ PRE_ADD
Definition layernorm2d_fwd_traits.hpp:29
@ NO_ADD
Definition layernorm2d_fwd_traits.hpp:25
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:21
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:23
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:22
Definition rmsnorm2d_fwd_traits.hpp:20
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:36
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:35
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:37
Definition rmsnorm2d_fwd_traits.hpp:34
Definition rmsnorm2d_fwd_traits.hpp:63
static constexpr bool kSaveUnquant
Definition rmsnorm2d_fwd_traits.hpp:66
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd
Definition rmsnorm2d_fwd_traits.hpp:68
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm
Definition rmsnorm2d_fwd_traits.hpp:70
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant
Definition rmsnorm2d_fwd_traits.hpp:69
static constexpr bool kTwoPass
Definition rmsnorm2d_fwd_traits.hpp:67
static constexpr bool kPadN
Definition rmsnorm2d_fwd_traits.hpp:64
static constexpr bool kSaveInvRms
Definition rmsnorm2d_fwd_traits.hpp:65
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:51
static constexpr const char * name
Definition rmsnorm2d_fwd_traits.hpp:52
Definition rmsnorm2d_fwd_traits.hpp:50