From 9b859ef1b60d23a2c20950b398dd279d2bb5e258 Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Mon, 7 Aug 2023 07:58:54 +0200 Subject: [PATCH] Add scripts for type-checking rst doc files and Jupyter notebooks --- ci/typecheck_docs.py | 167 ++++++++++++++++++++++++++++++++++++++ ci/typecheck_notebooks.py | 57 +++++++++++++ setup.py | 1 + 3 files changed, 225 insertions(+) create mode 100644 ci/typecheck_docs.py create mode 100644 ci/typecheck_notebooks.py diff --git a/ci/typecheck_docs.py b/ci/typecheck_docs.py new file mode 100644 index 000000000..d70b96ada --- /dev/null +++ b/ci/typecheck_docs.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +"""Type-check all code-blocks inside rst documentation files.""" +import argparse +import sys +import tempfile +from functools import partial +from io import StringIO +from pathlib import Path +from typing import Dict, List + +from docutils.core import publish_doctree # type: ignore +from docutils.nodes import literal_block +from mypy import api + +_TMP_DIR = Path(tempfile.gettempdir()) / "imitation" / "typecheck" +_REPO_DIR = Path(__file__).parent.parent + + +_info = partial(print, file=sys.stderr) + + +def get_files(input_paths: List) -> List[Path]: + """Build list of files to scan from list of paths and files.""" + files = [] + for file in input_paths: + if file.is_dir(): + files.extend(file.glob("**/*.rst")) + else: + if file.suffix == ".rst": + files.append(file) + else: + _info(f"Skipping {file} (not a documentation file)") + if not files: + _info("No documentation files found") + sys.exit(1) + return files + + +def get_code_blocks(file: Path) -> Dict[int, str]: + """Find all Python code-blocks inside an rst documentation file. + + Args: + file: The rst documentation file to scan. + + Returns: + Mapping from line number to Python code block. + """ + rst_content = file.read_text() + doc_parse_f = StringIO() + document = publish_doctree( + rst_content, + settings_overrides={"warning_stream": doc_parse_f}, + ) + + python_blocks = {} + for node in document.traverse(literal_block): + if "code" in node.get("classes") and "python" in node.get("classes"): + src_text = node.astext() + end_line = node.line # node.line = line number of the end of the block + start_line = end_line - len(src_text.split("\n")) + python_blocks[start_line] = src_text + + return python_blocks + + +def typecheck_doc_file(file: Path) -> List[str]: + """Type-check Python code-blocks inside an rst documentation file using pytype/mypy. + + Args: + file: The rst documentation file to type-check. + + Returns: + List of type errors (str) in the documentation code-blocks. + """ + code_blocks = get_code_blocks(file) + file = file.relative_to(_REPO_DIR) + tmp = _TMP_DIR / file + tmp.parent.mkdir(parents=True, exist_ok=True) + + errors = [] + for line, code_block in code_blocks.items(): + temp_file = tmp.with_suffix(f".{line}.py") + temp_file.write_text(code_block) + file_errors = mypy_codeblock(temp_file) + + def post_process_error_msg(error_msg: str) -> str: + """Change the error message to use the original file path and line number. + + Replaces temp_file path with original path in error_msg and + recalculates the line number. + + Args: + error_msg: The error message to post-process. + + Returns: + The post-processed error message in the standard mypy format. + """ + try: + path, line_no, *rest = error_msg.split(":") + return ":".join([str(file), str(line + int(line_no) - 1), *rest]) + except ValueError: + # error_msg is not a std mypy error message + return error_msg + + errors += [post_process_error_msg(msg) for msg in file_errors] + return errors + + +def mypy_codeblock(codeblock: Path) -> List: + stdout, stderr, exit_status = api.run([str(codeblock)]) + if exit == 0 or not stdout or "no issues found" in stdout: + return [] + # format of stdout output: + # /:6: error: Name "policy" is not defined [name-defined] + # /:8: error: Too many positional arguments for "register" ... + # Found 2 errors in 1 file (checked 1 source file) + return stdout.strip().split("\n")[:-1] # last line is redundant + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "files", + nargs="+", + type=Path, + help="List of files or paths to check", + ) + args = parser.parse_args() + return parser, args + + +def main(): + """Type-check all code-blocks inside rst documentation files.""" + parser, args = parse_args() + input_paths = args.files + + if len(input_paths) == 0: + parser.print_help() + sys.exit(1) + + files = get_files(input_paths) + + errors = [] + affected_files = 0 + for file in files: + if file_errors := typecheck_doc_file(file): + errors += file_errors + affected_files += 1 + _info(f"{file}: {len(file_errors)} error{'s'[:len(file_errors)^1]}") + else: + _info(f"{file}: OK") + + f = len(files) + e = len(errors) + a = affected_files + print("\n".join(errors)) + _info( + f"Found {e} error{'s'[:e^1]} in {a} file{'s'[:a^1]}" + f" (checked {f} source file{'s'[:f^1]}).", + ) + if errors: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ci/typecheck_notebooks.py b/ci/typecheck_notebooks.py new file mode 100644 index 000000000..21ab0cd06 --- /dev/null +++ b/ci/typecheck_notebooks.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +"""Type-check all code-blocks inside Jupyter notebook files. + +This relies on nbQA which can be installed without new extra dependencies. +""" +import argparse +import subprocess +import sys +from functools import partial +from pathlib import Path +from typing import List + +_info = partial(print, file=sys.stderr) + + +def get_files(input_paths: List) -> List[Path]: + """Build list of files to scan from list of paths and files.""" + files = [] + for file in input_paths: + if file.is_dir(): + files.extend(file.glob("**/*.ipynb")) + else: + if file.suffix == ".ipynb": + files.append(file) + else: + _info(f"Skipping {file} (not a Jupyter notebook file)") + if not files: + _info("No Jupyter notebooks found") + sys.exit(1) + return files + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "files", + nargs="+", + type=Path, + help="List of files or paths to check", + ) + args = parser.parse_args() + return parser, args + + +def main(): + """Type-check all code inside Jupyter notebook files.""" + parser, args = parse_args() + input_paths = get_files(args.files) + try: + subprocess.run(["nbqa", "mypy", *input_paths], check=True) + except subprocess.CalledProcessError as e: + sys.exit(e.returncode) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index fa1d03f31..db7213e39 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ # https://github.com/jupyter/jupyter_client/issues/637 is fixed "jupyter-client~=6.1.12", "mypy~=0.990", + "nbqa~=1.7.0", "pandas~=1.4.3", "pytest~=7.1.2", "pytest-cov~=3.0.0",