diff --git a/self_hosting_machinery/configs/__init__.py b/self_hosting_machinery/configs/__init__.py new file mode 100644 index 00000000..7d2f03e0 --- /dev/null +++ b/self_hosting_machinery/configs/__init__.py @@ -0,0 +1,8 @@ +from self_hosting_machinery.configs.configs_in_files import ( + rcfg_save, + rcfg_load, + rcfg_mtime, + rcfg_list_uploads +) + + diff --git a/self_hosting_machinery/configs/configs_in_files.py b/self_hosting_machinery/configs/configs_in_files.py new file mode 100644 index 00000000..a123c59a --- /dev/null +++ b/self_hosting_machinery/configs/configs_in_files.py @@ -0,0 +1,34 @@ +import os +import json +import time +from self_hosting_machinery import env +from typing import Dict, Any, Optional, List + + +def rcfg_save(path: str, adict: Dict[str, Any]) -> None: + with open(path + ".tmp", "w") as f: + json.dump(adict, f, indent=4) + os.rename(path + ".tmp", path) + + +def rcfg_load(path: str, default_if_not_found: Optional[Dict[str, Any]]=None) -> Dict[str, Any]: + if default_if_not_found is not None and not os.path.exists(path): + return default_if_not_found + with open(path, "r") as f: + return json.load(f) + + +def rcfg_load_not_too_old(path: str, seconds: int, default_if_not_found_or_too_old: Optional[Dict[str, Any]]=None) -> Dict[str, Any]: + if rcfg_mtime(path) + seconds < time.time(): + return default_if_not_found_or_too_old + return rcfg_load(path, default_if_not_found_or_too_old) + + +def rcfg_mtime(path) -> int: + if not os.path.exists(path): + return 0 + return os.path.getmtime(path) + + +def rcfg_list_uploads(uploaded_path: str) -> List[str]: + return list(sorted(os.listdir(uploaded_path))) diff --git a/self_hosting_machinery/webgui/tab_upload.py b/self_hosting_machinery/webgui/tab_upload.py index 080c7b09..5571fe27 100644 --- a/self_hosting_machinery/webgui/tab_upload.py +++ b/self_hosting_machinery/webgui/tab_upload.py @@ -15,6 +15,8 @@ from refact_data_pipeline.finetune.finetune_utils import get_prog_and_status_for_ui from self_hosting_machinery.webgui.selfhost_webutils import log from self_hosting_machinery import env +from self_hosting_machinery import configs + __all__ = ["TabUploadRouter", "download_file_from_url", "UploadViaURL"] @@ -107,46 +109,39 @@ async def _tab_files_get(self): result = { "uploaded_files": {} } - uploaded_path = env.DIR_UPLOADS - if os.path.isfile(env.CONFIG_HOW_TO_UNZIP): - how_to_process = json.load(open(env.CONFIG_HOW_TO_UNZIP, "r")) - else: - how_to_process = {'uploaded_files': {}} - - scan_stats = {"uploaded_files": {}} - stats_uploaded_files = {} - if os.path.isfile(env.CONFIG_PROCESSING_STATS): - scan_stats = json.load(open(env.CONFIG_PROCESSING_STATS, "r")) - mtime = os.path.getmtime(env.CONFIG_PROCESSING_STATS) - stats_uploaded_files = scan_stats.get("uploaded_files", {}) - for fstat in stats_uploaded_files.values(): - if fstat["status"] in ["working", "starting"]: - if mtime + 600 < time.time(): - fstat["status"] = "failed" - - if os.path.isfile(env.CONFIG_HOW_TO_FILETYPES): - result["filetypes"] = json.load(open(env.CONFIG_HOW_TO_FILETYPES, "r")) - else: - result["filetypes"] = { + how_to_process = configs.rcfg_load(env.CONFIG_HOW_TO_UNZIP, {'uploaded_files': {}}) + + scan_stats = configs.rcfg_load(env.CONFIG_PROCESSING_STATS, {"uploaded_files": {}}) + scan_stats_mtime = configs.rcfg_mtime(env.CONFIG_PROCESSING_STATS) + stats_uploaded_files = scan_stats["uploaded_files"] + for fstat in stats_uploaded_files.values(): + if fstat["status"] in ["working", "starting"]: + if scan_stats_mtime + 600 < time.time(): + fstat["status"] = "failed" + + filetypes = configs.rcfg_load(env.CONFIG_HOW_TO_FILETYPES, + { "filetypes_finetune": {}, "filetypes_db": {}, "force_include": "", "force_exclude": "", } + ) + result["filetypes"] = filetypes default = { "which_set": "train", "to_db": True, } - for fn in sorted(os.listdir(uploaded_path)): + for fn in configs.rcfg_list_uploads(env.DIR_UPLOADS): result["uploaded_files"][fn] = { "which_set": how_to_process["uploaded_files"].get(fn, default)["which_set"], "to_db": how_to_process["uploaded_files"].get(fn, default)["to_db"], "is_git": False, **stats_uploaded_files.get(fn, {}) } - if os.path.exists(os.path.join(uploaded_path, fn, env.GIT_CONFIG_FILENAME)): - with open(os.path.join(uploaded_path, fn, env.GIT_CONFIG_FILENAME)) as f: + if os.path.exists(os.path.join(env.DIR_UPLOADS, fn, env.GIT_CONFIG_FILENAME)): + with open(os.path.join(env.DIR_UPLOADS, fn, env.GIT_CONFIG_FILENAME)) as f: config = json.load(f) result["uploaded_files"][fn].update({ "is_git": True, @@ -166,9 +161,7 @@ async def _tab_files_get(self): return Response(json.dumps(result, indent=4) + "\n") async def _tab_files_save_config(self, config: TabFilesConfig): - with open(env.CONFIG_HOW_TO_UNZIP + ".tmp", "w") as f: - json.dump(config.dict(), f, indent=4) - os.rename(env.CONFIG_HOW_TO_UNZIP + ".tmp", env.CONFIG_HOW_TO_UNZIP) + configs.rcfg_save(env.CONFIG_HOW_TO_UNZIP, config.dict()) # _reset_process_stats() -- this requires process script restart, but it flashes too much in GUI return JSONResponse("OK") @@ -215,6 +208,7 @@ async def _tab_files_repo_upload(self, repo: CloneRepo): class IncorrectUrl(Exception): def __init__(self): super().__init__() + def cleanup_url(url: str): for sym in ["\t", " "]: splited = list(filter(lambda x: len(x) > 0, url.split(sym))) @@ -222,6 +216,7 @@ def cleanup_url(url: str): raise IncorrectUrl() url = splited[0] return url + def check_url(url: str): from giturlparse import parse if not parse(url).valid: @@ -235,11 +230,10 @@ def get_repo_name_from_url(url: str) -> str: last_suffix_index = url.rfind(".git") if last_suffix_index < 0: last_suffix_index = len(url) - if last_slash_index < 0 or last_suffix_index <= last_slash_index: raise Exception("Badly formatted url {}".format(url)) - return url[last_slash_index + 1:last_suffix_index] + try: url = cleanup_url(repo.url) url = check_url(url)