Skip to content

Commit

Permalink
Split hw (re)config between unpack/math threads (#34)
Browse files Browse the repository at this point in the history
* Add stall on unpack reconfig

* Split hw (re)config between unpack and math threads

* Stall on math and SFPU for math hw reconfig

* add hw reconfig param float_only

* add static check for int8 compatibility
  • Loading branch information
rdjogoTT authored Oct 8, 2024
1 parent 1665150 commit bbe1d58
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 13 deletions.
10 changes: 3 additions & 7 deletions common/inc/cunpack_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,8 @@ namespace ckernel::unpacker
((uint)unpA_dst_format == (uint)DataFormat::Int32) ||
((uint)unpB_dst_format == (uint)DataFormat::Int32);

constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK |
ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK;
alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcA = unpA_dst_format;
alu_payload.f.ALU_FORMAT_SPEC_REG1_SrcB = row_pool ? ((uint) DataFormat::Float16 | (exp_width<<2)) : unpB_dst_format;

constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK;

if ((uint)unpA_src_format == (uint)DataFormat::UInt8) {
alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcAUnsigned = 1;
}
Expand All @@ -252,10 +249,9 @@ namespace ckernel::unpacker
// NOTE: This assumes these config fields are adjacent and in same register!!
static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_FORMAT_SPEC_REG0_SrcA_ADDR32);
static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_ACC_CTRL_SFPU_Fp32_enabled_ADDR32);
constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_INT8_math_enabled_MASK | ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK;
constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK;
alu_payload.f.ALU_ACC_CTRL_Fp32_enabled = fp32_dest_acc_en;
alu_payload.f.ALU_ACC_CTRL_SFPU_Fp32_enabled = fp32_dest_acc_en;
alu_payload.f.ALU_ACC_CTRL_INT8_math_enabled = int8_math_enabled;
constexpr uint alu_stoch_rnd_mask = ALU_ROUNDING_MODE_Fpu_srnd_en_MASK | ALU_ROUNDING_MODE_Gasket_srnd_en_MASK | ALU_ROUNDING_MODE_Packer_srnd_en_MASK;
alu_payload.f.ALU_ROUNDING_MODE_Fpu_srnd_en = fpu_srnd_en;
alu_payload.f.ALU_ROUNDING_MODE_Gasket_srnd_en = pack_srnd_en;
Expand Down
56 changes: 50 additions & 6 deletions llk_lib/llk_math_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@

using namespace ckernel::math;

template <bool untilize_en=false, bool row_pool=false>
inline void _llk_math_hw_configure_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) {
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU);
uint exp_width = ((uint)srca_data_format>>2)&0x1; //0=5-bit, 1=8-bit
uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)srca_data_format == (uint)DataFormat::Int32) ||
((uint)srcb_data_format == (uint)DataFormat::Int32);
uint srcb_format = (row_pool ? ((uint)DataFormat::Float16 | (exp_width<<2)) : srcb_data_format);
uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) |
(srcb_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) |
(int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT);
constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK;
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_ADDR32, 0, config_mask>(config_data);
}

template <DstSync Dst>
inline void _llk_math_wait_for_dest_available_() {
// These lightweight functions for sync with packer imply
Expand Down Expand Up @@ -112,19 +128,47 @@ inline void _llk_math_debug_dump_seek_(std::uint8_t offset) {
debug_dump_seek(offset);
}

// Reconfigures the ALU source-A data format from the math thread.
//
// Template parameters:
//   to_from_int8        - when true, the function additionally stalls and rewrites the
//                         SrcA format together with the INT8 math enable bit.
//   is_fp32_dest_acc_en - must be true whenever to_from_int8 is true (static_assert below).
// Parameter:
//   srca_data_format - value written to the ALU_FORMAT_SPEC_REG0_SrcA field.
template <bool to_from_int8=false, bool is_fp32_dest_acc_en=false>
inline void _llk_math_reconfig_data_format_srca_(const std::uint32_t srca_data_format) {
// Unconditional write of the SrcA format field; no stall is issued before it.
// NOTE(review): this listing was taken from a diff view — confirm this line belongs to the
// current revision, since the int8 branch below writes the same field again after a stall.
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_RMW>(srca_data_format);
if constexpr (to_from_int8) {
// Compile-time guard: the int8 path is only allowed with FP32 dest accumulation enabled.
static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled");
// Presumably holds off the config write until MATH and SFPU are idle — TTI_STALLWAIT semantics are not visible here.
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU);
// 1 when the low 4 bits of the format equal Int8 or the whole format equals Int32, else 0.
// Only SrcA is inspected here; the SrcB format does not influence the flag in this function.
uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)srca_data_format == (uint)DataFormat::Int32);
// Pack the SrcA format and the INT8 enable flag at their field offsets.
uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT);
// Mask passed to the RMW helper: covers the SrcA field and the INT8 math enable bit.
constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK;
// Single RMW at the SrcA ADDR32 with shift 0 — assumes the INT8 enable bit lives in the same config word (TODO confirm).
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_ADDR32, 0, config_mask>(config_data);
}
}

// Reconfigures the ALU source-B data format from the math thread.
//
// Template parameters:
//   to_from_int8        - when true, the function additionally stalls and rewrites the
//                         SrcB format together with the INT8 math enable bit.
//   is_fp32_dest_acc_en - must be true whenever to_from_int8 is true (static_assert below).
// Parameter:
//   srcb_data_format - value written to the ALU_FORMAT_SPEC_REG1_SrcB field.
template <bool to_from_int8=false, bool is_fp32_dest_acc_en=false>
inline void _llk_math_reconfig_data_format_srcb_(const std::uint32_t srcb_data_format) {
// Unconditional write of the SrcB format field; no stall is issued before it.
// NOTE(review): this listing was taken from a diff view — confirm this line belongs to the
// current revision, since the int8 branch below writes the same field again after a stall.
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG1_SrcB_RMW>(srcb_data_format);
if constexpr (to_from_int8) {
// Compile-time guard: the int8 path is only allowed with FP32 dest accumulation enabled.
static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled");
// Presumably holds off the config write until MATH and SFPU are idle — TTI_STALLWAIT semantics are not visible here.
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU);
// 1 when the low 4 bits of the format equal Int8 or the whole format equals Int32, else 0.
// Only SrcB is inspected here; the SrcA format does not influence the flag in this function.
uint int8_math_enabled = ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)srcb_data_format == (uint)DataFormat::Int32);
// Pack the SrcB format and the INT8 enable flag at their field offsets.
uint config_data = (srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) | (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT);
// Mask passed to the RMW helper: covers the SrcB field and the INT8 math enable bit.
constexpr uint config_mask = ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK;
// The RMW is addressed at the SrcA ADDR32 although the mask selects SrcB —
// assumes SrcA, SrcB and the INT8 enable bit share one config word (TODO confirm).
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_ADDR32, 0, config_mask>(config_data);
}
}

// Reconfigures both ALU source data formats (SrcA and SrcB) from the math thread.
//
// Template parameters:
//   to_from_int8        - when true, the function additionally stalls and rewrites both
//                         formats together with the INT8 math enable bit.
//   is_fp32_dest_acc_en - must be true whenever to_from_int8 is true (static_assert below).
// Parameters:
//   srca_data_format - value written to the ALU_FORMAT_SPEC_REG0_SrcA field.
//   srcb_data_format - value written to the ALU_FORMAT_SPEC_REG1_SrcB field.
template <bool to_from_int8=false, bool is_fp32_dest_acc_en=false>
inline void _llk_math_reconfig_data_format_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) {

// Unconditional write of both format fields in one RMW; no stall is issued before it.
// NOTE(review): this listing was taken from a diff view — confirm these three lines belong to
// the current revision, since the int8 branch below rewrites the same fields after a stall.
uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | (srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT);
constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK;
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_ADDR32, 0, config_mask>(config_data);
if constexpr (to_from_int8) {
// Compile-time guard: the int8 path is only allowed with FP32 dest accumulation enabled.
static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled");
// Presumably holds off the config write until MATH and SFPU are idle — TTI_STALLWAIT semantics are not visible here.
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU);
// 1 when either format has Int8 in its low 4 bits or is exactly Int32, else 0.
uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) ||
((uint)srca_data_format == (uint)DataFormat::Int32) ||
((uint)srcb_data_format == (uint)DataFormat::Int32);
// These inner config_data / config_mask shadow the function-scope variables declared above.
uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) |
(srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) |
(int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT);
// Mask passed to the RMW helper: covers SrcA, SrcB and the INT8 math enable bit.
constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK;
// Single RMW at the SrcA ADDR32 with shift 0 — assumes all three fields share one config word (TODO confirm).
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcA_ADDR32, 0, config_mask>(config_data);
}
}

inline std::uint32_t _llk_math_get_compute_special_value_flags_() {
Expand Down
12 changes: 12 additions & 0 deletions llk_lib/llk_unpack_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,27 @@ inline void _llk_unpack_config_tile_dim_srcb_impl_(const std::uint32_t face_r_di
cfg_reg_rmw_tensix<THCON_SEC1_REG0_TileDescriptor_ADDR32+1, 16, 0xffff0000>(num_faces);
}

template <bool to_from_int8=false, bool is_fp32_dest_acc_en=false>
inline void _llk_unpack_reconfig_data_format_srca_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size)
{
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK0);
if constexpr (to_from_int8) {
static_assert(is_fp32_dest_acc_en, "Reconfiguring unpack to/from Int8 formats requires FP32 Dest mode enabled");
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcAUnsigned_RMW>(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0);
}
cfg_reg_rmw_tensix<THCON_SEC0_REG0_TileDescriptor_ADDR32, 0, 0x0f>(unpack_src_format);
cfg_reg_rmw_tensix<THCON_SEC0_REG2_Out_data_format_RMW>(unpack_dst_format);
TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_A)); // update gpr which holds tile size A
}

template <bool to_from_int8=false, bool is_fp32_dest_acc_en=false>
inline void _llk_unpack_reconfig_data_format_srcb_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size)
{
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK1);
if constexpr (to_from_int8) {
static_assert(is_fp32_dest_acc_en, "Reconfiguring unpack to/from Int8 formats requires FP32 Dest mode enabled");
cfg_reg_rmw_tensix<ALU_FORMAT_SPEC_REG0_SrcBUnsigned_RMW>(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0);
}
cfg_reg_rmw_tensix<THCON_SEC1_REG0_TileDescriptor_ADDR32, 0, 0x0f>(unpack_src_format);
cfg_reg_rmw_tensix<THCON_SEC1_REG2_Out_data_format_RMW>(unpack_dst_format);
TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_B)); // update gpr which holds tile size B
Expand Down

0 comments on commit bbe1d58

Please sign in to comment.