Skip to content

Commit

Permalink
TrivialPipeline accespts function names as arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
estshorter committed Aug 6, 2021
1 parent 62ea786 commit 317e9e1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 23 deletions.
65 changes: 42 additions & 23 deletions lwpipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


__version__ = "4.0.4"
__version__ = "4.0.1"


class DumpType(IntEnum):
Expand Down Expand Up @@ -368,17 +368,21 @@ def __init__(self, funcs: list[Callable], names: list[str] = None) -> None:
names: 関数の名前
"""
self.funcs = funcs
# node.name -> index in self.nodes
self.name_to_idx = dict()
if names is not None:
_assert_same_length(funcs, names, "funcs", "names")
if len(names) != len(set(names)):
raise ValueError(f"names is not unique: {names}")
self.names = names
for idx, name in enumerate(names):
self.name_to_idx[name] = idx
return

self.names = []
name_duplicate_counter = {}

for func in self.funcs:
for idx, func in enumerate(self.funcs):
if hasattr(func, "__name__"):
name = func.__name__
else:
Expand All @@ -388,39 +392,54 @@ def __init__(self, funcs: list[Callable], names: list[str] = None) -> None:
if dup_counter == 0:
name_duplicate_counter[name] = 1
else:
if name in name_duplicate_counter:
name_duplicate_counter[name] = dup_counter + 1
name += f"__{dup_counter + 1}__"
dup_counter = name_duplicate_counter.get(name, 0)
if dup_counter == 0:
name_duplicate_counter[name] = 1
else:
raise ValueError(
f"name: {name} is duplicated. Consider change name"
)
name_duplicate_counter[name] = dup_counter + 1
name += f"__{dup_counter + 1}__"
dup_counter = name_duplicate_counter.get(name, 0)
if dup_counter == 0:
name_duplicate_counter[name] = 1
else:
raise ValueError(
f"name: {name} is duplicated. Consider change name"
)
self.name_to_idx[name] = idx
self.names.append(name)

def run(self, from_=0, to_=None):
def _get_start_or_end_index(self, start_or_end: int | str, start_or_end_str: str):
if isinstance(start_or_end, int):
idx = start_or_end
elif isinstance(start_or_end, str):
try:
idx = self.name_to_idx[start_or_end]
except KeyError:
raise ValueError(
f"specified {start_or_end_str} node ({start_or_end}) is not found"
)

if idx < 0 or idx >= len(self.funcs):
raise ValueError(
f"0 <= {start_or_end_str} ({idx}) <= {len(self.funcs)-1} must be satisfied"
)
return idx

def run(self, from_: str | int = 0, to_: str | int | None = None):
if to_ is None:
to_ = len(self.funcs) - 1

if from_ < 0 or from_ >= len(self.funcs):
idx_from = self._get_start_or_end_index(from_, "start")
self.idx_from = idx_from
idx_to = self._get_start_or_end_index(to_, "end")
if idx_from > idx_to:
raise ValueError(
f"0 <= from({from_}) <= {len(self.funcs)-1} must be satisfied"
f"idx_from must satisfy idx_from ({idx_from}) <= idx_to ({idx_to})"
)
if from_ > to_:
raise ValueError("start <= to must be satisfied")
if to_ < 0 or to_ >= len(self.funcs):
raise ValueError(f"0 <= to({to_}) <= {len(self.funcs)-1} must be satisfied")

logger.info(
f"Scheduled {len(self.funcs[from_:to_+1])} tasks, {len(self.funcs)} tasks in total"
f"Scheduled {len(self.funcs[idx_from:idx_to+1])} tasks, {len(self.funcs)} tasks in total"
)
for idx, (func, name) in enumerate(
zip(self.funcs[from_ : to_ + 1], self.names[from_ : to_ + 1])
zip(self.funcs[idx_from : idx_to + 1], self.names[idx_from : idx_to + 1])
):
logger.info(
f"Running {idx+1}/{len(self.funcs[from_:to_+1])} tasks ({name})"
f"Running {idx+1}/{len(self.funcs[idx_from:idx_to+1])} tasks ({name})"
)
func()
logger.info("Completed all tasks!")
Expand Down
13 changes: 13 additions & 0 deletions tests/test_trivial_pipieline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ def test_name_uniqueness():
funcs = [no_op, no_op]
with pytest.raises(ValueError):
TrivialPipeline(funcs, names=["a1", "a1"])


def test_string_from_to():
funcs = [no_op, no_op, no_op]
pipe = TrivialPipeline(funcs, names=["func1", "func2", "func3"])
pipe.run("func1", "func2")
pipe.run("func2", "func3")
pipe.run("func1", "func3")
pipe.run("func3", "func3")
with pytest.raises(ValueError):
pipe.run("func3", "func1")
with pytest.raises(ValueError):
pipe.run("func2", "func1")

0 comments on commit 317e9e1

Please sign in to comment.