[pull] main from NVIDIA:main #26

Merged: 63 commits, Oct 25, 2024

Commits
df69965
[JAX] Fix unit tests to work around cuDNN 9.4 regression of 0 length…
mgoldfarb-nvidia Sep 16, 2024
af5daa0
Add dtensor support for TE optimizers (#1171)
blahBlahhhJ Sep 16, 2024
d2d4cf9
Update CI users (#1181)
timmoon10 Sep 16, 2024
9101a78
[JAX] Context Parallel Attention with All-Gather (#1106)
mgoldfarb-nvidia Sep 17, 2024
44fd316
[Common] Default CUDA_HOME to /usr/local/cuda when dynamically loadin…
denera Sep 17, 2024
528d44b
Changed VERSION to 1.12.0.dev
ptrendx Sep 17, 2024
28f95bd
Allow specifying cmake setup directory (#1186)
ryxli Sep 17, 2024
eb60b1a
Add docs for installing from PyPI (#1184)
ksivaman Sep 17, 2024
7e1068b
[PyTorch] Port fused optimizer tests to pytest (#1185)
timmoon10 Sep 18, 2024
841634c
[PyTorch] Check network interface name when initializing Userbuffers …
denera Sep 18, 2024
c0caadb
Expose `rotary_base` as an arg instead of hardcoding (#944)
sudhakarsingh27 Sep 18, 2024
0ee5ccd
[PyTorch] Relax the contiguous check for flash attention (#1176)
yaox12 Sep 19, 2024
195d703
Allow downloading of model weights automatically (#1172)
sudhakarsingh27 Sep 20, 2024
0c74535
Restore compatibility with Python 3.8 (#1189)
ptrendx Sep 20, 2024
a68acd7
Update list of CI users (#1198)
timmoon10 Sep 23, 2024
99af5c0
Allow passing architectures like 90a without being overridden (#1178)
aurianer Sep 24, 2024
a44cb72
Update list of CI users (#1203)
ksivaman Sep 24, 2024
209b8e5
fix NVTE_UB_WITH_MPI read (#1194)
erhoo82 Sep 25, 2024
c4a5cb8
[PyTorch] Add GroupedLinear to the docs and fix typos (#1206)
pggPL Sep 27, 2024
8a1b7ee
[PyTorch] Fix detection of 3 in 3hd/h3d layouts (#1187)
cyanguwa Sep 27, 2024
7b152a8
Fix CP unit test on A100 and L40s (#1211)
xrennvidia Sep 27, 2024
728c558
[PyTorch] Add pool argument to make_graphed_callable (#1218)
ksivaman Oct 1, 2024
46075b9
[PyTorch] Fix distributed testing (#1219)
ksivaman Oct 1, 2024
fb74961
Removed the unused options from GroupedLinear docs and fixed the bug …
ptrendx Oct 1, 2024
10cceae
[PyTorch] Move `block_table` argument to FA varlen function (#1222)
cyanguwa Oct 3, 2024
9d976bc
[PyTorch] Minor optimizations to reduce CPU overheads in modules (#1191)
timmoon10 Oct 4, 2024
f8eb799
[PyTorch] remove duplicate code (#1215)
emmanuel-ferdman Oct 6, 2024
60f738f
Tests for distributed (#1196)
pggPL Oct 7, 2024
c24a4c4
Hierarchical CP implementation (Ulysses + Ring) (#1209)
xrennvidia Oct 7, 2024
c3b3cd2
Fix cuDNN sliding window size (#1212)
cyanguwa Oct 7, 2024
e762592
[PyTorch] Miscellaneous fixes for FA3 attention (#1174)
cyanguwa Oct 8, 2024
5b89f1a
[PyTorch] Debug dtype casting in operation-based API (#1202)
timmoon10 Oct 9, 2024
2d87552
[PyTorch] Add documentation for FP8 attention checkpointing (#1223)
cyanguwa Oct 9, 2024
5b6546c
[PyTorch] Improve `get_qkv_layout` (#1214)
cyanguwa Oct 9, 2024
85e60e6
[JAX] Expose sliding window attn to TE-JAX API (#1205)
huanghua1994 Oct 10, 2024
3b89c36
Small fixes to Float8Tensor (#1225)
ptrendx Oct 10, 2024
9ee2dbd
Fix bug in torch compile and seqdim is integer (#1217)
wplf Oct 11, 2024
b36bd0a
Add FlashAttention3 to CP implementations (#1232)
xrennvidia Oct 11, 2024
55dcbb4
[PyTorch] Let Fused RoPE support CP with THD format (#1238)
yaox12 Oct 12, 2024
86f07be
Do not link against CUDA driver when building (#1240)
timmoon10 Oct 14, 2024
20c55e4
Check for backend support in Jax context parallel fused attention tes…
mgoldfarb-nvidia Oct 15, 2024
54aa12a
Create README.md for examples/ (#1221)
sbhavani Oct 15, 2024
f6b766b
[PyTorch] Build custom ORT ops before running ONNX export tests (#1252)
timmoon10 Oct 16, 2024
43b9e1e
fix assertion bug for SWA API in TE-JAX (#1242)
kocchop Oct 16, 2024
161b1d9
[PyTorch] Drop FA as an installation requirement (#1226)
cyanguwa Oct 16, 2024
6e90fcb
Upgrade pylint to 3.3.1 (#1257)
ksivaman Oct 16, 2024
a518151
[PyTorch] Fix FP8 activation recompute (#1254)
ksivaman Oct 16, 2024
9001081
Changed VERSION to 1.13.0.dev
ptrendx Oct 16, 2024
2d7020e
[PyTorch] Fix wgrads for GroupedLinear when weights don't require gra…
yaox12 Oct 17, 2024
8e97c8d
[Bugfix] Fix bias for 0-dim tensors in gemm (#1246)
yaox12 Oct 17, 2024
12f30ea
[TE/JAX] Enabling CudaGraph for custom calls with FFI (#1228)
phu0ngng Oct 17, 2024
a488b8b
Fix seq_dim in CP implementation (#1264)
xrennvidia Oct 17, 2024
41fe1e5
[PyTorch] Reorganize L1 tests (#1255)
timmoon10 Oct 18, 2024
927bca7
[Paddle] Debug wheel test (#1265)
timmoon10 Oct 18, 2024
3ea7dd3
[PyTorch] Remove PyTorch L0 distributed test (#1273)
timmoon10 Oct 18, 2024
29e3a09
[PyTorch] Reduce the number of FA versions in L3 tests (#1280)
cyanguwa Oct 21, 2024
7b18f23
Fused Attention Support 64-bit Ragged Offsets for Large THD Tensors (…
mgoldfarb-nvidia Oct 22, 2024
35f7d26
[JAX] Skip V100 encoder tests (#1262)
zlsh80826 Oct 22, 2024
d9b4bfb
Add THD + GQA supports (#1260)
zlsh80826 Oct 22, 2024
20c7529
[JAX] Fix correctness of JAX fused attention with CP and improve nume…
mgoldfarb-nvidia Oct 24, 2024
18c2234
[JAX] XLA Custom Calls with FFI for FusedAttnFwd, Quantize, Transpose…
huanghua1994 Oct 24, 2024
7a5fd0c
[Paddle] Update type names for Paddle 3.0 (#1286)
timmoon10 Oct 24, 2024
7b284fe
[Pytorch] Check gradient in test numerics (#1229)
pggPL Oct 24, 2024
4 changes: 4 additions & 0 deletions .github/workflows/trigger-ci.yml
@@ -35,6 +35,10 @@ jobs:
|| github.actor == 'xrennvidia'
|| github.actor == 'yaox12'
|| github.actor == 'huanghua1994'
|| github.actor == 'mgoldfarb-nvidia'
|| github.actor == 'pggPL'
|| github.actor == 'vasunvidia'
|| github.actor == 'erhoo82'
)
steps:
- name: Check if comment is issued by authorized person
1 change: 1 addition & 0 deletions .gitignore
@@ -39,3 +39,4 @@ develop-eggs/
dist/
downloads/
.pytest_cache/
compile_commands.json
10 changes: 9 additions & 1 deletion README.rst
@@ -174,7 +174,15 @@ To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).

Alternatively, the package can be directly installed from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.

.. code-block:: bash
pip install transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.

From source
^^^^^^^^^^^
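The two install paths covered in the README hunk above can be driven the same way; a minimal sketch, with the framework selections purely illustrative:

    # Build the stable branch from source, restricting the build to chosen frameworks
    NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

    # Or install from PyPI, naming the required frameworks as extras
    pip install "transformer_engine[pytorch,jax]"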
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
1.11.0.dev0
1.13.0.dev0
8 changes: 6 additions & 2 deletions build_tools/build_ext.py
@@ -106,8 +106,12 @@ def run(self) -> None:
if isinstance(ext, CMakeExtension):
print(f"Building CMake extension {ext.name}")
# Set up incremental builds for CMake extensions
setup_dir = Path(__file__).resolve().parent.parent
build_dir = setup_dir / "build" / "cmake"
build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
if build_dir:
build_dir = Path(build_dir).resolve()
else:
root_dir = Path(__file__).resolve().parent.parent
build_dir = root_dir / "build" / "cmake"

# Ensure the directory exists
build_dir.mkdir(parents=True, exist_ok=True)
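The build_ext.py change above lets the CMake build tree be placed anywhere via NVTE_CMAKE_BUILD_DIR, falling back to build/cmake under the repo root when unset. A hedged usage sketch from a source checkout (the directory path is illustrative):

    # Keep a persistent CMake build directory for faster incremental rebuilds
    NVTE_CMAKE_BUILD_DIR=$HOME/.cache/te-cmake pip install .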
10 changes: 5 additions & 5 deletions build_tools/pytorch.py
@@ -80,15 +80,15 @@ def setup_pytorch_extension(
)
)

if "80" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if "90" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

# Libraries
library_dirs = []
libraries = []
if os.getenv("NVTE_UB_WITH_MPI"):
if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
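With the stricter parsing above, NVTE_UB_WITH_MPI is read as an integer flag (so "0" now disables it instead of counting as a non-empty string), and MPI_HOME must be set when it is enabled. A hedged build invocation (the MPI path is illustrative):

    # Build the PyTorch extension with Userbuffers compiled against MPI
    NVTE_UB_WITH_MPI=1 MPI_HOME=/opt/openmpi pip install .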
3 changes: 3 additions & 0 deletions docs/api/pytorch.rst
@@ -9,6 +9,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
:members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
9 changes: 6 additions & 3 deletions docs/examples/te_llama/te_llama.py
@@ -102,8 +102,11 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config).to(kwargs["torch_dtype"])
is_local = os.path.isdir(pretrained_model_name_or_path)
# Before loading the model, set the default dtype for torch
torch.set_default_dtype(kwargs["torch_dtype"])

# Load the vanilla model weights
vanilla_model = cls(config)
subfolder = ""
variant = None
if os.path.isfile(
@@ -133,7 +136,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
)
65 changes: 46 additions & 19 deletions docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -247,15 +247,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -554,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
@@ -573,15 +582,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
@@ -653,15 +671,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
56 changes: 51 additions & 5 deletions docs/examples/te_llama/utils.py
@@ -25,7 +25,10 @@
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
# self.model_name = "" # <== Add model weight location here

# Set to Meta Llama 2 by default.
self.model_name = "meta-llama/Llama-2-7b-hf"

self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
@@ -35,6 +38,10 @@ def __init__(self):
self.num_warmup_steps = 5
self.num_training_steps = 10

# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""


hyperparams = HyperParameters()

@@ -76,13 +83,49 @@ def tokenize(element):
return train_dataloader


def ensure_model_is_downloaded(hyperparams):
assert hyperparams.model_name in [
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-2-7b-hf",
], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"

# Login using Huggingface Hub API
from huggingface_hub import login

try:
login(hyperparams.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")

# Download the model if it doesn't exist
from huggingface_hub import snapshot_download

supplied_cache_dir = (
hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
)
hyperparams.weights_cache_dir = snapshot_download(
repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
)

print(f"Model cache directory : {hyperparams.weights_cache_dir}")


def init_baseline_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)

# Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
# make sure to use flash_attention to do iso comparison with TELlamaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams):


def init_te_llama_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)

# Init the model
from te_llama import TELlamaForCausalLM

config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
75 changes: 75 additions & 0 deletions docs/faq.rst
@@ -0,0 +1,75 @@
..
Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine starts to support FP8 attention in 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As the FP8 attention support expands from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

Here, we take the `MultiheadAttention` module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as `core_attention._extra_state` as shown below.

.. code-block:: python

>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
... mha = MultiheadAttention(
... hidden_size=1024,
... num_attention_heads=16,
... bias=True,
... params_dtype=torch.bfloat16,
... input_layernorm=False,
... fuse_qkv_params=True,
... attention_type="self",
... qkv_weight_interleaved=True,
... ).to(dtype=torch.bfloat16, device="cuda")
...
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])

Here is a full list of the checkpoint save/load behaviors from all Transformer Engine versions.

.. list-table::

* - **Version: <= 1.5**

- Saves no FP8 metadata since FP8 attention is not supported
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Loads no FP8 metadata
:> 1.5: Error: unexpected key
* - **Version: 1.6, 1.7**

- Saves FP8 metadata to `core_attention.fused_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:1.6, 1.7: Loads FP8 metadata from checkpoint
:>= 1.8: Error: unexpected key
* - **Version: >=1.8, <= 1.11**

- Saves FP8 metadata to `core_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:1.6, 1.7: This save/load combination relies on users to map the 1.6/1.7 key to the 1.8-1.11 key. Otherwise, it initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes. The mapping can be done, in this `MultiheadAttention` example, by

.. code-block:: python

>>> state_dict["core_attention._extra_state"] = \
state_dict["core_attention.fused_attention._extra_state"]
>>> del state_dict["core_attention.fused_attention._extra_state"]

:>= 1.8: Loads FP8 metadata from checkpoint
* - **Version: >=1.12**

- Saves FP8 metadata to `core_attention._extra_state`
- Loading behavior for checkpoints created by the following versions:

:<= 1.5: Initializes FP8 metadata to the default, i.e. 1s for scaling factors, and 0s for amaxes
:>= 1.6: Loads FP8 metadata from checkpoint
1 change: 1 addition & 0 deletions docs/index.rst
@@ -30,6 +30,7 @@ Transformer Engine documentation

installation
examples/quickstart.ipynb
faq

.. toctree::
:hidden: