diff --git a/src/dolphin/_decorators.py b/src/dolphin/_decorators.py index 4820af1d..5dda97c0 100644 --- a/src/dolphin/_decorators.py +++ b/src/dolphin/_decorators.py @@ -1,13 +1,12 @@ +from __future__ import annotations + import functools -import inspect import shutil import tempfile -from inspect import Parameter from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Callable from dolphin._log import get_log -from dolphin._types import Filename logger = get_log(__name__) @@ -17,15 +16,14 @@ def atomic_output( - function: Optional[Callable] = None, output_arg: str = "output_file", is_dir: bool = False, - scratch_dir: Optional[Filename] = None, + use_tmp: bool = False, ) -> Callable: """Use a temporary file/directory for the `output_arg` until the function finishes. Decorator is used on a function which writes to an output file/directory in blocks. - If the function were interrupted, the file/directory would be partiall complete. + If the function were interrupted, the file/directory would be partially complete. This decorator replaces the final output name with a temp file/dir, and then renames the temp file/dir to the final name after the function finishes. @@ -36,65 +34,51 @@ def atomic_output( Parameters ---------- - function : Optional[Callable] - Used if the decorator is called without any arguments (i.e. as - `@atomic_output` instead of `@atomic_output(output_arg=...)`) output_arg : str, optional The name of the argument to replace, by default 'output_file' is_dir : bool, default = False - If True, the output argument is a directory, not a file - scratch_dir : Optional[Filename] - The directory to use for the temporary file, by default None - If None, uses the same directory as the final requested output. 
+ If `True`, the output argument is a directory, not a file + use_tmp : bool, default = False + If `False`, uses the parent directory of the desired output, with + a random suffix added to the name to distinguish from actual output. + If `True`, uses the `/tmp` directory (or wherever the default is + for the `tempfile` module). Returns ------- Callable The decorated function + + Raises + ------ + FileExistsError + if the file for `output_arg` already exists (if is_dir=False), or + if the directory at `output_arg` exists and is non-empty. + + Notes + ----- + The output at `output_arg` *must not* exist already, or the decorator will error + (though if `is_dir=True`, it is allowed to be an empty directory). + The function being decorated *must* be called with keyword args for `output_arg`. """ - def actual_decorator(func: Callable) -> Callable: - # Want to be able to use this decorator with or without arguments: - # https://stackoverflow.com/a/19017908/4174466 - # Code adapted from the `@login_required` decorator in Django: - # https://github.com/django/django/blob/d254a54e7f65e83d8971bd817031bc6af32a7a46/django/contrib/auth/decorators.py#L43 # noqa + def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Extract the output file path if output_arg in kwargs: final_out_name = kwargs[output_arg] - # Track where we will slot in the new tempfile name (kwargs here) - replace_tuple = (kwargs, output_arg) else: - # Check that it was provided as positional, in `args` - sig = inspect.signature(func) - for idx, param in enumerate(sig.parameters.values()): - if output_arg == param.name: - try: - # If the gave it in the args, use that - final_out_name = args[idx] - # Track where we will slot in the new tempfile name (args here) - # Need to make `args` into a list to we can mutate - replace_tuple = (list(args), idx) - except IndexError: - # Otherwise, nothing was given, so use the default - final_out_name = param.default - if
param.kind == Parameter.POSITIONAL_ONLY: - # Insert as a positional arg if it needs to be - replace_tuple = (list(args), idx) - else: - replace_tuple = (kwargs, output_arg) - break - else: - raise ValueError( - f"Argument {output_arg} not found in function {func.__name__}" - ) + raise FileExistsError( + f"Argument {output_arg} not found in function {func.__name__}:" + f" {kwargs}" + ) final_path = Path(final_out_name) - if scratch_dir is None: - tmp_dir = final_path.parent - else: - tmp_dir = None + # Make sure the desired final output doesn't already exist + _raise_if_exists(final_path, is_dir=is_dir) + # None means that tempfile will use /tmp + tmp_dir = final_path.parent if not use_tmp else None # Make the tempfile start the same as the desired output prefix = final_path.name @@ -109,11 +93,11 @@ def wrapper(*args, **kwargs) -> Any: try: # Replace the output file path with the temp file # It would be like this if we only allows keyword: - # kwargs[output_arg] = temp_path - replace_tuple[0][replace_tuple[1]] = temp_path + kwargs[output_arg] = temp_path # Execute the original function result = func(*args, **kwargs) # Move the temp file to the final location + logger.debug("Moving %s to %s", temp_path, final_path) shutil.move(temp_path, final_path) return result @@ -127,8 +111,19 @@ def wrapper(*args, **kwargs) -> Any: return wrapper - if function is not None: - # Decorator used without arguments - return actual_decorator(function) - # Decorator used with arguments - return actual_decorator + return decorator + + +def _raise_if_exists(final_path: Path, is_dir: bool): + if final_path.exists(): + err_msg = f"{final_path} already exists" + if is_dir and final_path.is_dir(): + try: + final_path.rmdir() + except OSError as e: + if "Directory not empty" not in e.args[0]: + raise e + else: + raise FileExistsError(err_msg) + else: + raise FileExistsError(err_msg) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index e47aa6ac..97625c42 100644 --- 
a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -14,62 +14,63 @@ def _long_write(filename, pause: float = 0.2): f.close() -@atomic_output -def default_write(output_file="out.txt"): - _long_write(output_file) +@atomic_output(output_arg="outname") +def default_write_newname(outname="out.txt"): + _long_write(outname) -@atomic_output(output_arg="outname") -def default_write_newname(outname="out3.txt"): +@atomic_output(output_arg="outname", use_tmp=True) +def default_write_newname_tmp(outname="out.txt"): + _long_write(outname) + + +@atomic_output(output_arg="output_dir", is_dir=True) +def default_write_dir(output_dir="some_dir", filename="testfile.txt"): + p = Path(output_dir) + p.mkdir(exist_ok=True, parents=True) + outname = p / filename _long_write(outname) -@atomic_output(output_arg="output_dir") -def default_write_dir(output_dir="some_dir"): +@atomic_output(output_arg="output_dir", is_dir=True, use_tmp=True) +def default_write_dir_tmp(output_dir="some_dir", filename="testfile.txt"): p = Path(output_dir) - p.mkdir(exist_ok=True) - outname = p / "testfile.txt" + p.mkdir(exist_ok=True, parents=True) + outname = p / filename _long_write(outname) def test_atomic_output(tmpdir): with tmpdir.as_cwd(): - default_write() - default_write(output_file="out2.txt") - default_write_newname() - default_write_newname(outname="out4.txt") - for fn in ["out.txt", "out2.txt", "out3.txt", "out4.txt"]: + default_write_newname(outname="out1.txt") + default_write_newname(outname="out2.txt") + for fn in ["out1.txt", "out2.txt"]: assert Path(fn).exists() -def test_atomic_output_name_swap(tmpdir): - # Kick off the writing function in the background - # so we see if a different file was created +def test_atomic_output_tmp(tmpdir): with tmpdir.as_cwd(): - # Check it works providing the "args" - t = threading.Thread(target=default_write) - t.start() - # It should NOT exist, yet - assert not Path("out.txt").exists() - time.sleep(0.5) - assert Path("out.txt").exists() - 
Path("out.txt").unlink() + default_write_newname_tmp(outname="out1.txt") + assert Path("out1.txt").exists() -def test_atomic_output_name_swap_with_args(tmpdir): - with tmpdir.as_cwd(): - outname2 = "out2.txt" - t = threading.Thread(target=default_write, args=(outname2,)) - t.start() - # It should NOT exist, yet - assert not Path(outname2).exists() - time.sleep(0.5) - t.join() - assert Path(outname2).exists() - Path(outname2).unlink() +def test_atomic_output_dir(tmp_path): + out_dir = tmp_path / "out" + filename = "testfile.txt" + out_dir.mkdir() + default_write_dir(output_dir=out_dir, filename=filename) + assert Path(out_dir / filename).exists() -def test_atomic_output_name_swap_with_kwargs(tmpdir): +def test_atomic_output_dir_tmp(tmp_path): + out_dir = tmp_path / "out" + filename = "testfile.txt" + out_dir.mkdir() + default_write_dir(output_dir=out_dir, filename=filename) + assert Path(out_dir / filename).exists() + + +def test_atomic_output_name_swap_file(tmpdir): with tmpdir.as_cwd(): outname2 = "out3.txt" t = threading.Thread(target=default_write_newname, kwargs={"outname": outname2}) @@ -79,18 +80,18 @@ def test_atomic_output_name_swap_with_kwargs(tmpdir): time.sleep(0.5) t.join() assert Path(outname2).exists() - Path(outname2).unlink() -def test_atomic_output_dir_name_swap(tmpdir): +def test_atomic_output_dir_swap(tmp_path): # Kick off the writing function in the background # so we see if a different file was created - with tmpdir.as_cwd(): - # Check it works providing the "args" - t = threading.Thread(target=default_write) - t.start() - # It should NOT exist, yet - assert not Path("out.txt").exists() - time.sleep(0.5) - assert Path("out.txt").exists() - Path("out.txt").unlink() + # Check it works providing the "args" + out_dir = tmp_path / "out" + out_dir.mkdir() + t = threading.Thread(target=default_write_dir, kwargs={"output_dir": out_dir}) + t.start() + # It should NOT exist, yet + assert not Path(out_dir / "testfile.txt").exists() + time.sleep(0.5) + 
t.join() + assert Path(out_dir / "testfile.txt").exists()