diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index b1c9597..105a608 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -139,7 +139,7 @@ When you're ready to contribute code to address an open issue, please follow the We also strive to maintain high test coverage, so most contributions should include additions to [the unit tests](https://github.com/allenai/python-package-template/tree/main/tests). These tests are run with [`pytest`](https://docs.pytest.org/en/latest/), which you can use to locally run any test modules that you've added or changed. - For example, if you've fixed a bug in `my_package/a/b.py`, you can run the tests specific to that module with + For example, if you've fixed a bug in `tri_rmsnorm/a/b.py`, you can run the tests specific to that module with pytest -v tests/a/b_test.py diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 547ae3d..6142eb6 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -17,7 +17,7 @@ body: ```python # All necessary imports at the beginning - import my_package + import tri_rmsnorm # A succinct reproducing example trimmed down to the essential parts: assert False is True, "Oh no!" diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 614553c..799f4ae 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -9,7 +9,7 @@ on: branches: - main paths: - - 'my_package/**' + - 'tri_rmsnorm/**' jobs: changelog: diff --git a/.gitignore b/.gitignore index 8d3db15..cd93a26 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# package specific +benchmarking/* + # build artifacts .eggs/ diff --git a/Makefile b/Makefile index 214718f..4247489 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY : docs docs : rm -rf docs/build/ - sphinx-autobuild -b html --watch my_package/ docs/source/ docs/build/ + sphinx-autobuild -b html --watch tri_rmsnorm/ docs/source/ docs/build/ .PHONY : run-checks run-checks : @@ -9,7 +9,7 @@ run-checks : black --check . ruff check . mypy . - CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tests/ my_package/ + CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tests/ tri_rmsnorm/ .PHONY : build build : diff --git a/README.md b/README.md index ff4556f..9185f88 100644 --- a/README.md +++ b/README.md @@ -1,111 +1,107 @@ -# python-package-template +# Tri-RMSNorm -This is a template repository for Python package projects. - -## In this README :point_down: - -- [Features](#features) -- [Usage](#usage) - - [Initial setup](#initial-setup) - - [Creating releases](#creating-releases) -- [Projects using this template](#projects-using-this-template) -- [FAQ](#faq) -- [Contributing](#contributing) +This small package provides an custom Triton kernel of RMS layer normalization with fused operations, leveraging the Triton compiler by OpenAI for high performance on GPUs. Implementation includes both forward and backward passes of RMS layer normalization, optimized for empowering deep learning training and inferencing. ## Features -This template repository comes with all of the boilerplate needed for: - -⚙️ Robust (and free) CI with [GitHub Actions](https://github.com/features/actions): - - Unit tests ran with [PyTest](https://docs.pytest.org) against multiple Python versions and operating systems. - - Type checking with [mypy](https://github.com/python/mypy). - - Linting with [ruff](https://astral.sh/ruff). - - Formatting with [isort](https://pycqa.github.io/isort/) and [black](https://black.readthedocs.io/en/stable/). - -🤖 [Dependabot](https://github.blog/2020-06-01-keep-all-your-packages-up-to-date-with-dependabot/) configuration to keep your dependencies up-to-date. - -📄 Great looking API documentation built using [Sphinx](https://www.sphinx-doc.org/en/master/) (run `make docs` to preview). +**Customized FW/BW RMS Normalization:** -🚀 Automatic GitHub and PyPI releases. Just follow the steps in [`RELEASE_PROCESS.md`](./RELEASE_PROCESS.md) to trigger a new release. +Implements the forward and backward passes of RMS normalization with fused operations for better performance. -## Usage - -### Initial setup - -1. [Create a new repository](https://github.com/allenai/python-package-template/generate) from this template with the desired name of your project. +**Triton and PyTorch Integration:** - *Your project name (i.e. the name of the repository) and the name of the corresponding Python package don't necessarily need to match, but you might want to check on [PyPI](https://pypi.org/) first to see if the package name you want is already taken.* +Utilizes Triton for GPU-accelerated computations and parallel computation, seamlessly integrated with PyTorch tensors. -2. Create a Python 3.8 or newer virtual environment. +**Customizable:** - *If you're not sure how to create a suitable Python environment, the easiest way is using [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On a Mac, for example, you can install Miniconda using [Homebrew](https://brew.sh/):* +Compile-time constants for block sizes, accommodating different GPU architectures and memory layouts. - ``` - brew install miniconda - ``` +**Atomic Operations for Gradient Accumulation:** - *Then you can create and activate a new Python environment by running:* +Atomic operations to safely accumulate gradients across threads, preventing race conditions and ensuring correct gradient computation during the backward pass. - ``` - conda create -n my-package python=3.9 - conda activate my-package - ``` +**Lock-Free Mechanisms:** -3. Now that you have a suitable Python environment, you're ready to personalize this repository. Just run: +Advanced sync to minimize locking and blocking, improving the performance and scalability of gradient computation. - ``` - pip install -r setup-requirements.txt - python scripts/personalize.py - ``` +## Getting Started - And then follow the prompts. +## **Installation** - :pencil: *NOTE: This script will overwrite the README in your repository.* +**Requirements** -4. Commit and push your changes, then make sure all GitHub Actions jobs pass. +```bash +torch==2.1.0+cu121 +torchaudio==2.1.0+cu121 +torchvision==0.16.0+cu121 +triton==2.1.0 +``` -5. (Optional) If you plan on publishing your package to PyPI, add repository secrets for `PYPI_USERNAME` and `PYPI_PASSWORD`. To add these, go to "Settings" > "Secrets" > "Actions", and then click "New repository secret". +You can install the package using `pip3 install -e .`: - *If you don't have PyPI account yet, you can [create one for free](https://pypi.org/account/register/).* +```bash +git clone https://github.com/simudt/Tri-RMSNorm +cd Tri-RMSNorm +pip3 install -e . +``` -6. (Optional) If you want to deploy your API docs to [readthedocs.org](https://readthedocs.org), go to the [readthedocs dashboard](https://readthedocs.org/dashboard/import/?) and import your new project. - - Then click on the "Admin" button, navigate to "Automation Rules" in the sidebar, click "Add Rule", and then enter the following fields: +## Usage - - **Description:** Publish new versions from tags - - **Match:** Custom Match - - **Custom match:** v[vV] - - **Version:** Tag - - **Action:** Activate version +The package provides two main functions: - Then hit "Save". +- `_rms_norm_fwd_fused` for the forward pass of RMS normalization - *After your first release, the docs will automatically be published to [your-project-name.readthedocs.io](https://your-project-name.readthedocs.io/).* +- `_rms_norm_bwd_dx_fused` for the backward pass, computing gradients with respect to X, W, B -### Creating releases +```python +class RMSNormFunctionCustomKernel(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty(M, dtype=torch.float32, device=x.device) + _rms_norm_fwd_fused[(M,)](x, y, weight, bias, rstd, x.stride(0), N, eps, BLOCK_SIZE=1024) + ctx.save_for_backward(x, weight, bias, rstd) + ctx.eps = eps + ctx.N = N + return y -Creating new GitHub and PyPI releases is easy. The GitHub Actions workflow that comes with this repository will handle all of that for you. -All you need to do is follow the instructions in [RELEASE_PROCESS.md](./RELEASE_PROCESS.md). + @staticmethod + def backward(ctx, dy): + x, weight, bias, rstd = ctx.saved_tensors + eps = ctx.eps + N = ctx.N + M = x.shape[0] + dx = torch.empty_like(x) + _dw = torch.empty_like(weight) + _db = torch.empty_like(bias) + locks = torch.zeros(2 * 32, dtype=torch.int32, device=x.device) + _rms_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, weight, bias, rstd, locks, x.stride(0), N, eps, GROUP_SIZE_M=32, BLOCK_SIZE_N=1024) + return dx, _dw, _db, None -## Projects using this template +def test_rms_norm_custom_kernel(): + eps = 1e-5 + input = torch.tensor([[0.1, -0.2] * 10] * 10, device='cuda', requires_grad=True) + weights = torch.tensor([0.1] * 20, device='cuda', requires_grad=True) + biases = torch.tensor([0.01] * 20, device='cuda', requires_grad=True) -Here is an incomplete list of some projects that started off with this template: + output = RMSNormFunctionCustomKernel.apply(input, weights, biases, eps) + loss = output.mean() + loss.backward() -- [ai2-tango](https://github.com/allenai/tango) -- [cached-path](https://github.com/allenai/cached_path) -- [beaker-py](https://github.com/allenai/beaker-py) -- [gantry](https://github.com/allenai/beaker-gantry) -- [ip-bot](https://github.com/abe-101/ip-bot) + print("Gradients on Input: ", input.grad) + print("Gradients on Weights: ", weights.grad) + print("Gradients on Biases: ", biases.grad) -☝️ *Want your work featured here? Just open a pull request that adds the link.* +test_rms_norm_custom_kernel() +``` -## FAQ +Adjust grid, block, and other parameters as per your requirements and GPU specifications. -#### Should I use this template even if I don't want to publish my package? +## Benchmark -Absolutely! If you don't want to publish your package, just delete the `docs/` directory and the `release` job in [`.github/workflows/main.yml`](https://github.com/allenai/python-package-template/blob/main/.github/workflows/main.yml). +Tri-RMSNorm kernel demonstrates improved speedup in initial benchmarks when compared to the PyTorch-based custom RMSNorm implementation. Benchmarks will be included in the repository to ensure reproducibility. -## Contributing +## License -If you find a bug :bug:, please open a [bug report](https://github.com/allenai/python-package-template/issues/new?assignees=&labels=bug&template=bug_report.md&title=). -If you have an idea for an improvement or new feature :rocket:, please open a [feature request](https://github.com/allenai/python-package-template/issues/new?assignees=&labels=Feature+request&template=feature_request.md&title=). +This package is licensed under the Apache License - see the LICENSE file for details. \ No newline at end of file diff --git a/RELEASE_PROCESS.md b/RELEASE_PROCESS.md deleted file mode 100644 index f214ab4..0000000 --- a/RELEASE_PROCESS.md +++ /dev/null @@ -1,24 +0,0 @@ -# GitHub Release Process - -## Steps - -1. Update the version in `my_package/version.py`. - -3. Run the release script: - - ```bash - ./scripts/release.sh - ``` - - This will commit the changes to the CHANGELOG and `version.py` files and then create a new tag in git - which will trigger a workflow on GitHub Actions that handles the rest. - -## Fixing a failed release - -If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete both the tag and corresponding release from GitHub. After you've pushed a fix, delete the tag from your local clone with - -```bash -git tag -l | xargs git tag -d && git fetch -t -``` - -Then repeat the steps above. diff --git a/docs/source/conf.py b/docs/source/conf.py index 2fab8ca..955a0f7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,11 +20,11 @@ sys.path.insert(0, os.path.abspath("../../")) -from my_package import VERSION, VERSION_SHORT # noqa: E402 +from tri_rmsnorm import VERSION, VERSION_SHORT # noqa: E402 # -- Project information ----------------------------------------------------- -project = "my-package" +project = "tri_rmsnorm" copyright = f"{datetime.today().year}, Allen Institute for Artificial Intelligence" author = "Allen Institute for Artificial Intelligence" version = VERSION_SHORT diff --git a/my_package/__init__.py b/my_package/__init__.py deleted file mode 100644 index ac02de7..0000000 --- a/my_package/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import annotations - -from .version import VERSION, VERSION_SHORT diff --git a/pyproject.toml b/pyproject.toml index 8bf2ebe..d2f512a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] # See https://setuptools.pypa.io/en/latest/userguide/quickstart.html for more project configuration options. -name = "my-package" +name = "tri_rmsnorm" dynamic = ["version"] readme = "README.md" classifiers = [ @@ -14,21 +14,12 @@ classifiers = [ "Programming Language :: Python :: 3", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] -authors = [ - {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"} -] requires-python = ">=3.8" dependencies = [ # Add your own dependencies here ] license = {file = "LICENSE"} -[project.urls] -Homepage = "https://github.com/allenai/python-package-template" -Repository = "https://github.com/allenai/python-package-template" -Changelog = "https://github.com/allenai/python-package-template/blob/main/CHANGELOG.md" -# Documentation = "https://my-package.readthedocs.io/" - [project.optional-dependencies] dev = [ "ruff", @@ -65,10 +56,10 @@ exclude = [ include-package-data = true [tool.setuptools.package-data] -my_package = ["py.typed"] +tri_rmsnorm = ["py.typed"] [tool.setuptools.dynamic] -version = {attr = "my_package.version.VERSION"} +version = {attr = "tri_rmsnorm.version.VERSION"} [tool.black] line-length = 100 diff --git a/scripts/personalize.py b/scripts/personalize.py index 8d1e2c8..566c258 100644 --- a/scripts/personalize.py +++ b/scripts/personalize.py @@ -99,7 +99,7 @@ def main( (BASE_URL_TO_REPLACE, repo_url), (REPO_NAME_TO_REPLACE, github_repo), ("my-package", package_actual_name), - ("my_package", package_dir_name), + ("tri_rmsnorm", package_dir_name), ] if dry_run: for old, new in replacements: @@ -108,11 +108,11 @@ def main( if path.resolve() not in FILES_TO_REMOVE: personalize_file(path, dry_run, replacements) - # Rename 'my_package' directory to `package_dir_name`. + # Rename 'tri_rmsnorm' directory to `package_dir_name`. if not dry_run: - (REPO_BASE / "my_package").replace(REPO_BASE / package_dir_name) + (REPO_BASE / "tri_rmsnorm").replace(REPO_BASE / package_dir_name) else: - print(f"Renaming 'my_package' directory to '{package_dir_name}'") + print(f"Renaming 'tri_rmsnorm' directory to '{package_dir_name}'") # Start with a fresh README. readme_contents = f"""# {package_actual_name}\n""" diff --git a/scripts/prepare_changelog.py b/scripts/prepare_changelog.py index 64a3616..9aa07a9 100644 --- a/scripts/prepare_changelog.py +++ b/scripts/prepare_changelog.py @@ -3,7 +3,7 @@ from datetime import datetime from pathlib import Path -from my_package.version import VERSION +from tri_rmsnorm.version import VERSION def main(): diff --git a/scripts/release.sh b/scripts/release.sh index b9a695e..89ffe65 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -2,7 +2,7 @@ set -e -TAG=$(python -c 'from my_package.version import VERSION; print("v" + VERSION)') +TAG=$(python -c 'from tri_rmsnorm.version import VERSION; print("v" + VERSION)') read -p "Creating new release for $TAG. Do you want to continue? [Y/n] " prompt diff --git a/setup-requirements.txt b/setup-requirements.txt index 103b77f..a715ec0 100644 --- a/setup-requirements.txt +++ b/setup-requirements.txt @@ -1,3 +1,7 @@ click>=7.0,<9.0 click-help-colors>=0.9.1,<0.10 rich>=11.0,<14.0 +torch==2.1.0+cu121 +torchaudio==2.1.0+cu121 +torchvision==0.16.0+cu121 +triton==2.1.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..66191da --- /dev/null +++ b/setup.py @@ -0,0 +1,26 @@ +from setuptools import setup, find_packages + +with open("README.md", "r") as fh: + long_description = fh.read() + +with open("setup-requirements.txt", "r") as req_file: + install_requires = req_file.read().splitlines() + +setup( + name="tri_rmsnorm", + version="0.1.0", + author="Simu", + author_email="simudtai@gmail.com", + description="Packaged version of Griffin for Jax + Flax.", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/simudt/RMSNorm-Triton", + packages=find_packages(), + install_requires=install_requires, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.8", +) \ No newline at end of file diff --git a/tri_rmsnorm/__init__.py b/tri_rmsnorm/__init__.py new file mode 100644 index 0000000..7c8f8a7 --- /dev/null +++ b/tri_rmsnorm/__init__.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from .version import VERSION, VERSION_SHORT + +from tri_rmsnorm.kernel.rms_normalization_kernel import ( + _rms_norm_fwd_fused, + _rms_norm_bwd_dx_fused, + _rms_norm_bwd_dwdb, +) + +__all__ = [ + "VERSION", + "VERSION_SHORT", + "_rms_norm_fwd_fused", + "_rms_norm_bwd_dx_fused", + "_rms_norm_bwd_dwdb", +] diff --git a/tri_rmsnorm/kernel/rms_normalization_kernel.py b/tri_rmsnorm/kernel/rms_normalization_kernel.py new file mode 100644 index 0000000..c2cb7b3 --- /dev/null +++ b/tri_rmsnorm/kernel/rms_normalization_kernel.py @@ -0,0 +1,183 @@ +import torch +import triton +import triton.language as tl + +if hasattr(tl, "libdevice"): + tl_math = tl.libdevice +else: + tl_math = tl.math + + +@triton.jit +def _rms_norm_fwd_fused( + X, + Y, + W, + B, + Rstd, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel invocation for forward pass of RMS normalization with fused operations + + Params: + - X (tensor): Input tensor + - Y (tensor): Output tensor where the normalized results will be written + - W (tensor): Scale tensor applied to the normalized input + - B (tensor): Bias tensor added to the scaled input + - Rstd (tensor): Reciprocal of the standard deviation used for normalization + - stride (int): Stride to be applied when accessing elements in the input and output tensors + - N (int): Number of elements in the input tensor + - eps (float): Small epsilon value added to the variance to prevent division by zero + - BLOCK_SIZE (constexpr): Size of the block for computation, provided as a compile-time constant + + Return: + - None + + Usage: + _rms_norm_fwd_fused[grid, block](X, Y, W, B, Rstd, stride, N, eps, BLOCK_SIZE) + """ + row = tl.program_id(0) + Y += row * stride + X += row * stride + + _rms = 0 + _rms = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _rms += a * a + rms = tl.sqrt(tl.sum(_rms) / N + eps) + + tl.store(Rstd + row, rms) + + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32) + x_hat = x / rms + y = x_hat * w + b + tl.store(Y + cols, y, mask=mask) + + +@triton.jit +def _rms_norm_bwd_dx_fused( + DX, + DY, + DW, + DB, + X, + W, + B, + Rstd, + Lock, + stride, + N, + eps, + GROUP_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + Kernel invocation for backward pass of RMS normalization, computing gradients w.r.t. input + + Params: + - DX (tensor): Gradient of the loss with respect to the inputs + - DY (tensor): Gradient of the loss with respect to the outputs + - DW (tensor): Gradient of the loss with respect to the scale tensor W + - DB (tensor): Gradient of the loss with respect to the bias tensor B + - X (tensor): Input tensor from the forward pass + - W (tensor): Scale tensor applied during the forward pass + - B (tensor): Bias tensor added during the forward pass + - Rstd (tensor): Reciprocal of the standard deviation used for normalization in the forward pass + - Lock (tensor): Lock tensor for atomic operations to prevent race conditions + - stride (int): Stride to be applied when accessing elements in the tensors + - N (int): Number of elements in each tensor + - eps (float): Small epsilon value used during the forward pass + - GROUP_SIZE_M (constexpr): Size of the group for M dimension, provided as a compile-time constant + - BLOCK_SIZE_N (constexpr): Size of the block for N dimension, provided as a compile-time constant + + Return: + - None + + Usage: + _rms_norm_bwd_dx_fused[grid, block](DX, DY, DW, DB, X, W, B, Rstd, Lock, stride, N, eps, GROUP_SIZE_M, BLOCK_SIZE_N) + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + DB = DB + lock_id * N + cols + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + rstd = tl.load(Rstd + row) + x_norm = x * rstd + wdy = w * dy + dx = wdy * rstd + tl.store(DX + cols, dx, mask=mask) + partial_dw = (dy * x_norm).to(w.dtype) + partial_db = dy.to(w.dtype) + + # Locking mechanism to prevent race conditions + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + partial_db += tl.load(DB, mask=mask) + + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _rms_norm_bwd_dwdb( + DW, DB, FINAL_DW, FINAL_DB, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr +): + """ + Kernel invocation for backward pass of RMS normalization, computing and aggregating gradients w.r.t. weights and biases + + Params: + - DW (tensor): Intermediate gradient tensor for the scale factors, W + - DB (tensor): Intermediate gradient tensor for the biases, B + - FINAL_DW (tensor): Aggregated gradient tensor for the scale factors, to be updated + - FINAL_DB (tensor): Aggregated gradient tensor for the biases, to be updated + - M (int): Number of groups or batch size dimension + - N (int): Dimensionality of the feature vectors or the number of features + - BLOCK_SIZE_M (constexpr): Compile-time constant defining the block size in the M dimension + - BLOCK_SIZE_N (constexpr): Compile-time constant defining the block size in the N dimension + + Return: + - None + + Usage: + _rms_norm_bwd_dwdb[grid, block](DW, DB, FINAL_DW, FINAL_DB, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N) + """ + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + tl.store(FINAL_DB + cols, sum_db, mask=cols < N) diff --git a/my_package/py.typed b/tri_rmsnorm/py.typed similarity index 100% rename from my_package/py.typed rename to tri_rmsnorm/py.typed diff --git a/my_package/version.py b/tri_rmsnorm/version.py similarity index 100% rename from my_package/version.py rename to tri_rmsnorm/version.py diff --git a/usage.py b/usage.py new file mode 100644 index 0000000..0af7b0a --- /dev/null +++ b/usage.py @@ -0,0 +1,64 @@ +import torch +from tri_rmsnorm.kernel.rms_normalization_kernel import ( + _rms_norm_fwd_fused, + _rms_norm_bwd_dx_fused, +) + + +class RMSNormFunctionCustomKernel(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + M, N = x.shape + y = torch.empty_like(x) + rstd = torch.empty(M, dtype=torch.float32, device=x.device) + _rms_norm_fwd_fused[(M,)](x, y, weight, bias, rstd, x.stride(0), N, eps, BLOCK_SIZE=1024) + ctx.save_for_backward(x, weight, bias, rstd) + ctx.eps = eps + ctx.N = N + return y + + @staticmethod + def backward(ctx, dy): + x, weight, bias, rstd = ctx.saved_tensors + eps = ctx.eps + N = ctx.N + M = x.shape[0] + dx = torch.empty_like(x) + _dw = torch.empty_like(weight) + _db = torch.empty_like(bias) + locks = torch.zeros(2 * 32, dtype=torch.int32, device=x.device) + _rms_norm_bwd_dx_fused[(M,)]( + dx, + dy, + _dw, + _db, + x, + weight, + bias, + rstd, + locks, + x.stride(0), + N, + eps, + GROUP_SIZE_M=32, + BLOCK_SIZE_N=1024, + ) + return dx, _dw, _db, None + + +def test_rms_norm_custom_kernel(): + eps = 1e-5 + input = torch.tensor([[0.1, -0.2] * 10] * 10, device="cuda", requires_grad=True) + weights = torch.tensor([0.1] * 20, device="cuda", requires_grad=True) + biases = torch.tensor([0.01] * 20, device="cuda", requires_grad=True) + + output = RMSNormFunctionCustomKernel.apply(input, weights, biases, eps) + loss = output.mean() + loss.backward() + + print("Grads X: ", input.grad) + print("Grads W: ", weights.grad) + print("Grads B: ", biases.grad) + + +test_rms_norm_custom_kernel()