Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support adding context instructions to basic block graphs. #289

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) {
return os;
}

BasicBlock::BasicBlock(std::vector<Instruction> instructions)
: instructions(std::move(instructions)) {}
BasicBlock::BasicBlock(std::vector<Instruction> instructions,
std::vector<Instruction> back_context,
std::vector<Instruction> front_context)
: instructions(std::move(instructions)),
back_context(std::move(back_context)),
front_context(std::move(front_context)) {}

bool BasicBlock::operator==(const BasicBlock& other) const {
return instructions == other.instructions;
return instructions == other.instructions &&
back_context == other.back_context &&
front_context == other.front_context;
}

std::string BasicBlock::ToString() const {
Expand All @@ -395,6 +401,24 @@ std::string BasicBlock::ToString() const {
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
}
if (!back_context.empty()) {
buffer += "back_context=InstructionList((";
for (const Instruction& instruction : back_context) {
buffer += instruction.ToString();
buffer += ", ";
}
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
}
if (!front_context.empty()) {
buffer += "front_context=InstructionList((";
for (const Instruction& instruction : front_context) {
buffer += instruction.ToString();
buffer += ", ";
}
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
}
buffer.push_back(')');
return buffer;
}
Expand Down
16 changes: 12 additions & 4 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);
// Represents an annotation holding a value such as some measure/statistic
// paired with the instruction.
struct Annotation {
Annotation() : value(-1){};
Annotation() : value(-1) {};

// Initializes all fields of the annotation.
Annotation(std::string name, double value);
Expand Down Expand Up @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction);
struct BasicBlock {
BasicBlock() {}

// Initializes the basic block from a list of instructions. Needed for
// compatibility with the Python code.
explicit BasicBlock(std::vector<Instruction> instructions);
// Initializes the basic block from a list of instructions and optional
// context. Needed for compatibility with the Python code.
explicit BasicBlock(
std::vector<Instruction> instructions,
std::vector<Instruction> back_context = std::vector<Instruction>(),
std::vector<Instruction> front_context = std::vector<Instruction>());

BasicBlock(const BasicBlock&) = default;
BasicBlock(BasicBlock&&) = default;
Expand All @@ -346,6 +349,11 @@ struct BasicBlock {

// The list of instructions in the basic block.
std::vector<Instruction> instructions;

// The back and front context instructions, i.e. those preceeding and
// following the instructions in the basic block.
std::vector<Instruction> back_context;
std::vector<Instruction> front_context;
};

std::ostream& operator<<(std::ostream& os, const BasicBlock& block);
Expand Down
11 changes: 9 additions & 2 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction(

BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) {
return BasicBlock(
/* instructions = */ ToVector<Instruction>(
proto.canonicalized_instructions(), InstructionFromProto));
/* instructions = */
ToVector<Instruction>(proto.canonicalized_instructions(),
InstructionFromProto),
/* back_context = */
ToVector<Instruction>(proto.canonicalized_back_context(),
InstructionFromProto),
/* front_context = */
ToVector<Instruction>(proto.canonicalized_front_context(),
InstructionFromProto));
}

} // namespace gematria
5 changes: 2 additions & 3 deletions gematria/basic_block/basic_block_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) {

TEST(InstructionOperandTest, ToString) {
const struct {
InstructionOperand opernad;
InstructionOperand operand;
const char* expected_string;
} kTestCases[] = {
{InstructionOperand::Register("RAX"),
Expand All @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) {
"InstructionOperand.from_memory(32)"}};

for (const auto& test_case : kTestCases) {
EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string);
EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string);
}
}

Expand All @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) {
}
}

// TODO(virajbshah): Add tests for Annotation.
TEST(AnnotationTest, Constructor) {
constexpr char kName[] = "cache_miss_freq";
constexpr double kValue = 0.875;
Expand Down
10 changes: 8 additions & 2 deletions gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) {

py::class_<BasicBlock> basic_block(m, "BasicBlock");
basic_block
.def(py::init<std::vector<Instruction> /* instructions */>(),
py::arg("instructions") = std::vector<Instruction>())
.def(py::init<std::vector<Instruction> /* instructions */,
std::vector<Instruction> /* back_context */,
std::vector<Instruction> /* front_context */>(),
py::arg("instructions") = std::vector<Instruction>(),
py::arg("back_context") = std::vector<Instruction>(),
py::arg("front_context") = std::vector<Instruction>())
.def_readwrite("instructions", &BasicBlock::instructions)
.def_readwrite("back_context", &BasicBlock::back_context)
.def_readwrite("front_context", &BasicBlock::front_context)
.def("__repr__", &BasicBlock::ToString)
.def("__str__", &BasicBlock::ToString)
.def("__eq__", &BasicBlock::operator==)
Expand Down
110 changes: 61 additions & 49 deletions gematria/granite/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder(
}

bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions(
const std::vector<Instruction>& instructions) {
const std::vector<Instruction>& instructions,
const std::vector<Instruction>& back_context,
const std::vector<Instruction>& front_context) {
if (instructions.empty()) return false;
AddBasicBlockTransaction transaction(this);

Expand All @@ -202,58 +204,65 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions(
const int prev_num_edges = num_edges();

NodeIndex previous_instruction_node = kInvalidNode;
for (const Instruction& instruction : instructions) {
// Add the instruction node.
const NodeIndex instruction_node =
AddNode(NodeType::kInstruction, instruction.mnemonic);
if (instruction_node == kInvalidNode) {
return false;
}
const struct {
const std::vector<Instruction>& instruction_group;
bool is_context;
} instruction_groups[] = {
{back_context, true}, {instructions, false}, {front_context, true}};
for (const auto [instruction_group, is_context] : instruction_groups) {
for (const Instruction& instruction : instruction_group) {
// Add the instruction node.
const NodeIndex instruction_node =
AddNode(NodeType::kInstruction, instruction.mnemonic, is_context);
if (instruction_node == kInvalidNode) {
return false;
}

// Store the annotations for later use (inclusion in embeddings), using -1
// as a default value wherever annotations are missing.
std::vector<float> row = std::vector<float>(annotation_names_.size(), -1);
for (const auto& [name, value] : instruction.instruction_annotations) {
const auto annotation_index = annotation_name_to_idx_.find(name);
if (annotation_index == annotation_name_to_idx_.end()) continue;
row[annotation_index->second] = value;
}
instruction_annotations_.push_back(row);
// Store the annotations for later use (inclusion in embeddings), using -1
// as a default value wherever annotations are missing.
std::vector<float> row = std::vector<float>(annotation_names_.size(), -1);
for (const auto& [name, value] : instruction.instruction_annotations) {
const auto annotation_index = annotation_name_to_idx_.find(name);
if (annotation_index == annotation_name_to_idx_.end()) continue;
row[annotation_index->second] = value;
}
instruction_annotations_.push_back(row);

// Add nodes for prefixes of the instruction.
for (const std::string& prefix : instruction.prefixes) {
const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix);
if (prefix_node == kInvalidNode) {
return false;
// Add nodes for prefixes of the instruction.
for (const std::string& prefix : instruction.prefixes) {
const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix);
if (prefix_node == kInvalidNode) {
return false;
}
AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node);
}
AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node);
}

// Add a structural dependency edge from the previous instruction.
if (previous_instruction_node >= 0) {
AddEdge(EdgeType::kStructuralDependency, previous_instruction_node,
instruction_node);
}
// Add a structural dependency edge from the previous instruction.
if (previous_instruction_node >= 0) {
AddEdge(EdgeType::kStructuralDependency, previous_instruction_node,
instruction_node);
}

// Add edges for input operands. And nodes too, if necessary.
for (const InstructionOperand& operand : instruction.input_operands) {
if (!AddInputOperand(instruction_node, operand)) return false;
}
for (const InstructionOperand& operand :
instruction.implicit_input_operands) {
if (!AddInputOperand(instruction_node, operand)) return false;
}
// Add edges for input operands. And nodes too, if necessary.
for (const InstructionOperand& operand : instruction.input_operands) {
if (!AddInputOperand(instruction_node, operand)) return false;
}
for (const InstructionOperand& operand :
instruction.implicit_input_operands) {
if (!AddInputOperand(instruction_node, operand)) return false;
}

// Add edges and nodes for output operands.
for (const InstructionOperand& operand : instruction.output_operands) {
if (!AddOutputOperand(instruction_node, operand)) return false;
}
for (const InstructionOperand& operand :
instruction.implicit_output_operands) {
if (!AddOutputOperand(instruction_node, operand)) return false;
}
// Add edges and nodes for output operands.
for (const InstructionOperand& operand : instruction.output_operands) {
if (!AddOutputOperand(instruction_node, operand)) return false;
}
for (const InstructionOperand& operand :
instruction.implicit_output_operands) {
if (!AddOutputOperand(instruction_node, operand)) return false;
}

previous_instruction_node = instruction_node;
previous_instruction_node = instruction_node;
}
}

global_features_.emplace_back(num_node_tokens(), 0);
Expand All @@ -276,6 +285,7 @@ void BasicBlockGraphBuilder::Reset() {

node_types_.clear();
node_features_.clear();
context_node_mask_.clear();

edge_senders_.clear();
edge_receivers_.clear();
Expand Down Expand Up @@ -404,15 +414,16 @@ bool BasicBlockGraphBuilder::AddDependencyOnRegister(
}

BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode(
NodeType node_type, TokenIndex token_index) {
NodeType node_type, TokenIndex token_index, bool is_context) {
const NodeIndex new_node_index = num_nodes();
node_types_.push_back(node_type);
node_features_.push_back(token_index);
context_node_mask_.push_back(is_context);
return new_node_index;
}

BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode(
NodeType node_type, const std::string& token) {
NodeType node_type, const std::string& token, bool is_context) {
const auto it = node_tokens_.find(token);
TokenIndex token_index = kInvalidTokenIndex;
if (it != node_tokens_.end()) {
Expand All @@ -427,7 +438,7 @@ BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode(
token_index = replacement_token_;
}
}
return AddNode(node_type, token_index);
return AddNode(node_type, token_index, is_context);
}

void BasicBlockGraphBuilder::AddEdge(EdgeType edge_type, NodeIndex sender,
Expand Down Expand Up @@ -505,6 +516,7 @@ std::string BasicBlockGraphBuilder::DebugString() const {
StrAppendList(buffer, "num_nodes_per_block", num_nodes_per_block());
StrAppendList(buffer, "num_edges_per_block", num_edges_per_block());
StrAppendList(buffer, "node_types", node_types());
StrAppendList(buffer, "context_node_mask", context_node_mask());
StrAppendList(buffer, "edge_senders", edge_senders());
StrAppendList(buffer, "edge_receivers", edge_receivers());
StrAppendList(buffer, "edge_types", edge_types());
Expand Down
30 changes: 24 additions & 6 deletions gematria/granite/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,23 @@ class BasicBlockGraphBuilder {
// method encountered an unknown token and the unknown token behavior is not
// kReplaceToken or when the basic block does not contain any instructions.
// When this happens, the graph builder is left in the previous state, i.e. no
// basic block is added to it.
bool AddBasicBlock(const BasicBlock& block) {
// basic block is added to it. The basic block context is added to the graph
// if and only if `add_context` is true.
bool AddBasicBlock(const BasicBlock& block, bool add_context = false) {
if (add_context) {
return AddBasicBlockFromInstructions(
block.instructions, block.back_context, block.front_context);
}
return AddBasicBlockFromInstructions(block.instructions);
}
// A version of AddBasicBlock that takes the list of instructions in the basic
// block instead of the basic block object itself.
// block and optionally its back and front contexts instead of the basic block
// object itself.
bool AddBasicBlockFromInstructions(
const std::vector<Instruction>& instructions);
const std::vector<Instruction>& instructions,
const std::vector<Instruction>& back_context = std::vector<Instruction>(),
const std::vector<Instruction>& front_context =
std::vector<Instruction>());

// Resets the graph builder so that it can be used to create a new graph from
// scratch.
Expand Down Expand Up @@ -242,6 +251,12 @@ class BasicBlockGraphBuilder {
// Feature value of the nodes in the batch (i.e. the indices of the tokens
// corresponding to the nodes). Corresponds to `GraphsTuple.nodes`.
const std::vector<int>& node_features() const { return node_features_; }
// Whether or not the corresponding node belongs to either the back or front
// context of the basic block, and not the basic block itself. Used by the
// models to exclude context nodes from predictions.
const std::vector<bool>& context_node_mask() const {
return context_node_mask_;
}

// Names of types of instruction annotations stored.
const std::vector<std::string>& annotation_names() const {
Expand Down Expand Up @@ -375,11 +390,13 @@ class BasicBlockGraphBuilder {

// Adds a new node to the batch; the feature of the node is given directly by
// the caller.
NodeIndex AddNode(NodeType node_type, TokenIndex token_index);
NodeIndex AddNode(NodeType node_type, TokenIndex token_index,
bool is_context = false);
// Adds a new edge to the batch; the feature of the node is determined from
// the token associated with the node. Returns kInvalidNode when the node was
// not added.
NodeIndex AddNode(NodeType node_type, const std::string& token);
NodeIndex AddNode(NodeType node_type, const std::string& token,
bool is_context = false);
// Adds a new edge to the batch.
void AddEdge(EdgeType edge_type, NodeIndex sender, NodeIndex receiver);

Expand All @@ -406,6 +423,7 @@ class BasicBlockGraphBuilder {

std::vector<NodeType> node_types_;
std::vector<TokenIndex> node_features_;
std::vector<bool> context_node_mask_;

// Mapping from annotation type names to corresponding row index in the
// `instruction_annotations_` matrix.
Expand Down
Loading
Loading