From 5d76ce5abfae376a0d32dd000f3677a658a15a1d Mon Sep 17 00:00:00 2001 From: Scott Staniewicz Date: Mon, 6 Nov 2023 12:50:01 -0500 Subject: [PATCH] make `@atomic_output` work with args or kwargs --- src/dolphin/_decorators.py | 29 ++++++++++-- tests/test_decorators.py | 96 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 tests/test_decorators.py diff --git a/src/dolphin/_decorators.py b/src/dolphin/_decorators.py index f68d97c1..4820af1d 100644 --- a/src/dolphin/_decorators.py +++ b/src/dolphin/_decorators.py @@ -2,6 +2,7 @@ import inspect import shutil import tempfile +from inspect import Parameter from pathlib import Path from typing import Any, Callable, Optional @@ -62,10 +63,28 @@ def wrapper(*args, **kwargs) -> Any: # Extract the output file path if output_arg in kwargs: final_out_name = kwargs[output_arg] + # Track where we will slot in the new tempfile name (kwargs here) + replace_tuple = (kwargs, output_arg) else: + # Check that it was provided as positional, in `args` sig = inspect.signature(func) - if output_arg in sig.parameters: - final_out_name = sig.parameters[output_arg].default + for idx, param in enumerate(sig.parameters.values()): + if output_arg == param.name: + try: + # If the gave it in the args, use that + final_out_name = args[idx] + # Track where we will slot in the new tempfile name (args here) + # Need to make `args` into a list to we can mutate + replace_tuple = (list(args), idx) + except IndexError: + # Otherwise, nothing was given, so use the default + final_out_name = param.default + if param.kind == Parameter.POSITIONAL_ONLY: + # Insert as a positional arg if it needs to be + replace_tuple = (list(args), idx) + else: + replace_tuple = (kwargs, output_arg) + break else: raise ValueError( f"Argument {output_arg} not found in function {func.__name__}" @@ -77,6 +96,7 @@ def wrapper(*args, **kwargs) -> Any: else: tmp_dir = None + # Make the tempfile start the same as the desired output prefix = final_path.name if is_dir: # Create a temporary directory @@ -88,7 +108,9 @@ def wrapper(*args, **kwargs) -> Any: try: # Replace the output file path with the temp file - kwargs[output_arg] = temp_path + # It would be like this if we only allows keyword: + # kwargs[output_arg] = temp_path + replace_tuple[0][replace_tuple[1]] = temp_path # Execute the original function result = func(*args, **kwargs) # Move the temp file to the final location @@ -97,6 +119,7 @@ def wrapper(*args, **kwargs) -> Any: return result finally: logger.debug("Cleaning up temp file %s", temp_path) + # Different cleanup is needed if is_dir: shutil.rmtree(temp_path, ignore_errors=True) else: diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..e47aa6ac --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,96 @@ +import threading +import time +from pathlib import Path + +from dolphin._decorators import atomic_output + + +def _long_write(filename, pause: float = 0.2): + """Simulate a long writing process""" + f = open(filename, "w") + f.write("aaa\n") + time.sleep(pause) + f.write("bbb\n") + f.close() + + +@atomic_output +def default_write(output_file="out.txt"): + _long_write(output_file) + + +@atomic_output(output_arg="outname") +def default_write_newname(outname="out3.txt"): + _long_write(outname) + + +@atomic_output(output_arg="output_dir") +def default_write_dir(output_dir="some_dir"): + p = Path(output_dir) + p.mkdir(exist_ok=True) + outname = p / "testfile.txt" + _long_write(outname) + + +def test_atomic_output(tmpdir): + with tmpdir.as_cwd(): + default_write() + default_write(output_file="out2.txt") + default_write_newname() + default_write_newname(outname="out4.txt") + for fn in ["out.txt", "out2.txt", "out3.txt", "out4.txt"]: + assert Path(fn).exists() + + +def test_atomic_output_name_swap(tmpdir): + # Kick off the writing function in the background + # so we see if a different file was created + with tmpdir.as_cwd(): + # Check it works providing the "args" + t = threading.Thread(target=default_write) + t.start() + # It should NOT exist, yet + assert not Path("out.txt").exists() + time.sleep(0.5) + assert Path("out.txt").exists() + Path("out.txt").unlink() + + +def test_atomic_output_name_swap_with_args(tmpdir): + with tmpdir.as_cwd(): + outname2 = "out2.txt" + t = threading.Thread(target=default_write, args=(outname2,)) + t.start() + # It should NOT exist, yet + assert not Path(outname2).exists() + time.sleep(0.5) + t.join() + assert Path(outname2).exists() + Path(outname2).unlink() + + +def test_atomic_output_name_swap_with_kwargs(tmpdir): + with tmpdir.as_cwd(): + outname2 = "out3.txt" + t = threading.Thread(target=default_write_newname, kwargs={"outname": outname2}) + t.start() + # It should NOT exist, yet + assert not Path(outname2).exists() + time.sleep(0.5) + t.join() + assert Path(outname2).exists() + Path(outname2).unlink() + + +def test_atomic_output_dir_name_swap(tmpdir): + # Kick off the writing function in the background + # so we see if a different file was created + with tmpdir.as_cwd(): + # Check it works providing the "args" + t = threading.Thread(target=default_write) + t.start() + # It should NOT exist, yet + assert not Path("out.txt").exists() + time.sleep(0.5) + assert Path("out.txt").exists() + Path("out.txt").unlink()