Commit 27ada75

chore: format

jjjermiah committed Dec 11, 2024
1 parent e5eb513 commit 27ada75
Showing 4 changed files with 72 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/orcestradownloader/cli/__init__.py
@@ -1,3 +1,3 @@
 from .cli import DatasetMultiCommand, cli
 
-__all__ = ['DatasetMultiCommand', 'cli']
+__all__ = ['DatasetMultiCommand', 'cli']
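
For orientation, the re-export above means downstream code can import the CLI from the package rather than the cli.cli submodule. A hypothetical usage sketch, assuming orcestradownloader is installed:

# Hypothetical usage of the re-export above (assumes the package is installed):
from orcestradownloader.cli import DatasetMultiCommand, cli

# __all__ also pins what a wildcard import exposes:
# from orcestradownloader.cli import *  ->  DatasetMultiCommand, cli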
13 changes: 6 additions & 7 deletions src/orcestradownloader/dataset_config.py
@@ -22,32 +22,31 @@ class DatasetConfig:
    'pharmacosets': DatasetConfig(
        url='https://orcestra.ca/api/pset/available',
        cache_file='pharmacosets.json',
-        dataset_type=PharmacoSet
+        dataset_type=PharmacoSet,
    ),
    'icbsets': DatasetConfig(
        url='https://orcestra.ca/api/clinical_icb/available',
        cache_file='icbsets.json',
-        dataset_type=ICBSet
+        dataset_type=ICBSet,
    ),
    'radiosets': DatasetConfig(
        url='https://orcestra.ca/api/radioset/available',
        cache_file='radiosets.json',
-        dataset_type=RadioSet
+        dataset_type=RadioSet,
    ),
    'xevasets': DatasetConfig(
        url='https://orcestra.ca/api/xevaset/available',
        cache_file='xevasets.json',
-        dataset_type=XevaSet
+        dataset_type=XevaSet,
    ),
    'toxicosets': DatasetConfig(
        url='https://orcestra.ca/api/toxicoset/available',
        cache_file='toxicosets.json',
-        dataset_type=ToxicoSet
+        dataset_type=ToxicoSet,
    ),
    'radiomicsets': DatasetConfig(
        url='https://orcestra.ca/api/radiomicset/available',
        cache_file='radiomicsets.json',
-        dataset_type=RadiomicSet
+        dataset_type=RadiomicSet,
    ),
 }
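
For context, a minimal self-contained sketch of the pattern this file uses: a mapping from manager names to DatasetConfig entries. The field names (url, cache_file, dataset_type) come from the diff above; the DATASET_CONFIG name and the stub PharmacoSet class are assumptions for illustration, not the package's real definitions.

from dataclasses import dataclass
from typing import Type


class PharmacoSet:  # stand-in for the real model class; assumed for this sketch
    pass


@dataclass
class DatasetConfig:
    url: str  # API endpoint listing available datasets
    cache_file: str  # local JSON cache of the API response
    dataset_type: Type  # model class used to parse each record


# Hypothetical registry name, mirroring the mapping in the diff above.
DATASET_CONFIG = {
    'pharmacosets': DatasetConfig(
        url='https://orcestra.ca/api/pset/available',
        cache_file='pharmacosets.json',
        dataset_type=PharmacoSet,
    ),
}

config = DATASET_CONFIG['pharmacosets']
print(config.url, config.cache_file, config.dataset_type.__name__)

The trailing commas this commit adds follow the magic-trailing-comma convention: the formatter then keeps one argument per line, so appending a field later produces a one-line diff.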

95 changes: 54 additions & 41 deletions src/orcestradownloader/managers.py
@@ -140,7 +140,7 @@ def get_manager(self, name: str) -> DatasetManager:
 
     def get_all_managers(self) -> Dict[str, DatasetManager]:
         return self.registry
-
+
     def __repr__(self) -> str:
         return f'DatasetRegistry(registry={self.registry})'
 
@@ -273,13 +273,13 @@ def list_all(self, pretty: bool = True, force: bool = False) -> None:
             click.echo(f'{name},{ds_name}')
 
     def download_by_name(
-        self, 
-        manager_name: str, 
+        self,
+        manager_name: str,
         ds_name: List[str],
-        directory: Path, 
-        overwrite: bool = False, 
+        directory: Path,
+        overwrite: bool = False,
         force: bool = False,
-        timeout_seconds: int = 3600
+        timeout_seconds: int = 3600,
     ) -> List[Path]:
         """Download a single dataset."""
         # Fetch data asynchronously
@@ -289,7 +289,7 @@ def download_by_name(
             log.exception('Error fetching %s: %s', manager_name, e)
             errmsg = f'Error fetching {manager_name}: {e}'
             raise ValueError(errmsg) from e
-
+
         manager = self.registry.get_manager(manager_name)
         dataset_list = [manager[ds_name] for ds_name in ds_name]
 
@@ -300,32 +300,36 @@ def download_by_name(
                 raise ValueError(msg)
             file_path = directory / manager_name / f'{ds.name}.RDS'
             if file_path.exists() and not overwrite:
-                Console().print(f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]')
+                Console().print(
+                    f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]'
+                )
                 sys.exit(1)
             file_path.parent.mkdir(parents=True, exist_ok=True)
             file_paths[ds.name] = file_path
 
         async def download_all(progress: Progress) -> List[Path]:
-            return await asyncio.gather(*[
-                download_dataset(
-                    ds.download_link,
-                    file_paths[ds.name],
-                    progress,
-                    timeout_seconds=timeout_seconds
-                )
-                for ds in dataset_list
-            ])
-
+            return await asyncio.gather(
+                *[
+                    download_dataset(
+                        ds.download_link,
+                        file_paths[ds.name],
+                        progress,
+                        timeout_seconds=timeout_seconds,
+                    )
+                    for ds in dataset_list
+                ]
+            )
+
         with Progress() as progress:
             return asyncio.run(download_all(progress))
 
     def download_all(
-        self, 
+        self,
         manager_name: str,
-        directory: Path, 
-        overwrite: bool = False, 
+        directory: Path,
+        overwrite: bool = False,
         force: bool = False,
-        timeout_seconds: int = 3600
+        timeout_seconds: int = 3600,
     ) -> List[Path]:
         """Download all datasets for a specific manager."""
         file_paths = []
@@ -335,7 +339,7 @@ def download_all(
             log.exception('Error fetching %s: %s', manager_name, e)
             errmsg = f'Error fetching {manager_name}: {e}'
             raise ValueError(errmsg) from e
-
+
         manager = self.registry.get_manager(manager_name)
         for ds in manager.datasets:
             if not ds.download_link:
@@ -344,22 +348,26 @@
                 continue
             file_path = directory / manager_name / f'{ds.name}.RDS'
             if file_path.exists() and not overwrite:
-                Console().print(f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]')
+                Console().print(
+                    f'[bold red]File {file_path} already exists. Use --overwrite to overwrite.[/]'
+                )
                 sys.exit(1)
             file_path.parent.mkdir(parents=True, exist_ok=True)
             file_paths.append(file_path)
 
         async def download_all_datasets(progress: Progress) -> List[Path]:
-            return await asyncio.gather(*[
-                download_dataset(
-                    ds.download_link,
-                    file_path,
-                    progress,
-                    timeout_seconds=timeout_seconds
-                )
-                for ds, file_path in zip(manager.datasets, file_paths)
-            ])
-
+            return await asyncio.gather(
+                *[
+                    download_dataset(
+                        ds.download_link,
+                        file_path,
+                        progress,
+                        timeout_seconds=timeout_seconds,
+                    )
+                    for ds, file_path in zip(manager.datasets, file_paths)
+                ]
+            )
+
         with Progress(
             '[progress.description]{task.description}',
             BarColumn(),
@@ -385,7 +393,10 @@ def __getitem__(self, name: str) -> DatasetManager:
             msg += f' Available managers: {", ".join(self.names())}'
             raise ValueError(msg) from se
 
-async def download_dataset(download_link: str, file_path: Path, progress: Progress, timeout_seconds: int = 3600) -> Path:
+
+async def download_dataset(
+    download_link: str, file_path: Path, progress: Progress, timeout_seconds: int = 3600
+) -> Path:
     """Download a single dataset.
     Called by the UnifiedDataManager.download_by_name method
@@ -395,14 +406,16 @@ async def download_dataset(
     try:
         async with session.get(download_link) as response:
             total = int(response.headers.get('content-length', 0))
-            task = progress.add_task(f"[cyan]Downloading {file_path.name}...", total=total)
+            task = progress.add_task(
+                f'[cyan]Downloading {file_path.name}...', total=total
+            )
             with file_path.open('wb') as f:
                 async for chunk in response.content.iter_chunked(8192):
                     f.write(chunk)
                     progress.update(task, advance=len(chunk))
     except asyncio.TimeoutError:
-        Console().print(f'[bold red]Timeout while downloading {file_path.name}. Please try again later.[/]')
+        Console().print(
+            f'[bold red]Timeout while downloading {file_path.name}. Please try again later.[/]'
+        )
         raise
     return file_path
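
To see the whole download pattern of managers.py in one place, here is a minimal runnable sketch of the same gather-based concurrent download, assuming aiohttp and rich are installed. The URL and file names are placeholders; unlike the code above, this sketch shares one ClientSession across downloads and omits the overwrite check and most error handling.

import asyncio
from pathlib import Path

import aiohttp
from rich.progress import Progress


async def download_one(
    session: aiohttp.ClientSession, url: str, dest: Path, progress: Progress
) -> Path:
    # Stream the response to disk in 8 KiB chunks, advancing a progress bar.
    async with session.get(url) as response:
        total = int(response.headers.get('content-length', 0))
        task = progress.add_task(f'[cyan]Downloading {dest.name}...', total=total)
        with dest.open('wb') as f:
            async for chunk in response.content.iter_chunked(8192):
                f.write(chunk)
                progress.update(task, advance=len(chunk))
    return dest


async def download_many(urls: dict[str, str], directory: Path) -> list[Path]:
    directory.mkdir(parents=True, exist_ok=True)
    timeout = aiohttp.ClientTimeout(total=3600)  # mirrors timeout_seconds above
    with Progress() as progress:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # One coroutine per file; gather runs them concurrently.
            return await asyncio.gather(
                *[
                    download_one(session, url, directory / name, progress)
                    for name, url in urls.items()
                ]
            )


if __name__ == '__main__':
    # Placeholder name and URL; substitute a real dataset download link.
    asyncio.run(
        download_many({'example.RDS': 'https://example.com/file'}, Path('downloads'))
    )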


17 changes: 11 additions & 6 deletions src/orcestradownloader/models/base.py
@@ -178,9 +178,7 @@ def print_summary(self, title: str | None = None) -> None:
     This method uses Rich to display a well-formatted table of the record's attributes.
     """
-    table = Table(
-        title=title if title else f'{self.__class__.__name__} Summary'
-    )
+    table = Table(title=title if title else f'{self.__class__.__name__} Summary')
 
     table.add_column('Field', style='bold cyan', no_wrap=True)
     table.add_column('Value', style='magenta')
@@ -195,16 +193,23 @@ def print_summary(self, title: str | None = None) -> None:
     table.add_row('Dataset Name', self.dataset.name)
     table.add_row('Dataset Version', self.dataset.version_info.version)
     table.add_row(
-        'Dataset Type', 
-        self.dataset.version_info.dataset_type.name if self.dataset.version_info.dataset_type else 'N/A',
+        'Dataset Type',
+        self.dataset.version_info.dataset_type.name
+        if self.dataset.version_info.dataset_type
+        else 'N/A',
     )
     table.add_row(
         'Available Datatypes',
         ', '.join(self.datatypes) if self.datatypes else 'N/A',
     )
     table.add_row(
         'Publications',
-        ', '.join([f"{pub.citation} ({pub.link})" for pub in self.dataset.version_info.publication])
+        ', '.join(
+            [
+                f'{pub.citation} ({pub.link})'
+                for pub in self.dataset.version_info.publication
+            ]
+        ),
     )
 
     console = Console()
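
For readers unfamiliar with the Rich calls being reflowed here, a minimal sketch of the same summary-table pattern; the row values are invented stand-ins, not real orcestradownloader records.

from rich.console import Console
from rich.table import Table

# Build a two-column summary table, as print_summary does above.
table = Table(title='PharmacoSet Summary')
table.add_column('Field', style='bold cyan', no_wrap=True)
table.add_column('Value', style='magenta')

# Invented example values; the real method pulls these from the dataset record.
table.add_row('Dataset Name', 'GDSC_2020')
table.add_row('Dataset Version', '8.2')
table.add_row('Dataset Type', 'N/A')
table.add_row('Publications', 'Example citation (https://example.com)')

Console().print(table)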
