diff --git a/.github/workflows/type_check.yaml b/.github/workflows/type_check.yaml index 1e54124..c7161c7 100644 --- a/.github/workflows/type_check.yaml +++ b/.github/workflows/type_check.yaml @@ -15,4 +15,4 @@ jobs: - name: Install dependencies run: pip install ".[all]" - name: Type Check (mypy) - run: mypy src/lean_dojo/interaction + run: mypy src/lean_dojo diff --git a/.gitignore b/.gitignore index cd92677..8d89e22 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # Pyre type checker .pyre/ + +# vscode debug config +.vscode/ \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 5edd158..e12d4d8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,7 @@ project = "LeanDojo" copyright = "2023, LeanDojo Team" author = "Kaiyu Yang" -release = "2.0.3" +release = "2.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/mypy.ini b/mypy.ini index 9ca677c..c48bfef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -8,4 +8,13 @@ disallow_untyped_calls = False follow_imports = skip [mypy-pexpect.*] +ignore_missing_imports = True + +[mypy-lxml.*] +ignore_missing_imports = True + +[mypy-tqdm.*] +ignore_missing_imports = True + +[mypy-networkx.*] ignore_missing_imports = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4625fc2..fec8ce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ exclude = [ [project] name = "lean-dojo" -version = "2.0.3" +version = "2.1.0" authors = [ { name="Kaiyu Yang", email="kaiyuy@meta.com" }, ] @@ -31,6 +31,7 @@ dependencies = [ "python-dotenv", "loguru", "filelock", + "gitpython", "psutil", "pexpect", "types-psutil", diff --git a/src/lean_dojo/constants.py b/src/lean_dojo/constants.py index aa50261..590c23a 100644 --- a/src/lean_dojo/constants.py +++ b/src/lean_dojo/constants.py @@ -14,7 +14,7 @@ 
load_dotenv() -__version__ = "2.0.3" +__version__ = "2.1.0" logger.remove() if "VERBOSE" in os.environ or "DEBUG" in os.environ: @@ -71,15 +71,16 @@ assert re.fullmatch(r"\d+g", TACTIC_MEMORY_LIMIT) -def check_git_version(min_version: Tuple[int, int, int]) -> Tuple[int, int, int]: +def check_git_version(min_version: Tuple[int, int, int]) -> None: """Check the version of Git installed on the system.""" res = subprocess.run("git --version", shell=True, capture_output=True, check=True) - output = res.stdout.decode() + output = res.stdout.decode().strip() error = res.stderr.decode() assert error == "", error - m = re.match(r"git version (?P<version>[0-9.]+)", output) - version = tuple(int(_) for _ in m["version"].split(".")) - + m = re.search(r"git version (\d+\.\d+\.\d+)", output) + assert m, f"Could not parse Git version from: {output}" + # Convert version number string to tuple of integers + version = tuple(int(_) for _ in m.group(1).split(".")) version_str = ".".join(str(_) for _ in version) min_version_str = ".".join(str(_) for _ in min_version) assert ( diff --git a/src/lean_dojo/data_extraction/ast.py b/src/lean_dojo/data_extraction/ast.py index f858c82..2b88640 100644 --- a/src/lean_dojo/data_extraction/ast.py +++ b/src/lean_dojo/data_extraction/ast.py @@ -27,7 +27,7 @@ def from_data(cls, node_data: Dict[str, Any], lean_file: LeanFile) -> "Node": return subcls.from_data(node_data, lean_file) @classmethod - def _kind_to_node_type(cls, kind: str) -> type: + def _kind_to_node_type(cls, kind: str) -> type["Node"]: prefix = "Lean.Parser." 
if kind.startswith(prefix): kind = kind[len(prefix) :] @@ -83,7 +83,7 @@ def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node": start = Pos.from_str(tree.attrib["start"]) if "start" in tree.attrib else None end = Pos.from_str(tree.attrib["end"]) if "end" in tree.attrib else None children = [Node.from_xml(subtree, lean_file) for subtree in tree] - kwargs = {} + kwargs: Dict[str, Any] = {} for field in subcls.__dataclass_fields__.values(): if field.name in ("lean_file", "start", "end", "children"): @@ -113,11 +113,13 @@ def from_xml(cls, tree: etree.Element, lean_file: LeanFile) -> "Node": return subcls(lean_file, start, end, children, **kwargs) # type: ignore - def get_closure(self) -> Tuple[Pos, Pos]: + def get_closure(self) -> Tuple[Optional[Pos], Optional[Pos]]: return self.start, self.end -def _parse_pos(info: Dict[str, Any], lean_file: LeanFile) -> Pos: +def _parse_pos( + info: Dict[str, Any], lean_file: LeanFile +) -> Optional[Tuple[Optional[Pos], Optional[Pos]]]: if "synthetic" in info and not info["synthetic"]["canonical"]: return None diff --git a/src/lean_dojo/data_extraction/cache.py b/src/lean_dojo/data_extraction/cache.py index acfe1c9..3bf239c 100644 --- a/src/lean_dojo/data_extraction/cache.py +++ b/src/lean_dojo/data_extraction/cache.py @@ -7,13 +7,12 @@ from pathlib import Path from loguru import logger from filelock import FileLock +from typing import Optional, Generator from dataclasses import dataclass, field -from typing import Optional, Tuple, Generator from ..utils import ( execute, url_exists, - get_repo_info, report_critical_failure, ) from ..constants import ( @@ -23,22 +22,6 @@ ) -def _split_git_url(url: str) -> Tuple[str, str]: - """Split a Git URL into user name and repo name.""" - if url.endswith("/"): - url = url[:-1] - assert not url.endswith("/"), f"Unexpected URL: {url}" - fields = url.split("/") - user_name = fields[-2] - repo_name = fields[-1] - return user_name, repo_name - - -def _format_dirname(url: str, commit: 
str) -> str: - user_name, repo_name = _split_git_url(url) - return f"{user_name}-{repo_name}-{commit}" - - _CACHE_CORRPUTION_MSG = "The cache may have been corrputed!" @@ -59,16 +42,20 @@ def __post_init__(self): lock_path = self.cache_dir.with_suffix(".lock") object.__setattr__(self, "lock", FileLock(lock_path)) - def get(self, url: str, commit: str) -> Optional[Path]: - """Get the path of a traced repo with URL ``url`` and commit hash ``commit``. Return None if no such repo can be found.""" - _, repo_name = _split_git_url(url) - dirname = _format_dirname(url, commit) + def get(self, rel_cache_dir: Path) -> Optional[Path]: + """Get the cache repo at ``CACHE_DIR / rel_cache_dir`` from the cache. + + Args: + rel_cache_dir (Path): The relative path of the stored repo in the cache. + """ + dirname = rel_cache_dir.parent dirpath = self.cache_dir / dirname + cache_path = self.cache_dir / rel_cache_dir with self.lock: if dirpath.exists(): - assert (dirpath / repo_name).exists() - return dirpath / repo_name + assert cache_path.exists() + return cache_path elif not DISABLE_REMOTE_CACHE: url = os.path.join(REMOTE_CACHE_URL, f"{dirname}.tar.gz") @@ -83,23 +70,27 @@ def get(self, url: str, commit: str) -> Optional[Path]: with tarfile.open(f"{dirpath}.tar.gz") as tar: tar.extractall(self.cache_dir) os.remove(f"{dirpath}.tar.gz") - assert (dirpath / repo_name).exists() + assert (cache_path).exists() - return dirpath / repo_name + return cache_path else: return None - def store(self, src: Path) -> Path: - """Store a traced repo at path ``src``. Return its path in the cache.""" - url, commit = get_repo_info(src) - dirpath = self.cache_dir / _format_dirname(url, commit) - _, repo_name = _split_git_url(url) + def store(self, src: Path, rel_cache_dir: Path) -> Path: + """Store a repo at path ``src``. Return its path in the cache. + + Args: + src (Path): Path to the repo. + rel_cache_dir (Path): The relative path of the stored repo in the cache. 
+ """ + dirpath = self.cache_dir / rel_cache_dir.parent + cache_path = self.cache_dir / rel_cache_dir if not dirpath.exists(): with self.lock: with report_critical_failure(_CACHE_CORRPUTION_MSG): - shutil.copytree(src, dirpath / repo_name) - return dirpath / repo_name + shutil.copytree(src, cache_path) + return cache_path cache = Cache(CACHE_DIR) diff --git a/src/lean_dojo/data_extraction/lean.py b/src/lean_dojo/data_extraction/lean.py index 5d7b140..e3aabc9 100644 --- a/src/lean_dojo/data_extraction/lean.py +++ b/src/lean_dojo/data_extraction/lean.py @@ -8,10 +8,14 @@ import toml import time import urllib +import shutil +import tempfile import webbrowser +from enum import Enum from pathlib import Path from loguru import logger from functools import cache +from git import Repo, BadName from github import Github, Auth from dataclasses import dataclass, field from github.Repository import Repository @@ -19,13 +23,13 @@ from typing import List, Dict, Any, Generator, Union, Optional, Tuple, Iterator from ..utils import ( - execute, read_url, url_exists, - get_repo_info, working_directory, + is_git_repo, ) -from ..constants import LEAN4_URL +from .cache import cache as repo_cache +from ..constants import TMP_DIR, LEAN4_URL GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN", None) @@ -43,24 +47,115 @@ ) GITHUB = Github() -LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") +LEAN4_REPO = None """The GitHub Repo for Lean 4 itself.""" _URL_REGEX = re.compile(r"(?P<url>.*?)/*") +_SSH_TO_HTTPS_REGEX = re.compile(r"^git@github\.com:(.+)/(.+)(?:\.git)?$") -def normalize_url(url: str) -> str: - return _URL_REGEX.fullmatch(url)["url"] # Remove trailing `/`. +REPO_CACHE_PREFIX = "repos" + + +class RepoType(Enum): + GITHUB = 0 + REMOTE = 1 # Remote but not GitHub. + LOCAL = 2 + + +def normalize_url(url: str, repo_type: RepoType = RepoType.GITHUB) -> str: + if repo_type == RepoType.LOCAL: # Convert to absolute path if local. + return os.path.abspath(url) + # Remove trailing `/`. 
+ return _URL_REGEX.fullmatch(url)["url"] # type: ignore + + +def get_repo_type(url: str) -> Optional[RepoType]: + """Get the type of the repository. + + Args: + url (str): The URL of the repository. + Returns: + Optional[str]: The type of the repository (None if the repo cannot be found). + """ + m = _SSH_TO_HTTPS_REGEX.match(url) + url = f"https://github.com/{m.group(1)}/{m.group(2)}" if m else url + parsed_url = urllib.parse.urlparse(url) # type: ignore + if parsed_url.scheme in ["http", "https"]: + # Case 1 - GitHub URL. + if "github.com" in url: + if not url.startswith("https://"): + logger.warning(f"{url} should start with https://") + return None + else: + return RepoType.GITHUB + # Case 2 - remote URL. + elif url_exists(url): # Not check whether it is a git URL + return RepoType.REMOTE + # Case 3 - local path + elif is_git_repo(Path(parsed_url.path)): + return RepoType.LOCAL + logger.warning(f"{url} is not a valid URL") + return None + + +def _split_git_url(url: str) -> Tuple[str, str]: + """Split a Git URL into user name and repo name.""" + if url.endswith("/"): + url = url[:-1] + assert not url.endswith("/"), f"Unexpected URL: {url}" + fields = url.split("/") + user_name = fields[-2] + repo_name = fields[-1] + return user_name, repo_name + + +def _format_cache_dirname(url: str, commit: str) -> str: + user_name, repo_name = _split_git_url(url) + repo_type = get_repo_type(url) + assert repo_type is not None, f"Invalid url {url}" + if repo_type == RepoType.GITHUB: + return f"{user_name}-{repo_name}-{commit}" + else: # git repo + return f"gitpython-{repo_name}-{commit}" @cache -def url_to_repo(url: str, num_retries: int = 2) -> Repository: +def url_to_repo( + url: str, + num_retries: int = 2, + repo_type: Optional[RepoType] = None, + tmp_dir: Optional[Path] = None, +) -> Union[Repo, Repository]: + """Convert a URL to a Repo object. + + Args: + url (str): The URL of the repository. + num_retries (int): Number of retries in case of failure. 
+ repo_type (Optional[RepoType]): The type of the repository. Defaults to None. + tmp_dir (Optional[Path]): The temporary directory to clone the repo to. Defaults to None. + + Returns: + Repo: A Git Repo object. + """ url = normalize_url(url) backoff = 1 - + if tmp_dir is None: + tmp_dir = (TMP_DIR or Path("/tmp")) / next(tempfile._get_candidate_names()) # type: ignore + repo_type = repo_type or get_repo_type(url) + assert repo_type is not None, f"Invalid url {url}" while True: try: - return GITHUB.get_repo("/".join(url.split("/")[-2:])) + if repo_type == RepoType.GITHUB: + return GITHUB.get_repo("/".join(url.split("/")[-2:])) + with working_directory(tmp_dir): + repo_name = os.path.basename(url) + if repo_type == RepoType.LOCAL: + assert is_git_repo(url), f"Local path {url} is not a git repo" + shutil.copytree(url, repo_name) + return Repo(repo_name) + else: + return Repo.clone_from(url, repo_name) except Exception as ex: if num_retries <= 0: raise ex @@ -74,7 +169,10 @@ def url_to_repo(url: str, num_retries: int = 2) -> Repository: def get_latest_commit(url: str) -> str: """Get the hash of the latest commit of the Git repo at ``url``.""" repo = url_to_repo(url) - return repo.get_branch(repo.default_branch).commit.sha + if isinstance(repo, Repository): + return repo.get_branch(repo.default_branch).commit.sha + else: + return repo.head.commit.hexsha def cleanse_string(s: Union[str, Path]) -> str: @@ -83,20 +181,24 @@ def cleanse_string(s: Union[str, Path]) -> str: @cache -def _to_commit_hash(repo: Repository, label: str) -> str: +def _to_commit_hash(repo: Union[Repository, Repo], label: str) -> str: """Convert a tag or branch to a commit hash.""" - logger.debug(f"Querying the commit hash for {repo.name} {label}") - - try: - return repo.get_branch(label).commit.sha - except GithubException: - pass - - for tag in repo.get_tags(): - if tag.name == label: - return tag.commit.sha - - raise ValueError(f"Invalid tag or branch: `{label}` for {repo}") + if isinstance(repo, 
Repository): # GitHub repository + logger.debug(f"Querying the commit hash for {repo.name} {label}") + try: + return repo.get_commit(label).sha + except GithubException as ex: + raise ValueError(f"Invalid tag or branch: `{label}` for {repo.name}") + else: # Local or remote Git repository + assert isinstance(repo, Repo) + logger.debug( + f"Querying the commit hash for {repo.working_dir} repository {label}" + ) + try: + # Resolve the label to a commit hash + return repo.commit(label).hexsha + except Exception as ex: + raise ValueError(f"Error converting ref to commit hash: {ex}") @dataclass(eq=True, unsafe_hash=True) @@ -318,6 +420,11 @@ def __getitem__(self, key) -> str: _LEAN4_VERSION_REGEX = re.compile(r"leanprover/lean4:(?P<version>.+?)") +def is_commit_hash(s: str): + """Check if a string is a valid commit hash.""" + return len(s) == 40 and _COMMIT_REGEX.fullmatch(s) + + def get_lean4_version_from_config(toolchain: str) -> str: """Return the required Lean version given a ``lean-toolchain`` config.""" m = _LEAN4_VERSION_REGEX.fullmatch(toolchain.strip()) @@ -327,6 +434,9 @@ def get_lean4_version_from_config(toolchain: str) -> str: def get_lean4_commit_from_config(config_dict: Dict[str, Any]) -> str: """Return the required Lean commit given a ``lean-toolchain`` config.""" + global LEAN4_REPO + if LEAN4_REPO is None: + LEAN4_REPO = GITHUB.get_repo("leanprover/lean4") assert "content" in config_dict, "config_dict must have a 'content' field" config = config_dict["content"].strip() prefix = "leanprover/lean4:" @@ -335,7 +445,9 @@ def get_lean4_commit_from_config(config_dict: Dict[str, Any]) -> str: return _to_commit_hash(LEAN4_REPO, version) -URL = TAG = COMMIT = str +URL = str +TAG = str +COMMIT = str @dataclass(frozen=True) @@ -386,9 +498,9 @@ class LeanGitRepo: """Git repo of a Lean project.""" url: str - """The repo's Github URL. + """The repo's URL. - Note that we only support Github as of now. 
+ It can be a GitHub URL that starts with https:// or git@github.com, a local path, or any other valid Git URL. """ commit: str @@ -397,42 +509,62 @@ class LeanGitRepo: You can also use tags such as ``v3.5.0``. They will be converted to commit hashes. """ - repo: Repository = field(init=False, repr=False) - """A :class:`github.Repository` object. + repo: Union[Repository, Repo] = field(init=False, repr=False) + """A :class:`github.Repository` object for GitHub repos or + a :class:`git.Repo` object for local or remote Git repos. """ lean_version: str = field(init=False, repr=False) """Required Lean version. """ + repo_type: RepoType = field(init=False, repr=False) + """Type of the repo. It can be ``GITHUB``, ``LOCAL`` or ``REMOTE``. + """ + def __post_init__(self) -> None: - if "github.com" not in self.url: - raise ValueError(f"{self.url} is not a Github URL") - if not self.url.startswith("https://"): + repo_type = get_repo_type(self.url) + if repo_type is None: raise ValueError(f"{self.url} is not a valid URL") - object.__setattr__(self, "url", normalize_url(self.url)) - object.__setattr__(self, "repo", url_to_repo(self.url)) - + object.__setattr__(self, "repo_type", repo_type) + object.__setattr__(self, "url", normalize_url(self.url, repo_type=repo_type)) + # set repo and commit + if repo_type == RepoType.GITHUB: + repo = url_to_repo(self.url, repo_type=repo_type) + else: + # get repo from cache + rel_cache_dir = lambda url, commit: Path( + f"{REPO_CACHE_PREFIX}/{_format_cache_dirname(url, commit)}/{self.name}" + ) + cache_repo_dir = repo_cache.get(rel_cache_dir(self.url, self.commit)) + if cache_repo_dir is None: + with working_directory() as tmp_dir: + repo = url_to_repo(self.url, repo_type=repo_type, tmp_dir=tmp_dir) + commit = _to_commit_hash(repo, self.commit) + cache_repo_dir = repo_cache.store( + repo.working_dir, rel_cache_dir(self.url, commit) + ) + repo = Repo(cache_repo_dir) # Convert tags or branches to commit hashes - if not (len(self.commit) == 40 
and _COMMIT_REGEX.fullmatch(self.commit)): + if not is_commit_hash(self.commit): if (self.url, self.commit) in info_cache.tag2commit: commit = info_cache.tag2commit[(self.url, self.commit)] else: - commit = _to_commit_hash(self.repo, self.commit) - assert _COMMIT_REGEX.fullmatch(commit), f"Invalid commit hash: {commit}" - info_cache.tag2commit[(self.url, self.commit)] = commit + commit = _to_commit_hash(repo, self.commit) + assert is_commit_hash(commit), f"Invalid commit hash: {commit}" + info_cache.tag2commit[(self.url, commit)] = commit object.__setattr__(self, "commit", commit) + object.__setattr__(self, "repo", repo) # Determine the required Lean version. if (self.url, self.commit) in info_cache.lean_version: lean_version = info_cache.lean_version[(self.url, self.commit)] - elif self.is_lean4: - lean_version = self.commit + if self.is_lean4: + lean_version = "latest" # lean4 itself else: config = self.get_config("lean-toolchain") - lean_version = get_lean4_commit_from_config(config) - v = get_lean4_version_from_config(config["content"]) - if not is_supported_version(v): + lean_version = get_lean4_version_from_config(config["content"]) + if not is_supported_version(lean_version): logger.warning( f"{self} relies on an unsupported Lean version: {lean_version}" ) @@ -440,14 +572,14 @@ def __post_init__(self) -> None: object.__setattr__(self, "lean_version", lean_version) @classmethod - def from_path(cls, path: Path) -> "LeanGitRepo": + def from_path(cls, path: Union[Path, str]) -> "LeanGitRepo": """Construct a :class:`LeanGitRepo` object from the path to a local Git repo.""" - url, commit = get_repo_info(path) - return cls(url, commit) + commit = Repo(path).head.commit.hexsha + return cls(str(path), commit) @property def name(self) -> str: - return self.repo.name + return os.path.basename(self.url) @property def is_lean4(self) -> bool: @@ -455,24 +587,37 @@ def is_lean4(self) -> bool: @property def commit_url(self) -> str: - return os.path.join(self.url, 
f"tree/{self.commit}") + return f"{self.url}/tree/{self.commit}" + + def get_cache_dirname(self) -> Path: + """Return the formatted cache directory name""" + assert is_commit_hash(self.commit), f"Invalid commit hash: {self.commit}" + return Path(_format_cache_dirname(self.url, self.commit)) def show(self) -> None: """Show the repo in the default browser.""" webbrowser.open(self.commit_url) def exists(self) -> bool: - return url_exists(self.commit_url) + if self.repo_type != RepoType.GITHUB: + repo = self.repo # git repo + try: + repo.commit(self.commit) + return repo.head.commit.hexsha == self.commit + except BadName: + logger.warning( + f"Commit {self.commit} does not exist in this repository." + ) + return False + else: + return url_exists(self.commit_url) def clone_and_checkout(self) -> None: """Clone the repo to the current working directory and checkout a specific commit.""" logger.debug(f"Cloning {self}") - execute(f"git clone -n --recursive {self.url}", capture_output=True) - with working_directory(self.name): - execute( - f"git checkout {self.commit} && git submodule update --recursive", - capture_output=True, - ) + repo = Repo.clone_from(self.url, Path(self.name), no_checkout=True) + repo.git.checkout(self.commit) + repo.submodule_update(init=True, recursive=True) def get_dependencies( self, path: Union[str, Path, None] = None @@ -541,13 +686,13 @@ def _parse_deps( deps = [] for m in matches: - url = m["url"] + url = m["url"] # type: ignore if url.endswith(".git"): url = url[:-4] if url.startswith("git@"): url = "https://" + url[4:].replace(":", "/") - rev = m["rev"] + rev = m["rev"] # type: ignore if rev is None: commit = get_latest_commit(url) elif len(rev) == 40 and _COMMIT_REGEX.fullmatch(rev): @@ -559,7 +704,7 @@ def _parse_deps( commit = get_latest_commit(url) assert _COMMIT_REGEX.fullmatch(commit) - deps.append((m["name"], LeanGitRepo(url, commit))) + deps.append((m["name"], LeanGitRepo(url, commit))) # type: ignore return deps @@ -573,8 +718,8 @@ 
def _parse_lakefile_toml_dependencies( ) matches = dict() - for requirement in _LAKEFILE_TOML_REQUIREMENT_REGEX.finditer(lakefile): - for line in requirement.strip().splitlines(): + for req in _LAKEFILE_TOML_REQUIREMENT_REGEX.finditer(lakefile): + for line in req.group().strip().splitlines(): key, value = line.split("=") key = key.strip() value = value.strip() @@ -587,7 +732,7 @@ def _parse_lakefile_toml_dependencies( if key == "name": matches["name"] = value - return self._parse_deps(lakefile, matches) + return self._parse_deps(matches) def get_license(self) -> Optional[str]: """Return the content of the ``LICENSE`` file.""" @@ -596,7 +741,7 @@ def get_license(self) -> Optional[str]: license_url = f"{url}/{self.commit}/LICENSE" try: return read_url(license_url) - except urllib.error.HTTPError: + except urllib.error.HTTPError: # type: ignore return None def _get_config_url(self, filename: str) -> str: @@ -606,8 +751,13 @@ def _get_config_url(self, filename: str) -> str: def get_config(self, filename: str, num_retries: int = 2) -> Dict[str, Any]: """Return the repo's files.""" - config_url = self._get_config_url(filename) - content = read_url(config_url, num_retries) + if self.repo_type == RepoType.GITHUB: + config_url = self._get_config_url(filename) + content = read_url(config_url, num_retries) + else: + working_dir = self.repo.working_dir + with open(os.path.join(working_dir, filename), "r") as f: + content = f.read() if filename.endswith(".toml"): return toml.loads(content) elif filename.endswith(".json"): diff --git a/src/lean_dojo/data_extraction/trace.py b/src/lean_dojo/data_extraction/trace.py index e5b2492..072abf4 100644 --- a/src/lean_dojo/data_extraction/trace.py +++ b/src/lean_dojo/data_extraction/trace.py @@ -54,7 +54,7 @@ def _monitor(paths: List[Path], num_total: int) -> None: @contextmanager -def launch_progressbar(paths: List[Union[str, Path]]) -> Generator[None, None, None]: +def launch_progressbar(paths: List[Path]) -> Generator[None, None, 
None]: """Launch an async progressbar to monitor the progress of tracing the repo.""" paths = [Path(p) for p in paths] olean_files = list( @@ -71,7 +71,7 @@ def get_lean_version() -> str: """Get the version of Lean.""" output = execute("lean --version", capture_output=True)[0].strip() m = re.match(r"Lean \(version (?P<version>\S+?),", output) - return m["version"] + return m["version"] # type: ignore def is_new_version(v: str) -> bool: @@ -93,7 +93,7 @@ def is_new_version(v: str) -> bool: return True -def check_files(packages_path: str, no_deps: bool) -> None: +def check_files(packages_path: Path, no_deps: bool) -> None: """Check if all \*.lean files have been processed to produce \*.ast.json and \*.dep_paths files.""" cwd = Path.cwd() packages_path = cwd / packages_path @@ -142,12 +142,12 @@ def _trace(repo: LeanGitRepo, build_deps: bool) -> None: # Copy the Lean 4 stdlib into the path of packages. lean_prefix = execute(f"lean --print-prefix", capture_output=True)[0].strip() if is_new_version(get_lean_version()): - packages_path = ".lake/packages" - build_path = ".lake/build" + packages_path = Path(".lake/packages") + build_path = Path(".lake/build") else: - packages_path = "lake-packages" - build_path = "build" - shutil.copytree(lean_prefix, f"{packages_path}/lean4") + packages_path = Path("lake-packages") + build_path = Path("build") + shutil.copytree(lean_prefix, str(packages_path / "lean4")) # Run ExtractData.lean to extract ASTs, tactic states, and premise information. shutil.copyfile(LEAN4_DATA_EXTRACTOR_PATH, LEAN4_DATA_EXTRACTOR_PATH.name) @@ -204,15 +204,17 @@ def get_traced_repo_path(repo: LeanGitRepo, build_deps: bool = True) -> Path: Returns: Path: The path of the traced repo in the cache, e.g. 
:file:`/home/kaiyu/.cache/lean_dojo/leanprover-community-mathlib-2196ab363eb097c008d4497125e0dde23fb36db2` """ - path = cache.get(repo.url, repo.commit) + rel_cache_dir = repo.get_cache_dirname() / repo.name + path = cache.get(rel_cache_dir) if path is None: logger.info(f"Tracing {repo}") with working_directory() as tmp_dir: logger.debug(f"Working in the temporary directory {tmp_dir}") _trace(repo, build_deps) - traced_repo = TracedRepo.from_traced_files(tmp_dir / repo.name, build_deps) + src_dir = tmp_dir / repo.name + traced_repo = TracedRepo.from_traced_files(src_dir, build_deps) traced_repo.save_to_disk() - path = cache.store(tmp_dir / repo.name) + path = cache.store(src_dir, rel_cache_dir) else: logger.debug("The traced repo is available in the cache.") return path diff --git a/src/lean_dojo/data_extraction/traced_data.py b/src/lean_dojo/data_extraction/traced_data.py index 82317b2..5e56050 100644 --- a/src/lean_dojo/data_extraction/traced_data.py +++ b/src/lean_dojo/data_extraction/traced_data.py @@ -25,7 +25,29 @@ to_json_path, to_xml_path, ) -from .ast import * +from .ast import ( + Node, + FileNode, + OtherNode, + LemmaNode, + IdentNode, + CommandEndNode, + ModuleImportNode, + ModulePreludeNode, + CommandSectionNode, + CommandTheoremNode, + CommandModuledocNode, + CommandNamespaceNode, + CommandDoccommentNode, + CommandDeclarationNode, + MathlibTacticLemmaNode, + TacticTacticseqbracketedNode, + TacticTacticseq1IndentedNode, + CommandNoncomputablesectionNode, + is_leaf, + is_mutual_lean4, + is_potential_premise_lean4, +) from .lean import LeanFile, LeanGitRepo, Theorem, Pos from ..constants import NUM_WORKERS, LOAD_USED_PACKAGES_ONLY, LEAN4_PACKAGES_DIR @@ -184,7 +206,10 @@ def get_annotated_tactic(self) -> Tuple[str, List[Dict[str, Any]]]: Returns: Tuple[str, List[Dict[str, Any]]]: The first return value is the tactic string marked by `` ... ``. The second return value is a list of provenances. 
""" - assert self.traced_theorem != None + assert ( + self.traced_theorem is not None + and self.traced_theorem.traced_file is not None + ) lean_file = self.traced_theorem.traced_file.lean_file annot_tac = [] provenances = [] @@ -243,7 +268,9 @@ class TracedTheorem: def __post_init__(self) -> None: assert ( - self.root_dir.is_absolute() and self.root_dir == self.traced_file.root_dir + self.root_dir.is_absolute() + and self.traced_file is not None + and self.root_dir == self.traced_file.root_dir ) def __getstate__(self) -> Dict[str, Any]: @@ -272,7 +299,7 @@ def file_path(self) -> Path: return self.theorem.file_path @property - def traced_repo(self) -> "TracedRepo": + def traced_repo(self) -> Optional["TracedRepo"]: """The traced repo this theorem belongs to.""" if self.traced_file is None: return None @@ -325,25 +352,11 @@ def get_tactic_proof(self) -> Optional[str]: def get_theorem_statement(self) -> str: """Return the theorem statement.""" proof_start, _ = self.locate_proof() + assert self.traced_file is not None return get_code_without_comments( self.traced_file.lean_file, self.ast.start, proof_start, self.comments ) - def get_single_tactic_proof(self) -> Optional[str]: - """Wrap the proof into a single (potentially very long) tactic.""" - if not self.has_tactic_proof(): - return None - node = self.get_proof_node() - start, end = node.get_closure() - proof = get_code_without_comments(node.lean_file, start, end, self.comments) - - raise NotImplementedError - assert isinstance(node.children[0], AtomNode) and node.children[0].val == "by" - assert proof.startswith("by") - proof = proof[len("by") :].strip() - - return proof - def get_premise_full_names(self) -> List[str]: """Return the fully qualified names of all premises used in the proof.""" names = [] @@ -526,7 +539,7 @@ def from_traced_file( def _from_lean4_traced_file( cls, root_dir: Path, json_path: Path, repo: LeanGitRepo ) -> "TracedFile": - lean_path = to_lean_path(root_dir, json_path, repo) + lean_path = 
to_lean_path(root_dir, json_path) lean_file = LeanFile(root_dir, lean_path) data = json.load(json_path.open()) @@ -712,6 +725,7 @@ def traverse_preorder(self, callback, node_cls: Optional[type] = None): def _get_repo_and_relative_path(self) -> Tuple[LeanGitRepo, Path]: """Return the repo this file belongs to, as well as the file's path relative to it.""" + assert self.traced_repo is not None if self.path.is_relative_to(LEAN4_PACKAGES_DIR): # The theorem belongs to one of the dependencies. p = self.path.relative_to(LEAN4_PACKAGES_DIR) @@ -737,24 +751,26 @@ def get_traced_theorem( def _callback( node: Union[CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode], _ - ) -> None: + ) -> bool: nonlocal result, private_result - if not isinstance( - node, - ( - CommandTheoremNode, - LemmaNode, - MathlibTacticLemmaNode, - ), + if ( + isinstance( + node, + ( + CommandTheoremNode, + LemmaNode, + MathlibTacticLemmaNode, + ), + ) + and node.full_name == thm.full_name ): - return False - if node.full_name == thm.full_name: comments = self._filter_comments(node.start, node.end) t = TracedTheorem(self.root_dir, thm, node, comments, self) if t.is_private: private_result = t else: result = t + return False self.ast.traverse_preorder(_callback, node_cls=None) @@ -769,7 +785,7 @@ def get_traced_theorems(self) -> List[TracedTheorem]: def _callback( node: Union[CommandTheoremNode, LemmaNode, MathlibTacticLemmaNode], _ - ) -> None: + ) -> bool: if not isinstance( node, ( @@ -906,7 +922,7 @@ def from_xml( root_dir = Path(root_dir) path = Path(path) assert path.suffixes == [".trace", ".xml"] - lean_path = to_lean_path(root_dir, path, repo) + lean_path = to_lean_path(root_dir, path) lean_file = LeanFile(root_dir, lean_path) tree = etree.parse(path).getroot() @@ -1128,6 +1144,7 @@ def from_traced_files( def get_traced_file(self, path: Union[str, Path]) -> TracedFile: """Return a traced file by its path.""" + assert self.traced_files_graph is not None return 
self.traced_files_graph.nodes[str(path)]["traced_file"] def _update_traced_files(self) -> None: diff --git a/src/lean_dojo/utils.py b/src/lean_dojo/utils.py index 0386949..57f6240 100644 --- a/src/lean_dojo/utils.py +++ b/src/lean_dojo/utils.py @@ -71,7 +71,7 @@ def ray_actor_pool( """ assert not ray.is_initialized() ray.init() - pool = ActorPool([actor_cls.remote(*args, **kwargs) for _ in range(NUM_WORKERS)]) + pool = ActorPool([actor_cls.remote(*args, **kwargs) for _ in range(NUM_WORKERS)]) # type: ignore try: yield pool finally: @@ -144,33 +144,6 @@ def camel_case(s: str) -> str: return _CAMEL_CASE_REGEX.sub(" ", s).title().replace(" ", "") -@cache -def get_repo_info(path: Path) -> Tuple[str, str]: - """Get the URL and commit hash of the Git repo at ``path``. - - Args: - path (Path): Path to the Git repo. - - Returns: - Tuple[str, str]: URL and (most recent) hash commit - """ - with working_directory(path): - # Get the URL. - url_msg, _ = execute(f"git remote get-url origin", capture_output=True) - url = url_msg.strip() - # Get the commit. 
- commit_msg, _ = execute(f"git log -n 1", capture_output=True) - m = re.search(r"(?<=^commit )[a-z0-9]+", commit_msg) - assert m is not None - commit = m.group() - - if url.startswith("git@"): - assert url.endswith(".git") - url = url[: -len(".git")].replace(":", "/").replace("git@", "https://") - - return url, commit - - def is_optional_type(tp: type) -> bool: """Test if ``tp`` is Optional[X].""" if typing.get_origin(tp) != Union: @@ -181,8 +154,7 @@ def is_optional_type(tp: type) -> bool: def remove_optional_type(tp: type) -> type: """Given Optional[X], return X.""" - if typing.get_origin(tp) != Union: - return False + assert typing.get_origin(tp) == Union args = typing.get_args(tp) if len(args) == 2 and args[1] == type(None): return args[0] @@ -196,7 +168,11 @@ def read_url(url: str, num_retries: int = 2) -> str: backoff = 1 while True: try: - with urllib.request.urlopen(url) as f: + request = urllib.request.Request(url) # type: ignore + gh_token = os.getenv("GITHUB_ACCESS_TOKEN") + if gh_token is not None: + request.add_header("Authorization", f"token {gh_token}") + with urllib.request.urlopen(request) as f: # type: ignore return f.read().decode() except Exception as ex: if num_retries <= 0: @@ -209,11 +185,15 @@ def read_url(url: str, num_retries: int = 2) -> str: @cache def url_exists(url: str) -> bool: - """Return True if the URL ``url`` exists.""" + """Return True if the URL ``url`` exists, using the GITHUB_ACCESS_TOKEN for authentication if provided.""" try: - with urllib.request.urlopen(url) as _: + request = urllib.request.Request(url) # type: ignore + gh_token = os.getenv("GITHUB_ACCESS_TOKEN") + if gh_token is not None: + request.add_header("Authorization", f"token {gh_token}") + with urllib.request.urlopen(request) as _: # type: ignore return True - except urllib.error.HTTPError: + except urllib.error.HTTPError: # type: ignore return False @@ -279,7 +259,7 @@ def to_json_path(root_dir: Path, path: Path, repo) -> Path: return _from_lean_path(root_dir, 
path, repo, ext=".ast.json") -def to_lean_path(root_dir: Path, path: Path, repo) -> bool: +def to_lean_path(root_dir: Path, path: Path) -> Path: if path.is_absolute(): path = path.relative_to(root_dir) diff --git a/tests/conftest.py b/tests/conftest.py index 1e14c96..d581c62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ AESOP_URL = "https://github.com/leanprover-community/aesop" MATHLIB4_URL = "https://github.com/leanprover-community/mathlib4" LEAN4_EXAMPLE_URL = "https://github.com/yangky11/lean4-example" +EXAMPLE_COMMIT_HASH = "3f8c5eb303a225cdef609498b8d87262e5ef344b" +REMOTE_EXAMPLE_URL = "https://gitee.com/rexzong/lean4-example" URLS = [ BATTERIES_URL, AESOP_URL, @@ -15,6 +17,21 @@ ] +@pytest.fixture(scope="session") +def remote_example_url(): + return REMOTE_EXAMPLE_URL + + +@pytest.fixture(scope="session") +def example_commit_hash(): + return EXAMPLE_COMMIT_HASH + + +@pytest.fixture(scope="session") +def lean4_example_url(): + return LEAN4_EXAMPLE_URL + + @pytest.fixture(scope="session") def monkeysession(): with pytest.MonkeyPatch.context() as mp: diff --git a/tests/data_extraction/test_cache.py b/tests/data_extraction/test_cache.py new file mode 100644 index 0000000..e80d0f8 --- /dev/null +++ b/tests/data_extraction/test_cache.py @@ -0,0 +1,40 @@ +# test for cache manager +from git import Repo +from pathlib import Path +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.cache import cache + + +def test_local_repo_cache(lean4_example_url, example_commit_hash): + # Note: The `git.Repo` requires the local repo to be cloned in a directory + # all cached repos are stored in CACHE_DIR/repos + prefix = "repos" + repo_name = "lean4-example" + with working_directory() as tmp_dir: + repo = Repo.clone_from(lean4_example_url, repo_name) + repo.git.checkout(example_commit_hash) + local_dir = tmp_dir / repo_name + rel_cache_dir = ( + prefix / Path(f"gitpython-{repo_name}-{example_commit_hash}") / repo_name + ) + 
cache.store(local_dir, rel_cache_dir) + # get the cache + repo_cache_dir = cache.get(rel_cache_dir) + assert repo_cache_dir is not None + + +def test_remote_repo_cache(remote_example_url): + prefix = "repos" + repo_name = "lean4-example" + with working_directory() as tmp_dir: + repo = Repo.clone_from(remote_example_url, repo_name) + tmp_remote_dir = tmp_dir / repo_name + rel_cache_dir = ( + prefix + / Path(f"gitpython-{repo_name}-{repo.head.commit.hexsha}") + / repo_name + ) + cache.store(tmp_remote_dir, rel_cache_dir) + # get the cache + repo_cache = cache.get(rel_cache_dir) + assert repo_cache is not None diff --git a/tests/data_extraction/test_lean_repo.py b/tests/data_extraction/test_lean_repo.py new file mode 100644 index 0000000..42fc5bd --- /dev/null +++ b/tests/data_extraction/test_lean_repo.py @@ -0,0 +1,141 @@ +# Tests for the class `LeanGitRepo` +from git import Repo +from lean_dojo import LeanGitRepo +from github.Repository import Repository +from lean_dojo.data_extraction.lean import ( + _to_commit_hash, + get_repo_type, + url_to_repo, + get_latest_commit, + is_commit_hash, + GITHUB, + RepoType, +) +from lean_dojo.utils import working_directory + + +def test_github_type(lean4_example_url, example_commit_hash): + repo_name = "lean4-example" + + ## get_latest_commit + gh_cm_hash = get_latest_commit(lean4_example_url) + assert is_commit_hash(gh_cm_hash) + + ## url_to_repo & get_repo_type + github_repo = url_to_repo(lean4_example_url) + assert get_repo_type(lean4_example_url) == RepoType.GITHUB + assert get_repo_type("git@github.com:yangky11/lean4-example.git") == RepoType.GITHUB + assert get_repo_type("git@github.com:yangky11/lean4-example") == RepoType.GITHUB + assert isinstance(github_repo, Repository) + assert github_repo.name == repo_name + + ## commit hash + assert _to_commit_hash(github_repo, example_commit_hash) == example_commit_hash + ### test branch, assume this branch is not changing + assert ( + _to_commit_hash(github_repo, "paper") + == 
"8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + ### test git tag + assert ( + _to_commit_hash(GITHUB.get_repo("leanprover/lean4"), "v4.9.1") + == "1b78cb4836cf626007bd38872956a6fab8910993" + ) + + ## LeanGitRepo + LeanGitRepo(lean4_example_url, "main") # init with branch + repo = LeanGitRepo(lean4_example_url, example_commit_hash) + assert repo.url == lean4_example_url + assert repo.repo_type == RepoType.GITHUB + assert repo.commit == example_commit_hash + assert repo.exists() + assert repo.name == repo_name + assert repo.lean_version == "v4.7.0" + assert repo.commit_url == f"{lean4_example_url}/tree/{example_commit_hash}" + # cache name + assert isinstance(repo.repo, Repository) + assert ( + str(repo.get_cache_dirname()) == f"yangky11-{repo_name}-{example_commit_hash}" + ) + + +def test_remote_type(remote_example_url, example_commit_hash): + repo_name = "lean4-example" + + remote_repo = url_to_repo(remote_example_url) + assert get_repo_type(remote_example_url) == RepoType.REMOTE + assert isinstance(remote_repo, Repo) + re_cm_hash = get_latest_commit(remote_example_url) + assert re_cm_hash == get_latest_commit(str(remote_repo.working_dir)) + assert _to_commit_hash(remote_repo, example_commit_hash) == example_commit_hash + assert ( + _to_commit_hash(remote_repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + + ## LeanGitRepo + LeanGitRepo(remote_example_url, "main") + repo = LeanGitRepo(remote_example_url, example_commit_hash) + assert repo.url == remote_example_url + assert repo.repo_type == RepoType.REMOTE + assert repo.commit == example_commit_hash + assert repo.exists() + assert repo.name == repo_name + assert repo.lean_version == "v4.7.0" + assert repo.commit_url == f"{remote_example_url}/tree/{example_commit_hash}" + # cache name + assert isinstance(repo.repo, Repo) + assert ( + str(repo.get_cache_dirname()) == f"gitpython-{repo_name}-{example_commit_hash}" + ) + + +def test_local_type(lean4_example_url, example_commit_hash): + 
repo_name = "lean4-example" + gh_cm_hash = get_latest_commit(lean4_example_url) + + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, repo_name) + + ## get_latest_commit + local_url = str((tmp_dir / repo_name).absolute()) + assert get_latest_commit(local_url) == gh_cm_hash + + ## url_to_repo & get_repo_type + local_repo = url_to_repo(local_url, repo_type=RepoType.LOCAL) + assert get_repo_type(local_url) == RepoType.LOCAL + assert isinstance(local_repo, Repo) + assert ( + local_repo.working_dir != local_url + ), "The working directory should not be the same as the original repo" + + ## commit hash + repo = Repo(local_url) + repo.git.checkout(example_commit_hash) + repo.create_tag("v0.1.0") # create a tag for the example commit hash + repo.git.checkout("main") # switch back to main branch + assert _to_commit_hash(repo, example_commit_hash) == example_commit_hash + assert ( + _to_commit_hash(repo, "origin/paper") + == "8bf74cf67d1acf652a0c74baaa9dc3b9b9e4098c" + ) + assert _to_commit_hash(repo, "v0.1.0") == example_commit_hash + + ## LeanGitRepo + LeanGitRepo(local_url, "main") + repo = LeanGitRepo(local_url, example_commit_hash) + repo2 = LeanGitRepo.from_path(local_url) # test from_path + assert repo.url == local_url == repo2.url + assert repo.repo_type == RepoType.LOCAL == repo2.repo_type + assert repo.commit == example_commit_hash and repo2.commit == gh_cm_hash + assert repo.exists() and repo2.exists() + assert repo.name == repo_name == repo2.name + assert repo.lean_version == "v4.7.0" + # cache name + assert isinstance(repo.repo, Repo) and isinstance(repo2.repo, Repo) + assert ( + str(repo.get_cache_dirname()) + == f"gitpython-{repo_name}-{example_commit_hash}" + ) + assert str(repo2.get_cache_dirname()) == f"gitpython-{repo_name}-{gh_cm_hash}" diff --git a/tests/data_extraction/test_trace.py b/tests/data_extraction/test_trace.py index 0064933..0c9710a 100644 --- 
a/tests/data_extraction/test_trace.py +++ b/tests/data_extraction/test_trace.py @@ -1,5 +1,42 @@ from pathlib import Path from lean_dojo import * +from lean_dojo.data_extraction.cache import cache +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.lean import RepoType +from git import Repo + + +def test_github_trace(lean4_example_url): + # github + github_repo = LeanGitRepo(lean4_example_url, "main") + assert github_repo.repo_type == RepoType.GITHUB + trace_repo = trace(github_repo) + path = cache.get(github_repo.get_cache_dirname() / github_repo.name) + assert path is not None + + +def test_remote_trace(remote_example_url): + # remote + remote_repo = LeanGitRepo(remote_example_url, "main") + assert remote_repo.repo_type == RepoType.REMOTE + trace_repo = trace(remote_repo) + path = cache.get(remote_repo.get_cache_dirname() / remote_repo.name) + assert path is not None + + +def test_local_trace(lean4_example_url): + # local + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, "lean4-example") + local_dir = str((tmp_dir / "lean4-example")) + local_url = str((tmp_dir / "lean4-example").absolute()) + local_repo = LeanGitRepo(local_dir, "main") + assert local_repo.url == local_url + assert local_repo.repo_type == RepoType.LOCAL + trace_repo = trace(local_repo) + path = cache.get(local_repo.get_cache_dirname() / local_repo.name) + assert path is not None def test_trace(traced_repo): diff --git a/tests/interaction/test_interaction.py b/tests/interaction/test_interaction.py new file mode 100644 index 0000000..83b1c45 --- /dev/null +++ b/tests/interaction/test_interaction.py @@ -0,0 +1,66 @@ +import os +from git import Repo +from lean_dojo.utils import working_directory +from lean_dojo.data_extraction.lean import RepoType +from lean_dojo import LeanGitRepo, Dojo, ProofFinished, ProofGivenUp, Theorem + + +# Avoid using remote cache +os.environ["DISABLE_REMOTE_CACHE"] = "1" + + 
+def test_github_interact(lean4_example_url): + repo = LeanGitRepo(url=lean4_example_url, commit="main") + assert repo.repo_type == RepoType.GITHUB + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished) + + +def test_remote_interact(remote_example_url): + repo = LeanGitRepo(url=remote_example_url, commit="main") + assert repo.repo_type == RepoType.REMOTE + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished) + + +def test_local_interact(lean4_example_url): + # Clone the GitHub repository to the local path + with working_directory() as tmp_dir: + # git repo placed in `tmp_dir / repo_name` + Repo.clone_from(lean4_example_url, "lean4-example") + + local_dir = str((tmp_dir / "lean4-example")) + repo = LeanGitRepo(local_dir, commit="main") + assert repo.repo_type == RepoType.LOCAL + theorem = Theorem(repo, "Lean4Example.lean", "hello_world") + # initial state + dojo, state_0 = Dojo(theorem).__enter__() + assert state_0.pp == "a b c : Nat\n⊢ a + b + c = 
a + c + b" + # state after running a tactic + state_1 = dojo.run_tac(state_0, "rw [add_assoc]") + assert state_1.pp == "a b c : Nat\n⊢ a + (b + c) = a + c + b" + # state after running another a sorry tactic + assert dojo.run_tac(state_1, "sorry") == ProofGivenUp() + # finish proof + final_state = dojo.run_tac(state_1, "rw [add_comm b, ←add_assoc]") + assert isinstance(final_state, ProofFinished) diff --git a/tests/interaction/test_tactics.py b/tests/interaction/test_tactics.py index c0a2ee8..1aba648 100644 --- a/tests/interaction/test_tactics.py +++ b/tests/interaction/test_tactics.py @@ -37,10 +37,10 @@ def test_example_append_subset(batteries_repo: LeanGitRepo) -> None: thm = Theorem( batteries_repo, "Batteries/Data/List/Lemmas.lean", - "List.append_subset", + "List.disjoint_append_left", ) with Dojo(thm) as (dojo, s0): - s1 = dojo.run_tac(s0, "simp [subset_def, or_imp, forall_and]") + s1 = dojo.run_tac(s0, "simp [Disjoint, or_imp, forall_and]") assert isinstance(s1, ProofFinished) assert dojo.is_successful @@ -130,10 +130,10 @@ def test_example_length_le(batteries_repo: LeanGitRepo) -> None: thm = Theorem( batteries_repo, "Batteries/Data/List/Lemmas.lean", - "List.IsSuffix.length_le", + "List.disjoint_of_disjoint_append_right_right", ) with Dojo(thm) as (dojo, s0): - s1 = dojo.run_tac(s0, "exact h.sublist.length_le") + s1 = dojo.run_tac(s0, "exact (disjoint_append_right.1 d).2") assert isinstance(s1, ProofFinished) assert dojo.is_successful