-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathwhisper.py
34 lines (28 loc) · 1.3 KB
/
whisper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# This example requires torch nightly (see README.md for recommended version)
# It further requires `pip install datasets soundfile librosa`
# Please run `pip install -r requirements.txt`
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from octoml_profile import remote_profile, accelerate
#
# The processor and the model are both loaded from the same pretrained
# checkpoint, so the checkpoint name is spelled out once.
_CHECKPOINT = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(_CHECKPOINT)
model = WhisperForConditionalGeneration.from_pretrained(_CHECKPOINT)

# Load the validation split of the dummy dataset and keep the "audio" entry
# of its first record; that entry is the input passed to predict() below.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)
sample = ds[0]["audio"]
@accelerate(dynamic=True)
def predict(sample):
    """Transcribe a single audio sample with the module-level Whisper model.

    Args:
        sample: Mapping that provides the waveform under ``"array"`` and its
            sampling rate under ``"sampling_rate"``.

    Returns:
        Whatever ``processor.batch_decode`` produces for the generated token
        ids with special tokens skipped (presumably a list of transcription
        strings -- confirm against the transformers documentation).
    """
    # Convert the raw waveform into the model's input features ("pt" tensors).
    encoded = processor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )
    # Generate token ids from the features, then decode them back to text.
    token_ids = model.generate(encoded.input_features)
    return processor.batch_decode(token_ids, skip_special_tokens=True)
# Run the transcription inside the remote_profile context manager, which is
# configured here with two backends and num_repeats=1.
# NOTE(review): the backend strings look like "<instance type>/<runtime>" --
# confirm the exact format against the octoml_profile documentation.
with remote_profile(backends=["g4dn.xlarge/onnxrt-cuda", "r6i.large/onnxrt-cpu"],
num_repeats=1):
# Call predict() on the same sample three times.
for _ in range(3):
text = predict(sample)
# Print the transcription returned by predict().
# NOTE(review): this file's indentation was lost, so it is unclear whether
# this print sits inside the loop or after it -- verify against the original.
print(text)