[training] bootstrapping training #99
Conversation
"Handling of intermediate results and storing them as outputs of the forward pass to be used later in the backward pass is NOT implemented."
autograd_engine = pyautograd.AutogradEngine(graph, autograd_config)

graph = autograd_engine.run()
dump_graph(graph, graph_name, "post_autograd")

context.losses = calculate_grads(outputs, dev, intermediate_tensors, False, context.losses)
# GOLDEN:
# context.losses = calculate_grads(outputs, dev, intermediate_tensors, False, context.losses)
Why is this one commented out? We don't track Golden for the initial implementation?
Commented out because it uses the deleted TTDevice object. I plan to file issues for a golden verification refactor - currently, the calculations needed for golden verification are sprinkled throughout compile.py, so they need to be centralized and the dependency on TTDevice removed.
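For illustration, a centralized golden check decoupled from TTDevice could be a small standalone helper along these lines. This is only a hedged sketch of the design choice being discussed; the function name, signature, and tolerances are hypothetical, not existing pybuda API:

```python
# Hypothetical sketch of a centralized golden-verification helper,
# decoupled from TTDevice. Names, signature, and tolerances are
# illustrative only and do not reflect the actual pybuda code.
from typing import Sequence

import torch


def verify_against_golden(
    outputs: Sequence[torch.Tensor],
    golden_outputs: Sequence[torch.Tensor],
    rtol: float = 1e-2,
    atol: float = 1e-2,
) -> bool:
    """Compare device outputs against CPU-computed golden outputs."""
    if len(outputs) != len(golden_outputs):
        return False
    return all(
        torch.allclose(
            out.to(torch.float32), golden.to(torch.float32), rtol=rtol, atol=atol
        )
        for out, golden in zip(outputs, golden_outputs)
    )
```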
if self.compiled_graph_state.graph.training():
    # For executing loss and its backward graph on CPU, we need to tell torch to compute gradients
    for output in outputs:
Do you think this needs to be part of compiler logic? E.g., can we set the PyTorch model into .train() mode, and just have a check here that the PyTorch module is in a valid state?
My reasoning is that if the user executes something on CPU, they should handle it properly. For ease of use, the compiler should have robust checks to inform the user if something is out of order.
Do you think this needs to be part of compiler logic?
I am not sure - but what I know for certain is that the APIs are still messy, and we'll need to iterate on this. 😄
We can sync offline further, but the reason I put this logic here is that we are creating/returning the output tensors... so, if we are creating them and know that they will later be used for the backward pass, why not set the flag...
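The idea under discussion is roughly the following. This is a minimal standalone PyTorch sketch, not the actual runtime code; the tensor values and loss function are made up for illustration:

```python
import torch

# The outputs returned from the compiled forward pass are plain tensors.
# If the loss and its backward graph run on CPU via torch, those outputs
# must require grad so torch can build an autograd graph through the loss.
outputs = [torch.tensor([[2.0, 4.0]])]  # stand-in for compiled-model outputs

for output in outputs:
    output.requires_grad_(True)

# CPU-side loss over the forward outputs (illustrative loss/target).
target = torch.zeros_like(outputs[0])
loss = torch.nn.functional.mse_loss(outputs[0], target)
loss.backward()

# outputs[0].grad now holds d(loss)/d(output), which can be passed on as the
# input gradient of the device-side backward program.
print(outputs[0].grad)
```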
My 2 cents - APIs are messy by design, as described in #15 :)
LGTM
312cefa to 737fcb7
Initial changes for bootstrapping training.
- expanding `pybuda.compile` to accept loss and optimizer module
- running autograd - forward and backward graphs are lowered into tt-mlir as two separate functions - resulting binary has two programs (fwd and bwd)
- expanding `CompiledModel` with `backward()`

For initial training support we are going with executing loss and optimizer modules on CPU.

Currently, the most problematic part is isolating and lowering fwd and bwd parts of the graph, since the autograd is creating a single graph. For now, I've hacked it up using `GraphTraversalContext`. In changes to follow, we will separate these two into different graphs. That should make things much cleaner - e.g., each graph will have precisely defined inputs and outputs, unlike now, where we need to deduce what those are for a particular subgraph.

Handling of intermediate results and storing them as outputs of the forward pass to be used later in the backward pass is NOT implemented. This will follow after the autograd changes are done and we have two separate graphs.

For separating backward and forward executions, I've experimented with two approaches:
1. creating two binaries
2. creating one binary with two programs

The second option seems cleaner and will simplify handling of the binary, but in the future we might support both approaches.

A basic test for training a single parametrized multiply op is added.

Issues #17 #18 #19 #20
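For context, the user-facing training flow described here might look roughly like the sketch below. Only `pybuda.compile` accepting loss and optimizer modules, `CompiledModel.backward()`, and the CPU loss/optimizer split come from the description above; the keyword-argument names, the forward-call convention, and the output indexing are hypothetical assumptions, not the actual API:

```python
import torch
import pybuda  # assumed import name

# Single parametrized multiply module, mirroring the basic test mentioned above.
class MultiplyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(1, 32))

    def forward(self, x):
        return x * self.weight

model = MultiplyModule()
loss_fn = torch.nn.MSELoss()                              # loss runs on CPU
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # optimizer runs on CPU

# Hypothetical call shape: compile() also takes the loss and optimizer modules,
# producing a CompiledModel whose binary holds both the fwd and bwd programs.
compiled_model = pybuda.compile(model, loss=loss_fn, optimizer=optimizer)

inputs = torch.rand(1, 32)
targets = torch.zeros(1, 32)

outputs = compiled_model(inputs)       # runs the fwd program; outputs require grad in training
loss = loss_fn(outputs[0], targets)    # CPU loss over the forward outputs
loss.backward()                        # CPU backward through the loss
compiled_model.backward()              # runs the bwd program from the output gradients
optimizer.step()                       # CPU parameter update
```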