From 48ae456ce42b9066e67d0df4a9ce696ab255fb49 Mon Sep 17 00:00:00 2001 From: Jermiah Joseph Date: Tue, 10 Dec 2024 14:08:26 -0500 Subject: [PATCH] feat: enhance dataset download functionality with support for multiple datasets and default directory --- .gitignore | 4 +- src/orcestradownloader/cli/__main__.py | 26 +++++++++---- src/orcestradownloader/managers.py | 52 +++++++++++++++++++++----- 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 1c8a509..c0b3032 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,6 @@ recipe.yaml .ruff_cache **/__pycache__/* *./**/.pyc -**.pyc \ No newline at end of file +**.pyc + +rawdata \ No newline at end of file diff --git a/src/orcestradownloader/cli/__main__.py b/src/orcestradownloader/cli/__main__.py index 3a46401..a2d2d42 100644 --- a/src/orcestradownloader/cli/__main__.py +++ b/src/orcestradownloader/cli/__main__.py @@ -11,6 +11,8 @@ from orcestradownloader.managers import REGISTRY, DatasetManager, UnifiedDataManager from orcestradownloader.models import ICBSet, PharmacoSet, RadioSet, ToxicoSet, XevaSet, RadiomicSet +DEFAULT_DATA_DIR = Path.cwd() / 'rawdata' / 'orcestradata' + @dataclass class DatasetConfig: url: str @@ -108,8 +110,19 @@ def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False, ds_n @ds_group.command(name='download') @click.option('--overwrite', '-o', is_flag=True, help='Overwrite existing file, if it exists.', default=False, show_default=True) - @click.option('--filename', '-f', help='Filename to save the file as. Defaults to the name of the dataset', default=None, type=str, required=False) - @click.option('--directory', '-d', help='Directory to save the file to', default=Path.cwd(), type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path), required=True) + @click.option( + '--directory', + '-d', + help=f'Directory to save the file to. Defaults to ./{DEFAULT_DATA_DIR.relative_to(Path.cwd())}', + default=DEFAULT_DATA_DIR, + type=click.Path( + exists=False, + file_okay=False, + dir_okay=True, + writable=True, + path_type=Path + ), + ) @click.argument( 'ds_name', type=str, @@ -123,18 +136,17 @@ def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False, ds_n def _download( ctx, ds_name: List[str], - directory: Path, + directory: Path, force: bool = False, verbose: int = 1, quiet: bool = False, - filename: str | None = None, overwrite: bool = False ): """Download a file for this dataset.""" - click.echo(f'Downloading {name} to {directory}') manager = UnifiedDataManager(force=force) - file_path = manager.download_one(name, ds_name, directory, overwrite, force) - click.echo(f'Downloaded {file_path}') + file_paths = manager.download_by_name(name, ds_name, directory, overwrite, force) + for file_path in file_paths: + click.echo(f'Downloaded {file_path}') return ds_group return None diff --git a/src/orcestradownloader/managers.py b/src/orcestradownloader/managers.py index cbaf295..fa2cf8c 100644 --- a/src/orcestradownloader/managers.py +++ b/src/orcestradownloader/managers.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import sys from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path @@ -262,13 +263,14 @@ def list_all(self, pretty: bool = True, force: bool = False) -> None: for ds_name in ds_names: click.echo(f'{name},{ds_name}') - def download_one(self, - manager_name: str, - ds_name: List[str], - directory: Path, - overwrite: bool = False, - force: bool = False - ) -> Path: + def download_by_name( + self, + manager_name: str, + ds_name: List[str], + directory: Path, + overwrite: bool = False, + force: bool = False + ) -> List[Path]: """Download a single dataset.""" # Fetch data asynchronously try: @@ -281,13 +283,26 @@ def download_one(self, manager = self.registry.get_manager(manager_name) dataset_list = [manager[ds_name] for ds_name in ds_name] + file_paths = {} for ds in dataset_list: if not ds.download_link: msg = f'Dataset {ds.name} does not have a download link.' raise ValueError(msg) - file_path = directory / f'{ds.name}.zip' - return file_path + if file_path.exists() and not overwrite: + Console().print(f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]') + sys.exit(1) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_paths[ds.name] = file_path + + async def download_all(progress: Progress) -> List[Path]: + return await asyncio.gather(*[ + download_dataset(ds.download_link, file_paths[ds.name], progress) + for ds in dataset_list + ]) + + with Progress() as progress: + return asyncio.run(download_all(progress)) def names(self) -> List[str]: """List all managers.""" @@ -300,4 +315,21 @@ def __getitem__(self, name: str) -> DatasetManager: except StopIteration as se: msg = f'Manager {name} not found in {self.__class__.__name__}.' msg += f' Available managers: {", ".join(self.names())}' - raise ValueError(msg) from se \ No newline at end of file + raise ValueError(msg) from se + +async def download_dataset(download_link: str, file_path: Path, progress: Progress) -> Path: + """Download a single dataset. + + Called by the UnifiedDataManager.download_by_name method + """ + async with aiohttp.ClientSession() as session: # noqa: SIM117 + async with session.get(download_link) as response: + total = int(response.headers.get('content-length', 0)) + task = progress.add_task(f"[cyan]Downloading {file_path.name}...", total=total) + with file_path.open('wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) + progress.update(task, advance=len(chunk)) + return file_path + +