Skip to content

Commit

Permalink
feat(cli): add option to choose which columns to display (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Dec 23, 2024
1 parent 30796aa commit 64a56c0
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 43 deletions.
112 changes: 80 additions & 32 deletions packages/ragbits-cli/src/ragbits/cli/state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import TypeVar

import typer
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table
from rich.table import Column, Table


class OutputType(Enum):
Expand All @@ -24,41 +26,87 @@ class CliState:

cli_state = CliState()

ModelT = TypeVar("ModelT", bound=BaseModel)

def print_output(data: Sequence[BaseModel] | BaseModel) -> None:

def print_output_table(
data: Sequence[ModelT], columns: Mapping[str, Column] | Sequence[str] | str | None = None
) -> None:
"""
Process and display output based on the current state's output type.
Display data from Pydantic models in a table format.
Args:
data: a list of pydantic models representing output of CLI function
columns: a list of columns to display in the output table: either as a list, string with comma separated names,
or for grater control over how the data is displayed a mapping of column names to Column objects.
If not provided, the columns will be inferred from the model schema.
"""
console = Console()
if isinstance(data, BaseModel):
data = [data]
if len(data) == 0:
_print_empty_list()

if not data:
console.print("No results")
return
first_el_instance = type(data[0])
if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
raise ValueError("All the rows need to be of the same type")
data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
output_type = cli_state.output_type
if output_type == OutputType.json:
console.print(json.dumps(data_dicts, indent=4))
elif output_type == OutputType.text:
table = Table(show_header=True, header_style="bold magenta")
properties = data[0].model_json_schema()["properties"]
for key in properties:
table.add_column(properties[key]["title"])
for row in data_dicts:
table.add_row(*[str(value) for value in row.values()])
console.print(table)
else:
raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
if cli_state.output_type == OutputType.text:
print("Empty data list")
elif cli_state.output_type == OutputType.json:
print(json.dumps([]))

fields = data[0].model_fields

# Human readable titles for columns
titles = {key: value.get("title", key) for key, value in data[0].model_json_schema()["properties"].items()}

# Normalize the list of columns
if columns is None:
columns = {key: Column() for key in fields}
elif isinstance(columns, str):
columns = {key: Column() for key in columns.split(",")}
elif isinstance(columns, Sequence):
columns = {key: Column() for key in columns}

# Add headers to columns if not provided
for key in columns:
if key not in fields:
Console(stderr=True).print(f"Unknown column: {key}")
raise typer.Exit(1)

column = columns[key]
if column.header == "":
column.header = titles.get(key, key)

# Create and print the table
table = Table(*columns.values(), show_header=True, header_style="bold magenta")
for row in data:
table.add_row(*[str(getattr(row, key)) for key in columns])
console.print(table)


def print_output_json(data: Sequence[ModelT]) -> None:
"""
Display data from Pydantic models in a JSON format.
Args:
data: a list of pydantic models representing output of CLI function
"""
console = Console()
console.print(json.dumps([output.model_dump(mode="json") for output in data], indent=4))


def print_output(
data: Sequence[ModelT] | ModelT, columns: Mapping[str, Column] | Sequence[str] | str | None = None
) -> None:
"""
Process and display output based on the current state's output type.
Args:
data: a list of pydantic models representing output of CLI function
columns: a list of columns to display in the output table: either as a list, string with comma separated names,
or for grater control over how the data is displayed a mapping of column names to Column objects.
If not provided, the columns will be inferred from the model schema.
"""
if not isinstance(data, Sequence):
data = [data]

match cli_state.output_type:
case OutputType.text:
print_output_table(data, columns)
case OutputType.json:
print_output_json(data)
case _:
raise ValueError(f"Unsupported output type: {cli_state.output_type}")
53 changes: 42 additions & 11 deletions packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import typer
from pydantic import BaseModel
Expand All @@ -21,11 +22,22 @@ class CLIState:

state: CLIState = CLIState()

# Default columns for commands that list entries
_default_columns = "id,key,metadata"


@vector_stores_app.callback()
def common_args(
factory_path: str | None = None,
yaml_path: Path | None = None,
factory_path: Annotated[
str | None,
typer.Option(
help="Python path to a function that creates a vector store, ina format 'module.submodule:function'"
),
] = None,
yaml_path: Annotated[
Path | None,
typer.Option(help="Path to a YAML configuration file for the vector store", exists=True, resolve_path=True),
] = None,
) -> None:
state.vector_store = get_instance_or_exit(
VectorStore,
Expand All @@ -35,7 +47,13 @@ def common_args(


@vector_stores_app.command(name="list")
def list_entries(limit: int = 10, offset: int = 0) -> None:
def list_entries(
limit: Annotated[int, typer.Option(help="Maximum number of entries to list")] = 10,
offset: Annotated[int, typer.Option(help="How many entries to skip")] = 0,
columns: Annotated[
str, typer.Option(help="Comma-separated list of columns to display, aviailable: id, key, vector, metadata")
] = _default_columns,
) -> None:
"""
List all objects in the chosen vector store.
"""
Expand All @@ -45,7 +63,7 @@ async def run() -> None:
raise ValueError("Vector store not initialized")

entries = await state.vector_store.list(limit=limit, offset=offset)
print_output(entries)
print_output(entries, columns=columns)

asyncio.run(run())

Expand All @@ -55,7 +73,9 @@ class RemovedItem(BaseModel):


@vector_stores_app.command()
def remove(ids: list[str]) -> None:
def remove(
ids: Annotated[list[str], typer.Argument(help="IDs of the entries to remove from the vector store")],
) -> None:
"""
Remove objects from the chosen vector store.
"""
Expand All @@ -75,11 +95,22 @@ async def run() -> None:

@vector_stores_app.command()
def query(
text: str,
k: int = 5,
max_distance: float | None = None,
embedder_factory_path: str | None = None,
embedder_yaml_path: Path | None = None,
text: Annotated[str, typer.Argument(help="Text to query the vector store with")],
k: Annotated[int, typer.Option(help="Number of entries to retrieve")] = 5,
max_distance: Annotated[float | None, typer.Option(help="Maximum distance to the query vector")] = None,
embedder_factory_path: Annotated[
str | None,
typer.Option(
help="Python path to a function that creates an embedder, in a format 'module.submodule:function'"
),
] = None,
embedder_yaml_path: Annotated[
Path | None,
typer.Option(help="Path to a YAML configuration file for the embedder", exists=True, resolve_path=True),
] = None,
columns: Annotated[
str, typer.Option(help="Comma-separated list of columns to display, aviailable: id, key, vector, metadata")
] = _default_columns,
) -> None:
"""
Query the chosen vector store.
Expand All @@ -104,6 +135,6 @@ async def run() -> None:
vector=search_vector[0],
options=options,
)
print_output(entries)
print_output(entries, columns=columns)

asyncio.run(run())
41 changes: 41 additions & 0 deletions packages/ragbits-core/tests/cli/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,47 @@ def test_vector_store_list_limit_offset():
assert "entry 3" not in result.stdout


def test_vector_store_list_columns():
runner = CliRunner(mix_stderr=False)
result = runner.invoke(
vector_stores_app,
["--factory-path", "cli.test_vector_store:vector_store_factory", "list", "--columns", "id,key,metadata"],
)
assert result.exit_code == 0
assert "entry 1" in result.stdout
assert "entry 2" in result.stdout
assert "entry 3" in result.stdout
assert "Vector" not in result.stdout
assert "Id" in result.stdout
assert "Key" in result.stdout
assert "Metadata" in result.stdout
assert "another_key" in result.stdout

result = runner.invoke(
vector_stores_app,
["--factory-path", "cli.test_vector_store:vector_store_factory", "list", "--columns", "id,key"],
)
assert result.exit_code == 0
assert "entry 1" in result.stdout
assert "entry 2" in result.stdout
assert "entry 3" in result.stdout
assert "Vector" not in result.stdout
assert "Id" in result.stdout
assert "Key" in result.stdout
assert "Metadata" not in result.stdout
assert "another_key" not in result.stdout


def test_vector_store_list_columns_non_existent():
runner = CliRunner(mix_stderr=False)
result = runner.invoke(
vector_stores_app,
["--factory-path", "cli.test_vector_store:vector_store_factory", "list", "--columns", "id,key,non_existent"],
)
assert result.exit_code == 1
assert "Unknown column: non_existent" in result.stderr


def test_vector_store_remove():
runner = CliRunner(mix_stderr=False)
result = runner.invoke(
Expand Down

0 comments on commit 64a56c0

Please sign in to comment.