[train] optimizer on device (#928)
With this change we can now run the complete training loop on the device.
The last missing piece was the optimizer.

For now, only forge optimizers are supported, since torch optimizers are
not of `nn.Module` type (so we have no way to compile them, as far as we
know).

To run the optimizer on the device, pass the forge optimizer to
`forge.compile()` when compiling a model with trainable parameters.
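
A minimal sketch of the intended usage, assuming a small torch module and the forge SGD optimizer (the exact `forge.compile()` keyword arguments and the optimizer constructor are assumptions, not taken from this change):

```python
import torch
import forge

# Module with trainable parameters; any torch module should do.
model = torch.nn.Linear(784, 10)

# Forge optimizer (not torch.optim.SGD, which can't be compiled).
optimizer = forge.optimizers.SGD(learning_rate=0.1)

# Passing the optimizer makes compile build the optimizer graph
# alongside the forward and backward graphs.
compiled = forge.compile(model, sample_inputs=[torch.randn(1, 784)], optimizer=optimizer)
```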

The compile flow:
- the optimizer passed in by the user is forwarded to the autograd pass
- autograd constructs the optimizer graph by calling
`generate_op_trace()` on the optimizer for each trainable parameter;
this function creates a subgraph defining the optimizer step for that
particular parameter, which is then merged into the main graph
- before lowering to MLIR, we split the graph into multiple graphs
(forward, backward, optimizer), as we did before this change
- finally, all of the optimizer parameters are stored in the
`CompiledModel` for the runtime, and the `CompiledModel` is linked to
the optimizer; this enables the user to call `optimizer.step()`, which
will in turn execute the optimizer graphs for all linked models (see
the sketch after this list)
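
A sketch of the resulting training loop; only `optimizer.step()` executing the optimizer graphs of linked models is established by this change, while the forward/backward invocations are assumptions:

```python
# Hypothetical on-device training step.
for inputs, targets in dataloader:
    outputs = compiled(inputs)       # forward graph
    loss = loss_fn(outputs, targets)
    loss.backward()                  # backward graph
    optimizer.step()                 # optimizer graphs for all linked CompiledModels
```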

Since we don't have a way to implement in-place updates yet, one major
workaround in this change is the introduction of aliased tensors. This is
done so that we can update the parameters' values after the execution of
the optimizer graph. E.g. `updated_weight = weight - lr * grad`, where the
`updated_weight` output is aliased to the `weight` tensor, so that the
runtime can swap out the original weight tensor's data for the updated
one.
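
A conceptual illustration of the workaround in plain PyTorch terms (not the actual runtime code): the optimizer step is computed as a regular graph output, and the alias tells the runtime which parameter's data to swap afterwards.

```python
import torch

lr = 0.1
weight = torch.randn(10, 10)  # parameter (graph input)
grad = torch.randn(10, 10)    # gradient (graph input)

# What the optimizer graph computes: a fresh output, no in-place update.
updated_weight = weight - lr * grad  # output node, marked as an alias of `weight`

# What the runtime then does, guided by the alias:
weight.data = updated_weight.data  # swap the original tensor's data
```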

Tests for compiling models with the SGD, Adam, and AdamW forge optimizers
are added, as well as an e2e test running MNIST training (with the SGD
optimizer) on the device.

Closes #176, closes #178
pilkicTT authored Jan 13, 2025
1 parent a05e77c commit 7866dfe
Showing 12 changed files with 501 additions and 171 deletions.
21 changes: 19 additions & 2 deletions forge/csrc/graph_lib/graph.cpp
@@ -437,11 +437,11 @@ std::vector<Edge> Graph::user_data_edges_for_operand_port(const Node *node, Port
     return result;
 }

-std::vector<Node *> Graph::users(const Node *node) const
+std::vector<Node *> Graph::users(const Node *node, std::function<bool(Edge)> edge_filter) const
 {
     std::vector<Node *> user_nodes;

-    for (auto &user_edge : this->user_edges(node))
+    for (auto &user_edge : this->user_edges(node, edge_filter))
     {
         NodeId consumer_node_id = user_edge.consumer_node_id;
         Node *consumer_node = node_by_id(consumer_node_id);
@@ -1096,6 +1096,23 @@ std::vector<std::string> Graph::get_constant_names() const
     return constant_names;
 }

+std::vector<Node *> Graph::get_optimizer_parameter_nodes() const
+{
+    std::vector<Node *> parameters;
+    for (Node *node : nodes_)
+    {
+        if (node->node_type() == NodeType::kInput)
+        {
+            InputNode *input_node = node->as<InputNode>();
+            if (input_node->is_optimizer_parameter())
+            {
+                parameters.push_back(node);
+            }
+        }
+    }
+    return parameters;
+}
+
 std::vector<Node *> Graph::get_parameter_nodes() const
 {
     std::vector<Node *> parameters;
4 changes: 3 additions & 1 deletion forge/csrc/graph_lib/graph.hpp
@@ -135,7 +135,8 @@ class Graph
     const std::unordered_set<Edge> &user_edges_set(const Node *node) const;
     std::vector<Node *> operands(const Node *node) const;
     std::vector<Node *> data_operands(const Node *node) const;
-    std::vector<Node *> users(const Node *node) const;
+    std::vector<Node *> users(
+        const Node *node, std::function<bool(Edge)> edge_filter = [](Edge) { return true; }) const;
     std::vector<Node *> data_users(const Node *node) const;
     std::unordered_set<NodeId> node_ids();

@@ -251,6 +252,7 @@ class Graph
     std::vector<Node *> ordered_intermediates() const;
     std::vector<Node *> get_constant_nodes(bool recurse = false) const;
     std::vector<Node *> get_parameter_nodes() const;
+    std::vector<Node *> get_optimizer_parameter_nodes() const;
     std::vector<std::string> get_constant_names() const;
     std::vector<std::string> get_ordered_input_names() const;
     std::vector<std::string> get_ordered_intermediate_names() const;
1 change: 1 addition & 0 deletions forge/csrc/graph_lib/node_types.cpp
@@ -360,6 +360,7 @@ std::unique_ptr<Node> InputNode::clone(std::string const &name) const
     node->tile_broadcast_dims_ = tile_broadcast_dims_;
     node->runtime_tensor_transform = runtime_tensor_transform;
     node->add_tags(this->as<TaggedNode>()->get_tags());
+    node->requires_grad_ = requires_grad_;
     return node;
 }

20 changes: 20 additions & 0 deletions forge/csrc/graph_lib/node_types.hpp
@@ -320,6 +320,8 @@ class OutputNode
 {
    protected:
     bool requires_grad_;
+    bool aliased_tensor_;
+    std::string alias_;
     bool is_loss_output_;
     bool is_intermediate_;
     bool untilize_;
@@ -333,6 +335,7 @@
     OutputNode(std::string name) :
         QueueNode(name, QueueNodeType::Output, NodeType::kOutput),
         requires_grad_(false),
+        aliased_tensor_(false),
         is_loss_output_(false),
         is_intermediate_(false),
         untilize_(true),
@@ -349,6 +352,23 @@
     void set_intermediate(bool intermediate) { is_intermediate_ = intermediate; }
     void set_untilize(bool should_untilize) { untilize_ = should_untilize; }
     void set_output_type(OutputType output_type) { output_type_ = output_type; }
+
+    void set_alias(const InputNode *node)
+    {
+        alias_ = node->name();
+        aliased_tensor_ = true;
+    }
+
+    // Indicates if this output node is actually an alias to an input node. This is used in optimizer graphs, where
+    // we want to update a parameter (e.g. `param = param - lr * grad`), but since the rest of the stack doesn't
+    // support this yet, we create a new output node that is an alias to the parameter (input) node. So we'll end up
+    // with something like this: `updated_param = param - lr * grad`, where `updated_param` is aliased to `param`.
+    // Then in the runtime we'll make sure to update the `param` tensor to point to the new data.
+    bool is_aliased_tensor() const { return aliased_tensor_; }
+
+    // Returns the name of the input node that this output node is aliased to.
+    std::string alias() const { return alias_; }
+
     virtual std::unique_ptr<Node> clone(std::string const &name = "") const override;

     void set_runtime_tensor_transform(RuntimeTensorTransform transform) { this->runtime_tensor_transform = transform; }
11 changes: 11 additions & 0 deletions forge/csrc/graph_lib/python_bindings.cpp
@@ -91,6 +91,7 @@ void GraphModule(py::module &m_graph)
         .def("get_ordered_input_names", &Graph::get_ordered_input_names)
         .def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
         .def("get_ordered_output_names", &Graph::get_ordered_output_names)
+        .def("get_ordered_output_nodes", &Graph::ordered_module_outputs)
         .def("get_ordered_external_output_names", &Graph::get_ordered_external_output_names)
         .def("get_ordered_target_names", &Graph::get_ordered_target_names)
         .def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
@@ -109,6 +110,7 @@ void GraphModule(py::module &m_graph)
             py::arg("recurse") = false)
         .def("get_subgraph_id_for_node", &Graph::get_subgraph_id_for_node)
         .def("get_parameter_nodes", &Graph::get_parameter_nodes, py::return_value_policy::reference)
+        .def("get_optimizer_parameter_nodes", &Graph::get_optimizer_parameter_nodes, py::return_value_policy::reference)
         .def(
             "register_module_inputs",
             &Graph::register_module_inputs,
@@ -251,6 +253,15 @@
         .def_property_readonly("shape", &Node::shape)
         .def_property_readonly("output_df", &Node::output_df);

+    py::class_<tt::graphlib::OutputNode, tt::raw_ptr<tt::graphlib::OutputNode>>(m_graph, "OutputNode")
+        .def_property_readonly("id", &Node::id)
+        .def_property_readonly("name", &Node::name)
+        .def_property_readonly("node_type", &Node::node_type)
+        .def_property_readonly("shape", &Node::shape)
+        .def_property_readonly("output_df", &Node::output_df)
+        .def_property_readonly("is_aliased", &graphlib::OutputNode::is_aliased_tensor)
+        .def_property_readonly("alias", &graphlib::OutputNode::alias);
+
     py::class_<graphlib::NodeContext>(m_graph, "NodeContext")
         .def_readonly("id", &graphlib::NodeContext::id)
         .def_readonly("name", &graphlib::NodeContext::name)
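A hypothetical peek at the new bindings from Python, using the method and property names registered above (how the `graph` object is obtained is outside the scope of this diff):

```python
# `graph` is assumed to be a graphlib graph exposed through these bindings.
for out in graph.get_ordered_output_nodes():
    if out.is_aliased:
        # e.g. for an optimizer graph: "updated_weight aliases weight"
        print(f"{out.name} aliases {out.alias}")

for opt_param in graph.get_optimizer_parameter_nodes():
    print(opt_param.name, opt_param.shape)
```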
8 changes: 6 additions & 2 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -208,7 +208,7 @@
     /// for the function.
     mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph, std::string fn_name = "forward")
     {
-        log_info("Emmiting mlir for function {}", fn_name);
+        log_info("Emitting mlir for function {}", fn_name);
         // Assemble the function arguments (inputs and parameters)
         llvm::SmallVector<mlir::Type> argument_types;
         llvm::SmallVector<graphlib::Node *> argument_nodes;
@@ -234,7 +234,11 @@
         }

         // Add the graph parameters to the argument list.
-        for (auto *parameter : graph->get_parameter_nodes())
+        // Both optimizer parameters and regular parameters are added.
+        auto opt_params = graph->get_optimizer_parameter_nodes();
+        auto params = graph->get_parameter_nodes();
+        params.insert(params.end(), opt_params.begin(), opt_params.end());
+        for (auto *parameter : params)
         {
             log_trace(LogMLIRCompiler, "Adding parameter {} to the argument list.", parameter->name());
