From 6dffa3daffd98b5e7f67ee10c825da4f6c8f7de2 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Fri, 3 Jan 2025 10:32:16 +0800 Subject: [PATCH] fix: missing git urls when freezing requirements (#22) Signed-off-by: Frost Ming --- nodes/api.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/nodes/api.py b/nodes/api.py index 17755f5..f3d7391 100644 --- a/nodes/api.py +++ b/nodes/api.py @@ -11,6 +11,7 @@ import time import uuid import zipfile +from importlib.metadata import Distribution, distributions from pathlib import Path from typing import Any, Union @@ -25,25 +26,33 @@ ZPath = Union[Path, zipfile.Path] TEMP_FOLDER = Path(__file__).parent.parent / "temp" COMFY_PACK_DIR = Path(__file__).parent.parent / "src" / "comfy_pack" -EXCLUDE_PACKAGES = ["bentoml", "onnxruntime"] # TODO: standardize this +EXCLUDE_PACKAGES = ["bentoml", "onnxruntime", "conda"] # TODO: standardize this + + +def _get_requirement_string(dist: Distribution) -> str: + direct_url_text = dist.read_text("direct_url.json") + pinned_str = f'{dist.metadata["Name"]}=={dist.version}' + if not direct_url_text: + return pinned_str + direct_url = json.loads(direct_url_text) + if url := direct_url.get("url"): + if url.startswith("file://"): + # we are not able to share local files + return pinned_str + if vcs_info := direct_url.get("vcs_info"): + url = f"{vcs_info['vcs']}+{url}@{vcs_info['commit_id']}" + if subdirectory := direct_url.get("subdirectory"): + url += f"#subdirectory={subdirectory}" + return f"{dist.metadata['Name']} @ {url}" + else: + return pinned_str async def _write_requirements(path: ZPath, extras: list[str] | None = None) -> None: print("Package => Writing requirements.txt") with path.joinpath("requirements.txt").open("w") as f: - proc = await asyncio.subprocess.create_subprocess_exec( - sys.executable, - "-m", - "pip", - "list", - "--format", - "freeze", - "--exclude-editable", - *[f"--exclude={p}" for p in EXCLUDE_PACKAGES], - stdout=subprocess.PIPE, - ) - stdout, _ = await proc.communicate() - f.write(stdout.decode().rstrip("\n") + "\n") + for dist in distributions(): + f.write(_get_requirement_string(dist) + "\n") if extras: f.write("\n".join(extras) + "\n")