diff --git a/src/dolphin/_decorators.py b/src/dolphin/_decorators.py index 33b7d22a..d9a5ca12 100644 --- a/src/dolphin/_decorators.py +++ b/src/dolphin/_decorators.py @@ -19,6 +19,7 @@ def atomic_output( output_arg: str = "output_file", is_dir: bool = False, use_tmp: bool = False, + overwrite: bool = False, ) -> Callable: """Use a temporary file/directory for the `output_arg` until the function finishes. @@ -43,6 +44,9 @@ def atomic_output( a random suffix added to the name to distinguish from actual output. If `True`, uses the `/tmp` directory (or wherever the default is for the `tempfile` module). + overwrite : bool, default = False + Overwrite an existing file. + If `False` raises `FileExistsError` if the file already exists. Returns ------- @@ -66,17 +70,17 @@ def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Extract the output file path - if output_arg in kwargs: + if kwargs.get(output_arg): final_out_name = kwargs[output_arg] else: raise FileExistsError( - f"Argument {output_arg} not found in function {func.__name__}:" + f"Argument {output_arg} not passed to function {func.__name__}:" f" {kwargs}" ) final_path = Path(final_out_name) # Make sure the desired final output doesn't already exist - _raise_if_exists(final_path, is_dir=is_dir) + _raise_if_exists(final_path, is_dir=is_dir, overwrite=overwrite) # None means that tempfile will use /tmp tmp_dir = final_path.parent if not use_tmp else None @@ -117,17 +121,27 @@ def wrapper(*args, **kwargs) -> Any: return decorator -def _raise_if_exists(final_path: Path, is_dir: bool): +def _raise_if_exists(final_path: Path, is_dir: bool, overwrite: bool): + msg = f"{final_path} already exists" if final_path.exists(): - err_msg = f"{final_path} already exists" + logger.debug(f"{final_path} already exists") + if overwrite: + if final_path.is_dir(): + shutil.rmtree(final_path) + else: + final_path.unlink() + return + if is_dir and final_path.is_dir(): + # We can work with an empty directory try: final_path.rmdir() except OSError as e: err_msg = str(e) - if "Directory not empty" not in err_msg: - raise e + if "Directory not empty" in err_msg: + raise FileExistsError(msg) else: - raise FileExistsError(err_msg) + # Some other error we don't know + raise e else: - raise FileExistsError(err_msg) + raise FileExistsError(msg) diff --git a/src/dolphin/unwrap.py b/src/dolphin/unwrap.py index 6cde4318..de6af998 100644 --- a/src/dolphin/unwrap.py +++ b/src/dolphin/unwrap.py @@ -579,7 +579,7 @@ def unwrap_snaphu_py( zero_where_masked: bool = True, nodata: str | float | None = None, init_method: str = "mst", -): +) -> tuple[Path, Path]: """Unwrap an interferogram using at multiple scales using `tophu`. Parameters @@ -668,3 +668,5 @@ def unwrap_snaphu_py( corr.close() if mask is not None: mask.close() + + return Path(unw_filename), Path(cc_filename)