-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from instadeepai/58-instanovo-100-code-tests
InstaNovo 1.0.0 code tests
- Loading branch information
Showing
28 changed files
with
2,582 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -171,7 +171,6 @@ docs/reference | |
coverage/ | ||
checkpoints/ | ||
data/ | ||
docs/ | ||
docs_public/ | ||
logs/ | ||
mlruns/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
--8<-- "LICENSE.md" |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""Generate the code reference pages and navigation.""" | ||
|
||
# This script is run by the mkdocs-gen-files plugin (https://oprypin.github.io/mkdocs-gen-files/)
# for MkDocs (https://www.mkdocs.org/). It creates a stub page for each module in the code
# and a "docs/reference/SUMMARY.md" page containing a table of contents with links to
# all the stub pages. When MkDocs runs, it populates the stub pages with the documentation
# pulled from the docstrings.
from __future__ import annotations | ||
|
||
from pathlib import Path | ||
|
||
import mkdocs_gen_files | ||
|
||
# Directory names that may contain *.py files but must be left out of the code reference.
IGNORE_DIRS = (
    "build",
    "docs_public",
    "docs",
    "tests",
    "scripts",
    "utils",
    ".venv",
)
|
||
|
||
def is_ignored_directory(module_path: Path) -> bool:
    """Report whether any component of ``module_path`` is listed in ``IGNORE_DIRS``.

    Every entry of ``module_path.parts`` is compared, so the final component
    (the module name itself) is checked as well as the directories above it.
    """
    for component in module_path.parts:
        if component in IGNORE_DIRS:
            return True
    return False
|
||
|
||
def is_ignored_file(module_path: Path) -> bool:
    """Check if the module is a test module or one of the explicitly excluded modules.

    Args:
        module_path: Path of the module; only its final component is inspected,
            and it is compared as-is (no suffix is stripped here).

    Returns:
        ``True`` when the final path component ends with ``_test`` or equals
        ``mlflow_auth``, ``types`` or ``constants``; ``False`` otherwise,
        including for an empty path.
    """
    # ``Path.name`` is "" for an empty path, whereas ``parts[-1]`` would raise IndexError.
    module_name = module_path.name
    return module_name.endswith("_test") or module_name in (
        "mlflow_auth",
        "types",
        "constants",
    )
|
||
|
||
def process_python_files(source_directory: str, module_name: str) -> None:
    """Generate documentation stub pages and the navigation file for Python modules.

    Every ``*.py`` file found recursively under ``source_directory`` that is not
    filtered out by ``is_ignored_directory`` / ``is_ignored_file`` gets a stub
    page below ``reference/`` containing a single ``::: dotted.module.name``
    directive. ``__init__`` modules are written as ``index.md`` and
    ``__main__`` modules are skipped. Finally ``reference/SUMMARY.md`` is
    written with the literate navigation built from all generated pages.

    Args:
        source_directory: Directory that is searched recursively for ``*.py`` files.
        module_name: Name of the package directory inside ``source_directory``;
            stub page paths are expressed relative to it.

    Raises:
        ValueError: If a non-ignored file is not located under
            ``source_directory/module_name`` (raised by ``Path.relative_to``).
    """
    nav = mkdocs_gen_files.Nav()

    # Build the package root once and pass it as a single argument: handing several
    # positional arguments to ``Path.relative_to`` is deprecated since Python 3.12.
    package_root = Path(source_directory, module_name)

    for python_file in sorted(Path(source_directory).rglob("*.py")):
        relative_module_path = python_file.relative_to(source_directory).with_suffix("")

        if not is_ignored_directory(relative_module_path) and not is_ignored_file(
            relative_module_path
        ):
            doc_path = python_file.relative_to(package_root).with_suffix(".md")
            full_doc_path = Path("reference", doc_path)

            parts = tuple(relative_module_path.parts)

            if parts[-1] == "__init__":
                # A package is documented on the index page of its directory.
                parts = parts[:-1]
                doc_path = doc_path.with_name("index.md")
                full_doc_path = full_doc_path.with_name("index.md")
            elif parts[-1] == "__main__":
                continue

            nav[parts] = doc_path.as_posix()

            with mkdocs_gen_files.open(full_doc_path, "w") as fd:
                ident = ".".join(parts)
                fd.write(f"::: {ident}")

            mkdocs_gen_files.set_edit_path(full_doc_path, ".." / python_file)

    with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file:
        nav_file.writelines(nav.build_literate_nav())


# Executed when this script is run, so the pages are generated immediately.
process_python_files(source_directory=".", module_name="instanovo")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
--8<-- "README.md" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
defaults: | ||
- instanovo_base | ||
- _self_ | ||
|
||
# Model parameters | ||
dim_model: 320 | ||
n_head: 20 | ||
dim_feedforward: 1280 | ||
n_layers: 6 | ||
dropout: 0. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
import random | ||
import sys | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
import pytorch_lightning as ptl | ||
import torch | ||
from hydra import compose, initialize | ||
from omegaconf import DictConfig, open_dict | ||
|
||
from instanovo.inference.knapsack_beam_search import KnapsackBeamSearchDecoder | ||
from instanovo.transformer.model import InstaNovo | ||
from instanovo.transformer.predict import _setup_knapsack | ||
|
||
|
||
# Append the repository root (the parent of this file's directory) to sys.path
# so that the project's modules can be imported during testing.
_tests_dir = os.path.dirname(__file__)
root_dir = os.path.dirname(_tests_dir)
sys.path.append(root_dir)
|
||
|
||
def reset_seed(seed: int = 42) -> None:
    """Seed torch (CPU and, when available, CUDA), NumPy, ``random`` and PyTorch Lightning.

    Args:
        seed: Value handed to every seeding function. Defaults to 42.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # The remaining seeders all take the seed as their only argument.
    for seeder in (np.random.seed, random.seed, ptl.seed_everything):
        seeder(seed)
|
||
|
||
@pytest.fixture()
def _reset_seed() -> None:
    """Fixture that calls ``reset_seed`` with its default seed before a test that requests it."""
    reset_seed()
|
||
|
||
@pytest.fixture(scope="session")
def checkpoints_dir() -> str:
    """Session fixture returning the absolute path of a ``checkpoints`` directory.

    The directory is resolved against the current working directory and is
    created when it does not exist yet.
    """
    directory_name = "checkpoints"
    os.makedirs(directory_name, exist_ok=True)
    absolute_path = os.path.abspath(directory_name)
    return absolute_path
|
||
|
||
@pytest.fixture(scope="session")
def instanovo_config() -> DictConfig:
    """Session fixture composing the ``instanovo_unit_test`` Hydra config.

    The ``model``, ``dataset`` and ``residues`` sub-configs, when present, are
    removed from the composed config and their entries merged into its top level.
    """
    with initialize(version_base=None, config_path="../instanovo/configs"):
        cfg = compose(config_name="instanovo_unit_test")

        for group_name in ("model", "dataset", "residues"):
            if group_name not in cfg:
                continue
            # open_dict is used so the config accepts keys being removed and added here.
            with open_dict(cfg):
                group_cfg = cfg[group_name]
                del cfg[group_name]
                cfg.update(group_cfg)

    return cfg
|
||
|
||
@pytest.fixture(scope="session")
def instanovo_inference_config() -> DictConfig:
    """Session fixture composing the ``unit_test`` Hydra config from the inference config directory."""
    with initialize(version_base=None, config_path="../instanovo/configs/inference"):
        inference_cfg = compose(config_name="unit_test")

    return inference_cfg
|
||
|
||
@pytest.fixture(scope="session")
def dir_paths() -> tuple[str, str]:
    """Session fixture returning ``(root_dir, data_dir)`` for the test resources.

    ``root_dir`` is the relative path ``./tests/instanovo_test_resources`` and
    ``data_dir`` is its ``example_data`` sub-directory.
    """
    resources_root = "./tests/instanovo_test_resources"
    return resources_root, os.path.join(resources_root, "example_data")
|
||
|
||
@pytest.fixture(scope="session")
def instanovo_checkpoint(dir_paths: tuple[str, str]) -> str:
    """Session fixture returning the path of ``model.ckpt`` inside the test resources root."""
    resources_root = dir_paths[0]
    checkpoint_path = os.path.join(resources_root, "model.ckpt")
    return checkpoint_path
|
||
|
||
@pytest.fixture(scope="session")
def instanovo_model(
    instanovo_checkpoint: str,
) -> tuple[Any, Any]:
    """Session fixture loading the checkpoint with ``InstaNovo.load``.

    Returns:
        The ``(model, config)`` pair produced by ``InstaNovo.load``.
    """
    loaded_model, loaded_config = InstaNovo.load(path=instanovo_checkpoint)
    return loaded_model, loaded_config
|
||
|
||
@pytest.fixture(scope="session")
def residue_set(instanovo_model: tuple[Any, Any]) -> Any:
    """Session fixture returning the ``residue_set`` attribute of the loaded model."""
    loaded_model = instanovo_model[0]
    return loaded_model.residue_set
|
||
|
||
@pytest.fixture(scope="session")
def instanovo_output(dir_paths: tuple[str, str]) -> pd.DataFrame:
    """Session fixture reading ``predictions.csv`` from the test resources root into a DataFrame."""
    resources_root = dir_paths[0]
    predictions_path = os.path.join(resources_root, "predictions.csv")
    return pd.read_csv(predictions_path)
|
||
|
||
@pytest.fixture(scope="session")
def knapsack_dir(dir_paths: tuple[str, str]) -> str:
    """Session fixture returning the absolute path of ``example_knapsack`` in the test resources root.

    The path is only computed here; nothing is created on disk.
    """
    resources_root = dir_paths[0]
    return os.path.abspath(os.path.join(resources_root, "example_knapsack"))
|
||
|
||
@pytest.fixture(scope="session")
def setup_knapsack_decoder(
    instanovo_model: tuple[Any, Any], knapsack_dir: str
) -> KnapsackBeamSearchDecoder:
    """Session fixture returning a ``KnapsackBeamSearchDecoder`` for the loaded model.

    When ``knapsack_dir`` exists the decoder is loaded from it. Otherwise a
    knapsack is built with ``_setup_knapsack``, saved to ``knapsack_dir`` and
    wrapped in a new decoder.
    """
    model = instanovo_model[0]

    if not os.path.exists(knapsack_dir):
        knapsack = _setup_knapsack(model)
        knapsack.save(path=knapsack_dir)
        print("Created and saved knapsack.")
        decoder = KnapsackBeamSearchDecoder(model, knapsack)
    else:
        decoder = KnapsackBeamSearchDecoder.from_file(model=model, path=knapsack_dir)

    # Both branches report the decoder as loaded once it exists.
    print("Loaded knapsack decoder.")
    return decoder
Empty file.
Oops, something went wrong.