From fe1230db224f53fccac86e247a2fb38566e58bba Mon Sep 17 00:00:00 2001 From: Sasa Vuckovic Date: Wed, 8 May 2024 12:12:48 +0000 Subject: [PATCH] [sparse] Add ability to calculate per-core sparse matmul metadata (used for execution cycles estimation) (cherry picked from commit 8a1e05ebddaef0b0010949c1df27fd122ad7a1af) --- README.debug.md | 3 +- pybuda/csrc/balancer/balancer_utils.cpp | 194 +++++++++--------- pybuda/csrc/balancer/balancer_utils.hpp | 2 +- pybuda/csrc/balancer/legalizer/legalizer.cpp | 13 +- .../csrc/balancer/policies/policy_utils.hpp | 1 + pybuda/csrc/balancer/python_bindings.cpp | 15 +- pybuda/csrc/balancer/types.cpp | 9 +- pybuda/csrc/balancer/types.hpp | 64 +++++- .../csrc/shared_utils/sparse_matmul_utils.hpp | 143 ++++++++++++- pybuda/pybuda/op/eval/buda/matmul.py | 47 ++++- pybuda/pybuda/op/eval/common.py | 73 +++++-- pybuda/pybuda/op/eval/pybuda/matmul.py | 4 +- .../test/benchmark/benchmark/models/other.py | 13 +- 13 files changed, 426 insertions(+), 155 deletions(-) diff --git a/README.debug.md b/README.debug.md index babc9f9f..a28a6d51 100644 --- a/README.debug.md +++ b/README.debug.md @@ -127,10 +127,11 @@ ## Temp overrides * PYBUDA\_TEMP\_ENABLE\_NEW\_SPARSE\_ESTIMATES: Apply new formula to estimate the cycle count of sparse matmul ops (currently only support LoFi and HiFi2 fidelities) * PYBUDA\_TEMP\_SCALE\_SPARSE\_ESTIMATE\_ARGS: Scale counts of non-zero tiles, ublocks and strips to reflect the numbers that would end up on a single core, since BBE estimates always assume grid_size [1,1]. +* PYBUDA\_TEMP\_SPARSE\_ESTIMATE\_ARGS\_PER\_CORE: Instead of uniformly scaling sparse args (as happens in PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS), calculate them per core. To use, PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS needs to be set to 1 as well. * PYBUDA\_TEMP\_ELT\_UNARY\_ESTIMATES\_LEGACY: Force legacy path of calculating execution cycles for eltwise unary ops - instead of calling into BBE, use hand-crafted FE-side logic * PYBUDA\_TEMP\_ENABLE\_NEW\_FUSED\_ESTIMATES: Apply new formula to estimate the cycle count of fused ops. The formula calls BBE to estimate each subop and sums up the results. * PYBUDA\_LEGACY\_KERNEL\_BROADCAST: Use legacy kernel broadcast detection path. Will detect fewer kernel broadcasts, and will oftentimes use more tiles (longer KBs). -* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True)) +* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True) * PYBUDA\_TEMP\_BALANCER\_DISABLE\_TARGET\_PROXIMITY: Disable target proximity in balancer. (default: 0/False) * PYBUDA\_TEMP\_DISABLE\_FJ\_NOP\_SCHEDULE\_FIX: This flag disables a fix that forces FJ buffering nops to be scheduled last. * PYBUDA\_TEMP\_FIX\_2351: Controls the fix for bug #2351 - fork-join can end up adding buffering nops and queues on same path, this control flag fixes it. 
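As an illustration, the three sparse-estimate overrides above would typically be combined like this in a test or benchmark script. This is a minimal sketch: only the environment variable names and values come from the list above, and the variables have to be exported before the compile/balancer run they are meant to affect.

import os

# Enable the new sparse matmul cycle estimation formula (LoFi/HiFi2 only)
os.environ["PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES"] = "1"
# Scale the sparse args down to a single core...
os.environ["PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS"] = "1"
# ...and calculate them per core instead of uniformly scaling them
os.environ["PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE"] = "1"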
diff --git a/pybuda/csrc/balancer/balancer_utils.cpp b/pybuda/csrc/balancer/balancer_utils.cpp index b58196a7..98602f3a 100644 --- a/pybuda/csrc/balancer/balancer_utils.cpp +++ b/pybuda/csrc/balancer/balancer_utils.cpp @@ -4,9 +4,13 @@ #include "balancer_utils.hpp" #include +#include #include +#include "balancer/types.hpp" #include "passes/t_stream.hpp" +#include "shared_utils/sparse_matmul_utils.hpp" +#include "utils/assert.hpp" #include "utils/hash_combine.hpp" #include "utils/logger.hpp" #include "utils/profile.hpp" @@ -583,124 +587,116 @@ ResourceUsage get_edge_resource_usage_simple( return usage; } -std::tuple get_sparse_matmul_metadata(balancer::OpModel const& op_model) +// Calculates sparse matmul metadata that is used for kernel execution cycle estimation. These are: +// - number of non-zero tiles per core +// - number of non-zero ublocks per core +// - number of non-zero strips per core +// +std::shared_ptr get_sparse_matmul_metadata(balancer::OpModel const& op_model) { - int grid_r = op_model.grid_shape.r; - int u_rt = op_model.output_buffers[0].block_shape.ublock.rt; - int u_kt = op_model.input_buffers[1].block_shape.ublock.rt; - int t_factor_c = op_model.t_stream_factor.c; - int t_factor_r = op_model.t_stream_factor.r; + OpModel::SparseMetadata sparse_metadata(op_model.grid_shape.r); + const sparse::SparseBUDA& sparse_buda = *(op_model.sparse_buda); + const int grid_r = op_model.grid_shape.r; + const int u_rt = op_model.output_buffers[0].block_shape.ublock.rt; + const int u_kt = op_model.input_buffers[1].block_shape.ublock.rt; + const int m_k = sparse_buda.sparse_shape[1] / sparse::TILE_DIM / u_kt; + const int t_factor_c = op_model.t_stream_factor.c; + const int t_factor_r = op_model.t_stream_factor.r; + const int bcast_factor = sparse_buda.bcast_factor; // broadcast factor + const std::vector& sparse_indices = sparse_buda.sparse_indices; auto layout = sparse::SparseBUDA::create_layout( op_model.has_sparse_buffer() or env_as("PYBUDA_FORCE_SPARSE_BUFFER_LAYOUT"), op_model.t_stream_factor.dir.z_major(), op_model.fracture_factor); - int bcast_factor = sparse_buda.bcast_factor; - int zdim = sparse_buda.sparse_zs.size(); - // Initialize tiles/ublocks/strips counter - int sum_nz_tiles = 0; - int sum_nz_ublocks = 0; - int sum_nz_strips = 0; - constexpr int TILE_DIM = tt::sparse::TILE_DIM; + const int sparse_rt = + sparse_buda.sparse_shape[0] / sparse::TILE_DIM; // number of tiles in row dim of sparse tensor + const int bcast_slice_size = sparse_rt / bcast_factor; // size of each slice after bcast slicing + const int r_tiles_in_core = sparse_rt / grid_r / t_factor_r; // rows of tiles in a single core + const int dflow_factor = + (layout == sparse::SparseBUDA::Layout::ZMajorDataflow) ? sparse_rt / grid_r / t_factor_r / bcast_factor : 1; - struct CounterEntry + // Helper struct - accounting for metadata per core + // + struct MetadataPerCore { - std::unordered_set rt_ct_cmb; - std::unordered_set ubc_ubr_cmb; - std::unordered_set ubc_idxs; - int smallest_rt; - CounterEntry() : smallest_rt(INT_MAX){}; + unsigned int nz_tiles = 0; + std::unordered_set nz_ublocks; + std::unordered_set nz_strips; }; - std::vector counters; - int slice_count = grid_r * t_factor_r; + std::vector metadata_per_core(grid_r); - // Iterate through all sparse tensors - for (int z = 0; z < zdim; z++) + // Calculate metadata below + // Go through all the sparse indices (coordinates of non-zero tiles in the canonical sparse tensor), and map them to + // coords (core, t-dim, ublock, ...) 
based on the layout provided + // + // The math below, which is different for each layout, is not obvious. Need to add some documentation to explain. + // Tracking issue: + // # tenstorrent/pybuda#2601 + // + int r, t, target_rt; + for (const sparse::SparseIndex& index : sparse_indices) { - auto sparse = sparse_buda.sparse_zs[z]; - - // Take stat of the sparseCOO - int dflow_factor = (layout == sparse::SparseBUDA::Layout::ZMajorDataflow) - ? (sparse.rt() / grid_r / t_factor_r / bcast_factor) - : 1; - int num_slices = (layout == tt::sparse::SparseBUDA::Layout::Default) - ? grid_r * t_factor_r - : grid_r * t_factor_r * bcast_factor * dflow_factor; - std::int64_t slice_height = sparse.shape[0] / num_slices; - - std::vector ret(num_slices); - for (size_t idx = 0; idx < sparse.rows.size(); idx++) + if (layout == sparse::SparseBUDA::Layout::Default) { - // Count nonzero tiles/ublocks/strips in the SparseCOO - int ret_slice_idx = -1, rt = -1; - if (layout == tt::sparse::SparseBUDA::Layout::Default) - { - ret_slice_idx = sparse.rows[idx] / slice_height; - rt = (sparse.rows[idx] % slice_height) / TILE_DIM; - } - else if (layout == tt::sparse::SparseBUDA::Layout::ZMajor) - { - int slice_idx = sparse.rows[idx] / slice_height; - int inner_idx = (slice_idx / slice_count) * grid_r + (slice_idx % grid_r); - int slice_inner_idx = inner_idx % bcast_factor; - ret_slice_idx = (slice_idx % (grid_r * t_factor_r)) / grid_r * grid_r + (inner_idx / bcast_factor); - int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx; - rt = new_rows / TILE_DIM; - } - else - { - TT_ASSERT( - layout == sparse::SparseBUDA::Layout::BufferOp or - layout == sparse::SparseBUDA::Layout::ZMajorDataflow); - if (layout == sparse::SparseBUDA::Layout::ZMajorDataflow and - ((sparse.rt() / grid_r / t_factor_r) % bcast_factor != 0)) - continue; - - int slice_idx = sparse.rows[idx] / slice_height; - int inner_idx = (slice_idx % (dflow_factor * grid_r)) * bcast_factor + - (slice_idx / (dflow_factor * grid_r * t_factor_r)); - int slice_inner_idx = inner_idx % (bcast_factor * dflow_factor); - ret_slice_idx = (slice_idx / (dflow_factor * grid_r)) % t_factor_r * grid_r + - (inner_idx / (bcast_factor * dflow_factor)); - int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx; - rt = new_rows / TILE_DIM; - } - int ct = sparse.cols[idx] / TILE_DIM; - int ubr_idx = rt / u_rt; - int ubc_idx = ct / u_kt; - uint64_t rt_ct_key = (uint64_t(rt) << 32) | (ct & 0x0FFFF); - uint64_t ubc_ubr_key = (uint64_t(ubc_idx) << 32) | (ubr_idx & 0x0FFFF); - - // Add the metadata to counting struct - CounterEntry& e = ret[ret_slice_idx]; - e.rt_ct_cmb.insert(rt_ct_key); - e.ubc_ubr_cmb.insert(ubc_ubr_key); - e.ubc_idxs.insert(ubc_idx); - if (rt < ret[ret_slice_idx].smallest_rt) - ret[ret_slice_idx].smallest_rt = rt; + r = (index.rt % (sparse_rt / t_factor_r)) / r_tiles_in_core; + t = index.rt / (sparse_rt / t_factor_r); + target_rt = index.rt % r_tiles_in_core; } - - // Count tiles, ublocks, strips - for (int idx = 0; idx < slice_count; idx++) + else if (layout == sparse::SparseBUDA::Layout::ZMajor) { - const CounterEntry& e = ret[idx]; - sum_nz_tiles += e.rt_ct_cmb.size(); - sum_nz_ublocks += e.ubc_ubr_cmb.size(); - sum_nz_strips += e.ubc_idxs.size(); - if (e.smallest_rt >= 1 and e.smallest_rt < INT_MAX) - { - sum_nz_tiles++; - sum_nz_ublocks++; - } + int core_factor = std::max(1, grid_r % bcast_factor); + r = index.rt / bcast_slice_size * (grid_r / bcast_factor) + (index.rt / r_tiles_in_core) % core_factor; + t = 
(index.rt % (sparse_rt / bcast_factor)) / (sparse_rt / t_factor_r / bcast_factor); + target_rt = index.rt % r_tiles_in_core; } + else + { + TT_ASSERT( + layout == sparse::SparseBUDA::Layout::ZMajorDataflow || layout == sparse::SparseBUDA::Layout::BufferOp); + const int t_slice_size = bcast_slice_size / t_factor_r; // size of each slice after vslice(t_factor_r) + r = (index.rt / dflow_factor) % grid_r; + t = index.rt / t_slice_size % t_factor_r; + target_rt = (index.rt % bcast_slice_size * bcast_factor + index.rt / bcast_slice_size) % + (bcast_factor * dflow_factor); + } + + const int ublock_r_idx = target_rt / u_rt; + const int ublock_c_idx = index.ct / u_kt; + const int strip_idx = t * m_k + ublock_c_idx; + + // Insert metadata per core + // + // - non-zero tiles + // - non-zero ublocks + // - non-zero strips + // + metadata_per_core.at(r).nz_tiles++; + std::uint64_t ublock_key = + static_cast(t) << 48 | // 16 bits for t + ((static_cast(ublock_r_idx) & 0xFFFFFF) << 24) | // 24 bits for ublock_r + (static_cast(ublock_c_idx) & 0xFFFFFF); // 24 bits for ublock_c + metadata_per_core.at(r).nz_ublocks.insert(ublock_key); + metadata_per_core.at(r).nz_strips.insert(strip_idx); + } + + // Copy over the results to return object + // + for (int r = 0; r < grid_r; ++r) + { + + // Previous solution multiplied with t_factor_c - copying that behaviour here... However, it's not obvious + // whether this is correct + // # tenstorrent/pybuda#2598 + // + sparse_metadata.nz_tiles.at(r) = metadata_per_core.at(r).nz_tiles * t_factor_c; + sparse_metadata.nz_ublocks.at(r) = metadata_per_core.at(r).nz_ublocks.size() * t_factor_c; + sparse_metadata.nz_strips.at(r) = metadata_per_core.at(r).nz_strips.size() * t_factor_c; } - sum_nz_tiles *= t_factor_c; - sum_nz_ublocks *= t_factor_c; - sum_nz_strips *= t_factor_c; - return std::make_tuple<>(sum_nz_tiles, sum_nz_ublocks, sum_nz_strips); + return std::make_shared(sparse_metadata); } } // namespace tt::balancer diff --git a/pybuda/csrc/balancer/balancer_utils.hpp b/pybuda/csrc/balancer/balancer_utils.hpp index b3db4b4c..e239f7ab 100644 --- a/pybuda/csrc/balancer/balancer_utils.hpp +++ b/pybuda/csrc/balancer/balancer_utils.hpp @@ -134,7 +134,7 @@ ResourceUsage get_edge_resource_usage_simple( OpModel const &consumer_op_model, bool is_queue = false); -std::tuple get_sparse_matmul_metadata(balancer::OpModel const &grid); +std::shared_ptr get_sparse_matmul_metadata(balancer::OpModel const &grid); } // namespace tt::balancer diff --git a/pybuda/csrc/balancer/legalizer/legalizer.cpp b/pybuda/csrc/balancer/legalizer/legalizer.cpp index 18d4fcc6..395b6ef5 100644 --- a/pybuda/csrc/balancer/legalizer/legalizer.cpp +++ b/pybuda/csrc/balancer/legalizer/legalizer.cpp @@ -326,7 +326,10 @@ static std::pair calculate_streaming_pars( graph->data_operands(op_node)[0]->as()->get_sparse_buda(); std::vector& sparse_zs = sparse_buda.sparse_zs; auto layout = sparse::SparseBUDA::create_layout(sparse_buffer_enable, dir.z_major(), fracture_factor); - int bcast_factor = (layout == sparse::SparseBUDA::Layout::ZMajor) ? sparse_buda.bcast_factor : 1; + int bcast_factor = + (layout == sparse::SparseBUDA::Layout::ZMajor || layout == sparse::SparseBUDA::Layout::ZMajorDataflow) + ? 
sparse_buda.bcast_factor + : 1; // Each potential t needs to evenly divide output's r-dim but also in1's r-dim (in0's c-dim) std::vector operands = graph->operand_data_edges(op_node); @@ -940,9 +943,8 @@ static std::vector calculate_output_buffer_models_for_grid( static bool is_input_node_parameter_or_constant(const graphlib::Node* node) { return node->node_type() == graphlib::NodeType::kInput and - (node->as()->is_parameter() - or node->as()->is_constant() - or node->as()->is_optimizer_parameter()); + (node->as()->is_parameter() or node->as()->is_constant() or + node->as()->is_optimizer_parameter()); } static std::vector calculate_parameter_buffer_models_for_grid( @@ -955,8 +957,7 @@ static std::vector calculate_parameter_buffer_models_for_grid( parameter_buffers.resize(operands.size()); for (int input_idx = 0; input_idx < (int)operands.size(); ++input_idx) { - if (is_input_node_parameter_or_constant(operands[input_idx]) and - not force_dram_parameters) + if (is_input_node_parameter_or_constant(operands[input_idx]) and not force_dram_parameters) { TensorShape const& parameter_shape = op_shape.producer_shapes[input_idx]; int grid_r = FactorizedInt(parameter_shape.rt).get_nearest_factor_le(selected_grid.r); diff --git a/pybuda/csrc/balancer/policies/policy_utils.hpp b/pybuda/csrc/balancer/policies/policy_utils.hpp index bb831014..7f1006ff 100644 --- a/pybuda/csrc/balancer/policies/policy_utils.hpp +++ b/pybuda/csrc/balancer/policies/policy_utils.hpp @@ -392,6 +392,7 @@ const OpModel* pick_preferred_op_model( { auto op_models = current_graph_solver.at(op); const OpModel* prefered_op_model = nullptr; + for (const auto& op_model : op_models) { log_trace( diff --git a/pybuda/csrc/balancer/python_bindings.cpp b/pybuda/csrc/balancer/python_bindings.cpp index f0a053e2..59cdb4d6 100644 --- a/pybuda/csrc/balancer/python_bindings.cpp +++ b/pybuda/csrc/balancer/python_bindings.cpp @@ -10,6 +10,7 @@ #include "balancer/policies/policy_utils.hpp" #include "balancer/python_interface.hpp" #include "balancer/balancer_utils.hpp" +#include "balancer/types.hpp" #include "graph_lib/utils.hpp" #include "placer/placer.hpp" #include "passes/fuse_ops.hpp" @@ -360,9 +361,7 @@ void BalancerModule(py::module &m_balancer) { .def_readonly("output_buffers", &OpModel::output_buffers) .def_readonly("parameter_buffers", &OpModel::parameter_buffers) .def_readonly("is_sparse_matmul", &OpModel::is_sparse_matmul) - .def_readonly("nz_tiles", &OpModel::nz_tiles) - .def_readonly("nz_ublocks", &OpModel::nz_ublocks) - .def_readonly("nz_strips", &OpModel::nz_strips) + .def("get_sparse_metadata", &OpModel::get_sparse_metadata) .def("block_shape", &OpModel::block_shape) .def("__repr__", [](OpModel const& a) { std::stringstream ss; @@ -370,6 +369,16 @@ void BalancerModule(py::module &m_balancer) { return ss.str(); }); + py::class_(m_balancer, "SparseMetadata") + .def_readonly("nz_tiles", &OpModel::SparseMetadata::nz_tiles) + .def_readonly("nz_ublocks", &OpModel::SparseMetadata::nz_ublocks) + .def_readonly("nz_strips", &OpModel::SparseMetadata::nz_strips) + .def("__repr__", [](OpModel::SparseMetadata const& a) { + std::stringstream ss; + ss << a; + return ss.str(); + }); + py::class_(m_balancer, "FusedSubOpModel") .def_readonly("type", &FusedSubOpModel::type) .def_readonly("mblock_m", &FusedSubOpModel::mblock_m) diff --git a/pybuda/csrc/balancer/types.cpp b/pybuda/csrc/balancer/types.cpp index 7bab5e5a..0296bf0a 100644 --- a/pybuda/csrc/balancer/types.cpp +++ b/pybuda/csrc/balancer/types.cpp @@ -229,18 +229,15 @@ int 
OpModel::get_execution_cycles_uncached(std::string const &arch_name, bool th { std::shared_ptr fused_op = this->fused_op(); - // Calculate sparse-matmul metadata and cache the result + // Calculate sparse matmul metadata and write into OpModel's SparseMetadata struct if (env_as("PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES", false) and this->is_sparse_matmul and - this->nz_ublocks == -1) + this->sparse_metadata == nullptr) { auto mf = this->math_fidelity(); if (mf == tt::MathFidelity::HiFi2 or mf == tt::MathFidelity::LoFi) { - auto [nz_tiles, nz_ublocks, nz_strips] = get_sparse_matmul_metadata(*this); auto *p_this = const_cast(this); - p_this->nz_tiles = nz_tiles; - p_this->nz_ublocks = nz_ublocks; - p_this->nz_strips = nz_strips; + p_this->sparse_metadata = get_sparse_matmul_metadata(*this); } } diff --git a/pybuda/csrc/balancer/types.hpp b/pybuda/csrc/balancer/types.hpp index 7b3424e4..b2618657 100644 --- a/pybuda/csrc/balancer/types.hpp +++ b/pybuda/csrc/balancer/types.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -263,6 +264,24 @@ struct Padding // struct OpModel { + struct SparseMetadata { + std::vector nz_tiles; + std::vector nz_ublocks; + std::vector nz_strips; + + bool operator==(SparseMetadata const &other) const + { + return nz_tiles == other.nz_tiles and nz_ublocks == other.nz_ublocks and nz_strips == other.nz_strips; + } + + SparseMetadata(int grid_r) + { + nz_tiles.resize(grid_r, 0); + nz_ublocks.resize(grid_r, 0); + nz_strips.resize(grid_r, 0); + } + }; + UniqueId id; GridShape grid_shape; OpShape op_shape; @@ -272,10 +291,9 @@ struct OpModel bool sparse_buffer = false; bool is_sparse_matmul = false; bool consumes_rz_major = false; - int nz_tiles = 0; // sparse-matmul specific - int nz_ublocks = -1; // sparse-matmul specific - int nz_strips = -1; // sparse-matmul specific - const sparse::SparseBUDA *sparse_buda = nullptr; // sparse-matmul specific + const sparse::SparseBUDA *sparse_buda = nullptr; // sparse-matmul specific + std::shared_ptr sparse_metadata = nullptr; // sparse-matmul specific + // ^ using shared_ptr (vs unique_ptr) to allow for implicit copy construction of OpModel TStreamFactor t_stream_factor; int fracture_factor; int sparse_indices; @@ -372,8 +390,14 @@ struct OpModel return static_cast(get_output_bytes()) / get_execution_cycles(arch_name); } + // PyBind is acting strange with std::shared_ptr, and there seem to be some bugs reported on this, doing this for + // now... + const SparseMetadata get_sparse_metadata() { return *sparse_metadata.get(); } + bool operator==(OpModel const &other) const { return id == other.id; } + // This function is used to compare two OpModels for similarity. It is used for caching mechanisms. 
+ // bool is_similar(OpModel const &other) const { return buda_op_node == other.buda_op_node @@ -382,12 +406,10 @@ struct OpModel and fracture_factor == other.fracture_factor and input_prologue == other.input_prologue and sparse_buffer == other.sparse_buffer - and nz_tiles == other.nz_tiles - and nz_ublocks == other.nz_ublocks - and nz_strips == other.nz_strips and padding == other.padding and input_buffers == other.input_buffers - and output_buffers == other.output_buffers; + and output_buffers == other.output_buffers + and is_similar_sparse_metadata(other); } TensorShape get_out_shape(bool post_t_stream = true) const @@ -409,6 +431,20 @@ struct OpModel private: int get_execution_cycles_uncached(std::string const &arch_name, bool theoretical = false) const; + bool is_similar_sparse_metadata(OpModel const &other) const + { + if (sparse_metadata == nullptr and other.sparse_metadata == nullptr) + { + return true; + } + + if (sparse_metadata == nullptr or other.sparse_metadata == nullptr) + { + return false; + } + + return *sparse_metadata == *other.sparse_metadata; + } }; using LegalOpModels = std::unordered_map>; @@ -861,6 +897,18 @@ inline std::ostream &ostream_with_indent(std::ostream &os, OpModel const &op_mod inline std::ostream &operator<<(std::ostream &os, OpModel const &op_model) { return ostream_with_indent(os, op_model); } +inline std::ostream &operator<<(std::ostream &os, OpModel::SparseMetadata const &sparse_metadata) +{ + os << "SparseMetadata{.nz_tiles = {"; + for (int nz_tile : sparse_metadata.nz_tiles) os << nz_tile << ", "; + os << "}, .nz_ublocks = {"; + for (int nz_ublock : sparse_metadata.nz_ublocks) os << nz_ublock << ", "; + os << "}, .nz_strips = {"; + for (int nz_strip : sparse_metadata.nz_strips) os << nz_strip << ", "; + os << "}}"; + return os; +} + inline std::ostream &operator<<(std::ostream &os, FusedSubOpModel const &sub_op_model) { os << "FusedSubOpModel{" << std::endl; diff --git a/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp b/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp index ffec94aa..b71499ed 100644 --- a/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp +++ b/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp @@ -347,12 +347,147 @@ struct SparseBUDA public: enum class Layout { - Default, - ZMajor, - ZMajorDataflow, - BufferOp, + Default, // Default layout + ZMajor, // Z-major layout, e.g. RZ streaming - go thru all Zs first, for a given R slice + ZMajorDataflow, // Z-major layout, special cased for sparse->dense dataflow (same as ZMajor, but slice + // vertically down to single tile) + BufferOp, // Used for sparse buffer matmuls }; + // A little more on layouts... + // + // Layout dictates what the sparse tensor of the sparse matmul will look like. Sparse tensors get divided into + // chunks, to accomodate for parallelization across cores and t-streaming. Depending on what the chunks look like, + // and which core receives which chunk, the performance profile of the sparse matmul can change - additionally, + // whether the parallelization is legal can also change. + // + // Sparse matmul ops are used for various scenarios, but most often they're there as building blocks of + // convolutions. Sparse tensors of such ops have a specific pattern that looks something like this (e.g. 
for a 2x2 + // convolution): + // [ + // 1 0 0 0 + // 0 1 0 0 + // 0 0 1 0 + // 0 0 0 1 + // 1 0 0 0 + // 0 1 0 0 + // 0 0 1 0 + // 0 0 0 1 + // 1 0 0 0 + // 0 1 0 0 + // 0 0 1 0 + // 0 0 0 1 + // 1 0 0 0 + // 0 1 0 0 + // 0 0 1 0 + // 0 0 0 1 + // ] + // So they will look like a set of diagonal matrices, each one representing a single kernel point of the + // convolution. + // + // Currently, there are 3 variants of the layout: Default, ZMajor, and ZMajorDataflow. + // + // Default: + // In this layout, the sparse tensor is in its original shape. + // + // ZMajor: + // In this layout, the goal is to eliminate serialization in execution between cores. It is easier to explain the + // transformation first, and then show why it eliminates serialization. + // + // The transformation is as follows: + // - Let's say we have 2x1 cores for this sparse op and t=2. The top core will handle the first two kernel + // points while the second core will handle the last two kernel points - this is by design of the layout. In t=0 + // the top core will get the top parts of the sparse tensor's first two kernel points, which is two 2x4x4 + // pieces. The bottom core will get the top parts of the sparse tensor's last two kernel points, which is also + // two 2x4x4 pieces. So for t=0, the first core's tensor will look like this: + // [ + // 1 0 0 0 <- first row of the first kernel point + // 0 1 0 0 <- second row of the first kernel point + // 1 0 0 0 <- first row of the second kernel point + // 0 1 0 0 <- second row of the second kernel point + // ] + // - And the second core, for t=0, it's tensor will look like this: + // [ + // 1 0 0 0 <- first row of the third kernel point + // 0 1 0 0 <- second row of the third kernel point + // 1 0 0 0 <- first row of the fourth kernel point + // 0 1 0 0 <- second row of the fourth kernel point + // ] + // - If we were to continue this for t=1, the first core's tensor would look like this: + // [ + // 0 0 1 0 <- third row of the first kernel point + // 0 0 0 1 <- fourth row of the first kernel point + // 0 0 1 0 <- third row of the second kernel point + // 0 0 0 1 <- fourth row of the second kernel point + // ] + // - And the second core, for t=1, it's tensor will look like this: + // [ + // 0 0 1 0 <- third row of the third kernel point + // 0 0 0 1 <- fourth row of the third kernel point + // 0 0 1 0 <- third row of the fourth kernel point + // 0 0 0 1 <- fourth row of the fourth kernel point + // ] + // + // From the perspective of the sparse matmul, where these sparse tensors are left operands, the ZMajor layout + // makes it so that all cores, in a given point in time, are working on the same inner dimension of the matmul. If + // we think about how the right operand is buffered (horizontal strips), and keep in mind that all cores receive + // the full right operand tensor, this layout makes it so that all cores are working on the same strip of the + // right operand tensor, which removes any serialization in execution. + // + // Naming comes from the fact that the sparse tensor is first vertically sliced by a factor of kernel points, + // which is in this case a vslice(4). Then, a core will work on a set of Zs for a given R slice, so in a way, if + // we fix the height that a core is reading (in this case 2-tile high), we go through the tensor in Z-major order, + // hence the name ZMajor for the layout. + // + // ZMajorDataflow: + // This layout is a special case of ZMajor, where the sparse tensor is vertically sliced all the way down to a + // single tile. 
The goal of this layout is to improve the dataflow between the sparse and dense matmuls - if we + // place a sparse&dense pair of ops next to each other, with the same height of cores, as we usually do, this + // layout will enable dataflow thru NOC to use direct pipes, which means that each sparse core will send its + // output directly and only to its corresponding core of dense matmul, which is right next to the sparse core. + // Hence the "Dataflow" in the name. + // + // Using the sparse tensor from the Default layout, the transformation to ZMajorDataflow is as follows: + // - Similar to ZMajor, we can first imagine that the sparse tensor is vertically sliced by a factor of kernel + // points, which is in this case a vslice(4), so the tensor goes from 16x4 to a 4x4x4. Then, imagine that the + // tensor is read completely in Z-major order, tile by tile. Let's say we have 2x1 cores for this sparse op and + // t=2. First core, for t=0, will get these tiles: + // [ + // 1 0 0 0 <- first tile of the first kernel point + // 1 0 0 0 <- first tile of the second kernel point + // 1 0 0 0 <- first tile of the third kernel point + // 1 0 0 0 <- first tile of the fourth kernel point + // ] + // - And the second core, for t=0, it's tensor will look like this: + // [ + // 0 1 0 0 <- second tile of the first kernel point + // 0 1 0 0 <- second tile of the second kernel point + // 0 1 0 0 <- second tile of the third kernel point + // 0 1 0 0 <- second tile of the fourth kernel point + // ] + // - If we were to continue this for t=1, the first core's tensor would look like this: + // [ + // 0 0 1 0 <- third tile of the first kernel point + // 0 0 1 0 <- third tile of the second kernel point + // 0 0 1 0 <- third tile of the third kernel point + // 0 0 1 0 <- third tile of the fourth kernel point + // ] + // - And the second core, for t=1, it's tensor will look like this: + // [ + // 0 0 0 1 <- fourth tile of the first kernel point + // 0 0 0 1 <- fourth tile of the second kernel point + // 0 0 0 1 <- fourth tile of the third kernel point + // 0 0 0 1 <- fourth tile of the fourth kernel point + // ] + // + // The main difference between ZMajor and ZMajorDataflow is that ZMajorDataflow is vertically sliced all the way + // down to a single tile, which enables the dataflow between the sparse and dense matmuls to use direct pipes. + // However, by doing this transformation, there is a side-effect on the output, the rows of the output are all + // mixed up, i.e. they're not in the order that the dense matmul "expects" them to be in. To correct this, a set + // of TM ops is applied to the output of the sparse matmul, which will reorder the rows back to the correct order. + // This set of TM ops might look complicated, but when worked out, it just makes it so that each sparse matmul + // core sends data only to its corresponding dense matmul core, which is right next to it. 
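The slicing described in the comment above can be reproduced with a small standalone script. The sketch below only illustrates the ZMajor and ZMajorDataflow layouts for the 2x2-convolution example (4 kernel points, a 2x1 grid of cores, t=2); the helper names are invented for the sketch and it does not reuse the production mapping code.

kernel_points = 4   # number of kernel points (bcast_factor)
grid_r = 2          # rows of cores assigned to the sparse matmul
t_factor_r = 2      # t-streaming factor in the row dimension
kp_height = 4       # row-tiles per kernel point

# 4 stacked 4x4 identity blocks -> the 16x4 tile pattern from the example above
identity = [[1 if c == r else 0 for c in range(4)] for r in range(kp_height)]
sparse = [row for _ in range(kernel_points) for row in identity]

def zmajor_slice(core, t):
    # Each core owns a contiguous group of kernel points; at step t it sees the
    # t-th horizontal band of every kernel point it owns.
    kp_per_core = kernel_points // grid_r
    rows_per_t = kp_height // t_factor_r
    out = []
    for kp in range(core * kp_per_core, (core + 1) * kp_per_core):
        base = kp * kp_height + t * rows_per_t
        out.extend(sparse[base : base + rows_per_t])
    return out

def zmajor_dataflow_slice(core, t):
    # Sliced all the way down to single tiles and read in Z-major order:
    # one row-tile from every kernel point per (core, t) pair.
    tile_idx = t * grid_r + core
    return [sparse[kp * kp_height + tile_idx] for kp in range(kernel_points)]

for t in range(t_factor_r):
    for core in range(grid_r):
        print("ZMajor         core", core, "t", t, zmajor_slice(core, t))
        print("ZMajorDataflow core", core, "t", t, zmajor_dataflow_slice(core, t))

Running it prints exactly the per-core, per-t pictures drawn in the comment above.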
+ static Layout create_layout(bool buffer_op, bool z_major, int fracture_factor); std::vector sparse_zs; diff --git a/pybuda/pybuda/op/eval/buda/matmul.py b/pybuda/pybuda/op/eval/buda/matmul.py index 3c9a4041..3e38b077 100644 --- a/pybuda/pybuda/op/eval/buda/matmul.py +++ b/pybuda/pybuda/op/eval/buda/matmul.py @@ -388,16 +388,55 @@ def input_ublock_order(type, attr, num_operands): def execution_cycles(type, arch_name, op_model, theoretical) -> int: + # Special handling for sparse matmul as the backend API assumes 1x1 grid, but each sparse matmul core can do + # a different amount of work, depending on what the encodings (sparse tensor) look like. Call for each core to find + # the slowest one. + # + if op_model.is_sparse_matmul: + # Calculate cycles per core + # + if ( + os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False) + and os.environ.get("PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE", False) + ): + cycles_to_return = 0 + for r in range(op_model.grid_shape.r): + # Generate op model desc for current core + # + op_model_desc = op_model_to_desc(type, arch_name, op_model, sparse_grid_row=r) + + # Get execution cycles, try from cache first, if miss, then calculate + # + curr_cycles = 0 + compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) + if compiler_cache_cycles is not None: + curr_cycles = compiler_cache_cycles + else: + curr_cycles = get_op_model_execution_cycles(op_model_desc) + + # Save max cycles + # + cycles_to_return = max(cycles_to_return, curr_cycles) + else: + # Otherwise fall back to default behavior (create single op_model_desc, and let it decide whether to average + # parameters, or to sum and pretend everything is on a single core) + # + op_model_desc = op_model_to_desc(type, arch_name, op_model) + compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) + if compiler_cache_cycles is not None: + cycles_to_return = compiler_cache_cycles + else: + cycles_to_return = get_op_model_execution_cycles(op_model_desc) + + return cycles_to_return + # End sparse matmul exec cycles calculation + op_model_desc = op_model_to_desc(type, arch_name, op_model) compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) if compiler_cache_cycles is not None: return compiler_cache_cycles - is_sparse = op_model.is_sparse_matmul - if is_sparse: - return get_op_model_execution_cycles(op_model_desc) - # Math fidelity and data format are just estimated guesses for now math_fid = math_fidelity_to_multiplier(op_model.math_fidelity()) u_kt = op_model.input_buffers[0].block_shape.ublock.ct diff --git a/pybuda/pybuda/op/eval/common.py b/pybuda/pybuda/op/eval/common.py index 6c3eddc4..65d1227d 100644 --- a/pybuda/pybuda/op/eval/common.py +++ b/pybuda/pybuda/op/eval/common.py @@ -331,7 +331,14 @@ def data_format_to_int(df: DataFormat) -> int: return 11 raise RuntimeError(f"Unknown data format {df}") -def op_model_to_desc(type: str, arch_name: str, op_model: OpModel, sub_op_model: FusedSubOpModel = None) -> OpModelDesc: +def op_model_to_desc( + type: str, + arch_name: str, + op_model: OpModel, + sub_op_model: FusedSubOpModel = None, + sparse_grid_row=-1, +) -> OpModelDesc: + desc = OpModelDesc() desc.arch = arch_name desc.data_format = op_model.data_format @@ -339,6 +346,13 @@ def op_model_to_desc(type: str, arch_name: str, op_model: OpModel, sub_op_model: desc.t = op_model.output_buffers[0].block_shape.t desc.approx_mode = False + # tenstorrent/pybuda#2565 + # Example overrides (DataFormat or MathFidelity) to target sparse estimates v2 when op is originally Bfp8_b+HiFi2 + # if 
type == "matmul" and op_model.is_sparse_matmul: + # if desc.data_format == DataFormat.Bfp8_b and desc.math_fidelity == MathFidelity.HiFi2: + # # desc.data_format = DataFormat.Float16_b # Override DataFormat + # desc.math_fidelity = MathFidelity.LoFi # Override MathFidelity + if op_model.op_type() == "fused_op": desc.type = sub_op_model.type desc.mblock_m = sub_op_model.mblock_m @@ -366,22 +380,49 @@ def op_model_to_desc(type: str, arch_name: str, op_model: OpModel, sub_op_model: desc.mblock_k = op_model.op_shape.inputs[1].rt // desc.ublock_kt desc.sparse_indices = op_model.sparse_indices if os.environ.get("PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES", False): - desc.sparse_nz_ublocks = op_model.nz_ublocks - desc.sparse_nz_strips = op_model.nz_strips - - # op model descriptor assumes grid_size [1, 1], so we need to scale down the number of - # sparse tiles, ublocks and strips to what is expected to end up on a single core - if os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False): - if op_model.nz_tiles > 1: - desc.sparse_indices = max(op_model.nz_tiles // op_model.grid_shape.r, 1) + sparse_metadata = op_model.get_sparse_metadata() + desc.sparse_indices = sum(sparse_metadata.nz_tiles) + desc.sparse_nz_ublocks = sum(sparse_metadata.nz_ublocks) + desc.sparse_nz_strips = sum(sparse_metadata.nz_strips) + + # Op model descriptor assumes grid_size [1, 1], so we need to scale down the parameters to what is + # expected to end up on a single core. Initially, we did this by averaging the parameters with the + # number of cores. However, not all the cores perform the same amount of work, so we need to + # calculate parameters per core. We keep both of these modes in this transition period. + # + # PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS must be set to true to enable any of the mentioned modes. 
+ # + # Mode 1: + # Average the parameters (by default) + # Mode 2: + # Scale the parameters by the number of cores (needs the env var + # "PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE" to be set to true) + # + per_core_mode = os.environ.get("PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE", False) + if not per_core_mode: + # Average mode + # + nz_tiles = sum(sparse_metadata.nz_tiles) + nz_ublocks = sum(sparse_metadata.nz_ublocks) + nz_strips = sum(sparse_metadata.nz_strips) + + if nz_tiles > 1: + desc.sparse_indices = max(nz_tiles // op_model.grid_shape.r, 1) else: - desc.sparse_indices = op_model.nz_tiles - - if op_model.nz_ublocks > 1: - desc.sparse_nz_ublocks = max(op_model.nz_ublocks // op_model.grid_shape.r, 1) - - if op_model.nz_strips > 1: - desc.sparse_nz_strips = max(op_model.nz_strips // op_model.grid_shape.r, 1) + desc.sparse_indices = nz_tiles + + if nz_ublocks > 1: + desc.sparse_nz_ublocks = max(nz_ublocks // op_model.grid_shape.r, 1) + + if nz_strips > 1: + desc.sparse_nz_strips = max(nz_strips // op_model.grid_shape.r, 1) + else: + # Per core mode + # + assert sparse_grid_row != -1 # Must provide which row of cores we're fetching the estimates for + desc.sparse_indices = sparse_metadata.nz_tiles[sparse_grid_row] + desc.sparse_nz_ublocks = sparse_metadata.nz_ublocks[sparse_grid_row] + desc.sparse_nz_strips = sparse_metadata.nz_strips[sparse_grid_row] else: # old sparse estimates if os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False): diff --git a/pybuda/pybuda/op/eval/pybuda/matmul.py b/pybuda/pybuda/op/eval/pybuda/matmul.py index de7a433c..ae57acf8 100644 --- a/pybuda/pybuda/op/eval/pybuda/matmul.py +++ b/pybuda/pybuda/op/eval/pybuda/matmul.py @@ -179,7 +179,7 @@ def lower(type, attr, buda_attr, lc, ops, outputs): picker = lc.get_pytorch_tensor(in0) zdim = 1 if len(picker.shape) < 3 else picker.shape[-3] - z_bcast_factor = 1 if len(attr) < 2 else attr[1] + z_bcast_factor = 1 if len(attr) < 2 else attr[1] # set in sparse matmul's decompose # We can fully fracture kH * kW max_fracture_factor = z_bcast_factor if is_kernel_fracturing_candidate(ops, z_bcast_factor) else 1 @@ -255,6 +255,8 @@ def decompose(type, attr, dc, inputs): accumulate = (len(attr) >= 1) and bool(attr[0]) z_bcast_factor = zdim if (zdim > 1 and in1.shape[-3] == 1) else 1 + # In case of convolutions, z_bcast_factor is the volume of the conv's kernel (kernel_height * kernel_width) + if z_bcast_factor > 1: picker = torch.cat([picker[0][z] for z in range(z_bcast_factor)]) sparse = dc.tensor(picker) diff --git a/pybuda/test/benchmark/benchmark/models/other.py b/pybuda/test/benchmark/benchmark/models/other.py index ac6a257f..43bc45ca 100644 --- a/pybuda/test/benchmark/benchmark/models/other.py +++ b/pybuda/test/benchmark/benchmark/models/other.py @@ -5,7 +5,7 @@ Catch-all for random perf testing """ -import numpy as np +import os import pybuda import torch @@ -84,21 +84,21 @@ def forward(self, x): @benchmark_model(configs=["224"]) -def big_conv(training: bool, config: str, microbatch: int, devtype: str, arch: str): +def big_conv(training: bool, config: str, microbatch: int, devtype: str, arch: str, data_type: str, math_fidelity: str): if config == "224": input_size = (224, 224) cin = 3 cout = 64 kH = 7 kW = 7 - stride = 2 - padding = 3 + stride = 1 + padding = "same" dilation = 1 else: raise RuntimeError(f"Invalid config: {config}") if microbatch == 0: - microbatch = 1 + microbatch = 64 mod = ConvTModule( name="big_conv_benchmark", @@ -113,7 +113,8 @@ def big_conv(training: bool, config: str, microbatch: int, 
devtype: str, arch: s bias=False) compiler_cfg = _get_global_compiler_config() - compiler_cfg.balancer_policy = "CNN" + compiler_cfg.balancer_policy = "Ribbon" + os.environ["PYBUDA_RIBBON2"] = "1" models = {"tt": mod} inputs = [torch.rand(microbatch, cin, input_size[0], input_size[1])]
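Putting the pieces above together, the new estimation flow can be summarized with a small Python model: count non-zero tiles, ublocks and strips per row of cores, then keep the estimate of the slowest core, which is what execution_cycles does when both scaling env vars are set. This is only a sketch: it mirrors just the Default-layout mapping from get_sparse_matmul_metadata, the cost function is a made-up placeholder for the real per-core backend call (get_op_model_execution_cycles), and the real code additionally multiplies the per-core counts by t_factor_c.

def per_core_metadata(indices, sparse_rt, grid_r, t_factor_r, u_rt, u_kt, m_k):
    # indices: (rt, ct) coordinates of non-zero tiles in the canonical sparse tensor
    r_tiles_in_core = sparse_rt // grid_r // t_factor_r
    nz_tiles = [0] * grid_r
    nz_ublocks = [set() for _ in range(grid_r)]
    nz_strips = [set() for _ in range(grid_r)]
    for rt, ct in indices:
        # Default-layout mapping of a tile to (core row, t, row-tile within core)
        r = (rt % (sparse_rt // t_factor_r)) // r_tiles_in_core
        t = rt // (sparse_rt // t_factor_r)
        target_rt = rt % r_tiles_in_core
        nz_tiles[r] += 1
        nz_ublocks[r].add((t, target_rt // u_rt, ct // u_kt))
        nz_strips[r].add(t * m_k + ct // u_kt)
    return nz_tiles, [len(s) for s in nz_ublocks], [len(s) for s in nz_strips]

def slowest_core_cycles(nz_tiles, nz_ublocks, nz_strips):
    # Placeholder cost model - the real flow builds one OpModelDesc per core row,
    # asks the backend for its cycle count, and keeps the maximum.
    per_core = [10 * t + 20 * u + 50 * s for t, u, s in zip(nz_tiles, nz_ublocks, nz_strips)]
    return max(per_core)

# Toy example: 8 row-tiles, 4 column-tiles, 2 core rows, t=2, 1x1 ublocks
meta = per_core_metadata([(0, 0), (1, 1), (5, 2), (6, 3)],
                         sparse_rt=8, grid_r=2, t_factor_r=2, u_rt=1, u_kt=1, m_k=4)
print(meta, slowest_core_cycles(*meta))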