Skip to content

Commit

Permalink
feat: enhance dataset download functionality with support for multipl…
Browse files Browse the repository at this point in the history
…e datasets and default directory
  • Loading branch information
jjjermiah committed Dec 10, 2024
1 parent 83e68be commit 48ae456
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,6 @@ recipe.yaml
.ruff_cache
**/__pycache__/*
*./**/.pyc
**.pyc
**.pyc

rawdata
26 changes: 19 additions & 7 deletions src/orcestradownloader/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from orcestradownloader.managers import REGISTRY, DatasetManager, UnifiedDataManager
from orcestradownloader.models import ICBSet, PharmacoSet, RadioSet, ToxicoSet, XevaSet, RadiomicSet

DEFAULT_DATA_DIR = Path.cwd() / 'rawdata' / 'orcestradata'

@dataclass
class DatasetConfig:
url: str
Expand Down Expand Up @@ -108,8 +110,19 @@ def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False, ds_n

@ds_group.command(name='download')
@click.option('--overwrite', '-o', is_flag=True, help='Overwrite existing file, if it exists.', default=False, show_default=True)
@click.option('--filename', '-f', help='Filename to save the file as. Defaults to the name of the dataset', default=None, type=str, required=False)
@click.option('--directory', '-d', help='Directory to save the file to', default=Path.cwd(), type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path), required=True)
@click.option(
'--directory',
'-d',
help=f'Directory to save the file to. Defaults to ./{DEFAULT_DATA_DIR.relative_to(Path.cwd())}',
default=DEFAULT_DATA_DIR,
type=click.Path(
exists=False,
file_okay=False,
dir_okay=True,
writable=True,
path_type=Path
),
)
@click.argument(
'ds_name',
type=str,
Expand All @@ -123,18 +136,17 @@ def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False, ds_n
def _download(
ctx,
ds_name: List[str],
directory: Path,
directory: Path,
force: bool = False,
verbose: int = 1,
quiet: bool = False,
filename: str | None = None,
overwrite: bool = False
):
"""Download a file for this dataset."""
click.echo(f'Downloading {name} to {directory}')
manager = UnifiedDataManager(force=force)
file_path = manager.download_one(name, ds_name, directory, overwrite, force)
click.echo(f'Downloaded {file_path}')
file_paths = manager.download_by_name(name, ds_name, directory, overwrite, force)
for file_path in file_paths:
click.echo(f'Downloaded {file_path}')
return ds_group
return None

Expand Down
52 changes: 42 additions & 10 deletions src/orcestradownloader/managers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -262,13 +263,14 @@ def list_all(self, pretty: bool = True, force: bool = False) -> None:
for ds_name in ds_names:
click.echo(f'{name},{ds_name}')

def download_one(self,
manager_name: str,
ds_name: List[str],
directory: Path,
overwrite: bool = False,
force: bool = False
) -> Path:
def download_by_name(
self,
manager_name: str,
ds_name: List[str],
directory: Path,
overwrite: bool = False,
force: bool = False
) -> List[Path]:
"""Download a single dataset."""
# Fetch data asynchronously
try:
Expand All @@ -281,13 +283,26 @@ def download_one(self,
manager = self.registry.get_manager(manager_name)
dataset_list = [manager[ds_name] for ds_name in ds_name]

file_paths = {}
for ds in dataset_list:
if not ds.download_link:
msg = f'Dataset {ds.name} does not have a download link.'
raise ValueError(msg)

file_path = directory / f'{ds.name}.zip'
return file_path
if file_path.exists() and not overwrite:
Console().print(f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]')
sys.exit(1)
file_path.parent.mkdir(parents=True, exist_ok=True)
file_paths[ds.name] = file_path

async def download_all(progress: Progress) -> List[Path]:
return await asyncio.gather(*[
download_dataset(ds.download_link, file_paths[ds.name], progress)
for ds in dataset_list
])

with Progress() as progress:
return asyncio.run(download_all(progress))

def names(self) -> List[str]:
"""List all managers."""
Expand All @@ -300,4 +315,21 @@ def __getitem__(self, name: str) -> DatasetManager:
except StopIteration as se:
msg = f'Manager {name} not found in {self.__class__.__name__}.'
msg += f' Available managers: {", ".join(self.names())}'
raise ValueError(msg) from se
raise ValueError(msg) from se

async def download_dataset(download_link: str, file_path: Path, progress: Progress) -> Path:
    """Stream a single dataset from ``download_link`` into ``file_path``.

    Called by the UnifiedDataManager.download_by_name method.

    Args:
        download_link: URL the dataset is fetched from with an HTTP GET.
        file_path: Destination file; it is opened in ``'wb'`` mode, so any
            existing content is truncated once the download starts.
        progress: Progress display that receives one task for this download
            and is advanced by the size of every chunk written.

    Returns:
        ``file_path``, once the response body has been fully written.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an HTTP
            status of 400 or higher.
    """
    async with aiohttp.ClientSession() as session:  # noqa: SIM117
        async with session.get(download_link) as response:
            # Fail before touching the destination file: without this check
            # an HTTP error page would be saved to disk as if it were the
            # dataset, and an existing file would be truncated for nothing.
            response.raise_for_status()
            # A missing Content-Length header falls back to a total of 0.
            total = int(response.headers.get('content-length', 0))
            task = progress.add_task(f"[cyan]Downloading {file_path.name}...", total=total)
            with file_path.open('wb') as f:
                # Write the body in 8 KiB chunks rather than reading the
                # whole response into memory at once.
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)
                    progress.update(task, advance=len(chunk))
    return file_path


0 comments on commit 48ae456

Please sign in to comment.