Skip to content

Commit

Permalink
Flag the t5 example
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Aug 26, 2024
1 parent 8f13371 commit 7444f87
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions examples/inference/pippy/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@
import time

import torch
from packaging import version
from transformers import AutoModelForSeq2SeqLM

from accelerate import PartialState, prepare_pippy
from accelerate import __version__ as accelerate_version
from accelerate.utils import set_seed


if version.parse(accelerate_version) > version.parse("0.33.0"):
raise RuntimeError(
"Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
"Please use a lower accelerate version and `torchpippy`, which this example uses."
)


# Set the random seed to have reproducable outputs
set_seed(42)

Expand All @@ -32,7 +41,7 @@
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(1, 1024), # bs x seq_len
size=(2, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
Expand All @@ -59,18 +68,9 @@
# gather_outputs=True
# )

# Create new inputs of the expected size (n_processes)
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 512), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
)
# The model expects a tuple during real inference
# with the data on the first device
args = (input.to("cuda:0"), input.to("cuda:0"))
args = (example_inputs["input_ids"].to("cuda:0"), example_inputs["decoder_input_ids"].to("cuda:0"))

# Take an average of 5 times
# Measure first batch
Expand Down

0 comments on commit 7444f87

Please sign in to comment.