diff --git a/common/inc/cunpack_common.h b/common/inc/cunpack_common.h index 96c55f1..3a6a5d9 100644 --- a/common/inc/cunpack_common.h +++ b/common/inc/cunpack_common.h @@ -236,11 +236,8 @@ namespace ckernel::unpacker ((uint)unpA_dst_format == (uint)DataFormat::Int32) || ((uint)unpB_dst_format == (uint)DataFormat::Int32); - constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | - ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK; - alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcA = unpA_dst_format; - alu_payload.f.ALU_FORMAT_SPEC_REG1_SrcB = row_pool ? ((uint) DataFormat::Float16 | (exp_width<<2)) : unpB_dst_format; - + constexpr uint alu_format_mask = ALU_FORMAT_SPEC_REG0_SrcAUnsigned_MASK | ALU_FORMAT_SPEC_REG0_SrcBUnsigned_MASK; + if ((uint)unpA_src_format == (uint)DataFormat::UInt8) { alu_payload.f.ALU_FORMAT_SPEC_REG0_SrcAUnsigned = 1; } @@ -252,10 +249,9 @@ namespace ckernel::unpacker // NOTE: This assumes these config fields are adjacent and in same register!! static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_FORMAT_SPEC_REG0_SrcA_ADDR32); static_assert(ALU_ACC_CTRL_Fp32_enabled_ADDR32 == ALU_ACC_CTRL_SFPU_Fp32_enabled_ADDR32); - constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_INT8_math_enabled_MASK | ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK; + constexpr uint alu_dest_format_mask = ALU_ACC_CTRL_SFPU_Fp32_enabled_MASK | ALU_ACC_CTRL_Fp32_enabled_MASK; alu_payload.f.ALU_ACC_CTRL_Fp32_enabled = fp32_dest_acc_en; alu_payload.f.ALU_ACC_CTRL_SFPU_Fp32_enabled = fp32_dest_acc_en; - alu_payload.f.ALU_ACC_CTRL_INT8_math_enabled = int8_math_enabled; constexpr uint alu_stoch_rnd_mask = ALU_ROUNDING_MODE_Fpu_srnd_en_MASK | ALU_ROUNDING_MODE_Gasket_srnd_en_MASK | ALU_ROUNDING_MODE_Packer_srnd_en_MASK; alu_payload.f.ALU_ROUNDING_MODE_Fpu_srnd_en = fpu_srnd_en; alu_payload.f.ALU_ROUNDING_MODE_Gasket_srnd_en = pack_srnd_en; diff --git a/llk_lib/llk_math_common.h b/llk_lib/llk_math_common.h index eda8802..4f46931 100644 --- a/llk_lib/llk_math_common.h +++ b/llk_lib/llk_math_common.h @@ -14,6 +14,22 @@ using namespace ckernel::math; +template +inline void _llk_math_hw_configure_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU); + uint exp_width = ((uint)srca_data_format>>2)&0x1; //0=5-bit, 1=8-bit + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + uint srcb_format = (row_pool ? ((uint)DataFormat::Float16 | (exp_width<<2)) : srcb_data_format); + uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | + (srcb_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) | + (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT); + constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK; + cfg_reg_rmw_tensix(config_data); +} + template inline void _llk_math_wait_for_dest_available_() { // These liteweight functions for sync with packer imply @@ -112,19 +128,47 @@ inline void _llk_math_debug_dump_seek_(std::uint8_t offset) { debug_dump_seek(offset); } +template inline void _llk_math_reconfig_data_format_srca_(const std::uint32_t srca_data_format) { - cfg_reg_rmw_tensix(srca_data_format); + if constexpr (to_from_int8) { + static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled"); + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU); + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32); + uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT); + constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK; + cfg_reg_rmw_tensix(config_data); + } } +template inline void _llk_math_reconfig_data_format_srcb_(const std::uint32_t srcb_data_format) { - cfg_reg_rmw_tensix(srcb_data_format); + if constexpr (to_from_int8) { + static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled"); + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU); + uint int8_math_enabled = ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + uint config_data = (srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) | (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT); + constexpr uint config_mask = ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK; + cfg_reg_rmw_tensix(config_data); + } } +template inline void _llk_math_reconfig_data_format_(const std::uint32_t srca_data_format, const std::uint32_t srcb_data_format) { - - uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | (srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT); - constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK; - cfg_reg_rmw_tensix(config_data); + if constexpr (to_from_int8) { + static_assert(is_fp32_dest_acc_en, "Reconfiguring math to/from Int8 formats requires FP32 Dest mode enabled"); + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::MATH | p_stall::WAIT_SFPU); + uint int8_math_enabled = ((uint)(srca_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)(srcb_data_format & 0xF) == (uint)DataFormat::Int8) || + ((uint)srca_data_format == (uint)DataFormat::Int32) || + ((uint)srcb_data_format == (uint)DataFormat::Int32); + uint config_data = (srca_data_format << ALU_FORMAT_SPEC_REG0_SrcA_SHAMT) | + (srcb_data_format << ALU_FORMAT_SPEC_REG1_SrcB_SHAMT) | + (int8_math_enabled << ALU_ACC_CTRL_INT8_math_enabled_SHAMT); + constexpr uint config_mask = ALU_FORMAT_SPEC_REG0_SrcA_MASK | ALU_FORMAT_SPEC_REG1_SrcB_MASK | ALU_ACC_CTRL_INT8_math_enabled_MASK; + cfg_reg_rmw_tensix(config_data); + } } inline std::uint32_t _llk_math_get_compute_special_value_flags_() { diff --git a/llk_lib/llk_unpack_common.h b/llk_lib/llk_unpack_common.h index 15a384f..aa0891d 100644 --- a/llk_lib/llk_unpack_common.h +++ b/llk_lib/llk_unpack_common.h @@ -80,15 +80,27 @@ inline void _llk_unpack_config_tile_dim_srcb_impl_(const std::uint32_t face_r_di cfg_reg_rmw_tensix(num_faces); } +template inline void _llk_unpack_reconfig_data_format_srca_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK0); + if constexpr (to_from_int8) { + static_assert(is_fp32_dest_acc_en, "Reconfiguring unpack to/from Int8 formats requires FP32 Dest mode enabled"); + cfg_reg_rmw_tensix(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0); + } cfg_reg_rmw_tensix(unpack_src_format); cfg_reg_rmw_tensix(unpack_dst_format); TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_A)); // update gpr which holds tile size A } +template inline void _llk_unpack_reconfig_data_format_srcb_impl_(const std::uint32_t unpack_src_format, const std::uint32_t unpack_dst_format, const std::uint32_t tile_size) { + TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::UNPACK1); + if constexpr (to_from_int8) { + static_assert(is_fp32_dest_acc_en, "Reconfiguring unpack to/from Int8 formats requires FP32 Dest mode enabled"); + cfg_reg_rmw_tensix(((uint)unpack_src_format == (uint)DataFormat::UInt8) ? 1 : 0); + } cfg_reg_rmw_tensix(unpack_src_format); cfg_reg_rmw_tensix(unpack_dst_format); TT_SETDMAREG(0, LOWER_HALFWORD(tile_size), 0, LO_16(p_gpr_unpack::TILE_SIZE_B)); // update gpr which holds tile size B