Skip to content

Commit

Permalink
Refactor whisper_wrapper.cpp: Add trim function and handle exceptions in transcription and result callback
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed Oct 25, 2024
1 parent fddebc4 commit 56d27b4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
18 changes: 13 additions & 5 deletions simpler_whisper/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def __del__(self):

class ThreadedWhisperModel:
def __init__(
self, model_path: str, use_gpu=False, max_duration_sec=10.0, sample_rate=16000
self,
model_path: str,
callback: Callable[[int, str, bool], None],
use_gpu=False,
max_duration_sec=10.0,
sample_rate=16000,
):
"""
Initialize a threaded Whisper model for continuous audio processing.
Expand All @@ -39,10 +44,13 @@ def __init__(
model_path, use_gpu, max_duration_sec, sample_rate
)
self._is_running = False
self.callback = callback

def start(
self, callback: Callable[[int, str, bool], None], result_check_interval_ms=100
):
def handle_result(self, chunk_id: int, text: str, is_partial: bool):
if self.callback is not None:
self.callback(chunk_id, text, is_partial)

def start(self, result_check_interval_ms=100):
"""
Start the processing threads with a callback for results.
Expand All @@ -56,7 +64,7 @@ def start(
if self._is_running:
return

self.model.start(callback, result_check_interval_ms)
self.model.start(self.handle_result, result_check_interval_ms)
self._is_running = True

def stop(self):
Expand Down
8 changes: 6 additions & 2 deletions src/whisper_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class WhisperModel
{
throw std::runtime_error("Failed to initialize whisper context");
}
params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
}

~WhisperModel()
Expand Down Expand Up @@ -85,7 +86,6 @@ class WhisperModel

std::vector<std::string> transcribe_raw_audio(const float *audio_data, int n_samples)
{
whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
if (whisper_full(ctx, params, audio_data, n_samples) != 0)
{
throw std::runtime_error("Whisper inference failed");
Expand All @@ -104,6 +104,7 @@ class WhisperModel

private:
whisper_context *ctx;
whisper_full_params params;
};

struct AudioChunk
Expand Down Expand Up @@ -237,7 +238,10 @@ class ThreadedWhisperModel
std::cerr << "Unknown exception during transcription" << std::endl;
}

std::cout << "Transcription: " << segments[0] << std::endl;
if (segments.empty())
{
return;
}

TranscriptionResult result;
result.chunk_id = current_id;
Expand Down
28 changes: 19 additions & 9 deletions test_simpler_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,46 @@ def test_simpler_whisper():

def test_threaded_whisper():
def handle_result(chunk_id: int, text: str, is_partial: bool):
print(f"Chunk {chunk_id} results ({'partial' if is_partial else 'final'}):")
print(f" {text}")
print(
f"Chunk {chunk_id} results ({'partial' if is_partial else 'final'}): {text}"
)

# Create model with 10-second max duration
model = ThreadedWhisperModel(
model_path=model_path, use_gpu=True, max_duration_sec=10.0
model_path=model_path,
callback=handle_result,
use_gpu=True,
max_duration_sec=10.0,
)

# load audio from file with av
import av
container = av.open(R"C:\Users\roysh\Downloads\1847363777395929088.mp4")

container = av.open(
R"local_path_to_audio_file"
)
audio = container.streams.audio[0]
print(audio)
frame_generator = container.decode(audio)

# Start processing with callback
print("Starting threaded Whisper model...")
model.start(callback=handle_result)
model.start()

for i, frame in enumerate(frame_generator):
# print(f"Queueing audio chunk {i + 1}")
# Read audio chunk
incoming_audio = frame.to_ndarray().mean(axis=0)
incoming_audio = incoming_audio / 32768.0 # normalize to [-1, 1]
# resample to 16kHz
samples = resampy.resample(frame.to_ndarray().mean(axis=0), frame.rate, 16000)
samples = resampy.resample(incoming_audio, frame.rate, 16000)

# Queue some audio (will get partial results until 10 seconds accumulate)
chunk_id = model.queue_audio(samples)
# print(f" Queued chunk {i + 1} with ID {chunk_id} size {len(samples)}")
# sleep for the size of the audio chunk
time.sleep(len(samples) / 16000)
try:
time.sleep(len(samples) / 16000)
except:
break

# close the container
container.close()
Expand Down

0 comments on commit 56d27b4

Please sign in to comment.