From d00a01052402c6b0c6157be96c865db74048689d Mon Sep 17 00:00:00 2001 From: Sasa Vuckovic Date: Wed, 8 May 2024 15:24:39 +0000 Subject: [PATCH] Revert "[sparse] Add ability to calculate per-core sparse matmul metadata (used for execution cycles estimation)" (cherry picked from commit a1bae21af4c89bcbdb7e900a8c3f8b6095c7925d) --- README.debug.md | 3 +- pybuda/csrc/balancer/balancer_utils.cpp | 194 +++++++++--------- pybuda/csrc/balancer/balancer_utils.hpp | 2 +- pybuda/csrc/balancer/legalizer/legalizer.cpp | 13 +- .../csrc/balancer/policies/policy_utils.hpp | 1 - pybuda/csrc/balancer/python_bindings.cpp | 15 +- pybuda/csrc/balancer/types.cpp | 9 +- pybuda/csrc/balancer/types.hpp | 64 +----- .../csrc/shared_utils/sparse_matmul_utils.hpp | 143 +------------ pybuda/pybuda/op/eval/buda/matmul.py | 47 +---- pybuda/pybuda/op/eval/common.py | 73 ++----- pybuda/pybuda/op/eval/pybuda/matmul.py | 4 +- .../test/benchmark/benchmark/models/other.py | 13 +- 13 files changed, 155 insertions(+), 426 deletions(-) diff --git a/README.debug.md b/README.debug.md index a28a6d51..babc9f9f 100644 --- a/README.debug.md +++ b/README.debug.md @@ -127,11 +127,10 @@ ## Temp overrides * PYBUDA\_TEMP\_ENABLE\_NEW\_SPARSE\_ESTIMATES: Apply new formula to estimate the cycle count of sparse matmul ops (currently only support LoFi and HiFi2 fidelities) * PYBUDA\_TEMP\_SCALE\_SPARSE\_ESTIMATE\_ARGS: Scale counts of non-zero tiles, ublocks and strips to reflect the numbers that would end up on a single core, since BBE estimates always assume grid_size [1,1]. -* PYBUDA\_TEMP\_SPARSE\_ESTIMATE\_ARGS\_PER\_CORE: Instead of uniformly scaling sparse args (as happens in PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS), calculate them per core. To use, need set PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS to 1 as well. * PYBUDA\_TEMP\_ELT\_UNARY\_ESTIMATES\_LEGACY: Force legacy path of calculating execution cycles for eltwise unary ops - instead of calling into BBE, use hand-crafted FE-side logic * PYBUDA\_TEMP\_ENABLE\_NEW\_FUSED\_ESTIMATES: Apply new formula to estimate the cycle count of fused ops. The formula calls BBE to estimate each subop and sums up the results. * PYBUDA\_LEGACY\_KERNEL\_BROADCAST: Use legacy kernel broadcast detection path. Will detect fewer kernel broadcasts, and will oftentimes use more tiles (longer KBs). -* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True) +* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True)) * PYBUDA\_TEMP\_BALANCER\_DISABLE\_TARGET\_PROXIMITY: Disable target proximity in balancer. (default: 0/False) * PYBUDA\_TEMP\_DISABLE\_FJ\_NOP\_SCHEDULE\_FIX: This flag disables a fix that forces FJ buffering nops to be scheduled last. * PYBUDA\_TEMP\_FIX\_2351: Controls the fix for bug #2351 - fork-join can end up adding buffering nops and queues on same path, this control flag fixes it. diff --git a/pybuda/csrc/balancer/balancer_utils.cpp b/pybuda/csrc/balancer/balancer_utils.cpp index 98602f3a..b58196a7 100644 --- a/pybuda/csrc/balancer/balancer_utils.cpp +++ b/pybuda/csrc/balancer/balancer_utils.cpp @@ -4,13 +4,9 @@ #include "balancer_utils.hpp" #include -#include #include -#include "balancer/types.hpp" #include "passes/t_stream.hpp" -#include "shared_utils/sparse_matmul_utils.hpp" -#include "utils/assert.hpp" #include "utils/hash_combine.hpp" #include "utils/logger.hpp" #include "utils/profile.hpp" @@ -587,116 +583,124 @@ ResourceUsage get_edge_resource_usage_simple( return usage; } -// Calculates sparse matmul metadata that is used for kernel execution cycle estimatation. These are: -// - number of non-zero tiles per core -// - number of non-zero ublocks per core -// - number of non-zero strips per core -// -std::shared_ptr get_sparse_matmul_metadata(balancer::OpModel const& op_model) +std::tuple get_sparse_matmul_metadata(balancer::OpModel const& op_model) { - OpModel::SparseMetadata sparse_metadata(op_model.grid_shape.r); - + int grid_r = op_model.grid_shape.r; + int u_rt = op_model.output_buffers[0].block_shape.ublock.rt; + int u_kt = op_model.input_buffers[1].block_shape.ublock.rt; + int t_factor_c = op_model.t_stream_factor.c; + int t_factor_r = op_model.t_stream_factor.r; const sparse::SparseBUDA& sparse_buda = *(op_model.sparse_buda); - const int grid_r = op_model.grid_shape.r; - const int u_rt = op_model.output_buffers[0].block_shape.ublock.rt; - const int u_kt = op_model.input_buffers[1].block_shape.ublock.rt; - const int m_k = sparse_buda.sparse_shape[1] / sparse::TILE_DIM / u_kt; - const int t_factor_c = op_model.t_stream_factor.c; - const int t_factor_r = op_model.t_stream_factor.r; - const int bcast_factor = sparse_buda.bcast_factor; // broadcast factor - const std::vector& sparse_indices = sparse_buda.sparse_indices; auto layout = sparse::SparseBUDA::create_layout( op_model.has_sparse_buffer() or env_as("PYBUDA_FORCE_SPARSE_BUFFER_LAYOUT"), op_model.t_stream_factor.dir.z_major(), op_model.fracture_factor); + int bcast_factor = sparse_buda.bcast_factor; + int zdim = sparse_buda.sparse_zs.size(); - const int sparse_rt = - sparse_buda.sparse_shape[0] / sparse::TILE_DIM; // number of tiles in row dim of sparse tensor - const int bcast_slice_size = sparse_rt / bcast_factor; // size of each slice after bcast slicing - const int r_tiles_in_core = sparse_rt / grid_r / t_factor_r; // rows of tiles in a single core - const int dflow_factor = - (layout == sparse::SparseBUDA::Layout::ZMajorDataflow) ? sparse_rt / grid_r / t_factor_r / bcast_factor : 1; + // Initialize tiles/ublocks/strips counter + int sum_nz_tiles = 0; + int sum_nz_ublocks = 0; + int sum_nz_strips = 0; + constexpr int TILE_DIM = tt::sparse::TILE_DIM; - // Helper struct - accounting for metadata per core - // - struct MetadataPerCore + struct CounterEntry { - unsigned int nz_tiles = 0; - std::unordered_set nz_ublocks; - std::unordered_set nz_strips; + std::unordered_set rt_ct_cmb; + std::unordered_set ubc_ubr_cmb; + std::unordered_set ubc_idxs; + int smallest_rt; + CounterEntry() : smallest_rt(INT_MAX){}; }; - std::vector metadata_per_core(grid_r); + std::vector counters; + int slice_count = grid_r * t_factor_r; - // Calculate metadata below - // Go through all the sparse indices (coordinates of non-zero tiles in the canonical sparse tensor), and map them to - // coords (core, t-dim, ublock, ...) based on the layout provided - // - // The math below, which is different for each layout, is not obvious. Need to add some documentation to explain. - // Tracking issue: - // # tenstorrent/pybuda#2601 - // - int r, t, target_rt; - for (const sparse::SparseIndex& index : sparse_indices) + // Iterate throufh all sparse tensors + for (int z = 0; z < zdim; z++) { - if (layout == sparse::SparseBUDA::Layout::Default) - { - r = (index.rt % (sparse_rt / t_factor_r)) / r_tiles_in_core; - t = index.rt / (sparse_rt / t_factor_r); - target_rt = index.rt % r_tiles_in_core; - } - else if (layout == sparse::SparseBUDA::Layout::ZMajor) + auto sparse = sparse_buda.sparse_zs[z]; + + // Take stat of the sparseCOO + int dflow_factor = (layout == sparse::SparseBUDA::Layout::ZMajorDataflow) + ? (sparse.rt() / grid_r / t_factor_r / bcast_factor) + : 1; + int num_slices = (layout == tt::sparse::SparseBUDA::Layout::Default) + ? grid_r * t_factor_r + : grid_r * t_factor_r * bcast_factor * dflow_factor; + std::int64_t slice_height = sparse.shape[0] / num_slices; + + std::vector ret(num_slices); + for (size_t idx = 0; idx < sparse.rows.size(); idx++) { - int core_factor = std::max(1, grid_r % bcast_factor); - r = index.rt / bcast_slice_size * (grid_r / bcast_factor) + (index.rt / r_tiles_in_core) % core_factor; - t = (index.rt % (sparse_rt / bcast_factor)) / (sparse_rt / t_factor_r / bcast_factor); - target_rt = index.rt % r_tiles_in_core; + // Count nonzero tiles/ublocks/strips in the SparseCOO + int ret_slice_idx = -1, rt = -1; + if (layout == tt::sparse::SparseBUDA::Layout::Default) + { + ret_slice_idx = sparse.rows[idx] / slice_height; + rt = (sparse.rows[idx] % slice_height) / TILE_DIM; + } + else if (layout == tt::sparse::SparseBUDA::Layout::ZMajor) + { + int slice_idx = sparse.rows[idx] / slice_height; + int inner_idx = (slice_idx / slice_count) * grid_r + (slice_idx % grid_r); + int slice_inner_idx = inner_idx % bcast_factor; + ret_slice_idx = (slice_idx % (grid_r * t_factor_r)) / grid_r * grid_r + (inner_idx / bcast_factor); + int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx; + rt = new_rows / TILE_DIM; + } + else + { + TT_ASSERT( + layout == sparse::SparseBUDA::Layout::BufferOp or + layout == sparse::SparseBUDA::Layout::ZMajorDataflow); + if (layout == sparse::SparseBUDA::Layout::ZMajorDataflow and + ((sparse.rt() / grid_r / t_factor_r) % bcast_factor != 0)) + continue; + + int slice_idx = sparse.rows[idx] / slice_height; + int inner_idx = (slice_idx % (dflow_factor * grid_r)) * bcast_factor + + (slice_idx / (dflow_factor * grid_r * t_factor_r)); + int slice_inner_idx = inner_idx % (bcast_factor * dflow_factor); + ret_slice_idx = (slice_idx / (dflow_factor * grid_r)) % t_factor_r * grid_r + + (inner_idx / (bcast_factor * dflow_factor)); + int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx; + rt = new_rows / TILE_DIM; + } + int ct = sparse.cols[idx] / TILE_DIM; + int ubr_idx = rt / u_rt; + int ubc_idx = ct / u_kt; + uint64_t rt_ct_key = (uint64_t(rt) << 32) | (ct & 0x0FFFF); + uint64_t ubc_ubr_key = (uint64_t(ubc_idx) << 32) | (ubr_idx & 0x0FFFF); + + // Add the metadata to counting struct + CounterEntry& e = ret[ret_slice_idx]; + e.rt_ct_cmb.insert(rt_ct_key); + e.ubc_ubr_cmb.insert(ubc_ubr_key); + e.ubc_idxs.insert(ubc_idx); + if (rt < ret[ret_slice_idx].smallest_rt) + ret[ret_slice_idx].smallest_rt = rt; } - else + + // Count tiles, ublocks, strips + for (int idx = 0; idx < slice_count; idx++) { - TT_ASSERT( - layout == sparse::SparseBUDA::Layout::ZMajorDataflow || layout == sparse::SparseBUDA::Layout::BufferOp); - const int t_slice_size = bcast_slice_size / t_factor_r; // size of each slice after vslice(t_factor_r) - r = (index.rt / dflow_factor) % grid_r; - t = index.rt / t_slice_size % t_factor_r; - target_rt = (index.rt % bcast_slice_size * bcast_factor + index.rt / bcast_slice_size) % - (bcast_factor * dflow_factor); + const CounterEntry& e = ret[idx]; + sum_nz_tiles += e.rt_ct_cmb.size(); + sum_nz_ublocks += e.ubc_ubr_cmb.size(); + sum_nz_strips += e.ubc_idxs.size(); + if (e.smallest_rt >= 1 and e.smallest_rt < INT_MAX) + { + sum_nz_tiles++; + sum_nz_ublocks++; + } } - - const int ublock_r_idx = target_rt / u_rt; - const int ublock_c_idx = index.ct / u_kt; - const int strip_idx = t * m_k + ublock_c_idx; - - // Insert metadata per core - // - // - non-zero tiles - // - non-zero ublocks - // - non-zero strips - // - metadata_per_core.at(r).nz_tiles++; - std::uint64_t ublock_key = - static_cast(t) << 48 | // 16 bits for t - ((static_cast(ublock_r_idx) & 0xFFFFFF) << 24) | // 24 bits for ublock_r - (static_cast(ublock_c_idx) & 0xFFFFFF); // 24 bits for ublock_c - metadata_per_core.at(r).nz_ublocks.insert(ublock_key); - metadata_per_core.at(r).nz_strips.insert(strip_idx); - } - - // Copy over the results to return object - // - for (int r = 0; r < grid_r; ++r) - { - - // Previous solution multiplied with t_factor_c - copying that behaviour here... However, it's not obvious - // whether this is correct - // # tenstorrent/pybuda#2598 - // - sparse_metadata.nz_tiles.at(r) = metadata_per_core.at(r).nz_tiles * t_factor_c; - sparse_metadata.nz_ublocks.at(r) = metadata_per_core.at(r).nz_ublocks.size() * t_factor_c; - sparse_metadata.nz_strips.at(r) = metadata_per_core.at(r).nz_strips.size() * t_factor_c; } - return std::make_shared(sparse_metadata); + sum_nz_tiles *= t_factor_c; + sum_nz_ublocks *= t_factor_c; + sum_nz_strips *= t_factor_c; + return std::make_tuple<>(sum_nz_tiles, sum_nz_ublocks, sum_nz_strips); } } // namespace tt::balancer diff --git a/pybuda/csrc/balancer/balancer_utils.hpp b/pybuda/csrc/balancer/balancer_utils.hpp index e239f7ab..b3db4b4c 100644 --- a/pybuda/csrc/balancer/balancer_utils.hpp +++ b/pybuda/csrc/balancer/balancer_utils.hpp @@ -134,7 +134,7 @@ ResourceUsage get_edge_resource_usage_simple( OpModel const &consumer_op_model, bool is_queue = false); -std::shared_ptr get_sparse_matmul_metadata(balancer::OpModel const &grid); +std::tuple get_sparse_matmul_metadata(balancer::OpModel const &grid); } // namespace tt::balancer diff --git a/pybuda/csrc/balancer/legalizer/legalizer.cpp b/pybuda/csrc/balancer/legalizer/legalizer.cpp index 395b6ef5..18d4fcc6 100644 --- a/pybuda/csrc/balancer/legalizer/legalizer.cpp +++ b/pybuda/csrc/balancer/legalizer/legalizer.cpp @@ -326,10 +326,7 @@ static std::pair calculate_streaming_pars( graph->data_operands(op_node)[0]->as()->get_sparse_buda(); std::vector& sparse_zs = sparse_buda.sparse_zs; auto layout = sparse::SparseBUDA::create_layout(sparse_buffer_enable, dir.z_major(), fracture_factor); - int bcast_factor = - (layout == sparse::SparseBUDA::Layout::ZMajor || layout == sparse::SparseBUDA::Layout::ZMajorDataflow) - ? sparse_buda.bcast_factor - : 1; + int bcast_factor = (layout == sparse::SparseBUDA::Layout::ZMajor) ? sparse_buda.bcast_factor : 1; // Each potential t needs to evenly divide output's r-dim but also in1's r-dim (in0's c-dim) std::vector operands = graph->operand_data_edges(op_node); @@ -943,8 +940,9 @@ static std::vector calculate_output_buffer_models_for_grid( static bool is_input_node_parameter_or_constant(const graphlib::Node* node) { return node->node_type() == graphlib::NodeType::kInput and - (node->as()->is_parameter() or node->as()->is_constant() or - node->as()->is_optimizer_parameter()); + (node->as()->is_parameter() + or node->as()->is_constant() + or node->as()->is_optimizer_parameter()); } static std::vector calculate_parameter_buffer_models_for_grid( @@ -957,7 +955,8 @@ static std::vector calculate_parameter_buffer_models_for_grid( parameter_buffers.resize(operands.size()); for (int input_idx = 0; input_idx < (int)operands.size(); ++input_idx) { - if (is_input_node_parameter_or_constant(operands[input_idx]) and not force_dram_parameters) + if (is_input_node_parameter_or_constant(operands[input_idx]) and + not force_dram_parameters) { TensorShape const& parameter_shape = op_shape.producer_shapes[input_idx]; int grid_r = FactorizedInt(parameter_shape.rt).get_nearest_factor_le(selected_grid.r); diff --git a/pybuda/csrc/balancer/policies/policy_utils.hpp b/pybuda/csrc/balancer/policies/policy_utils.hpp index 7f1006ff..bb831014 100644 --- a/pybuda/csrc/balancer/policies/policy_utils.hpp +++ b/pybuda/csrc/balancer/policies/policy_utils.hpp @@ -392,7 +392,6 @@ const OpModel* pick_preferred_op_model( { auto op_models = current_graph_solver.at(op); const OpModel* prefered_op_model = nullptr; - for (const auto& op_model : op_models) { log_trace( diff --git a/pybuda/csrc/balancer/python_bindings.cpp b/pybuda/csrc/balancer/python_bindings.cpp index 59cdb4d6..f0a053e2 100644 --- a/pybuda/csrc/balancer/python_bindings.cpp +++ b/pybuda/csrc/balancer/python_bindings.cpp @@ -10,7 +10,6 @@ #include "balancer/policies/policy_utils.hpp" #include "balancer/python_interface.hpp" #include "balancer/balancer_utils.hpp" -#include "balancer/types.hpp" #include "graph_lib/utils.hpp" #include "placer/placer.hpp" #include "passes/fuse_ops.hpp" @@ -361,7 +360,9 @@ void BalancerModule(py::module &m_balancer) { .def_readonly("output_buffers", &OpModel::output_buffers) .def_readonly("parameter_buffers", &OpModel::parameter_buffers) .def_readonly("is_sparse_matmul", &OpModel::is_sparse_matmul) - .def("get_sparse_metadata", &OpModel::get_sparse_metadata) + .def_readonly("nz_tiles", &OpModel::nz_tiles) + .def_readonly("nz_ublocks", &OpModel::nz_ublocks) + .def_readonly("nz_strips", &OpModel::nz_strips) .def("block_shape", &OpModel::block_shape) .def("__repr__", [](OpModel const& a) { std::stringstream ss; @@ -369,16 +370,6 @@ void BalancerModule(py::module &m_balancer) { return ss.str(); }); - py::class_(m_balancer, "SparseMetadata") - .def_readonly("nz_tiles", &OpModel::SparseMetadata::nz_tiles) - .def_readonly("nz_ublocks", &OpModel::SparseMetadata::nz_ublocks) - .def_readonly("nz_strips", &OpModel::SparseMetadata::nz_strips) - .def("__repr__", [](OpModel::SparseMetadata const& a) { - std::stringstream ss; - ss << a; - return ss.str(); - }); - py::class_(m_balancer, "FusedSubOpModel") .def_readonly("type", &FusedSubOpModel::type) .def_readonly("mblock_m", &FusedSubOpModel::mblock_m) diff --git a/pybuda/csrc/balancer/types.cpp b/pybuda/csrc/balancer/types.cpp index 0296bf0a..7bab5e5a 100644 --- a/pybuda/csrc/balancer/types.cpp +++ b/pybuda/csrc/balancer/types.cpp @@ -229,15 +229,18 @@ int OpModel::get_execution_cycles_uncached(std::string const &arch_name, bool th { std::shared_ptr fused_op = this->fused_op(); - // Calculate sparse matmul metadata and write into OpModel's SparseMetadata struct + // Calculate sparse-matmul metadata and cache the result if (env_as("PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES", false) and this->is_sparse_matmul and - this->sparse_metadata == nullptr) + this->nz_ublocks == -1) { auto mf = this->math_fidelity(); if (mf == tt::MathFidelity::HiFi2 or mf == tt::MathFidelity::LoFi) { + auto [nz_tiles, nz_ublocks, nz_strips] = get_sparse_matmul_metadata(*this); auto *p_this = const_cast(this); - p_this->sparse_metadata = get_sparse_matmul_metadata(*this); + p_this->nz_tiles = nz_tiles; + p_this->nz_ublocks = nz_ublocks; + p_this->nz_strips = nz_strips; } } diff --git a/pybuda/csrc/balancer/types.hpp b/pybuda/csrc/balancer/types.hpp index b2618657..7b3424e4 100644 --- a/pybuda/csrc/balancer/types.hpp +++ b/pybuda/csrc/balancer/types.hpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -264,24 +263,6 @@ struct Padding // struct OpModel { - struct SparseMetadata { - std::vector nz_tiles; - std::vector nz_ublocks; - std::vector nz_strips; - - bool operator==(SparseMetadata const &other) const - { - return nz_tiles == other.nz_tiles and nz_ublocks == other.nz_ublocks and nz_strips == other.nz_strips; - } - - SparseMetadata(int grid_r) - { - nz_tiles.resize(grid_r, 0); - nz_ublocks.resize(grid_r, 0); - nz_strips.resize(grid_r, 0); - } - }; - UniqueId id; GridShape grid_shape; OpShape op_shape; @@ -291,9 +272,10 @@ struct OpModel bool sparse_buffer = false; bool is_sparse_matmul = false; bool consumes_rz_major = false; - const sparse::SparseBUDA *sparse_buda = nullptr; // sparse-matmul specific - std::shared_ptr sparse_metadata = nullptr; // sparse-matmul specific - // ^ using shared_ptr (vs unique_ptr) to allow for implicit copy construction of OpModel + int nz_tiles = 0; // sparse-matmul specific + int nz_ublocks = -1; // sparse-matmul specific + int nz_strips = -1; // sparse-matmul specific + const sparse::SparseBUDA *sparse_buda = nullptr; // sparse-matmul specific TStreamFactor t_stream_factor; int fracture_factor; int sparse_indices; @@ -390,14 +372,8 @@ struct OpModel return static_cast(get_output_bytes()) / get_execution_cycles(arch_name); } - // PyBind is acting strange with std::shared_ptr, and there seem to be some bugs reported on this, doing this for - // now... - const SparseMetadata get_sparse_metadata() { return *sparse_metadata.get(); } - bool operator==(OpModel const &other) const { return id == other.id; } - // This function is used to compare two OpModels for similarity. It is used for caching mechanisms. - // bool is_similar(OpModel const &other) const { return buda_op_node == other.buda_op_node @@ -406,10 +382,12 @@ struct OpModel and fracture_factor == other.fracture_factor and input_prologue == other.input_prologue and sparse_buffer == other.sparse_buffer + and nz_tiles == other.nz_tiles + and nz_ublocks == other.nz_ublocks + and nz_strips == other.nz_strips and padding == other.padding and input_buffers == other.input_buffers - and output_buffers == other.output_buffers - and is_similar_sparse_metadata(other); + and output_buffers == other.output_buffers; } TensorShape get_out_shape(bool post_t_stream = true) const @@ -431,20 +409,6 @@ struct OpModel private: int get_execution_cycles_uncached(std::string const &arch_name, bool theoretical = false) const; - bool is_similar_sparse_metadata(OpModel const &other) const - { - if (sparse_metadata == nullptr and other.sparse_metadata == nullptr) - { - return true; - } - - if (sparse_metadata == nullptr or other.sparse_metadata == nullptr) - { - return false; - } - - return *sparse_metadata == *other.sparse_metadata; - } }; using LegalOpModels = std::unordered_map>; @@ -897,18 +861,6 @@ inline std::ostream &ostream_with_indent(std::ostream &os, OpModel const &op_mod inline std::ostream &operator<<(std::ostream &os, OpModel const &op_model) { return ostream_with_indent(os, op_model); } -inline std::ostream &operator<<(std::ostream &os, OpModel::SparseMetadata const &sparse_metadata) -{ - os << "SparseMetadata{.nz_tiles = {"; - for (int nz_tile : sparse_metadata.nz_tiles) os << nz_tile << ", "; - os << "}, .nz_ublocks = {"; - for (int nz_ublock : sparse_metadata.nz_ublocks) os << nz_ublock << ", "; - os << "}, .nz_strips = {"; - for (int nz_strip : sparse_metadata.nz_strips) os << nz_strip << ", "; - os << "}}"; - return os; -} - inline std::ostream &operator<<(std::ostream &os, FusedSubOpModel const &sub_op_model) { os << "FusedSubOpModel{" << std::endl; diff --git a/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp b/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp index b71499ed..ffec94aa 100644 --- a/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp +++ b/pybuda/csrc/shared_utils/sparse_matmul_utils.hpp @@ -347,147 +347,12 @@ struct SparseBUDA public: enum class Layout { - Default, // Default layout - ZMajor, // Z-major layout, e.g. RZ streaming - go thru all Zs first, for a given R slice - ZMajorDataflow, // Z-major layout, special cased for sparse->dense dataflow (same as ZMajor, but slice - // vertically down to single tile) - BufferOp, // Used for sparse buffer matmuls + Default, + ZMajor, + ZMajorDataflow, + BufferOp, }; - // A little more on layouts... - // - // Layout dictates what the sparse tensor of the sparse matmul will look like. Sparse tensors get divided into - // chunks, to accomodate for parallelization across cores and t-streaming. Depending on what the chunks look like, - // and which core receives which chunk, the performance profile of the sparse matmul can change - additionally, - // whether the parallelization is legal can also change. - // - // Sparse matmul ops are used for various scenarios, but most often they're there as building blocks of - // convolutions. Sparse tensors of such ops have a specific pattern that looks something like this (e.g. for a 2x2 - // convolution): - // [ - // 1 0 0 0 - // 0 1 0 0 - // 0 0 1 0 - // 0 0 0 1 - // 1 0 0 0 - // 0 1 0 0 - // 0 0 1 0 - // 0 0 0 1 - // 1 0 0 0 - // 0 1 0 0 - // 0 0 1 0 - // 0 0 0 1 - // 1 0 0 0 - // 0 1 0 0 - // 0 0 1 0 - // 0 0 0 1 - // ] - // So they will look like a set of diagonal matrices, each one representing a single kernel point of the - // convolution. - // - // Currently, there are 3 variants of the layout: Default, ZMajor, and ZMajorDataflow. - // - // Default: - // In this layout, the sparse tensor is in its original shape. - // - // ZMajor: - // In this layout, the goal is to eliminate serialization in execution between cores. It is easier to explain the - // transformation first, and then show why it eliminates serialization. - // - // The transformation is as follows: - // - Let's say we have 2x1 cores for this sparse op and t=2. The top core will handle the first two kernel - // points while the second core will handle the last two kernel points - this is by design of the layout. In t=0 - // the top core will get the top parts of the sparse tensor's first two kernel points, which is two 2x4x4 - // pieces. The bottom core will get the top parts of the sparse tensor's last two kernel points, which is also - // two 2x4x4 pieces. So for t=0, the first core's tensor will look like this: - // [ - // 1 0 0 0 <- first row of the first kernel point - // 0 1 0 0 <- second row of the first kernel point - // 1 0 0 0 <- first row of the second kernel point - // 0 1 0 0 <- second row of the second kernel point - // ] - // - And the second core, for t=0, it's tensor will look like this: - // [ - // 1 0 0 0 <- first row of the third kernel point - // 0 1 0 0 <- second row of the third kernel point - // 1 0 0 0 <- first row of the fourth kernel point - // 0 1 0 0 <- second row of the fourth kernel point - // ] - // - If we were to continue this for t=1, the first core's tensor would look like this: - // [ - // 0 0 1 0 <- third row of the first kernel point - // 0 0 0 1 <- fourth row of the first kernel point - // 0 0 1 0 <- third row of the second kernel point - // 0 0 0 1 <- fourth row of the second kernel point - // ] - // - And the second core, for t=1, it's tensor will look like this: - // [ - // 0 0 1 0 <- third row of the third kernel point - // 0 0 0 1 <- fourth row of the third kernel point - // 0 0 1 0 <- third row of the fourth kernel point - // 0 0 0 1 <- fourth row of the fourth kernel point - // ] - // - // From the perspective of the sparse matmul, where these sparse tensors are left operands, the ZMajor layout - // makes it so that all cores, in a given point in time, are working on the same inner dimension of the matmul. If - // we think about how the right operand is buffered (horizontal strips), and keep in mind that all cores receive - // the full right operand tensor, this layout makes it so that all cores are working on the same strip of the - // right operand tensor, which removes any serialization in execution. - // - // Naming comes from the fact that the sparse tensor is first vertically sliced by a factor of kernel points, - // which is in this case a vslice(4). Then, a core will work on a set of Zs for a given R slice, so in a way, if - // we fix the height that a core is reading (in this case 2-tile high), we go through the tensor in Z-major order, - // hence the name ZMajor for the layout. - // - // ZMajorDataflow: - // This layout is a special case of ZMajor, where the sparse tensor is vertically sliced all the way down to a - // single tile. The goal of this layout is to improve the dataflow between the sparse and dense matmuls - if we - // place a sparse&dense pair of ops next to each other, with the same height of cores, as we usually do, this - // layout will enable dataflow thru NOC to use direct pipes, which means that each sparse core will send its - // output directly and only to its corresponding core of dense matmul, which is right next to the sparse core. - // Hence the "Dataflow" in the name. - // - // Using the sparse tensor from the Default layout, the transformation to ZMajorDataflow is as follows: - // - Similar to ZMajor, we can first imagine that the sparse tensor is vertically sliced by a factor of kernel - // points, which is in this case a vslice(4), so the tensor goes from 16x4 to a 4x4x4. Then, imagine that the - // tensor is read completely in Z-major order, tile by tile. Let's say we have 2x1 cores for this sparse op and - // t=2. First core, for t=0, will get these tiles: - // [ - // 1 0 0 0 <- first tile of the first kernel point - // 1 0 0 0 <- first tile of the second kernel point - // 1 0 0 0 <- first tile of the third kernel point - // 1 0 0 0 <- first tile of the fourth kernel point - // ] - // - And the second core, for t=0, it's tensor will look like this: - // [ - // 0 1 0 0 <- second tile of the first kernel point - // 0 1 0 0 <- second tile of the second kernel point - // 0 1 0 0 <- second tile of the third kernel point - // 0 1 0 0 <- second tile of the fourth kernel point - // ] - // - If we were to continue this for t=1, the first core's tensor would look like this: - // [ - // 0 0 1 0 <- third tile of the first kernel point - // 0 0 1 0 <- third tile of the second kernel point - // 0 0 1 0 <- third tile of the third kernel point - // 0 0 1 0 <- third tile of the fourth kernel point - // ] - // - And the second core, for t=1, it's tensor will look like this: - // [ - // 0 0 0 1 <- fourth tile of the first kernel point - // 0 0 0 1 <- fourth tile of the second kernel point - // 0 0 0 1 <- fourth tile of the third kernel point - // 0 0 0 1 <- fourth tile of the fourth kernel point - // ] - // - // The main difference between ZMajor and ZMajorDataflow is that ZMajorDataflow is vertically sliced all the way - // down to a single tile, which enables the dataflow between the sparse and dense matmuls to use direct pipes. - // However, by doing this transformation, there is a side-effect on the output, the rows of the output are all - // mixed up, i.e. they're not in the order that the dense matmul "expects" them to be in. To correct this, a set - // of TM ops is applied to the output of the sparse matmul, which will reorder the rows back to the correct order. - // This set of TM ops might look complicated, but when worked out, it just makes it so that each sparse matmul - // core sends data only to its corresponding dense matmul core, which is right next to it. - static Layout create_layout(bool buffer_op, bool z_major, int fracture_factor); std::vector sparse_zs; diff --git a/pybuda/pybuda/op/eval/buda/matmul.py b/pybuda/pybuda/op/eval/buda/matmul.py index 3e38b077..3c9a4041 100644 --- a/pybuda/pybuda/op/eval/buda/matmul.py +++ b/pybuda/pybuda/op/eval/buda/matmul.py @@ -388,55 +388,16 @@ def input_ublock_order(type, attr, num_operands): def execution_cycles(type, arch_name, op_model, theoretical) -> int: - # Special handling for sparse matmul as the backend API assumes 1x1 grid, but each sparse matmul core can do - # different amount of work, depending on what the encodings (sparse tensor) look like. Call for each core to find - # the slowest one. - # - if op_model.is_sparse_matmul: - # Calculate cycles per core - # - if ( - os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False) - and os.environ.get("PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE", False) - ): - cycles_to_return = 0 - for r in range(op_model.grid_shape.r): - # Generate op model desc for current core - # - op_model_desc = op_model_to_desc(type, arch_name, op_model, sparse_r=r) - - # Get execution cycles, try from cache first, if miss, then calculate - # - curr_cycles = 0 - compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) - if compiler_cache_cycles is not None: - curr_cycles = compiler_cache_cycles - else: - curr_cycles = get_op_model_execution_cycles(op_model_desc) - - # Save max cycles - # - cycles_to_return = max(cycles_to_return, curr_cycles) - else: - # Otherwise fallback to default behavior (create single op_model_desc, and let it decide whether to average - # parameters, or to sum and pretend everything is on a single core) - # - op_model_desc = op_model_to_desc(type, arch_name, op_model) - compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) - if compiler_cache_cycles is not None: - cycles_to_return = compiler_cache_cycles - else: - cycles_to_return = get_op_model_execution_cycles(op_model_desc) - - return cycles_to_return - # End sparse matmul exec cycles calculation - op_model_desc = op_model_to_desc(type, arch_name, op_model) compiler_cache_cycles = get_compiler_cached_cycles(op_model_desc) if compiler_cache_cycles is not None: return compiler_cache_cycles + is_sparse = op_model.is_sparse_matmul + if is_sparse: + return get_op_model_execution_cycles(op_model_desc) + # Math fidelity and data format are just estimated guesses for now math_fid = math_fidelity_to_multiplier(op_model.math_fidelity()) u_kt = op_model.input_buffers[0].block_shape.ublock.ct diff --git a/pybuda/pybuda/op/eval/common.py b/pybuda/pybuda/op/eval/common.py index 65d1227d..6c3eddc4 100644 --- a/pybuda/pybuda/op/eval/common.py +++ b/pybuda/pybuda/op/eval/common.py @@ -331,14 +331,7 @@ def data_format_to_int(df: DataFormat) -> int: return 11 raise RuntimeError(f"Unknown data format {df}") -def op_model_to_desc( - type: str, - arch_name: str, - op_model: OpModel, - sub_op_model: FusedSubOpModel = None, - sparse_grid_row=-1, -) -> OpModelDesc: - +def op_model_to_desc(type: str, arch_name: str, op_model: OpModel, sub_op_model: FusedSubOpModel = None) -> OpModelDesc: desc = OpModelDesc() desc.arch = arch_name desc.data_format = op_model.data_format @@ -346,13 +339,6 @@ def op_model_to_desc( desc.t = op_model.output_buffers[0].block_shape.t desc.approx_mode = False - # tenstorrent/pybuda#2565 - # Example overrides (DataFormat or MathFidelity) to target sparse estimates v2 when op is originally Bfp8_b+HiFi2 - # if type == "matmul" and op_model.is_sparse_matmul: - # if desc.data_format == DataFormat.Bfp8_b and desc.math_fidelity == MathFidelity.HiFi2: - # # desc.data_format = DataFormat.Float16_b # Override DataFormat - # desc.math_fidelity = MathFidelity.LoFi # Override MathFidelity - if op_model.op_type() == "fused_op": desc.type = sub_op_model.type desc.mblock_m = sub_op_model.mblock_m @@ -380,49 +366,22 @@ def op_model_to_desc( desc.mblock_k = op_model.op_shape.inputs[1].rt // desc.ublock_kt desc.sparse_indices = op_model.sparse_indices if os.environ.get("PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES", False): - sparse_metadata = op_model.get_sparse_metadata() - desc.sparse_indices = sum(sparse_metadata.nz_tiles) - desc.sparse_nz_ublocks = sum(sparse_metadata.nz_ublocks) - desc.sparse_nz_strips = sum(sparse_metadata.nz_strips) - - # Op model descriptor assumes grid_size [1, 1], so we need to scale down the parameters to what is - # expected to end up on a single core. Initially, we did this by averaging the parameters with the - # number of cores. However, not all the cores perform the same amount of work, so we need to - # calculate parameters per core. We keep both of these modes in this transition period. - # - # PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS must be set to true to enable any of the mentioned modes. - # - # Mode 1: - # Average the parameters (by default) - # Mode 2: - # Scale the parameters by the number of cores (needs the env var - # "PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE" to be set to true) - # - per_core_mode = os.environ.get("PYBUDA_TEMP_SPARSE_ESTIMATE_ARGS_PER_CORE", False) - if not per_core_mode: - # Average mode - # - nz_tiles = sum(sparse_metadata.nz_tiles) - nz_ublocks = sum(sparse_metadata.nz_ublocks) - nz_strips = sum(sparse_metadata.nz_strips) - - if nz_tiles > 1: - desc.sparse_indices = max(nz_tiles // op_model.grid_shape.r, 1) + desc.sparse_nz_ublocks = op_model.nz_ublocks + desc.sparse_nz_strips = op_model.nz_strips + + # op model descriptor assumes grid_size [1, 1], so we need to scale down the number of + # sparse tiles, ublocks and strips to what is expected to end up on a single core + if os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False): + if op_model.nz_tiles > 1: + desc.sparse_indices = max(op_model.nz_tiles // op_model.grid_shape.r, 1) else: - desc.sparse_indices = nz_tiles - - if nz_ublocks > 1: - desc.sparse_nz_ublocks = max(nz_ublocks // op_model.grid_shape.r, 1) - - if nz_strips > 1: - desc.sparse_nz_strips = max(nz_strips // op_model.grid_shape.r, 1) - else: - # Per core mode - # - assert sparse_grid_row != -1 # Must provide which row of cores we're fetching the estimates for - desc.sparse_indices = sparse_metadata.nz_tiles[sparse_grid_row] - desc.sparse_nz_ublocks = sparse_metadata.nz_ublocks[sparse_grid_row] - desc.sparse_nz_strips = sparse_metadata.nz_strips[sparse_grid_row] + desc.sparse_indices = op_model.nz_tiles + + if op_model.nz_ublocks > 1: + desc.sparse_nz_ublocks = max(op_model.nz_ublocks // op_model.grid_shape.r, 1) + + if op_model.nz_strips > 1: + desc.sparse_nz_strips = max(op_model.nz_strips // op_model.grid_shape.r, 1) else: # old sparse estimates if os.environ.get("PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS", False): diff --git a/pybuda/pybuda/op/eval/pybuda/matmul.py b/pybuda/pybuda/op/eval/pybuda/matmul.py index ae57acf8..de7a433c 100644 --- a/pybuda/pybuda/op/eval/pybuda/matmul.py +++ b/pybuda/pybuda/op/eval/pybuda/matmul.py @@ -179,7 +179,7 @@ def lower(type, attr, buda_attr, lc, ops, outputs): picker = lc.get_pytorch_tensor(in0) zdim = 1 if len(picker.shape) < 3 else picker.shape[-3] - z_bcast_factor = 1 if len(attr) < 2 else attr[1] # set in sparse matmul's decompose + z_bcast_factor = 1 if len(attr) < 2 else attr[1] # We can fully fracture kH * kW max_fracture_factor = z_bcast_factor if is_kernel_fracturing_candidate(ops, z_bcast_factor) else 1 @@ -255,8 +255,6 @@ def decompose(type, attr, dc, inputs): accumulate = (len(attr) >= 1) and bool(attr[0]) z_bcast_factor = zdim if (zdim > 1 and in1.shape[-3] == 1) else 1 - # In case of convolutions, z_bcast_factor is the volume of the conv's kernel (kernel_height * kernel_width) - if z_bcast_factor > 1: picker = torch.cat([picker[0][z] for z in range(z_bcast_factor)]) sparse = dc.tensor(picker) diff --git a/pybuda/test/benchmark/benchmark/models/other.py b/pybuda/test/benchmark/benchmark/models/other.py index 43bc45ca..ac6a257f 100644 --- a/pybuda/test/benchmark/benchmark/models/other.py +++ b/pybuda/test/benchmark/benchmark/models/other.py @@ -5,7 +5,7 @@ Catch-all for random perf testing """ -import os +import numpy as np import pybuda import torch @@ -84,21 +84,21 @@ def forward(self, x): @benchmark_model(configs=["224"]) -def big_conv(training: bool, config: str, microbatch: int, devtype: str, arch: str, data_type: str, math_fidelity: str): +def big_conv(training: bool, config: str, microbatch: int, devtype: str, arch: str): if config == "224": input_size = (224, 224) cin = 3 cout = 64 kH = 7 kW = 7 - stride = 1 - padding = "same" + stride = 2 + padding = 3 dilation = 1 else: raise RuntimeError(f"Invalid config: {config}") if microbatch == 0: - microbatch = 64 + microbatch = 1 mod = ConvTModule( name="big_conv_benchmark", @@ -113,8 +113,7 @@ def big_conv(training: bool, config: str, microbatch: int, devtype: str, arch: s bias=False) compiler_cfg = _get_global_compiler_config() - compiler_cfg.balancer_policy = "Ribbon" - os.environ["PYBUDA_RIBBON2"] = "1" + compiler_cfg.balancer_policy = "CNN" models = {"tt": mod} inputs = [torch.rand(microbatch, cin, input_size[0], input_size[1])]