Skip to content

Commit

Permalink
Revert "[sparse] Add ability to calculate per-core sparse matmul meta…
Browse files Browse the repository at this point in the history
…data (used for execution cycles estimation)"

(cherry picked from commit a1bae21af4c89bcbdb7e900a8c3f8b6095c7925d)
  • Loading branch information
svuckovicTT authored and vmilosevic committed May 9, 2024
1 parent fe1230d commit d00a010
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 426 deletions.
3 changes: 1 addition & 2 deletions README.debug.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@
## Temp overrides
* PYBUDA\_TEMP\_ENABLE\_NEW\_SPARSE\_ESTIMATES: Apply new formula to estimate the cycle count of sparse matmul ops (currently only support LoFi and HiFi2 fidelities)
* PYBUDA\_TEMP\_SCALE\_SPARSE\_ESTIMATE\_ARGS: Scale counts of non-zero tiles, ublocks and strips to reflect the numbers that would end up on a single core, since BBE estimates always assume grid_size [1,1].
* PYBUDA\_TEMP\_SPARSE\_ESTIMATE\_ARGS\_PER\_CORE: Instead of uniformly scaling sparse args (as happens in PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS), calculate them per core. To use, need set PYBUDA_TEMP_SCALE_SPARSE_ESTIMATE_ARGS to 1 as well.
* PYBUDA\_TEMP\_ELT\_UNARY\_ESTIMATES\_LEGACY: Force legacy path of calculating execution cycles for eltwise unary ops - instead of calling into BBE, use hand-crafted FE-side logic
* PYBUDA\_TEMP\_ENABLE\_NEW\_FUSED\_ESTIMATES: Apply new formula to estimate the cycle count of fused ops. The formula calls BBE to estimate each subop and sums up the results.
* PYBUDA\_LEGACY\_KERNEL\_BROADCAST: Use legacy kernel broadcast detection path. Will detect fewer kernel broadcasts, and will oftentimes use more tiles (longer KBs).
* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True)
* PYBUDA\_TEMP\_BALANCER\_MODEL\_PCIE\_BW: Estimate PCIe bandwidth in limiter cycles. (default: 1/True))
* PYBUDA\_TEMP\_BALANCER\_DISABLE\_TARGET\_PROXIMITY: Disable target proximity in balancer. (default: 0/False)
* PYBUDA\_TEMP\_DISABLE\_FJ\_NOP\_SCHEDULE\_FIX: This flag disables a fix that forces FJ buffering nops to be scheduled last.
* PYBUDA\_TEMP\_FIX\_2351: Controls the fix for bug #2351 - fork-join can end up adding buffering nops and queues on same path, this control flag fixes it.
Expand Down
194 changes: 99 additions & 95 deletions pybuda/csrc/balancer/balancer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@
#include "balancer_utils.hpp"

#include <cstdint>
#include <memory>
#include <unordered_map>

#include "balancer/types.hpp"
#include "passes/t_stream.hpp"
#include "shared_utils/sparse_matmul_utils.hpp"
#include "utils/assert.hpp"
#include "utils/hash_combine.hpp"
#include "utils/logger.hpp"
#include "utils/profile.hpp"
Expand Down Expand Up @@ -587,116 +583,124 @@ ResourceUsage get_edge_resource_usage_simple(
return usage;
}

// Calculates sparse matmul metadata that is used for kernel execution cycle estimatation. These are:
// - number of non-zero tiles per core
// - number of non-zero ublocks per core
// - number of non-zero strips per core
//
std::shared_ptr<const OpModel::SparseMetadata> get_sparse_matmul_metadata(balancer::OpModel const& op_model)
std::tuple<uint32_t, uint32_t, uint32_t> get_sparse_matmul_metadata(balancer::OpModel const& op_model)
{
OpModel::SparseMetadata sparse_metadata(op_model.grid_shape.r);

int grid_r = op_model.grid_shape.r;
int u_rt = op_model.output_buffers[0].block_shape.ublock.rt;
int u_kt = op_model.input_buffers[1].block_shape.ublock.rt;
int t_factor_c = op_model.t_stream_factor.c;
int t_factor_r = op_model.t_stream_factor.r;
const sparse::SparseBUDA& sparse_buda = *(op_model.sparse_buda);
const int grid_r = op_model.grid_shape.r;
const int u_rt = op_model.output_buffers[0].block_shape.ublock.rt;
const int u_kt = op_model.input_buffers[1].block_shape.ublock.rt;
const int m_k = sparse_buda.sparse_shape[1] / sparse::TILE_DIM / u_kt;
const int t_factor_c = op_model.t_stream_factor.c;
const int t_factor_r = op_model.t_stream_factor.r;
const int bcast_factor = sparse_buda.bcast_factor; // broadcast factor
const std::vector<sparse::SparseIndex>& sparse_indices = sparse_buda.sparse_indices;
auto layout = sparse::SparseBUDA::create_layout(
op_model.has_sparse_buffer() or env_as<bool>("PYBUDA_FORCE_SPARSE_BUFFER_LAYOUT"),
op_model.t_stream_factor.dir.z_major(),
op_model.fracture_factor);
int bcast_factor = sparse_buda.bcast_factor;
int zdim = sparse_buda.sparse_zs.size();

const int sparse_rt =
sparse_buda.sparse_shape[0] / sparse::TILE_DIM; // number of tiles in row dim of sparse tensor
const int bcast_slice_size = sparse_rt / bcast_factor; // size of each slice after bcast slicing
const int r_tiles_in_core = sparse_rt / grid_r / t_factor_r; // rows of tiles in a single core
const int dflow_factor =
(layout == sparse::SparseBUDA::Layout::ZMajorDataflow) ? sparse_rt / grid_r / t_factor_r / bcast_factor : 1;
// Initialize tiles/ublocks/strips counter
int sum_nz_tiles = 0;
int sum_nz_ublocks = 0;
int sum_nz_strips = 0;
constexpr int TILE_DIM = tt::sparse::TILE_DIM;

// Helper struct - accounting for metadata per core
//
struct MetadataPerCore
struct CounterEntry
{
unsigned int nz_tiles = 0;
std::unordered_set<uint64_t> nz_ublocks;
std::unordered_set<uint64_t> nz_strips;
std::unordered_set<uint64_t> rt_ct_cmb;
std::unordered_set<uint64_t> ubc_ubr_cmb;
std::unordered_set<int> ubc_idxs;
int smallest_rt;
CounterEntry() : smallest_rt(INT_MAX){};
};

std::vector<MetadataPerCore> metadata_per_core(grid_r);
std::vector<CounterEntry> counters;
int slice_count = grid_r * t_factor_r;

// Calculate metadata below
// Go through all the sparse indices (coordinates of non-zero tiles in the canonical sparse tensor), and map them to
// coords (core, t-dim, ublock, ...) based on the layout provided
//
// The math below, which is different for each layout, is not obvious. Need to add some documentation to explain.
// Tracking issue:
// # tenstorrent/pybuda#2601
//
int r, t, target_rt;
for (const sparse::SparseIndex& index : sparse_indices)
// Iterate throufh all sparse tensors
for (int z = 0; z < zdim; z++)
{
if (layout == sparse::SparseBUDA::Layout::Default)
{
r = (index.rt % (sparse_rt / t_factor_r)) / r_tiles_in_core;
t = index.rt / (sparse_rt / t_factor_r);
target_rt = index.rt % r_tiles_in_core;
}
else if (layout == sparse::SparseBUDA::Layout::ZMajor)
auto sparse = sparse_buda.sparse_zs[z];

// Take stat of the sparseCOO
int dflow_factor = (layout == sparse::SparseBUDA::Layout::ZMajorDataflow)
? (sparse.rt() / grid_r / t_factor_r / bcast_factor)
: 1;
int num_slices = (layout == tt::sparse::SparseBUDA::Layout::Default)
? grid_r * t_factor_r
: grid_r * t_factor_r * bcast_factor * dflow_factor;
std::int64_t slice_height = sparse.shape[0] / num_slices;

std::vector<CounterEntry> ret(num_slices);
for (size_t idx = 0; idx < sparse.rows.size(); idx++)
{
int core_factor = std::max(1, grid_r % bcast_factor);
r = index.rt / bcast_slice_size * (grid_r / bcast_factor) + (index.rt / r_tiles_in_core) % core_factor;
t = (index.rt % (sparse_rt / bcast_factor)) / (sparse_rt / t_factor_r / bcast_factor);
target_rt = index.rt % r_tiles_in_core;
// Count nonzero tiles/ublocks/strips in the SparseCOO
int ret_slice_idx = -1, rt = -1;
if (layout == tt::sparse::SparseBUDA::Layout::Default)
{
ret_slice_idx = sparse.rows[idx] / slice_height;
rt = (sparse.rows[idx] % slice_height) / TILE_DIM;
}
else if (layout == tt::sparse::SparseBUDA::Layout::ZMajor)
{
int slice_idx = sparse.rows[idx] / slice_height;
int inner_idx = (slice_idx / slice_count) * grid_r + (slice_idx % grid_r);
int slice_inner_idx = inner_idx % bcast_factor;
ret_slice_idx = (slice_idx % (grid_r * t_factor_r)) / grid_r * grid_r + (inner_idx / bcast_factor);
int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx;
rt = new_rows / TILE_DIM;
}
else
{
TT_ASSERT(
layout == sparse::SparseBUDA::Layout::BufferOp or
layout == sparse::SparseBUDA::Layout::ZMajorDataflow);
if (layout == sparse::SparseBUDA::Layout::ZMajorDataflow and
((sparse.rt() / grid_r / t_factor_r) % bcast_factor != 0))
continue;

int slice_idx = sparse.rows[idx] / slice_height;
int inner_idx = (slice_idx % (dflow_factor * grid_r)) * bcast_factor +
(slice_idx / (dflow_factor * grid_r * t_factor_r));
int slice_inner_idx = inner_idx % (bcast_factor * dflow_factor);
ret_slice_idx = (slice_idx / (dflow_factor * grid_r)) % t_factor_r * grid_r +
(inner_idx / (bcast_factor * dflow_factor));
int new_rows = (sparse.rows[idx] % slice_height) + slice_height * slice_inner_idx;
rt = new_rows / TILE_DIM;
}
int ct = sparse.cols[idx] / TILE_DIM;
int ubr_idx = rt / u_rt;
int ubc_idx = ct / u_kt;
uint64_t rt_ct_key = (uint64_t(rt) << 32) | (ct & 0x0FFFF);
uint64_t ubc_ubr_key = (uint64_t(ubc_idx) << 32) | (ubr_idx & 0x0FFFF);

// Add the metadata to counting struct
CounterEntry& e = ret[ret_slice_idx];
e.rt_ct_cmb.insert(rt_ct_key);
e.ubc_ubr_cmb.insert(ubc_ubr_key);
e.ubc_idxs.insert(ubc_idx);
if (rt < ret[ret_slice_idx].smallest_rt)
ret[ret_slice_idx].smallest_rt = rt;
}
else

// Count tiles, ublocks, strips
for (int idx = 0; idx < slice_count; idx++)
{
TT_ASSERT(
layout == sparse::SparseBUDA::Layout::ZMajorDataflow || layout == sparse::SparseBUDA::Layout::BufferOp);
const int t_slice_size = bcast_slice_size / t_factor_r; // size of each slice after vslice(t_factor_r)
r = (index.rt / dflow_factor) % grid_r;
t = index.rt / t_slice_size % t_factor_r;
target_rt = (index.rt % bcast_slice_size * bcast_factor + index.rt / bcast_slice_size) %
(bcast_factor * dflow_factor);
const CounterEntry& e = ret[idx];
sum_nz_tiles += e.rt_ct_cmb.size();
sum_nz_ublocks += e.ubc_ubr_cmb.size();
sum_nz_strips += e.ubc_idxs.size();
if (e.smallest_rt >= 1 and e.smallest_rt < INT_MAX)
{
sum_nz_tiles++;
sum_nz_ublocks++;
}
}

const int ublock_r_idx = target_rt / u_rt;
const int ublock_c_idx = index.ct / u_kt;
const int strip_idx = t * m_k + ublock_c_idx;

// Insert metadata per core
//
// - non-zero tiles
// - non-zero ublocks
// - non-zero strips
//
metadata_per_core.at(r).nz_tiles++;
std::uint64_t ublock_key =
static_cast<std::uint64_t>(t) << 48 | // 16 bits for t
((static_cast<std::uint64_t>(ublock_r_idx) & 0xFFFFFF) << 24) | // 24 bits for ublock_r
(static_cast<std::uint64_t>(ublock_c_idx) & 0xFFFFFF); // 24 bits for ublock_c
metadata_per_core.at(r).nz_ublocks.insert(ublock_key);
metadata_per_core.at(r).nz_strips.insert(strip_idx);
}

// Copy over the results to return object
//
for (int r = 0; r < grid_r; ++r)
{

// Previous solution multiplied with t_factor_c - copying that behaviour here... However, it's not obvious
// whether this is correct
// # tenstorrent/pybuda#2598
//
sparse_metadata.nz_tiles.at(r) = metadata_per_core.at(r).nz_tiles * t_factor_c;
sparse_metadata.nz_ublocks.at(r) = metadata_per_core.at(r).nz_ublocks.size() * t_factor_c;
sparse_metadata.nz_strips.at(r) = metadata_per_core.at(r).nz_strips.size() * t_factor_c;
}

return std::make_shared<const OpModel::SparseMetadata>(sparse_metadata);
sum_nz_tiles *= t_factor_c;
sum_nz_ublocks *= t_factor_c;
sum_nz_strips *= t_factor_c;
return std::make_tuple<>(sum_nz_tiles, sum_nz_ublocks, sum_nz_strips);
}

} // namespace tt::balancer
2 changes: 1 addition & 1 deletion pybuda/csrc/balancer/balancer_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ ResourceUsage get_edge_resource_usage_simple(
OpModel const &consumer_op_model,
bool is_queue = false);

std::shared_ptr<const OpModel::SparseMetadata> get_sparse_matmul_metadata(balancer::OpModel const &grid);
std::tuple<uint32_t, uint32_t, uint32_t> get_sparse_matmul_metadata(balancer::OpModel const &grid);

} // namespace tt::balancer

Expand Down
13 changes: 6 additions & 7 deletions pybuda/csrc/balancer/legalizer/legalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,7 @@ static std::pair<FactorizedShape, LegalSparseUKts> calculate_streaming_pars(
graph->data_operands(op_node)[0]->as<graphlib::ConstantInputNode>()->get_sparse_buda();
std::vector<tt::sparse::SparseCOO>& sparse_zs = sparse_buda.sparse_zs;
auto layout = sparse::SparseBUDA::create_layout(sparse_buffer_enable, dir.z_major(), fracture_factor);
int bcast_factor =
(layout == sparse::SparseBUDA::Layout::ZMajor || layout == sparse::SparseBUDA::Layout::ZMajorDataflow)
? sparse_buda.bcast_factor
: 1;
int bcast_factor = (layout == sparse::SparseBUDA::Layout::ZMajor) ? sparse_buda.bcast_factor : 1;

// Each potential t needs to evenly divide output's r-dim but also in1's r-dim (in0's c-dim)
std::vector<graphlib::Edge> operands = graph->operand_data_edges(op_node);
Expand Down Expand Up @@ -943,8 +940,9 @@ static std::vector<BufferModel> calculate_output_buffer_models_for_grid(
static bool is_input_node_parameter_or_constant(const graphlib::Node* node)
{
return node->node_type() == graphlib::NodeType::kInput and
(node->as<graphlib::InputNode>()->is_parameter() or node->as<graphlib::InputNode>()->is_constant() or
node->as<graphlib::InputNode>()->is_optimizer_parameter());
(node->as<graphlib::InputNode>()->is_parameter()
or node->as<graphlib::InputNode>()->is_constant()
or node->as<graphlib::InputNode>()->is_optimizer_parameter());
}

static std::vector<BufferModel> calculate_parameter_buffer_models_for_grid(
Expand All @@ -957,7 +955,8 @@ static std::vector<BufferModel> calculate_parameter_buffer_models_for_grid(
parameter_buffers.resize(operands.size());
for (int input_idx = 0; input_idx < (int)operands.size(); ++input_idx)
{
if (is_input_node_parameter_or_constant(operands[input_idx]) and not force_dram_parameters)
if (is_input_node_parameter_or_constant(operands[input_idx]) and
not force_dram_parameters)
{
TensorShape const& parameter_shape = op_shape.producer_shapes[input_idx];
int grid_r = FactorizedInt(parameter_shape.rt).get_nearest_factor_le(selected_grid.r);
Expand Down
1 change: 0 additions & 1 deletion pybuda/csrc/balancer/policies/policy_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,6 @@ const OpModel* pick_preferred_op_model(
{
auto op_models = current_graph_solver.at(op);
const OpModel* prefered_op_model = nullptr;

for (const auto& op_model : op_models)
{
log_trace(
Expand Down
15 changes: 3 additions & 12 deletions pybuda/csrc/balancer/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "balancer/policies/policy_utils.hpp"
#include "balancer/python_interface.hpp"
#include "balancer/balancer_utils.hpp"
#include "balancer/types.hpp"
#include "graph_lib/utils.hpp"
#include "placer/placer.hpp"
#include "passes/fuse_ops.hpp"
Expand Down Expand Up @@ -361,24 +360,16 @@ void BalancerModule(py::module &m_balancer) {
.def_readonly("output_buffers", &OpModel::output_buffers)
.def_readonly("parameter_buffers", &OpModel::parameter_buffers)
.def_readonly("is_sparse_matmul", &OpModel::is_sparse_matmul)
.def("get_sparse_metadata", &OpModel::get_sparse_metadata)
.def_readonly("nz_tiles", &OpModel::nz_tiles)
.def_readonly("nz_ublocks", &OpModel::nz_ublocks)
.def_readonly("nz_strips", &OpModel::nz_strips)
.def("block_shape", &OpModel::block_shape)
.def("__repr__", [](OpModel const& a) {
std::stringstream ss;
ss << a;
return ss.str();
});

py::class_<OpModel::SparseMetadata>(m_balancer, "SparseMetadata")
.def_readonly("nz_tiles", &OpModel::SparseMetadata::nz_tiles)
.def_readonly("nz_ublocks", &OpModel::SparseMetadata::nz_ublocks)
.def_readonly("nz_strips", &OpModel::SparseMetadata::nz_strips)
.def("__repr__", [](OpModel::SparseMetadata const& a) {
std::stringstream ss;
ss << a;
return ss.str();
});

py::class_<FusedSubOpModel>(m_balancer, "FusedSubOpModel")
.def_readonly("type", &FusedSubOpModel::type)
.def_readonly("mblock_m", &FusedSubOpModel::mblock_m)
Expand Down
9 changes: 6 additions & 3 deletions pybuda/csrc/balancer/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,18 @@ int OpModel::get_execution_cycles_uncached(std::string const &arch_name, bool th
{
std::shared_ptr<FusedOp> fused_op = this->fused_op();

// Calculate sparse matmul metadata and write into OpModel's SparseMetadata struct
// Calculate sparse-matmul metadata and cache the result
if (env_as<bool>("PYBUDA_TEMP_ENABLE_NEW_SPARSE_ESTIMATES", false) and this->is_sparse_matmul and
this->sparse_metadata == nullptr)
this->nz_ublocks == -1)
{
auto mf = this->math_fidelity();
if (mf == tt::MathFidelity::HiFi2 or mf == tt::MathFidelity::LoFi)
{
auto [nz_tiles, nz_ublocks, nz_strips] = get_sparse_matmul_metadata(*this);
auto *p_this = const_cast<OpModel *>(this);
p_this->sparse_metadata = get_sparse_matmul_metadata(*this);
p_this->nz_tiles = nz_tiles;
p_this->nz_ublocks = nz_ublocks;
p_this->nz_strips = nz_strips;
}
}

Expand Down
Loading

0 comments on commit d00a010

Please sign in to comment.