Skip to content

Commit

Permalink
simplify logic by requiring kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
scottstanie committed Nov 7, 2023
1 parent 5d76ce5 commit 278ac23
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 103 deletions.
105 changes: 50 additions & 55 deletions src/dolphin/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import functools
import inspect
import shutil
import tempfile
from inspect import Parameter
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Any, Callable

from dolphin._log import get_log
from dolphin._types import Filename

logger = get_log(__name__)

Expand All @@ -17,15 +16,14 @@


def atomic_output(
function: Optional[Callable] = None,
output_arg: str = "output_file",
is_dir: bool = False,
scratch_dir: Optional[Filename] = None,
use_tmp: bool = False,
) -> Callable:
"""Use a temporary file/directory for the `output_arg` until the function finishes.
Decorator is used on a function which writes to an output file/directory in blocks.
If the function were interrupted, the file/directory would be partiall complete.
If the function were interrupted, the file/directory would be partially complete.
This decorator replaces the final output name with a temp file/dir, and then
renames the temp file/dir to the final name after the function finishes.
Expand All @@ -36,65 +34,51 @@ def atomic_output(
Parameters
----------
function : Optional[Callable]
Used if the decorator is called without any arguments (i.e. as
`@atomic_output` instead of `@atomic_output(output_arg=...)`)
output_arg : str, optional
The name of the argument to replace, by default 'output_file'
is_dir : bool, default = False
If True, the output argument is a directory, not a file
scratch_dir : Optional[Filename]
The directory to use for the temporary file, by default None
If None, uses the same directory as the final requested output.
If `True`, the output argument is a directory, not a file
use_tmp : bool, default = False
If `False`, uses the parent directory of the desired output, with
a random suffix added to the name to distinguish from actual output.
If `True`, uses the `/tmp` directory (or wherever the default is
for the `tempfile` module).
Returns
-------
Callable
The decorated function
Raises
------
FileExistsError
if the file for `output_arg` already exists (if out_dir=False), or
if the directory at `output_arg` exists and is non-empty.
Notes
-----
The output at `output_arg` *must not* exist already, or the decorator will error
(though if `is_dir=True`, it is allowed to be an empty directory).
The function being decorated *must* be called with keyword args for `output_arg`.
"""

def actual_decorator(func: Callable) -> Callable:
# Want to be able to use this decorator with or without arguments:
# https://stackoverflow.com/a/19017908/4174466
# Code adapted from the `@login_required` decorator in Django:
# https://github.com/django/django/blob/d254a54e7f65e83d8971bd817031bc6af32a7a46/django/contrib/auth/decorators.py#L43 # noqa
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
# Extract the output file path
if output_arg in kwargs:
final_out_name = kwargs[output_arg]
# Track where we will slot in the new tempfile name (kwargs here)
replace_tuple = (kwargs, output_arg)
else:
# Check that it was provided as positional, in `args`
sig = inspect.signature(func)
for idx, param in enumerate(sig.parameters.values()):
if output_arg == param.name:
try:
# If the gave it in the args, use that
final_out_name = args[idx]
# Track where we will slot in the new tempfile name (args here)
# Need to make `args` into a list to we can mutate
replace_tuple = (list(args), idx)
except IndexError:
# Otherwise, nothing was given, so use the default
final_out_name = param.default
if param.kind == Parameter.POSITIONAL_ONLY:
# Insert as a positional arg if it needs to be
replace_tuple = (list(args), idx)
else:
replace_tuple = (kwargs, output_arg)
break
else:
raise ValueError(
f"Argument {output_arg} not found in function {func.__name__}"
)
raise FileExistsError(
f"Argument {output_arg} not found in function {func.__name__}:"
f" {kwargs}"
)

final_path = Path(final_out_name)
if scratch_dir is None:
tmp_dir = final_path.parent
else:
tmp_dir = None
# Make sure the desired final output doesn't already exist
_raise_if_exists(final_path, is_dir=is_dir)
# None means that tempfile will use /tmp
tmp_dir = final_path.parent if not use_tmp else None

# Make the tempfile start the same as the desired output
prefix = final_path.name
Expand All @@ -109,11 +93,11 @@ def wrapper(*args, **kwargs) -> Any:
try:
# Replace the output file path with the temp file
# It would be like this if we only allows keyword:
# kwargs[output_arg] = temp_path
replace_tuple[0][replace_tuple[1]] = temp_path
kwargs[output_arg] = temp_path
# Execute the original function
result = func(*args, **kwargs)
# Move the temp file to the final location
logger.debug("Moving %s to %s", temp_path, final_path)
shutil.move(temp_path, final_path)

return result
Expand All @@ -127,8 +111,19 @@ def wrapper(*args, **kwargs) -> Any:

return wrapper

if function is not None:
# Decorator used without arguments
return actual_decorator(function)
# Decorator used with arguments
return actual_decorator
return decorator


def _raise_if_exists(final_path: Path, is_dir: bool):
if final_path.exists():
err_msg = f"{final_path} already exists"
if is_dir and final_path.is_dir():
try:
final_path.rmdir()
except OSError as e:
if "Directory not empty" not in e.args[0]:
raise e
else:
raise FileExistsError(err_msg)
else:
raise FileExistsError(err_msg)
97 changes: 49 additions & 48 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,62 +14,63 @@ def _long_write(filename, pause: float = 0.2):
f.close()


@atomic_output
def default_write(output_file="out.txt"):
_long_write(output_file)
@atomic_output(output_arg="outname")
def default_write_newname(outname="out.txt"):
_long_write(outname)


@atomic_output(output_arg="outname")
def default_write_newname(outname="out3.txt"):
@atomic_output(output_arg="outname", use_tmp=True)
def default_write_newname_tmp(outname="out.txt"):
_long_write(outname)


@atomic_output(output_arg="output_dir", is_dir=True)
def default_write_dir(output_dir="some_dir", filename="testfile.txt"):
p = Path(output_dir)
p.mkdir(exist_ok=True, parents=True)
outname = p / filename
_long_write(outname)


@atomic_output(output_arg="output_dir")
def default_write_dir(output_dir="some_dir"):
@atomic_output(output_arg="output_dir", is_dir=True, use_tmp=True)
def default_write_dir_tmp(output_dir="some_dir", filename="testfile.txt"):
p = Path(output_dir)
p.mkdir(exist_ok=True)
outname = p / "testfile.txt"
p.mkdir(exist_ok=True, parents=True)
outname = p / filename
_long_write(outname)


def test_atomic_output(tmpdir):
with tmpdir.as_cwd():
default_write()
default_write(output_file="out2.txt")
default_write_newname()
default_write_newname(outname="out4.txt")
for fn in ["out.txt", "out2.txt", "out3.txt", "out4.txt"]:
default_write_newname(outname="out1.txt")
default_write_newname(outname="out2.txt")
for fn in ["out1.txt", "out2.txt"]:
assert Path(fn).exists()


def test_atomic_output_name_swap(tmpdir):
# Kick off the writing function in the background
# so we see if a different file was created
def test_atomic_output_tmp(tmpdir):
with tmpdir.as_cwd():
# Check it works providing the "args"
t = threading.Thread(target=default_write)
t.start()
# It should NOT exist, yet
assert not Path("out.txt").exists()
time.sleep(0.5)
assert Path("out.txt").exists()
Path("out.txt").unlink()
default_write_newname_tmp(outname="out1.txt")
assert Path("out1.txt").exists()


def test_atomic_output_name_swap_with_args(tmpdir):
with tmpdir.as_cwd():
outname2 = "out2.txt"
t = threading.Thread(target=default_write, args=(outname2,))
t.start()
# It should NOT exist, yet
assert not Path(outname2).exists()
time.sleep(0.5)
t.join()
assert Path(outname2).exists()
Path(outname2).unlink()
def test_atomic_output_dir(tmp_path):
out_dir = tmp_path / "out"
filename = "testfile.txt"
out_dir.mkdir()
default_write_dir(output_dir=out_dir, filename=filename)
assert Path(out_dir / filename).exists()


def test_atomic_output_name_swap_with_kwargs(tmpdir):
def test_atomic_output_dir_tmp(tmp_path):
out_dir = tmp_path / "out"
filename = "testfile.txt"
out_dir.mkdir()
default_write_dir(output_dir=out_dir, filename=filename)
assert Path(out_dir / filename).exists()


def test_atomic_output_name_swap_file(tmpdir):
with tmpdir.as_cwd():
outname2 = "out3.txt"
t = threading.Thread(target=default_write_newname, kwargs={"outname": outname2})
Expand All @@ -79,18 +80,18 @@ def test_atomic_output_name_swap_with_kwargs(tmpdir):
time.sleep(0.5)
t.join()
assert Path(outname2).exists()
Path(outname2).unlink()


def test_atomic_output_dir_name_swap(tmpdir):
def test_atomic_output_dir_swap(tmp_path):
# Kick off the writing function in the background
# so we see if a different file was created
with tmpdir.as_cwd():
# Check it works providing the "args"
t = threading.Thread(target=default_write)
t.start()
# It should NOT exist, yet
assert not Path("out.txt").exists()
time.sleep(0.5)
assert Path("out.txt").exists()
Path("out.txt").unlink()
# Check it works providing the "args"
out_dir = tmp_path / "out"
out_dir.mkdir()
t = threading.Thread(target=default_write_dir, kwargs={"output_dir": out_dir})
t.start()
# It should NOT exist, yet
assert not Path(out_dir / "testfile.txt").exists()
time.sleep(0.5)
t.join()
assert Path(out_dir / "testfile.txt").exists()

0 comments on commit 278ac23

Please sign in to comment.