Skip to content

Commit

Permalink
make @atomic_output work with args or kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
scottstanie committed Nov 6, 2023
1 parent a9b8049 commit 5d76ce5
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/dolphin/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import shutil
import tempfile
from inspect import Parameter
from pathlib import Path
from typing import Any, Callable, Optional

Expand Down Expand Up @@ -62,10 +63,28 @@ def wrapper(*args, **kwargs) -> Any:
# Extract the output file path
if output_arg in kwargs:
final_out_name = kwargs[output_arg]
# Track where we will slot in the new tempfile name (kwargs here)
replace_tuple = (kwargs, output_arg)
else:
# Check that it was provided as positional, in `args`
sig = inspect.signature(func)
if output_arg in sig.parameters:
final_out_name = sig.parameters[output_arg].default
for idx, param in enumerate(sig.parameters.values()):
if output_arg == param.name:
try:
# If the gave it in the args, use that
final_out_name = args[idx]
# Track where we will slot in the new tempfile name (args here)
# Need to make `args` into a list to we can mutate
replace_tuple = (list(args), idx)
except IndexError:
# Otherwise, nothing was given, so use the default
final_out_name = param.default
if param.kind == Parameter.POSITIONAL_ONLY:
# Insert as a positional arg if it needs to be
replace_tuple = (list(args), idx)
else:
replace_tuple = (kwargs, output_arg)
break
else:
raise ValueError(
f"Argument {output_arg} not found in function {func.__name__}"
Expand All @@ -77,6 +96,7 @@ def wrapper(*args, **kwargs) -> Any:
else:
tmp_dir = None

# Make the tempfile start the same as the desired output
prefix = final_path.name
if is_dir:
# Create a temporary directory
Expand All @@ -88,7 +108,9 @@ def wrapper(*args, **kwargs) -> Any:

try:
# Replace the output file path with the temp file
kwargs[output_arg] = temp_path
# It would be like this if we only allows keyword:
# kwargs[output_arg] = temp_path
replace_tuple[0][replace_tuple[1]] = temp_path
# Execute the original function
result = func(*args, **kwargs)
# Move the temp file to the final location
Expand All @@ -97,6 +119,7 @@ def wrapper(*args, **kwargs) -> Any:
return result
finally:
logger.debug("Cleaning up temp file %s", temp_path)
# Different cleanup is needed
if is_dir:
shutil.rmtree(temp_path, ignore_errors=True)
else:
Expand Down
96 changes: 96 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import threading
import time
from pathlib import Path

from dolphin._decorators import atomic_output


def _long_write(filename, pause: float = 0.2):
"""Simulate a long writing process"""
f = open(filename, "w")
f.write("aaa\n")
time.sleep(pause)
f.write("bbb\n")
f.close()


@atomic_output
def default_write(output_file="out.txt"):
_long_write(output_file)


@atomic_output(output_arg="outname")
def default_write_newname(outname="out3.txt"):
_long_write(outname)


@atomic_output(output_arg="output_dir")
def default_write_dir(output_dir="some_dir"):
p = Path(output_dir)
p.mkdir(exist_ok=True)
outname = p / "testfile.txt"
_long_write(outname)


def test_atomic_output(tmpdir):
with tmpdir.as_cwd():
default_write()
default_write(output_file="out2.txt")
default_write_newname()
default_write_newname(outname="out4.txt")
for fn in ["out.txt", "out2.txt", "out3.txt", "out4.txt"]:
assert Path(fn).exists()


def test_atomic_output_name_swap(tmpdir):
# Kick off the writing function in the background
# so we see if a different file was created
with tmpdir.as_cwd():
# Check it works providing the "args"
t = threading.Thread(target=default_write)
t.start()
# It should NOT exist, yet
assert not Path("out.txt").exists()
time.sleep(0.5)
assert Path("out.txt").exists()
Path("out.txt").unlink()


def test_atomic_output_name_swap_with_args(tmpdir):
with tmpdir.as_cwd():
outname2 = "out2.txt"
t = threading.Thread(target=default_write, args=(outname2,))
t.start()
# It should NOT exist, yet
assert not Path(outname2).exists()
time.sleep(0.5)
t.join()
assert Path(outname2).exists()
Path(outname2).unlink()


def test_atomic_output_name_swap_with_kwargs(tmpdir):
with tmpdir.as_cwd():
outname2 = "out3.txt"
t = threading.Thread(target=default_write_newname, kwargs={"outname": outname2})
t.start()
# It should NOT exist, yet
assert not Path(outname2).exists()
time.sleep(0.5)
t.join()
assert Path(outname2).exists()
Path(outname2).unlink()


def test_atomic_output_dir_name_swap(tmpdir):
# Kick off the writing function in the background
# so we see if a different file was created
with tmpdir.as_cwd():
# Check it works providing the "args"
t = threading.Thread(target=default_write)
t.start()
# It should NOT exist, yet
assert not Path("out.txt").exists()
time.sleep(0.5)
assert Path("out.txt").exists()
Path("out.txt").unlink()

0 comments on commit 5d76ce5

Please sign in to comment.