From 5525d0e6037e92408395a888cf82afd7a349a300 Mon Sep 17 00:00:00 2001 From: Kristoffer Skare Date: Fri, 15 Mar 2024 17:45:20 +0100 Subject: [PATCH] Cleanup code for finding default paths --- src/mlfmu/api.py | 147 +++++++++++++++++++++-------------------------- 1 file changed, 66 insertions(+), 81 deletions(-) diff --git a/src/mlfmu/api.py b/src/mlfmu/api.py index a880c29..2ec4ac5 100644 --- a/src/mlfmu/api.py +++ b/src/mlfmu/api.py @@ -66,6 +66,7 @@ class MlFmuBuilder: fmu_output_folder: Optional[Path] = None delete_build_folders: bool = False temp_folder: Optional[tempfile.TemporaryDirectory[str]] = None + root_directory: Path def __init__( self, @@ -76,6 +77,7 @@ def __init__( fmu_output_folder: Optional[Path] = None, build_folder: Optional[Path] = None, delete_build_folders: bool = False, + root_directory: Optional[Path] = None, ): self.fmu_name = fmu_name self.interface_file = interface_file @@ -84,27 +86,24 @@ def __init__( self.fmu_output_folder = fmu_output_folder self.build_folder = build_folder self.delete_build_folders = delete_build_folders + self.root_directory = root_directory or Path(os.getcwd()) def build(self): # TODO: Raise errors - if self.source_folder is None: - self.source_folder = self.default_build_source_folder() + self.source_folder = self.source_folder or self.default_build_source_folder() + + self.ml_model_file = self.ml_model_file or self.default_model_file() if self.ml_model_file is None: - self.ml_model_file = self.default_model_file() - if self.ml_model_file is None: - raise + raise + self.interface_file = self.interface_file or self.default_interface_file() if self.interface_file is None: - self.interface_file = self.default_interface_file() - if self.interface_file is None: - raise + raise - if self.build_folder is None: - self.build_folder = self.default_build_folder() + self.build_folder = self.build_folder or self.default_build_folder() - if self.fmu_output_folder is None: - self.fmu_output_folder = self.default_fmu_output_folder() + self.fmu_output_folder = self.fmu_output_folder or self.default_fmu_output_folder() try: fmi_model = builder.generate_fmu_files(self.source_folder, self.ml_model_file, self.interface_file) @@ -129,18 +128,17 @@ def build(self): def generate(self): # TODO: Raise errors - if self.source_folder is None: - self.source_folder = self.default_generate_source_folder() + self.source_folder = self.source_folder or self.default_generate_source_folder() + + self.ml_model_file = self.ml_model_file or self.default_model_file() if self.ml_model_file is None: - self.ml_model_file = self.default_model_file() - if self.ml_model_file is None: - raise + raise + self.interface_file = self.interface_file or self.default_interface_file() if self.interface_file is None: - self.interface_file = self.default_interface_file() - if self.interface_file is None: - raise + raise + try: fmi_model = builder.generate_fmu_files(self.source_folder, self.ml_model_file, self.interface_file) self.fmu_name = fmi_model.name @@ -151,22 +149,12 @@ def generate(self): return def compile(self): - if self.build_folder is None: - self.build_folder = self.default_build_folder() + self.build_folder = self.build_folder or self.default_build_folder() - if self.fmu_output_folder is None: - self.fmu_output_folder = self.default_fmu_output_folder() + self.fmu_output_folder = self.fmu_output_folder or self.default_fmu_output_folder() if self.fmu_name is None or self.source_folder is None: - search_folders: List[Path] = [] - if self.source_folder is not None: - search_folders.append(self.source_folder) - - search_folders.append(Path(os.getcwd())) - - source_child_folder = self.default_compile_source_folder( - search_folders=search_folders, folder_name=self.fmu_name - ) + source_child_folder = self.default_compile_source_folder() if source_child_folder is None: raise self.fmu_name = source_child_folder.stem @@ -198,13 +186,52 @@ def delete_temp_folder(self): if self.temp_folder is not None and Path(self.temp_folder.name).exists(): shutil.rmtree(Path(self.temp_folder.name)) - @staticmethod - def default_interface_file(): - return MlFmuBuilder._find_default_file(Path(os.getcwd()), "json", "interface") + def default_interface_file(self): + return MlFmuBuilder._find_default_file(self.root_directory, "json", "interface") - @staticmethod - def default_model_file(): - return MlFmuBuilder._find_default_file(Path(os.getcwd()), "onnx", "model") + def default_model_file(self): + return MlFmuBuilder._find_default_file(self.root_directory, "onnx", "model") + + def default_build_folder(self): + self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True) + return Path(self.temp_folder.name) / "build" + + def default_build_source_folder(self): + self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True) + return Path(self.temp_folder.name) / "src" + + def default_generate_source_folder(self): + return self.root_directory + + def default_compile_source_folder(self): + search_folders: List[Path] = [] + if self.source_folder is not None: + search_folders.append(self.source_folder) + search_folders.append(self.root_directory) + source_folder: Optional[Path] = None + # If source folder is not provide try to find one in current folder that is compatible with the tool + # I.e a folder that contains everything needed for compilation + for current_folder in search_folders: + for dir, _, _ in os.walk(current_folder): + try: + possible_source_folder = Path(dir) + # If a fmu name is given and the candidate folder name does not match. Skip it! + if self.fmu_name is not None and possible_source_folder.stem != self.fmu_name: + continue + builder.validate_fmu_source_files(possible_source_folder) + source_folder = possible_source_folder + # If a match was found stop searching + break + except Exception: + # Any folder that does not contain the correct folder structure and files needed for compilation will raise and exception + continue + # If a match was found stop searching + if source_folder is not None: + break + return source_folder + + def default_fmu_output_folder(self): + return self.root_directory @staticmethod def _find_default_file(dir: Path, file_extension: str, default_name: Optional[str] = None): @@ -241,48 +268,6 @@ def _find_default_file(dir: Path, file_extension: str, default_name: Optional[st return name_matches[0] return - def default_build_folder(self): - self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True) - return Path(self.temp_folder.name) / "build" - - def default_build_source_folder(self): - self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True) - return Path(self.temp_folder.name) / "src" - - @staticmethod - def default_generate_source_folder(): - return Path(os.getcwd()) - - @staticmethod - def default_compile_source_folder(search_folders: Optional[List[Path]] = None, folder_name: Optional[str] = None): - if search_folders is None: - search_folders = [Path(os.getcwd())] - source_folder: Optional[Path] = None - # If source folder is not provide try to find one in current folder that is compatible with the tool - # I.e a folder that contains everything needed for compilation - for current_folder in search_folders: - for dir, _, _ in os.walk(current_folder): - try: - possible_source_folder = Path(dir) - # If a folder name is given and the candidate folder name does not match. Skip it! - if folder_name is not None and possible_source_folder.stem != folder_name: - continue - builder.validate_fmu_source_files(possible_source_folder) - source_folder = possible_source_folder - # If a match was found stop searching - break - except Exception: - # Any folder that does not contain the correct folder structure and files needed for compilation will raise and exception - continue - # If a match was found stop searching - if source_folder is not None: - break - return source_folder - - @staticmethod - def default_fmu_output_folder(): - return Path(os.getcwd()) - class MlFmuProcess: """Top level class encapsulating the mlfmu process."""