-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ae48a22
commit 7dc5fd1
Showing
20 changed files
with
2,899 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
OPENAI_KEY=your-openai-api-key | ||
PYANNOTE_TOKEN=your-pyannote-api-token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.11.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# API Costs | ||
api_costs: | ||
whisper: | ||
cost_per_minute: 0.006 # USD per minute for Whisper | ||
pyannote: | ||
cost_per_hour: 0.18 # USD per hour for diarization | ||
|
||
# Logging Configuration | ||
logging: | ||
log_directory: "logs" | ||
max_log_size: 10485760 # 10 MB in bytes | ||
backup_count: 5 | ||
|
||
# Model Configuration | ||
model: | ||
default_model_id: "openai/whisper-large-v3" | ||
|
||
# Supported Services | ||
supported_upload_services: | ||
- "0x0.st" | ||
- "file.io" | ||
|
||
# Timeout Settings (in seconds) | ||
timeouts: | ||
download_timeout: 120 | ||
upload_timeout: 120 | ||
diarization_timeout: 3600 | ||
job_status_timeout: 60 |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
[tool.poetry] | ||
name = "yawt" | ||
version = "0.1.0" | ||
description = "YAWT (Yet Another Whisper-based Transcriber) is a transcription tool that leverages OpenAI's Whisper model to deliver accurate and efficient audio-to-text conversion." | ||
authors = ["Yaniv Golan <[email protected]>"] | ||
packages = [{ include = "yawt", from = "src" }] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.11.4" | ||
torch = "^2.1.0" | ||
transformers = ">=4.35.0" | ||
tqdm = ">=4.66.0" | ||
python-dotenv = ">=1.0.0" | ||
ffmpeg-python = ">=0.2.0" | ||
numpy = ">=1.21.0" | ||
requests = ">=2.31.0" | ||
requests-toolbelt = ">=0.10.1" | ||
accelerate = ">=0.26.0" | ||
srt = ">=3.0.0" | ||
pyyaml = ">=6.0" | ||
|
||
[tool.poetry.dev-dependencies] | ||
pytest = "^7.0" | ||
pytest-mock = "^3.10.0" | ||
pytest-cov = "^4.0.0" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry.scripts] | ||
yawt = "yawt.main:main" |
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import pytest | ||
from unittest.mock import patch, mock_open, MagicMock | ||
from yawt.audio_handler import load_audio, upload_file, download_audio | ||
import os | ||
|
||
def test_load_audio_success(mocker): | ||
mock_ffmpeg = mocker.patch('yawt.audio_handler.ffmpeg') | ||
mock_ffmpeg.input.return_value.output.return_value.run.return_value = (b'\x00\x01', None) | ||
audio = load_audio("path/to/audio.wav") | ||
assert isinstance(audio, list) | ||
assert len(audio) == 2 # Based on mock data | ||
|
||
def test_load_audio_ffmpeg_error(mocker): | ||
mock_ffmpeg = mocker.patch('yawt.audio_handler.ffmpeg') | ||
mock_ffmpeg.input.return_value.output.return_value.run.side_effect = ffmpeg.Error('ffmpeg', 'Error', 'Error message') | ||
|
||
with pytest.raises(SystemExit): | ||
load_audio("path/to/invalid_audio.wav") | ||
|
||
@patch('yawt.audio_handler.requests.post') | ||
def test_upload_file_0x0_st_success(mock_post, mocker): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.text = "https://0x0.st/abc123" | ||
mock_post.return_value = mock_response | ||
|
||
with patch('builtins.open', mock_open(read_data=b"data")): | ||
file_url = upload_file("path/to/file.wav", service='0x0.st') | ||
assert file_url == "https://0x0.st/abc123" | ||
|
||
@patch('yawt.audio_handler.requests.post') | ||
def test_upload_file_0x0_st_failure(mock_post, mocker): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 400 | ||
mock_response.text = "Bad Request" | ||
mock_post.return_value = mock_response | ||
|
||
with patch('builtins.open', mock_open(read_data=b"data")), pytest.raises(Exception): | ||
upload_file("path/to/file.wav", service='0x0.st') | ||
|
||
@patch('yawt.audio_handler.requests.post') | ||
def test_upload_file_file_io_success(mock_post, mocker): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.json.return_value = {"success": True, "link": "https://file.io/xyz789"} | ||
mock_post.return_value = mock_response | ||
|
||
with patch('builtins.open', mock_open(read_data=b"data")): | ||
file_url = upload_file("path/to/file.wav", service='file.io') | ||
assert file_url == "https://file.io/xyz789" | ||
|
||
@patch('yawt.audio_handler.requests.get') | ||
def test_download_audio_success(mock_get, mocker, tmp_path): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.iter_content.return_value = [b'data'] | ||
mock_response.__enter__.return_value = mock_response | ||
mock_get.return_value = mock_response | ||
|
||
with patch('builtins.open', mock_open()) as mocked_file: | ||
file_path = download_audio("https://example.com/audio.wav") | ||
mocked_file.assert_called_once_with(os.path.join(tmp_path, 'temp_audio.wav'), 'wb') | ||
assert file_path.endswith('.wav') | ||
|
||
def test_download_audio_empty_file(mocker, tmp_path): | ||
mock_get = mocker.patch('yawt.audio_handler.requests.get') | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.iter_content.return_value = [] | ||
mock_response.__enter__.return_value = mock_response | ||
mock_get.return_value = mock_response | ||
|
||
with patch('builtins.open', mock_open()), \ | ||
patch('yawt.audio_handler.os.path.getsize', return_value=0), \ | ||
pytest.raises(Exception): | ||
download_audio("https://example.com/empty_audio.wav") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
from yawt.config import load_config, validate_config | ||
import os | ||
import yaml | ||
|
||
def test_load_config(): | ||
config = load_config("config/default_config.yaml") | ||
assert config is not None, "Configuration should not be None" | ||
assert "api_costs" in config, "Configuration should contain 'api_costs'" | ||
assert "logging" in config, "Configuration should contain 'logging'" | ||
assert "model" in config, "Configuration should contain 'model'" | ||
assert "supported_upload_services" in config, "Configuration should contain 'supported_upload_services'" | ||
assert "timeouts" in config, "Configuration should contain 'timeouts'" | ||
|
||
def test_load_config_file_not_found(): | ||
with pytest.raises(FileNotFoundError): | ||
load_config("config/nonexistent_config.yaml") | ||
|
||
def test_validate_config_valid(): | ||
config = load_config("config/default_config.yaml") | ||
# If load_config already validates, this should pass without exceptions | ||
validate_config(config) # Should not raise | ||
|
||
def test_validate_config_invalid_whisper_cost(tmp_path): | ||
# Create an invalid config with negative whisper cost | ||
invalid_config = { | ||
"api_costs": { | ||
"whisper": { | ||
"cost_per_minute": -0.01 # Invalid negative value | ||
}, | ||
"pyannote": { | ||
"cost_per_hour": 0.18 | ||
} | ||
}, | ||
"logging": { | ||
"log_directory": "logs", | ||
"max_log_size": 10485760, | ||
"backup_count": 5 | ||
}, | ||
"model": { | ||
"default_model_id": "openai/whisper-large-v3" | ||
}, | ||
"supported_upload_services": ["0x0.st", "file.io"], | ||
"timeouts": { | ||
"download_timeout": 120, | ||
"upload_timeout": 120, | ||
"diarization_timeout": 3600, | ||
"job_status_timeout": 60 | ||
} | ||
} | ||
config_path = tmp_path / "invalid_config.yaml" | ||
with open(config_path, 'w') as f: | ||
yaml.dump(invalid_config, f) | ||
|
||
with pytest.raises(ValueError) as exc_info: | ||
config = load_config(str(config_path)) | ||
validate_config(config) | ||
assert "Whisper cost per minute must be non-negative." in str(exc_info.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
from yawt.diarization import submit_diarization_job, wait_for_diarization, get_job_status | ||
|
||
@patch('yawt.diarization.requests.post') | ||
def test_submit_diarization_job_success(mock_post): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.json.return_value = {"jobId": "job123"} | ||
mock_post.return_value = mock_response | ||
|
||
job_id = submit_diarization_job("fake_token", "https://example.com/audio.wav", num_speakers=2) | ||
assert job_id == "job123" | ||
|
||
@patch('yawt.diarization.requests.post') | ||
def test_submit_diarization_job_rate_limit(mock_post, mocker): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 429 | ||
mock_response.headers = {'Retry-After': '1'} | ||
mock_post.return_value = mock_response | ||
|
||
with patch('time.sleep') as mock_sleep: | ||
mock_post.side_effect = [mock_response, MagicMock(status_code=200, json=lambda: {"jobId": "job123"})] | ||
job_id = submit_diarization_job("fake_token", "https://example.com/audio.wav") | ||
assert job_id == "job123" | ||
assert mock_sleep.called_once_with(1) | ||
|
||
@patch('yawt.diarization.requests.post') | ||
def test_submit_diarization_job_failure(mock_post): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 400 | ||
mock_response.text = "Bad Request" | ||
mock_post.return_value = mock_response | ||
|
||
with pytest.raises(Exception) as exc_info: | ||
submit_diarization_job("fake_token", "https://example.com/audio.wav") | ||
assert "Diarization submission failed" in str(exc_info.value) | ||
|
||
@patch('yawt.diarization.requests.get') | ||
def test_get_job_status_success(mock_get): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 200 | ||
mock_response.json.return_value = {"status": "succeeded", "output": {"diarization": []}} | ||
mock_get.return_value = mock_response | ||
|
||
job_info = get_job_status("fake_token", "job123") | ||
assert job_info['status'] == "succeeded" | ||
|
||
@patch('yawt.diarization.requests.get') | ||
def test_get_job_status_rate_limit(mock_get, mocker): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 429 | ||
mock_response.headers = {'Retry-After': '1'} | ||
mock_get.return_value = mock_response | ||
|
||
with patch('time.sleep') as mock_sleep: | ||
mock_get.side_effect = [mock_response, MagicMock(status_code=200, json=lambda: {"status": "succeeded"})] | ||
job_info = get_job_status("fake_token", "job123") | ||
assert job_info['status'] == "succeeded" | ||
assert mock_sleep.called_once_with(1) | ||
|
||
@patch('yawt.diarization.requests.get') | ||
def test_get_job_status_failure(mock_get): | ||
mock_response = MagicMock() | ||
mock_response.status_code = 500 | ||
mock_response.text = "Internal Server Error" | ||
mock_get.return_value = mock_response | ||
|
||
with pytest.raises(Exception) as exc_info: | ||
get_job_status("fake_token", "job123") | ||
assert "Failed to get job status" in str(exc_info.value) | ||
|
||
@patch('yawt.diarization.get_job_status') | ||
def test_wait_for_diarization_success(mock_get_job_status, mocker): | ||
mock_get_job_status.return_value = {"status": "succeeded", "output": {"diarization": []}} | ||
job_info = wait_for_diarization("fake_token", "job123", "https://example.com/audio.wav") | ||
assert job_info['status'] == "succeeded" | ||
|
||
@patch('yawt.diarization.get_job_status') | ||
def test_wait_for_diarization_timeout(mock_get_job_status, mocker): | ||
mock_get_job_status.return_value = {"status": "in_progress"} | ||
|
||
with patch('time.time', side_effect=[0, 4000]): | ||
with patch('time.sleep'): | ||
with pytest.raises(Exception) as exc_info: | ||
wait_for_diarization("fake_token", "job123", "https://example.com/audio.wav", timeout=3600) | ||
assert "Diarization job timed out." in str(exc_info.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
from unittest.mock import patch, MagicMock | ||
from yawt.logging_setup import setup_logging | ||
import logging | ||
|
||
def test_setup_logging_default(mocker): | ||
mock_makedirs = mocker.patch('yawt.logging_setup.os.makedirs') | ||
mock_RotatingFileHandler = mocker.patch('yawt.logging_setup.RotatingFileHandler') | ||
mock_StreamHandler = mocker.patch('yawt.logging_setup.logging.StreamHandler') | ||
|
||
setup_logging() | ||
|
||
mock_makedirs.assert_called_once_with("logs", exist_ok=True) | ||
assert logging.getLogger().level == logging.WARNING | ||
mock_StreamHandler.assert_called_once_with() | ||
|
||
def test_setup_logging_debug(mocker): | ||
mock_makedirs = mocker.patch('yawt.logging_setup.os.makedirs') | ||
mock_RotatingFileHandler = mocker.patch('yawt.logging_setup.RotatingFileHandler') | ||
mock_StreamHandler = mocker.patch('yawt.logging_setup.logging.StreamHandler') | ||
mock_basicConfig = mocker.patch('yawt.logging_setup.logging.basicConfig') | ||
|
||
setup_logging(debug=True) | ||
|
||
mock_basicConfig.assert_called_once() | ||
args, kwargs = mock_basicConfig.call_args | ||
assert kwargs['level'] == logging.DEBUG | ||
|
||
def test_setup_logging_verbose(mocker): | ||
mock_makedirs = mocker.patch('yawt.logging_setup.os.makedirs') | ||
mock_RotatingFileHandler = mocker.patch('yawt.logging_setup.RotatingFileHandler') | ||
mock_StreamHandler = mocker.patch('yawt.logging_setup.logging.StreamHandler') | ||
mock_basicConfig = mocker.patch('yawt.logging_setup.logging.basicConfig') | ||
|
||
setup_logging(verbose=True) | ||
|
||
mock_basicConfig.assert_called_once() | ||
args, kwargs = mock_basicConfig.call_args | ||
assert kwargs['level'] == logging.INFO |
Oops, something went wrong.