Skip to content

Commit

Permalink
Cleanup code for finding default paths
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferSkare committed Mar 15, 2024
1 parent b1aa071 commit 5525d0e
Showing 1 changed file with 66 additions and 81 deletions.
147 changes: 66 additions & 81 deletions src/mlfmu/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class MlFmuBuilder:
fmu_output_folder: Optional[Path] = None
delete_build_folders: bool = False
temp_folder: Optional[tempfile.TemporaryDirectory[str]] = None
root_directory: Path

def __init__(
self,
Expand All @@ -76,6 +77,7 @@ def __init__(
fmu_output_folder: Optional[Path] = None,
build_folder: Optional[Path] = None,
delete_build_folders: bool = False,
root_directory: Optional[Path] = None,
):
self.fmu_name = fmu_name
self.interface_file = interface_file
Expand All @@ -84,27 +86,24 @@ def __init__(
self.fmu_output_folder = fmu_output_folder
self.build_folder = build_folder
self.delete_build_folders = delete_build_folders
self.root_directory = root_directory or Path(os.getcwd())

def build(self):
# TODO: Raise errors
if self.source_folder is None:
self.source_folder = self.default_build_source_folder()

self.source_folder = self.source_folder or self.default_build_source_folder()

self.ml_model_file = self.ml_model_file or self.default_model_file()
if self.ml_model_file is None:
self.ml_model_file = self.default_model_file()
if self.ml_model_file is None:
raise
raise

self.interface_file = self.interface_file or self.default_interface_file()
if self.interface_file is None:
self.interface_file = self.default_interface_file()
if self.interface_file is None:
raise
raise

if self.build_folder is None:
self.build_folder = self.default_build_folder()
self.build_folder = self.build_folder or self.default_build_folder()

if self.fmu_output_folder is None:
self.fmu_output_folder = self.default_fmu_output_folder()
self.fmu_output_folder = self.fmu_output_folder or self.default_fmu_output_folder()

try:
fmi_model = builder.generate_fmu_files(self.source_folder, self.ml_model_file, self.interface_file)
Expand All @@ -129,18 +128,17 @@ def build(self):

def generate(self):
# TODO: Raise errors
if self.source_folder is None:
self.source_folder = self.default_generate_source_folder()

self.source_folder = self.source_folder or self.default_generate_source_folder()

self.ml_model_file = self.ml_model_file or self.default_model_file()
if self.ml_model_file is None:
self.ml_model_file = self.default_model_file()
if self.ml_model_file is None:
raise
raise

self.interface_file = self.interface_file or self.default_interface_file()
if self.interface_file is None:
self.interface_file = self.default_interface_file()
if self.interface_file is None:
raise
raise

try:
fmi_model = builder.generate_fmu_files(self.source_folder, self.ml_model_file, self.interface_file)
self.fmu_name = fmi_model.name
Expand All @@ -151,22 +149,12 @@ def generate(self):
return

def compile(self):
if self.build_folder is None:
self.build_folder = self.default_build_folder()
self.build_folder = self.build_folder or self.default_build_folder()

if self.fmu_output_folder is None:
self.fmu_output_folder = self.default_fmu_output_folder()
self.fmu_output_folder = self.fmu_output_folder or self.default_fmu_output_folder()

if self.fmu_name is None or self.source_folder is None:
search_folders: List[Path] = []
if self.source_folder is not None:
search_folders.append(self.source_folder)

search_folders.append(Path(os.getcwd()))

source_child_folder = self.default_compile_source_folder(
search_folders=search_folders, folder_name=self.fmu_name
)
source_child_folder = self.default_compile_source_folder()
if source_child_folder is None:
raise
self.fmu_name = source_child_folder.stem
Expand Down Expand Up @@ -198,13 +186,52 @@ def delete_temp_folder(self):
if self.temp_folder is not None and Path(self.temp_folder.name).exists():
shutil.rmtree(Path(self.temp_folder.name))

@staticmethod
def default_interface_file():
return MlFmuBuilder._find_default_file(Path(os.getcwd()), "json", "interface")
def default_interface_file(self):
return MlFmuBuilder._find_default_file(self.root_directory, "json", "interface")

@staticmethod
def default_model_file():
return MlFmuBuilder._find_default_file(Path(os.getcwd()), "onnx", "model")
def default_model_file(self):
return MlFmuBuilder._find_default_file(self.root_directory, "onnx", "model")

def default_build_folder(self):
self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True)
return Path(self.temp_folder.name) / "build"

def default_build_source_folder(self):
self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True)
return Path(self.temp_folder.name) / "src"

def default_generate_source_folder(self):
return self.root_directory

def default_compile_source_folder(self):
search_folders: List[Path] = []
if self.source_folder is not None:
search_folders.append(self.source_folder)
search_folders.append(self.root_directory)
source_folder: Optional[Path] = None
# If source folder is not provide try to find one in current folder that is compatible with the tool
# I.e a folder that contains everything needed for compilation
for current_folder in search_folders:
for dir, _, _ in os.walk(current_folder):
try:
possible_source_folder = Path(dir)
# If a fmu name is given and the candidate folder name does not match. Skip it!
if self.fmu_name is not None and possible_source_folder.stem != self.fmu_name:
continue
builder.validate_fmu_source_files(possible_source_folder)
source_folder = possible_source_folder
# If a match was found stop searching
break
except Exception:
# Any folder that does not contain the correct folder structure and files needed for compilation will raise and exception
continue
# If a match was found stop searching
if source_folder is not None:
break
return source_folder

def default_fmu_output_folder(self):
return self.root_directory

@staticmethod
def _find_default_file(dir: Path, file_extension: str, default_name: Optional[str] = None):
Expand Down Expand Up @@ -241,48 +268,6 @@ def _find_default_file(dir: Path, file_extension: str, default_name: Optional[st
return name_matches[0]
return

def default_build_folder(self):
self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True)
return Path(self.temp_folder.name) / "build"

def default_build_source_folder(self):
self.temp_folder = self.temp_folder or tempfile.TemporaryDirectory(prefix="mlfmu_", delete=True)
return Path(self.temp_folder.name) / "src"

@staticmethod
def default_generate_source_folder():
return Path(os.getcwd())

@staticmethod
def default_compile_source_folder(search_folders: Optional[List[Path]] = None, folder_name: Optional[str] = None):
if search_folders is None:
search_folders = [Path(os.getcwd())]
source_folder: Optional[Path] = None
# If source folder is not provide try to find one in current folder that is compatible with the tool
# I.e a folder that contains everything needed for compilation
for current_folder in search_folders:
for dir, _, _ in os.walk(current_folder):
try:
possible_source_folder = Path(dir)
# If a folder name is given and the candidate folder name does not match. Skip it!
if folder_name is not None and possible_source_folder.stem != folder_name:
continue
builder.validate_fmu_source_files(possible_source_folder)
source_folder = possible_source_folder
# If a match was found stop searching
break
except Exception:
# Any folder that does not contain the correct folder structure and files needed for compilation will raise and exception
continue
# If a match was found stop searching
if source_folder is not None:
break
return source_folder

@staticmethod
def default_fmu_output_folder():
return Path(os.getcwd())


class MlFmuProcess:
"""Top level class encapsulating the mlfmu process."""
Expand Down

0 comments on commit 5525d0e

Please sign in to comment.