Commit
chore: use batching in ReadOnlyProject.fetch_read_only_experiments (#151)

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: use batching in ReadOnlyRun.fetch_read_only_experiments

* chore: MR changes
PatrykGala authored Jan 17, 2025
1 parent 59669a0 commit 9532fc8
Showing 4 changed files with 178 additions and 51 deletions.
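
For orientation, here is a minimal usage sketch of the two methods this commit touches, assuming `ReadOnlyProject` is importable from `neptune_fetcher` and that the project and API token are configured via environment variables. With this change, each call resolves the requested ids or names through batched NQL queries instead of issuing one request per item.

```python
from neptune_fetcher import ReadOnlyProject

# Assumes NEPTUNE_PROJECT and NEPTUNE_API_TOKEN are set in the environment.
project = ReadOnlyProject()

# Experiments are resolved with a single batched NQL query over all names.
for exp in project.fetch_read_only_experiments(names=["exp-a", "exp-b"]):
    print(exp["sys/name"].fetch())

# eager_load_fields=False defers loading field definitions until first access.
for run in project.fetch_read_only_runs(custom_ids=["run-1", "run-2"], eager_load_fields=False):
    print(run["sys/custom_run_id"].fetch())
```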
4 changes: 3 additions & 1 deletion src/neptune_fetcher/nql.py
@@ -38,6 +38,8 @@
Union,
)

from neptune_fetcher.util import escape_nql_criterion


@dataclass
class NQLQuery:
@@ -185,7 +187,7 @@ def prepare_nql_query(
name="sys/id",
type=NQLAttributeType.STRING,
operator=NQLAttributeOperator.EQUALS,
value=api_id,
value=escape_nql_criterion(api_id),
)
for api_id in ids
],
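A hedged illustration of the `nql.py` change: user-supplied values are now passed through `escape_nql_criterion` before being interpolated into string criteria. The class and enum names below are taken from the diff; the exact escaping rules are not shown in this commit, so treat this as a sketch rather than a reference.

```python
from neptune_fetcher.nql import (
    NQLAttributeOperator,
    NQLAttributeType,
    NQLQueryAttribute,
)
from neptune_fetcher.util import escape_nql_criterion

# Build a single sys/id equality criterion. Escaping the value prevents quotes
# or other special characters in the id from breaking the generated NQL string.
criterion = NQLQueryAttribute(
    name="sys/id",
    type=NQLAttributeType.STRING,
    operator=NQLAttributeOperator.EQUALS,
    value=escape_nql_criterion("FETCH-123"),
)

# The fetcher serializes queries with str(...) before sending them to the backend.
query_string = str(criterion)
```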
112 changes: 95 additions & 17 deletions src/neptune_fetcher/read_only_project.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = [
"ReadOnlyProject",
]
__all__ = ["ReadOnlyProject"]

import collections
import concurrent.futures
@@ -69,7 +67,8 @@
logger = logging.getLogger(__name__)

PROJECT_ENV_NAME = "NEPTUNE_PROJECT"
SYS_COLUMNS = ["sys/id", "sys/name", "sys/custom_run_id"]
SYS_ID = "sys/id"
SYS_COLUMNS = [SYS_ID, "sys/name", "sys/custom_run_id"]

MAX_CUMULATIVE_LENGTH = 100000
MAX_QUERY_LENGTH = 250000
@@ -166,9 +165,7 @@ def list_experiments(self) -> Generator[Dict[str, Optional[str]], None, None]:
yield {column: exp.attributes.get(column, None) for column in SYS_COLUMNS}

def fetch_read_only_runs(
self,
with_ids: Optional[List[str]] = None,
custom_ids: Optional[List[str]] = None,
self, with_ids: Optional[List[str]] = None, custom_ids: Optional[List[str]] = None, eager_load_fields=True
) -> Iterator[ReadOnlyRun]:
"""Lists runs of the project in the form of read-only runs.
@@ -177,23 +174,42 @@
Args:
with_ids: List of run ids to fetch.
custom_ids: List of custom run ids to fetch.
eager_load_fields: Whether to eagerly load the run fields definitions.
If `False`, individual fields are loaded only when accessed. Default is `True`.
"""
for run_id in with_ids or []:
yield ReadOnlyRun(read_only_project=self, with_id=run_id)

for custom_id in custom_ids or []:
yield ReadOnlyRun(read_only_project=self, custom_id=custom_id)
queries = []
if with_ids:
queries.append(_make_leaderboard_nql(is_run=True, with_ids=with_ids))

if custom_ids:
queries.append(_make_leaderboard_nql(is_run=True, custom_ids=custom_ids))

for query in queries:
runs = list_objects_from_project(
backend=self._backend,
project_id=self._project_id,
query=str(query),
object_type="run",
columns=[SYS_ID],
)

for run in runs:
yield ReadOnlyRun._create(
read_only_project=self, sys_id=run.attributes[SYS_ID], eager_load_fields=eager_load_fields
)

def fetch_read_only_experiments(
self,
names: Optional[List[str]] = None,
self, names: Optional[List[str]] = None, eager_load_fields=True
) -> Iterator[ReadOnlyRun]:
"""Lists experiments of the project in the form of read-only runs.
Returns a generator of `ReadOnlyRun` instances.
Args:
names: List of experiment names to fetch.
eager_load_fields: Whether to eagerly load the run fields definitions.
If `False`, individual fields are loaded only when accessed. Default is `True`.
Example:
```
@@ -202,8 +218,20 @@ def fetch_read_only_experiments(
...
```
"""
for name in names or []:
yield ReadOnlyRun(read_only_project=self, experiment_name=name)
if names is None or names == []:
return
query = _make_leaderboard_nql(is_run=False, names=names)
experiments = list_objects_from_project(
backend=self._backend,
project_id=self._project_id,
query=str(query),
object_type="experiment",
columns=[SYS_ID],
)
for exp in experiments:
yield ReadOnlyRun._create(
read_only_project=self, sys_id=exp.attributes[SYS_ID], eager_load_fields=eager_load_fields
)

def fetch_runs(self) -> "DataFrame":
"""Fetches a table containing identifiers and names of runs in the project.
@@ -595,6 +623,35 @@ def _stream_attributes(
if remaining <= 0 or not next_page_token:
break

def _fetch_sys_id(
self, sys_id: Optional[str] = None, custom_id: Optional[str] = None, experiment_name: Optional[str] = None
) -> Optional[str]:
if sys_id is not None:
query = _make_leaderboard_nql(with_ids=[sys_id], trashed=False)
object_type = "run"

elif custom_id is not None:
query = _make_leaderboard_nql(custom_ids=[custom_id], trashed=False)
object_type = "run"

elif experiment_name is not None:
query = _make_leaderboard_nql(names=[experiment_name], trashed=False)
object_type = "experiment"

container = list(
list_objects_from_project(
backend=self._backend,
project_id=self._project_id,
limit=1,
columns=[SYS_ID],
query=str(query),
object_type=object_type,
)
)
if len(container) == 0:
return None
return container[0].attributes[SYS_ID]


def _extract_value(attr: ProtoAttributeDTO) -> Any:
if attr.type == "string":
Expand Down Expand Up @@ -742,7 +799,7 @@ def _stream_runs(


def _find_sort_type(backend, project_id, sort_by):
if sort_by == "sys/id":
if sort_by == SYS_ID:
return "string"
elif sort_by == "sys/creation_time":
return "datetime"
@@ -789,6 +846,7 @@ def _make_leaderboard_nql(
tags: Optional[Iterable[str]] = None,
trashed: Optional[bool] = False,
names_regex: Optional[str] = None,
names: Optional[List[str]] = None,
names_exclude_regex: Optional[Union[str, Iterable[str]]] = None,
custom_id_regex: Optional[Union[str, Iterable[str]]] = None,
is_run: bool = True,
@@ -834,6 +892,26 @@
aggregator=NQLAggregator.AND,
)

if names is not None:
query = NQLQueryAggregate(
items=[
query,
NQLQueryAggregate(
items=[
NQLQueryAttribute(
name="sys/name",
type=NQLAttributeType.STRING,
operator=NQLAttributeOperator.EQUALS,
value=escape_nql_criterion(name),
)
for name in names
],
aggregator=NQLAggregator.OR,
),
],
aggregator=NQLAggregator.AND,
)

if isinstance(names_exclude_regex, str):
names_exclude_regex = [names_exclude_regex]

@@ -903,7 +981,7 @@ def list_objects_from_project(
columns: Iterable[str] = None,
query: str = "(`sys/trashed`:bool = false)",
limit: Optional[int] = None,
sort_by: Tuple[str, str, Literal["ascending", "descending"]] = ("sys/id", "string", "descending"),
sort_by: Tuple[str, str, Literal["ascending", "descending"]] = (SYS_ID, "string", "descending"),
) -> Generator[_AttributeContainer, None, None]:
offset = 0
batch_size = 10_000
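The diff only shows `offset = 0` and `batch_size = 10_000` at the top of `list_objects_from_project`, so the loop below is an assumed reconstruction of a generic offset-based paging pattern of that shape, not the library's actual implementation; `fetch_page` is a hypothetical stand-in for the backend call.

```python
from typing import Any, Callable, Generator, Iterable, Optional


def paged_fetch(
    fetch_page: Callable[[int, int], Iterable[Any]],  # hypothetical backend call: (offset, count) -> objects
    limit: Optional[int] = None,
    batch_size: int = 10_000,
) -> Generator[Any, None, None]:
    """Yield objects page by page using an offset/batch_size scheme."""
    offset = 0
    yielded = 0
    while True:
        count = batch_size if limit is None else min(batch_size, limit - yielded)
        page = list(fetch_page(offset, count))
        for item in page:
            yield item
            yielded += 1
        # Stop on a short page (no more data) or once the requested limit is reached.
        if len(page) < count or (limit is not None and yielded >= limit):
            break
        offset += len(page)
```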
68 changes: 35 additions & 33 deletions src/neptune_fetcher/read_only_run.py
@@ -35,7 +35,6 @@
FieldDefinition,
FieldType,
)
from neptune_fetcher.util import escape_nql_criterion

if TYPE_CHECKING:
from neptune_fetcher.read_only_project import ReadOnlyProject
@@ -50,40 +49,23 @@ def __init__(
experiment_name: Optional[str] = None,
eager_load_fields: bool = True,
) -> None:
self.project = read_only_project

if with_id is None and custom_id is None and experiment_name is None:
raise ValueError("You must provide one of: `with_id`, `custom_id`, and `experiment_name`.")
sys_id = ReadOnlyRun._select_and_validate_id(
read_only_project, sys_id=with_id, custom_id=custom_id, experiment_name=experiment_name
)
self._initialize_instance(read_only_project, sys_id, eager_load_fields)

if sum([with_id is not None, custom_id is not None, experiment_name is not None]) != 1:
raise ValueError("You must provide exactly one of: `with_id`, `custom_id`, and `experiment_name`.")
@classmethod
def _create(cls, read_only_project, sys_id, eager_load_fields) -> "ReadOnlyRun":
instance = cls.__new__(cls)
instance._initialize_instance(read_only_project, sys_id, eager_load_fields)
return instance

if custom_id is not None:
run = read_only_project.fetch_runs_df(
query=f'`sys/custom_run_id`:string = "{escape_nql_criterion(custom_id)}"', limit=1, columns=["sys/id"]
)

if len(run) == 0:
raise ValueError(f"No experiment found with custom id '{custom_id}'")
self.with_id = run.iloc[0]["sys/id"]
elif experiment_name is not None:
experiment = read_only_project.fetch_experiments_df(
query=f'`sys/name`:string = "{escape_nql_criterion(experiment_name)}"', limit=1, columns=["sys/id"]
)
if len(experiment) == 0:
raise ValueError(f"No experiment found with name '{experiment_name}'")
self.with_id = experiment.iloc[0]["sys/id"]
else:
run = read_only_project.fetch_runs_df(
query=f'`sys/id`:string = "{escape_nql_criterion(with_id)}"', limit=1, columns=["sys/id"]
)
if len(run) == 0:
raise ValueError(f"No experiment found with Neptune ID '{with_id}'")
self.with_id = with_id

self._container_id = f"{self.project.project_identifier}/{self.with_id}"
def _initialize_instance(self, read_only_project, sys_id, eager_load_fields):
self.project = read_only_project
self.with_id = sys_id
self._container_id = f"{read_only_project.project_identifier}/{sys_id}"
self._cache = FieldsCache(
backend=self.project._backend,
backend=read_only_project._backend,
container_id=self._container_id,
)
self._loaded_structure = False
@@ -92,6 +74,23 @@ def __init__(
else:
self._structure = {}

@staticmethod
def _select_and_validate_id(read_only_project, sys_id, custom_id, experiment_name):
if sum([sys_id is not None, custom_id is not None, experiment_name is not None]) != 1:
raise ValueError("You must provide exactly one of: `with_id`, `custom_id`, and `experiment_name`.")

sys_id = read_only_project._fetch_sys_id(sys_id=sys_id, custom_id=custom_id, experiment_name=experiment_name)

if sys_id is not None:
return sys_id

if custom_id is not None:
raise ValueError(f"No experiment found with custom id '{custom_id}'")
elif experiment_name is not None:
raise ValueError(f"No experiment found with name '{experiment_name}'")
else:
raise ValueError(f"No experiment found with Neptune ID '{sys_id}'")

def __getitem__(self, item: str) -> Union[Fetchable, FetchableSeries]:
try:
return self._structure[item]
@@ -102,7 +101,10 @@ def __getitem__(self, item: str) -> Union[Fetchable, FetchableSeries]:
# raise KeyError if the field does not indeed exist backend-side.
field = self._cache[item]
self._structure[item] = which_fetchable(
FieldDefinition(path=item, type=field.type), self.project._backend, self._container_id, self._cache
FieldDefinition(path=item, type=field.type),
self.project._backend,
self._container_id,
self._cache,
)

return self._structure[item]
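The `read_only_run.py` refactor splits construction in two: the public `__init__` resolves and validates the id via the project, while the new `_create` classmethod bypasses `__init__` (and its validation) for objects whose `sys/id` was already returned by a batched query. A standalone sketch of that pattern, using hypothetical names, looks like this:

```python
class Record:
    """Illustrates the __init__ / _create split used by ReadOnlyRun (hypothetical names)."""

    def __init__(self, raw_id: str) -> None:
        sys_id = self._validate(raw_id)  # public path: may raise ValueError
        self._initialize_instance(sys_id)

    @classmethod
    def _create(cls, sys_id: str) -> "Record":
        instance = cls.__new__(cls)  # skip __init__, the id is already known to be valid
        instance._initialize_instance(sys_id)
        return instance

    def _initialize_instance(self, sys_id: str) -> None:
        self.sys_id = sys_id

    @staticmethod
    def _validate(raw_id: str) -> str:
        if not raw_id:
            raise ValueError("an id must be provided")
        return raw_id
```

This keeps a single validation path for user-facing construction while letting batched fetches skip the per-object lookup.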
45 changes: 45 additions & 0 deletions tests/e2e/test_read_only_project.py
@@ -0,0 +1,45 @@
import os
import random

import pytest

NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_FIXED_PROJECT")


@pytest.mark.parametrize("eager_load_fields", [True, False])
def test_fetch_read_only_experiments(project, all_experiment_ids, eager_load_fields):
experiments = list(project.fetch_read_only_experiments(names=None, eager_load_fields=eager_load_fields))
experiments_empty_names = list(project.fetch_read_only_experiments(names=[], eager_load_fields=eager_load_fields))
assert len(experiments) == len(experiments_empty_names) == 0


def test_fetch_read_only_with_experiment_names(project, all_experiment_names):
exp_names = random.sample(all_experiment_names, 3)
experiments = list(project.fetch_read_only_experiments(names=exp_names))

assert len(experiments) == len(exp_names)
assert set(exp_names) == {exp["sys/name"].fetch() for exp in experiments}


@pytest.mark.parametrize("eager_load_fields", [True, False])
def test_fetch_read_only_runs(project, all_experiment_ids, eager_load_fields):
assert [] == list(project.fetch_read_only_runs(eager_load_fields=eager_load_fields))
assert [] == list(project.fetch_read_only_runs(with_ids=[], custom_ids=[], eager_load_fields=eager_load_fields))

runs_ids = random.sample(all_experiment_ids, 3)
filtered = list(project.fetch_read_only_runs(custom_ids=runs_ids, eager_load_fields=eager_load_fields))

assert len(filtered) == len(runs_ids)
assert set(runs_ids) == {run["sys/custom_run_id"].fetch() for run in filtered}

sys_ids = [run["sys/id"].fetch() for run in filtered]

filtered_by_sys_id = list(project.fetch_read_only_runs(with_ids=sys_ids, eager_load_fields=eager_load_fields))
assert set(runs_ids) == {run["sys/custom_run_id"].fetch() for run in filtered_by_sys_id}

duplicated = list(
project.fetch_read_only_runs(custom_ids=runs_ids, with_ids=sys_ids, eager_load_fields=eager_load_fields)
)
# custom_ids=runs_ids, with_ids=sys_ids duplicates the results
assert len(duplicated) == (len(runs_ids) * 2)
assert set(runs_ids) == {run["sys/custom_run_id"].fetch() for run in duplicated}
