Skip to content

Commit

Permalink
feat: enhance dataset management commands with optional dataset name …
Browse files Browse the repository at this point in the history
…and summary printing
  • Loading branch information
jjjermiah committed Dec 10, 2024
1 parent b2f99db commit d20afd7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 15 deletions.
38 changes: 29 additions & 9 deletions src/orcestradownloader/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Type
from typing import Dict, Type, List

import click
from click import Group, MultiCommand
Expand Down Expand Up @@ -82,25 +82,37 @@ def get_command(self, ctx, name):
@click.option('--no-pretty', is_flag=True, help='Disable pretty printing')
@click.pass_context
def _list(ctx, force: bool = False, no_pretty: bool = False, verbose: int = 1, quiet: bool = False):
"""List items for this dataset."""
"""List ALL datasets for this data type."""
manager = UnifiedDataManager(force=force)
manager.list_one(name, pretty=not no_pretty)

@ds_group.command(name='table')
@set_log_verbosity()
@click.argument('ds_name', nargs=1, type=str, required=False, metavar='[NAME OF DATASET]')
@click.option('--force', is_flag=True, help='Force fetch new data')
@click.pass_context
def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False):
"""Print a table of items for this dataset."""
def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False, ds_name: str | None = None):
"""Print a table summary items for this dataset.
If no dataset name is provided, prints a table of all datasets.
If a dataset name is provided, prints a table of the specified dataset.
"""
manager = UnifiedDataManager(force=force)
manager.print_one_table(name)
manager.fetch_one(name)
ds_manager = manager[name]
if ds_name:
ds_manager[ds_name].print_summary(title=f'{ds_name} Summary')
else:
manager.print_one_table(name)



@ds_group.command(name='download')
@click.option('--overwrite', '-o', is_flag=True, help='Overwrite existing file, if it exists.', default=False, show_default=True)
@click.option('--filename', '-f', help='Filename to save the file as. Defaults to the name of the dataset', default=None, type=str, required=False)
@click.option('--directory', '-d', help='Directory to save the file to', default=Path.cwd(), type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path), required=True)
@click.argument(
'name',
'ds_name',
type=str,
required=True,
nargs=-1,
Expand All @@ -111,6 +123,7 @@ def _table(ctx, force: bool = False, verbose: int = 1, quiet: bool = False):
@click.pass_context
def _download(
ctx,
ds_name: List[str],
directory: Path,
force: bool = False,
verbose: int = 1,
Expand All @@ -120,6 +133,9 @@ def _download(
):
"""Download a file for this dataset."""
click.echo(f'Downloading {name} to {directory}')
manager = UnifiedDataManager(force=force)
file_path = manager.download_one(name, ds_name, directory, overwrite, force)
click.echo(f'Downloaded {file_path}')
return ds_group
return None

Expand All @@ -131,8 +147,8 @@ def format_usage(self, ctx, formatter):
)

@click.command(cls=DatasetMultiCommand, context_settings=CONTEXT_SETTINGS, invoke_without_command=True)
@click.help_option("-h", "--help", help="Show this message and exit.")
@click.option('-r', '--refresh', is_flag=True, help='Fetch all datasets and hydrate the cache.', default=False, show_default=True)
@click.help_option("-h", "--help", help="Show this message and exit.")
@set_log_verbosity()
@click.pass_context
def cli(ctx, refresh: bool = False, verbose: int = 0, quiet: bool = False):
Expand All @@ -158,14 +174,18 @@ def cli(ctx, refresh: bool = False, verbose: int = 0, quiet: bool = False):
"""
ctx.ensure_object(dict)

# if user wants to refresh all datasets in the cache
if refresh:
manager = UnifiedDataManager(force=True)
manager.hydrate_cache()
manager.list_all()
return
click.echo(ctx.get_help())


# if no subcommand is provided, print help
elif ctx.invoked_subcommand is None:
click.echo(ctx.get_help())
return

if __name__ == '__main__':
cli()
61 changes: 57 additions & 4 deletions src/orcestradownloader/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar

import aiohttp
from rich.console import Console
Expand All @@ -13,6 +13,10 @@

from orcestradownloader.cache import Cache
from orcestradownloader.logging_config import logger as log
from orcestradownloader.models.base import BaseModel

# Type variable for subclasses of BaseModel
T = TypeVar('T', bound=BaseModel)

CACHE_DIR = Path.home() / '.cache/orcestradownloader'

Expand Down Expand Up @@ -62,13 +66,13 @@ def print_table(self, items: List[Any], row_generator: Callable) -> None:


@dataclass
class DatasetManager:
class DatasetManager(Generic[T]):
"""Base class for managing datasets."""

url: str
cache_file: str
dataset_type: Type[Any]
datasets: List[Any] = field(default_factory=list)
dataset_type: Type[T]
datasets: List[T] = field(default_factory=list)

def __post_init__(self) -> None:
self.cache = Cache(CACHE_DIR, self.cache_file, CACHE_DAYS_TO_KEEP)
Expand Down Expand Up @@ -105,6 +109,15 @@ def names(self) -> List[str]:
"""List all datasets."""
return [ds.name for ds in self.datasets]

def __getitem__(self, name: str) -> T:
"""Get a dataset by name."""
try:
return next(ds for ds in self.datasets if ds.name == name)
except StopIteration as se:
msg = f'Dataset {name} not found in {self.__class__.__name__}.'
msg += f' Available datasets: {", ".join(self.names())}'
raise ValueError(msg) from se


class DatasetRegistry:
"""Registry to hold dataset manager instances."""
Expand Down Expand Up @@ -248,3 +261,43 @@ def list_all(self, pretty: bool = True, force: bool = False) -> None:

for ds_name in ds_names:
click.echo(f'{name},{ds_name}')

def download_one(self,
manager_name: str,
ds_name: List[str],
directory: Path,
overwrite: bool = False,
force: bool = False
) -> Path:
"""Download a single dataset."""
# Fetch data asynchronously
try:
self.fetch_one(manager_name)
except Exception as e:
log.exception('Error fetching %s: %s', manager_name, e)
errmsg = f'Error fetching {manager_name}: {e}'
raise ValueError(errmsg) from e

manager = self.registry.get_manager(manager_name)
dataset_list = [manager[ds_name] for ds_name in ds_name]

for ds in dataset_list:
if not ds.download_link:
msg = f'Dataset {ds.name} does not have a download link.'
raise ValueError(msg)

file_path = directory / f'{ds.name}.zip'
return file_path

def names(self) -> List[str]:
"""List all managers."""
return list(self.registry.get_all_managers().keys())

def __getitem__(self, name: str) -> DatasetManager:
"""Get a manager by name."""
try:
return self.registry.get_manager(name)
except StopIteration as se:
msg = f'Manager {name} not found in {self.__class__.__name__}.'
msg += f' Available managers: {", ".join(self.names())}'
raise ValueError(msg) from se
14 changes: 12 additions & 2 deletions src/orcestradownloader/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,15 @@ def datatypes(self) -> List[str]:
"""
return [datatype.name for datatype in self.available_datatypes]

def print_summary(self) -> None:
def print_summary(self, title: str | None = None) -> None:
"""
Print a summary of the dataset record.
This method uses Rich to display a well-formatted table of the record's attributes.
"""
table = Table(title=f'{self.__class__.__name__} Summary')
table = Table(
title=title if title else f'{self.__class__.__name__} Summary'
)

table.add_column('Field', style='bold cyan', no_wrap=True)
table.add_column('Value', style='magenta')
Expand All @@ -192,10 +194,18 @@ def print_summary(self) -> None:
table.add_row('Download Link', self.download_link)
table.add_row('Dataset Name', self.dataset.name)
table.add_row('Dataset Version', self.dataset.version_info.version)
table.add_row(
'Dataset Type',
self.dataset.version_info.dataset_type.name if self.dataset.version_info.dataset_type else 'N/A',
)
table.add_row(
'Available Datatypes',
', '.join(self.datatypes) if self.datatypes else 'N/A',
)
table.add_row(
'Publications',
', '.join([f"{pub.citation} ({pub.link})" for pub in self.dataset.version_info.publication])
)

console = Console()
console.print(table)

0 comments on commit d20afd7

Please sign in to comment.