generated from dtreai/Python-Template
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 20 changed files with 386 additions and 129 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ on: | |
branches: | ||
- main | ||
paths: | ||
- 'my_package/**' | ||
- 'tri_rmsnorm/**' | ||
|
||
jobs: | ||
changelog: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
# package specific | ||
benchmarking/* | ||
|
||
# build artifacts | ||
|
||
.eggs/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,111 +1,107 @@ | ||
# python-package-template | ||
# Tri-RMSNorm | ||
|
||
This is a template repository for Python package projects. | ||
|
||
## In this README :point_down: | ||
|
||
- [Features](#features) | ||
- [Usage](#usage) | ||
- [Initial setup](#initial-setup) | ||
- [Creating releases](#creating-releases) | ||
- [Projects using this template](#projects-using-this-template) | ||
- [FAQ](#faq) | ||
- [Contributing](#contributing) | ||
This small package provides a custom Triton kernel for RMS layer normalization with fused operations, using OpenAI's Triton compiler for high performance on GPUs. The implementation includes both the forward and backward passes, optimized for deep learning training and inference. | ||
|
||
## Features | ||
|
||
This template repository comes with all of the boilerplate needed for: | ||
|
||
⚙️ Robust (and free) CI with [GitHub Actions](https://github.com/features/actions): | ||
- Unit tests ran with [PyTest](https://docs.pytest.org) against multiple Python versions and operating systems. | ||
- Type checking with [mypy](https://github.com/python/mypy). | ||
- Linting with [ruff](https://astral.sh/ruff). | ||
- Formatting with [isort](https://pycqa.github.io/isort/) and [black](https://black.readthedocs.io/en/stable/). | ||
|
||
🤖 [Dependabot](https://github.blog/2020-06-01-keep-all-your-packages-up-to-date-with-dependabot/) configuration to keep your dependencies up-to-date. | ||
|
||
📄 Great looking API documentation built using [Sphinx](https://www.sphinx-doc.org/en/master/) (run `make docs` to preview). | ||
**Customized FW/BW RMS Normalization:** | ||
|
||
🚀 Automatic GitHub and PyPI releases. Just follow the steps in [`RELEASE_PROCESS.md`](./RELEASE_PROCESS.md) to trigger a new release. | ||
Implements the forward and backward passes of RMS normalization with fused operations for better performance. | ||
|
||
## Usage | ||
|
||
### Initial setup | ||
|
||
1. [Create a new repository](https://github.com/allenai/python-package-template/generate) from this template with the desired name of your project. | ||
**Triton and PyTorch Integration:** | ||
|
||
*Your project name (i.e. the name of the repository) and the name of the corresponding Python package don't necessarily need to match, but you might want to check on [PyPI](https://pypi.org/) first to see if the package name you want is already taken.* | ||
Uses Triton for GPU-accelerated parallel computation, integrated with PyTorch tensors. | ||
|
||
2. Create a Python 3.8 or newer virtual environment. | ||
**Customizable:** | ||
|
||
*If you're not sure how to create a suitable Python environment, the easiest way is using [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On a Mac, for example, you can install Miniconda using [Homebrew](https://brew.sh/):* | ||
Compile-time constants for block sizes, accommodating different GPU architectures and memory layouts. | ||
|
||
``` | ||
brew install miniconda | ||
``` | ||
**Atomic Operations for Gradient Accumulation:** | ||
|
||
*Then you can create and activate a new Python environment by running:* | ||
Atomic operations to safely accumulate gradients across threads, preventing race conditions and ensuring correct gradient computation during the backward pass. | ||
|
||
``` | ||
conda create -n my-package python=3.9 | ||
conda activate my-package | ||
``` | ||
**Lock-Free Mechanisms:** | ||
|
||
3. Now that you have a suitable Python environment, you're ready to personalize this repository. Just run: | ||
Synchronization designed to minimize locking and blocking, improving the performance and scalability of gradient computation. | ||
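The gradient accumulation described above can be pictured with a CPU analogy. The sketch below is not the kernel's code: it assumes that rows are mapped onto a small number of partial buffers (similar in spirit to the `GROUP_SIZE_M` parameter in the usage example), that each buffer is guarded against concurrent writes, and that the buffers are reduced once all rows are done. The function name `accumulate_weight_grads` is hypothetical, and ordinary Python locks stand in for the GPU atomics.

```python
import threading


def accumulate_weight_grads(row_grads, group_size):
    """Sum per-row gradient vectors using `group_size` guarded partial buffers.

    Hypothetical CPU illustration only; the Triton kernel may be organized differently.
    """
    n = len(row_grads[0])
    partials = [[0.0] * n for _ in range(group_size)]
    locks = [threading.Lock() for _ in range(group_size)]

    def worker(row_index):
        # Rows that share a slot contend for the same buffer, so guard the update.
        slot = row_index % group_size
        with locks[slot]:
            buf = partials[slot]
            for j, g in enumerate(row_grads[row_index]):
                buf[j] += g

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(row_grads))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Final reduction over the partial buffers.
    return [sum(p[j] for p in partials) for j in range(n)]


total = accumulate_weight_grads([[1.0, 2.0] for _ in range(100)], group_size=4)
print(total)
```

Spreading rows over several buffers reduces contention compared with a single shared accumulator, at the cost of one extra reduction step.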
|
||
``` | ||
pip install -r setup-requirements.txt | ||
python scripts/personalize.py | ||
``` | ||
## Getting Started | ||
|
||
And then follow the prompts. | ||
### Installation | ||
|
||
:pencil: *NOTE: This script will overwrite the README in your repository.* | ||
**Requirements** | ||
|
||
4. Commit and push your changes, then make sure all GitHub Actions jobs pass. | ||
```text | ||
torch==2.1.0+cu121 | ||
torchaudio==2.1.0+cu121 | ||
torchvision==0.16.0+cu121 | ||
triton==2.1.0 | ||
``` | ||
|
||
5. (Optional) If you plan on publishing your package to PyPI, add repository secrets for `PYPI_USERNAME` and `PYPI_PASSWORD`. To add these, go to "Settings" > "Secrets" > "Actions", and then click "New repository secret". | ||
You can install the package using `pip3 install -e .`: | ||
|
||
*If you don't have PyPI account yet, you can [create one for free](https://pypi.org/account/register/).* | ||
```bash | ||
git clone https://github.com/simudt/Tri-RMSNorm | ||
cd Tri-RMSNorm | ||
pip3 install -e . | ||
``` | ||
|
||
6. (Optional) If you want to deploy your API docs to [readthedocs.org](https://readthedocs.org), go to the [readthedocs dashboard](https://readthedocs.org/dashboard/import/?) and import your new project. | ||
Then click on the "Admin" button, navigate to "Automation Rules" in the sidebar, click "Add Rule", and then enter the following fields: | ||
## Usage | ||
|
||
- **Description:** Publish new versions from tags | ||
- **Match:** Custom Match | ||
- **Custom match:** v[vV] | ||
- **Version:** Tag | ||
- **Action:** Activate version | ||
The package provides two main functions: | ||
|
||
Then hit "Save". | ||
- `_rms_norm_fwd_fused` for the forward pass of RMS normalization | ||
|
||
*After your first release, the docs will automatically be published to [your-project-name.readthedocs.io](https://your-project-name.readthedocs.io/).* | ||
- `_rms_norm_bwd_dx_fused` for the backward pass, computing gradients with respect to X, W, and B (input, weight, and bias) | ||
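For reference, the quantities these two functions are expected to compute can be written out as follows. This assumes the kernels follow the standard RMSNorm definition with an affine weight and bias, and that the `rstd` buffer in the example stores the reciprocal root mean square $r$. The symbols $r$, $\hat{x}$ and $L$ are notation introduced here, not names from the package.

```latex
% Forward pass for one row x \in \mathbb{R}^N with weight w and bias b
r = \left( \frac{1}{N} \sum_{j=1}^{N} x_j^2 + \varepsilon \right)^{-1/2}, \qquad
\hat{x}_i = r\, x_i, \qquad
y_i = w_i\, \hat{x}_i + b_i

% Backward pass for the same row, given the upstream gradient \partial L / \partial y
\frac{\partial L}{\partial x_k}
  = r \left( w_k \frac{\partial L}{\partial y_k}
  - \hat{x}_k \cdot \frac{1}{N} \sum_{i=1}^{N} w_i \frac{\partial L}{\partial y_i}\, \hat{x}_i \right)

% Parameter gradients are summed over all M rows
\frac{\partial L}{\partial w_i} = \sum_{m=1}^{M} \frac{\partial L}{\partial y_{m,i}}\, \hat{x}_{m,i}, \qquad
\frac{\partial L}{\partial b_i} = \sum_{m=1}^{M} \frac{\partial L}{\partial y_{m,i}}
```

The sums over rows in the last line are the part that needs cross-row accumulation on the GPU.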
|
||
### Creating releases | ||
```python | ||
import torch | ||
# _rms_norm_fwd_fused and _rms_norm_bwd_dx_fused are the package's Triton kernels and must be imported before use. | ||
class RMSNormFunctionCustomKernel(torch.autograd.Function): | ||
    @staticmethod | ||
    def forward(ctx, x, weight, bias, eps): | ||
        M, N = x.shape | ||
        y = torch.empty_like(x) | ||
        rstd = torch.empty(M, dtype=torch.float32, device=x.device) | ||
        _rms_norm_fwd_fused[(M,)](x, y, weight, bias, rstd, x.stride(0), N, eps, BLOCK_SIZE=1024) | ||
        ctx.save_for_backward(x, weight, bias, rstd) | ||
        ctx.eps = eps | ||
        ctx.N = N | ||
        return y | ||
|
||
Creating new GitHub and PyPI releases is easy. The GitHub Actions workflow that comes with this repository will handle all of that for you. | ||
All you need to do is follow the instructions in [RELEASE_PROCESS.md](./RELEASE_PROCESS.md). | ||
    @staticmethod | ||
    def backward(ctx, dy): | ||
        x, weight, bias, rstd = ctx.saved_tensors | ||
        eps = ctx.eps | ||
        N = ctx.N | ||
        M = x.shape[0] | ||
        dx = torch.empty_like(x) | ||
        _dw = torch.empty_like(weight) | ||
        _db = torch.empty_like(bias) | ||
        locks = torch.zeros(2 * 32, dtype=torch.int32, device=x.device) | ||
        _rms_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, weight, bias, rstd, locks, x.stride(0), N, eps, GROUP_SIZE_M=32, BLOCK_SIZE_N=1024) | ||
        return dx, _dw, _db, None | ||
|
||
## Projects using this template | ||
def test_rms_norm_custom_kernel(): | ||
    eps = 1e-5 | ||
    input = torch.tensor([[0.1, -0.2] * 10] * 10, device='cuda', requires_grad=True) | ||
    weights = torch.tensor([0.1] * 20, device='cuda', requires_grad=True) | ||
    biases = torch.tensor([0.01] * 20, device='cuda', requires_grad=True) | ||
|
||
Here is an incomplete list of some projects that started off with this template: | ||
    output = RMSNormFunctionCustomKernel.apply(input, weights, biases, eps) | ||
    loss = output.mean() | ||
    loss.backward() | ||
|
||
- [ai2-tango](https://github.com/allenai/tango) | ||
- [cached-path](https://github.com/allenai/cached_path) | ||
- [beaker-py](https://github.com/allenai/beaker-py) | ||
- [gantry](https://github.com/allenai/beaker-gantry) | ||
- [ip-bot](https://github.com/abe-101/ip-bot) | ||
    print("Gradients on Input: ", input.grad) | ||
    print("Gradients on Weights: ", weights.grad) | ||
    print("Gradients on Biases: ", biases.grad) | ||
|
||
☝️ *Want your work featured here? Just open a pull request that adds the link.* | ||
test_rms_norm_custom_kernel() | ||
``` | ||
|
||
## FAQ | ||
Adjust the grid size, block sizes, and other launch parameters to match your input shapes and GPU. | ||
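When changing these parameters, it can help to have a slow but simple reference to compare the kernel output against. The sketch below is a pure-Python, single-row reference. It assumes the standard RMSNorm definition with weight and bias, which may differ in detail from the kernel. The function names `rms_norm_forward` and `rms_norm_backward` are hypothetical and not part of the package. The inputs mirror one row of the test above.

```python
import math


def rms_norm_forward(x, weight, bias, eps=1e-5):
    """Normalize one row: y_i = x_i * rstd * w_i + b_i, with rstd = 1 / sqrt(mean(x^2) + eps)."""
    n = len(x)
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    y = [v * rstd * w + b for v, w, b in zip(x, weight, bias)]
    return y, rstd


def rms_norm_backward(dy, x, weight, rstd):
    """Return per-row (dx, dw, db) given the upstream gradient dy."""
    n = len(x)
    x_hat = [v * rstd for v in x]
    wdy = [w * g for w, g in zip(weight, dy)]
    # Mean of x_hat * (w * dy): the correction term from differentiating rstd.
    c = sum(a * b for a, b in zip(x_hat, wdy)) / n
    dx = [(g - xh * c) * rstd for g, xh in zip(wdy, x_hat)]
    dw = [g * xh for g, xh in zip(dy, x_hat)]  # sum these over rows for the full dW
    db = list(dy)                              # sum these over rows for the full dB
    return dx, dw, db


# One row of the test inputs used above.
x = [0.1, -0.2] * 10
weight = [0.1] * 20
bias = [0.01] * 20

y, rstd = rms_norm_forward(x, weight, bias, eps=1e-5)
# loss = output.mean() over a 10 x 20 output gives dy = 1 / 200 for every element.
dy = [1.0 / 200] * 20
dx, dw, db = rms_norm_backward(dy, x, weight, rstd)

print("y[:2]  =", y[:2])
print("dx[:2] =", dx[:2])
```

Because every row of the test input is identical, the full weight and bias gradients are the per-row values multiplied by the number of rows (10).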
|
||
#### Should I use this template even if I don't want to publish my package? | ||
## Benchmark | ||
|
||
Absolutely! If you don't want to publish your package, just delete the `docs/` directory and the `release` job in [`.github/workflows/main.yml`](https://github.com/allenai/python-package-template/blob/main/.github/workflows/main.yml). | ||
In initial benchmarks, the Tri-RMSNorm kernel runs faster than a PyTorch-based custom RMSNorm implementation. Benchmarks will be added to the repository so the results can be reproduced. | ||
|
||
## Contributing | ||
## License | ||
|
||
If you find a bug :bug:, please open a [bug report](https://github.com/allenai/python-package-template/issues/new?assignees=&labels=bug&template=bug_report.md&title=). | ||
If you have an idea for an improvement or new feature :rocket:, please open a [feature request](https://github.com/allenai/python-package-template/issues/new?assignees=&labels=Feature+request&template=feature_request.md&title=). | ||
This package is licensed under the Apache License; see the LICENSE file for details. |
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" | |
|
||
[project] | ||
# See https://setuptools.pypa.io/en/latest/userguide/quickstart.html for more project configuration options. | ||
name = "my-package" | ||
name = "tri_rmsnorm" | ||
dynamic = ["version"] | ||
readme = "README.md" | ||
classifiers = [ | ||
|
@@ -14,21 +14,12 @@ classifiers = [ | |
"Programming Language :: Python :: 3", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
] | ||
authors = [ | ||
{name = "Allen Institute for Artificial Intelligence", email = "[email protected]"} | ||
] | ||
requires-python = ">=3.8" | ||
dependencies = [ | ||
# Add your own dependencies here | ||
] | ||
license = {file = "LICENSE"} | ||
|
||
[project.urls] | ||
Homepage = "https://github.com/allenai/python-package-template" | ||
Repository = "https://github.com/allenai/python-package-template" | ||
Changelog = "https://github.com/allenai/python-package-template/blob/main/CHANGELOG.md" | ||
# Documentation = "https://my-package.readthedocs.io/" | ||
|
||
[project.optional-dependencies] | ||
dev = [ | ||
"ruff", | ||
|
@@ -65,10 +56,10 @@ exclude = [ | |
include-package-data = true | ||
|
||
[tool.setuptools.package-data] | ||
my_package = ["py.typed"] | ||
tri_rmsnorm = ["py.typed"] | ||
|
||
[tool.setuptools.dynamic] | ||
version = {attr = "my_package.version.VERSION"} | ||
version = {attr = "tri_rmsnorm.version.VERSION"} | ||
|
||
[tool.black] | ||
line-length = 100 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
click>=7.0,<9.0 | ||
click-help-colors>=0.9.1,<0.10 | ||
rich>=11.0,<14.0 | ||
torch==2.1.0+cu121 | ||
torchaudio==2.1.0+cu121 | ||
torchvision==0.16.0+cu121 | ||
triton==2.1.0 |