Skip to content

Commit

Permalink
Restructured folders and files
Browse files Browse the repository at this point in the history
  • Loading branch information
yaniv-golan committed Oct 10, 2024
1 parent ae48a22 commit 7dc5fd1
Show file tree
Hide file tree
Showing 20 changed files with 2,899 additions and 19 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
OPENAI_KEY=your-openai-api-key
PYANNOTE_TOKEN=your-pyannote-api-token
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11.4
28 changes: 28 additions & 0 deletions config/default_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Default YAWT configuration.
# Nested keys are indented two spaces; the nesting matters to consumers that
# read e.g. api_costs -> whisper -> cost_per_minute.

# API Costs
api_costs:
  whisper:
    cost_per_minute: 0.006  # USD per minute for Whisper
  pyannote:
    cost_per_hour: 0.18  # USD per hour for diarization

# Logging Configuration
logging:
  log_directory: "logs"
  max_log_size: 10485760  # 10 MB in bytes
  backup_count: 5

# Model Configuration
model:
  default_model_id: "openai/whisper-large-v3"

# Supported Services
supported_upload_services:
  - "0x0.st"
  - "file.io"

# Timeout Settings (in seconds)
timeouts:
  download_timeout: 120
  upload_timeout: 120
  diarization_timeout: 3600
  job_status_timeout: 60
1,530 changes: 1,530 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[tool.poetry]
name = "yawt"
version = "0.1.0"
description = "YAWT (Yet Another Whisper-based Transcriber) is a transcription tool that leverages OpenAI's Whisper model to deliver accurate and efficient audio-to-text conversion."
authors = ["Yaniv Golan <[email protected]>"]
packages = [{ include = "yawt", from = "src" }]

[tool.poetry.dependencies]
python = "^3.11.4"
torch = "^2.1.0"
transformers = ">=4.35.0"
tqdm = ">=4.66.0"
python-dotenv = ">=1.0.0"
ffmpeg-python = ">=0.2.0"
numpy = ">=1.21.0"
requests = ">=2.31.0"
requests-toolbelt = ">=0.10.1"
accelerate = ">=0.26.0"
srt = ">=3.0.0"
pyyaml = ">=6.0"

[tool.poetry.dev-dependencies]
pytest = "^7.0"
pytest-mock = "^3.10.0"
pytest-cov = "^4.0.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
yawt = "yawt.main:main"
744 changes: 733 additions & 11 deletions requirements.txt

Large diffs are not rendered by default.

File renamed without changes.
6 changes: 3 additions & 3 deletions config.py → src/yawt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file
# Load environment variables from the .env file
load_dotenv()

# Define the path to the config file
CONFIG_FILE = Path(__file__).parent / "config.yaml"
# Define the path to the default config file
CONFIG_FILE = Path(__file__).parent / "../../config/default_config.yaml"

def load_config(config_path=CONFIG_FILE):
if not config_path.exists():
Expand Down
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions trans-ctx.py → src/yawt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,12 @@ def main():

logging.info("Script started.")

# Initialize environment variables
initialize_environment()
# Initialize and load configuration
config = load_and_prepare_config(args.config) # Load the specified config

# Retrieve API tokens
pyannote_token = args.pyannote_token or os.getenv("PYANNOTE_TOKEN")
openai_key = args.openai_key or os.getenv("OPENAI_KEY")
# Retrieve API tokens; precedence: CLI argument, then config file, then environment variable
pyannote_token = args.pyannote_token or config.get('pyannote_token') or os.getenv("PYANNOTE_TOKEN")
openai_key = args.openai_key or config.get('openai_key') or os.getenv("OPENAI_KEY")

# Check if API tokens are set
check_api_tokens(pyannote_token, openai_key)
Expand Down
File renamed without changes.
File renamed without changes.
76 changes: 76 additions & 0 deletions tests/test_audio_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest
from unittest.mock import patch, mock_open, MagicMock
from yawt.audio_handler import load_audio, upload_file, download_audio
import os

def test_load_audio_success(mocker):
    """Expect load_audio to return a two-element list for a two-byte payload."""
    ffmpeg_stub = mocker.patch('yawt.audio_handler.ffmpeg')
    # Stub the whole ffmpeg.input(...).output(...).run(...) call chain.
    run_stub = ffmpeg_stub.input.return_value.output.return_value.run
    run_stub.return_value = (b'\x00\x01', None)

    samples = load_audio("path/to/audio.wav")

    assert isinstance(samples, list)
    assert len(samples) == 2  # Based on mock data

def test_load_audio_ffmpeg_error(mocker):
    """Expect load_audio to exit (SystemExit) when the ffmpeg pipeline fails.

    The test module never imports ``ffmpeg``, so the error type is supplied
    by a local stand-in. The stand-in is also installed as ``Error`` on the
    patched module: once ``yawt.audio_handler.ffmpeg`` is a mock, its
    ``Error`` attribute would otherwise be a mock too, which cannot be used
    in an ``except`` clause.
    """

    class FakeFFmpegError(Exception):
        """Stand-in exposing the cmd/stdout/stderr fields of ffmpeg.Error."""

        def __init__(self, cmd, stdout, stderr):
            super().__init__(f"{cmd} error (see stderr output for detail)")
            self.stdout = stdout
            self.stderr = stderr

    mock_ffmpeg = mocker.patch('yawt.audio_handler.ffmpeg')
    mock_ffmpeg.Error = FakeFFmpegError
    # NOTE(review): stdout/stderr are bytes, as captured process output would
    # be -- confirm against how load_audio reports the error.
    failure = FakeFFmpegError('ffmpeg', b'Error', b'Error message')
    mock_ffmpeg.input.return_value.output.return_value.run.side_effect = failure

    with pytest.raises(SystemExit):
        load_audio("path/to/invalid_audio.wav")

@patch('yawt.audio_handler.requests.post')
def test_upload_file_0x0_st_success(mock_post, mocker):
    """Expect upload_file to return the URL from a 200 response of 0x0.st."""
    mock_post.return_value = MagicMock(
        status_code=200,
        text="https://0x0.st/abc123",
    )

    # The file on disk is faked; only its (mocked) contents are uploaded.
    fake_open = mock_open(read_data=b"data")
    with patch('builtins.open', fake_open):
        uploaded_url = upload_file("path/to/file.wav", service='0x0.st')

    assert uploaded_url == "https://0x0.st/abc123"

@patch('yawt.audio_handler.requests.post')
def test_upload_file_0x0_st_failure(mock_post, mocker):
    """Expect upload_file to raise when 0x0.st answers with HTTP 400."""
    mock_post.return_value = MagicMock(status_code=400, text="Bad Request")

    fake_open = mock_open(read_data=b"data")
    with patch('builtins.open', fake_open):
        with pytest.raises(Exception):
            upload_file("path/to/file.wav", service='0x0.st')

@patch('yawt.audio_handler.requests.post')
def test_upload_file_file_io_success(mock_post, mocker):
    """Expect upload_file to return the 'link' field of file.io's JSON reply."""
    reply = MagicMock(status_code=200)
    reply.json.return_value = {"success": True, "link": "https://file.io/xyz789"}
    mock_post.return_value = reply

    fake_open = mock_open(read_data=b"data")
    with patch('builtins.open', fake_open):
        uploaded_url = upload_file("path/to/file.wav", service='file.io')

    assert uploaded_url == "https://file.io/xyz789"

@patch('yawt.audio_handler.requests.get')
def test_download_audio_success(mock_get, mocker, tmp_path):
    """Expect download_audio to write the body to a .wav file and return its path."""
    reply = MagicMock(status_code=200)
    reply.iter_content.return_value = [b'data']
    # The response may be used as a context manager; make it yield itself.
    reply.__enter__.return_value = reply
    mock_get.return_value = reply

    expected_target = os.path.join(tmp_path, 'temp_audio.wav')
    with patch('builtins.open', mock_open()) as fake_open:
        saved_path = download_audio("https://example.com/audio.wav")

    fake_open.assert_called_once_with(expected_target, 'wb')
    assert saved_path.endswith('.wav')

def test_download_audio_empty_file(mocker, tmp_path):
    """Expect download_audio to raise when the downloaded file is zero bytes."""
    reply = MagicMock(status_code=200)
    reply.iter_content.return_value = []
    reply.__enter__.return_value = reply
    mocker.patch('yawt.audio_handler.requests.get').return_value = reply

    # No real file is written; getsize is forced to report an empty file.
    open_patch = patch('builtins.open', mock_open())
    size_patch = patch('yawt.audio_handler.os.path.getsize', return_value=0)
    with open_patch, size_patch, pytest.raises(Exception):
        download_audio("https://example.com/empty_audio.wav")
58 changes: 58 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
from yawt.config import load_config, validate_config
import os
import yaml

def test_load_config():
    """Expect the default configuration to load and expose every top-level section.

    ``load_config`` calls ``.exists()`` on its argument, so it must be given a
    ``pathlib.Path`` rather than a ``str``. The path is resolved from this
    test file's location so the test does not depend on the working directory.
    """
    from pathlib import Path

    default_config = Path(__file__).resolve().parent.parent / "config" / "default_config.yaml"
    config = load_config(default_config)

    assert config is not None, "Configuration should not be None"
    required_sections = (
        "api_costs",
        "logging",
        "model",
        "supported_upload_services",
        "timeouts",
    )
    for section in required_sections:
        assert section in config, f"Configuration should contain '{section}'"

def test_load_config_file_not_found():
    """Expect load_config to raise FileNotFoundError for a missing file.

    A ``pathlib.Path`` is passed because ``load_config`` calls ``.exists()``
    on its argument; a plain ``str`` would fail with AttributeError before
    the missing-file check could run.
    """
    from pathlib import Path

    with pytest.raises(FileNotFoundError):
        load_config(Path("config/nonexistent_config.yaml"))

def test_validate_config_valid():
    """Expect the default configuration to pass validate_config without raising.

    The config is loaded through a ``pathlib.Path`` (``load_config`` calls
    ``.exists()`` on its argument), resolved from this test file's location.
    """
    from pathlib import Path

    default_config = Path(__file__).resolve().parent.parent / "config" / "default_config.yaml"
    config = load_config(default_config)
    # If load_config already validates, this should pass without exceptions
    validate_config(config)  # Should not raise

def test_validate_config_invalid_whisper_cost(tmp_path):
    """Expect a negative Whisper cost to be rejected with a ValueError.

    A complete configuration with only ``cost_per_minute`` set to a negative
    value is written to a temporary YAML file. The file is handed to
    ``load_config`` as a ``pathlib.Path`` (it calls ``.exists()`` on its
    argument). The error may come from either ``load_config`` or
    ``validate_config``, so both calls sit inside the ``raises`` block.
    """
    # Create an invalid config with negative whisper cost
    invalid_config = {
        "api_costs": {
            "whisper": {"cost_per_minute": -0.01},  # Invalid negative value
            "pyannote": {"cost_per_hour": 0.18},
        },
        "logging": {
            "log_directory": "logs",
            "max_log_size": 10485760,
            "backup_count": 5,
        },
        "model": {"default_model_id": "openai/whisper-large-v3"},
        "supported_upload_services": ["0x0.st", "file.io"],
        "timeouts": {
            "download_timeout": 120,
            "upload_timeout": 120,
            "diarization_timeout": 3600,
            "job_status_timeout": 60,
        },
    }
    config_path = tmp_path / "invalid_config.yaml"
    with open(config_path, 'w') as f:
        yaml.dump(invalid_config, f)

    with pytest.raises(ValueError) as exc_info:
        config = load_config(config_path)
        validate_config(config)

    # Checked after the block: statements following the raising call inside
    # the ``with`` body would never execute.
    assert "Whisper cost per minute must be non-negative." in str(exc_info.value)
87 changes: 87 additions & 0 deletions tests/test_diarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
from unittest.mock import patch, MagicMock
from yawt.diarization import submit_diarization_job, wait_for_diarization, get_job_status

@patch('yawt.diarization.requests.post')
def test_submit_diarization_job_success(mock_post):
    """Expect submit_diarization_job to return the jobId from a 200 response."""
    accepted = MagicMock(status_code=200)
    accepted.json.return_value = {"jobId": "job123"}
    mock_post.return_value = accepted

    returned_id = submit_diarization_job(
        "fake_token",
        "https://example.com/audio.wav",
        num_speakers=2,
    )

    assert returned_id == "job123"

@patch('yawt.diarization.requests.post')
def test_submit_diarization_job_rate_limit(mock_post, mocker):
    """Expect a 429 reply to be retried after sleeping for Retry-After seconds.

    The first POST is answered with HTTP 429 and ``Retry-After: 1``; the
    second succeeds. The sleep is verified with ``assert_called_once_with``.
    The previous ``assert mock_sleep.called_once_with(1)`` only created a
    truthy child mock (and raises AttributeError on Python 3.12+), so it
    never checked anything.
    """
    throttled = MagicMock()
    throttled.status_code = 429
    throttled.headers = {'Retry-After': '1'}
    accepted = MagicMock(status_code=200, json=lambda: {"jobId": "job123"})
    mock_post.side_effect = [throttled, accepted]

    with patch('time.sleep') as mock_sleep:
        job_id = submit_diarization_job("fake_token", "https://example.com/audio.wav")

    assert job_id == "job123"
    mock_sleep.assert_called_once_with(1)

@patch('yawt.diarization.requests.post')
def test_submit_diarization_job_failure(mock_post):
    """Expect submit_diarization_job to raise with a descriptive message on HTTP 400."""
    mock_post.return_value = MagicMock(status_code=400, text="Bad Request")

    with pytest.raises(Exception) as exc_info:
        submit_diarization_job("fake_token", "https://example.com/audio.wav")

    error_text = str(exc_info.value)
    assert "Diarization submission failed" in error_text

@patch('yawt.diarization.requests.get')
def test_get_job_status_success(mock_get):
    """Expect get_job_status to return the JSON body of a 200 response."""
    reply = MagicMock(status_code=200)
    reply.json.return_value = {"status": "succeeded", "output": {"diarization": []}}
    mock_get.return_value = reply

    status_payload = get_job_status("fake_token", "job123")

    assert status_payload['status'] == "succeeded"

@patch('yawt.diarization.requests.get')
def test_get_job_status_rate_limit(mock_get, mocker):
    """Expect a 429 reply to be retried after sleeping for Retry-After seconds.

    The first GET is answered with HTTP 429 and ``Retry-After: 1``; the
    second succeeds. The sleep is verified with ``assert_called_once_with``.
    The previous ``assert mock_sleep.called_once_with(1)`` only created a
    truthy child mock (and raises AttributeError on Python 3.12+), so it
    never checked anything.
    """
    throttled = MagicMock()
    throttled.status_code = 429
    throttled.headers = {'Retry-After': '1'}
    succeeded = MagicMock(status_code=200, json=lambda: {"status": "succeeded"})
    mock_get.side_effect = [throttled, succeeded]

    with patch('time.sleep') as mock_sleep:
        job_info = get_job_status("fake_token", "job123")

    assert job_info['status'] == "succeeded"
    mock_sleep.assert_called_once_with(1)

@patch('yawt.diarization.requests.get')
def test_get_job_status_failure(mock_get):
    """Expect get_job_status to raise with a descriptive message on HTTP 500."""
    mock_get.return_value = MagicMock(
        status_code=500,
        text="Internal Server Error",
    )

    with pytest.raises(Exception) as exc_info:
        get_job_status("fake_token", "job123")

    error_text = str(exc_info.value)
    assert "Failed to get job status" in error_text

@patch('yawt.diarization.get_job_status')
def test_wait_for_diarization_success(mock_get_job_status, mocker):
    """Expect wait_for_diarization to return the job info once the status is 'succeeded'."""
    finished_job = {"status": "succeeded", "output": {"diarization": []}}
    mock_get_job_status.return_value = finished_job

    result = wait_for_diarization(
        "fake_token",
        "job123",
        "https://example.com/audio.wav",
    )

    assert result['status'] == "succeeded"

@patch('yawt.diarization.get_job_status')
def test_wait_for_diarization_timeout(mock_get_job_status, mocker):
    """Expect wait_for_diarization to raise once the timeout has elapsed.

    The job never leaves ``in_progress``. ``time.time`` is patched
    process-wide, so other code (for example the timestamp on a logging
    record) may read the fake clock as well. A fixed two-value list would then
    run out and raise StopIteration. The fake clock is therefore unbounded and
    advances 4000 s per reading, so any two consecutive readings are further
    apart than the 3600 s timeout.
    """
    import itertools

    mock_get_job_status.return_value = {"status": "in_progress"}

    fake_clock = itertools.count(0, 4000)
    with patch('time.time', side_effect=fake_clock), patch('time.sleep'):
        with pytest.raises(Exception) as exc_info:
            wait_for_diarization(
                "fake_token",
                "job123",
                "https://example.com/audio.wav",
                timeout=3600,
            )

    assert "Diarization job timed out." in str(exc_info.value)
39 changes: 39 additions & 0 deletions tests/test_logging_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from unittest.mock import patch, MagicMock
from yawt.logging_setup import setup_logging
import logging

def test_setup_logging_default(mocker):
    """Expect default setup to create 'logs', build one StreamHandler, and leave root at WARNING."""
    makedirs_stub = mocker.patch('yawt.logging_setup.os.makedirs')
    # Patched so no real rotating log file is opened.
    mocker.patch('yawt.logging_setup.RotatingFileHandler')
    stream_handler_stub = mocker.patch('yawt.logging_setup.logging.StreamHandler')

    setup_logging()

    makedirs_stub.assert_called_once_with("logs", exist_ok=True)
    root_level = logging.getLogger().level
    assert root_level == logging.WARNING
    stream_handler_stub.assert_called_once_with()

def test_setup_logging_debug(mocker):
    """Expect debug=True to configure logging once at DEBUG level."""
    # Filesystem and handler construction are stubbed out.
    mocker.patch('yawt.logging_setup.os.makedirs')
    mocker.patch('yawt.logging_setup.RotatingFileHandler')
    mocker.patch('yawt.logging_setup.logging.StreamHandler')
    basic_config_stub = mocker.patch('yawt.logging_setup.logging.basicConfig')

    setup_logging(debug=True)

    basic_config_stub.assert_called_once()
    configured_level = basic_config_stub.call_args.kwargs['level']
    assert configured_level == logging.DEBUG

def test_setup_logging_verbose(mocker):
    """Expect verbose=True to configure logging once at INFO level."""
    # Filesystem and handler construction are stubbed out.
    mocker.patch('yawt.logging_setup.os.makedirs')
    mocker.patch('yawt.logging_setup.RotatingFileHandler')
    mocker.patch('yawt.logging_setup.logging.StreamHandler')
    basic_config_stub = mocker.patch('yawt.logging_setup.logging.basicConfig')

    setup_logging(verbose=True)

    basic_config_stub.assert_called_once()
    configured_level = basic_config_stub.call_args.kwargs['level']
    assert configured_level == logging.INFO
Loading

0 comments on commit 7dc5fd1

Please sign in to comment.