With this change we can now run the complete training loop on the device. The last missing piece was the optimizer. For now, only forge optimizers are supported, since torch optimizers are not of `nn.Module` type, so we have no way to compile them. To run the optimizer on the device, pass the forge optimizer into `forge.compile()` when compiling a model with trainable parameters.

The compile flow:
- The optimizer passed by the user is forwarded to the autograd pass.
- The autograd pass constructs the optimizer graph by calling `generate_op_trace()` on the optimizer for each trainable parameter; this function creates a subgraph defining the optimizer step for that parameter, which is then merged into the main graph.
- Before lowering to MLIR, we split the graph into multiple graphs (forward, backward, optimizer), as we did before this change.
- Finally, all optimizer parameters are stored in the `CompiledModel` for the runtime, and the `CompiledModel` is linked to the optimizer. This lets the user call `optimizer.step()`, which in turn executes the optimizer graphs for all linked models.

Since we don't yet have a way to implement in-place updates, one major workaround in this change is the introduction of aliased tensors. Aliasing lets us update the parameters' values after the optimizer graph executes. E.g. in `updated_weight = weight - lr * grad`, the `updated_weight` output is aliased to the `weight` tensor, so the runtime can swap the original weight tensor's data with the updated data.

Tests for compiling models with the SGD, Adam, and AdamW forge optimizers are added, as well as an e2e test that runs MNIST training (with the SGD optimizer) on the device.

Closes #176, closes #178
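The model-optimizer linkage and the aliased-weight swap described above can be sketched with a small mock. This is a plain-Python illustration of the described semantics, not the actual forge API: `CompiledModel`, `link_module`, and `run_optimizer_graph` are hypothetical names chosen for the example, and the per-parameter "graph" is stood in for by a plain SGD update.

```python
# Hypothetical mock of the described runtime linkage: a compiled model holds
# its trainable parameters, the optimizer holds a list of linked models, and
# optimizer.step() executes each model's optimizer graph. The aliasing
# workaround is simulated by computing a fresh `updated_weight` functionally
# and then swapping it into the original weight slot, since in-place updates
# are not supported.

class CompiledModel:
    def __init__(self, weights):
        self.weights = dict(weights)  # parameter name -> value
        self.grads = {}               # parameter name -> gradient (from backward graph)

    def run_optimizer_graph(self, lr):
        # One optimizer subgraph per trainable parameter. The graph's output
        # (`updated_weight`) is aliased to the input weight, so after execution
        # the runtime swaps the original weight's data with the updated data.
        for name, weight in self.weights.items():
            updated_weight = weight - lr * self.grads[name]  # plain SGD step
            self.weights[name] = updated_weight              # simulated alias swap

class SGD:
    def __init__(self, lr):
        self.lr = lr
        self.linked_models = []  # populated during compilation in the real flow

    def link_module(self, compiled_model):
        self.linked_models.append(compiled_model)

    def step(self):
        # Execute the optimizer graphs of every linked compiled model.
        for model in self.linked_models:
            model.run_optimizer_graph(self.lr)

optimizer = SGD(lr=0.1)
model = CompiledModel({"w": 1.0})
optimizer.link_module(model)

model.grads["w"] = 2.0  # pretend the backward graph produced this gradient
optimizer.step()
print(model.weights["w"])  # 1.0 - 0.1 * 2.0
```

The key design point mirrored here is that `step()` lives on the optimizer, not the model: one optimizer can drive the update graphs of several compiled models it has been linked to.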
Showing 12 changed files with 501 additions and 171 deletions.