diff --git a/packages/subiquity_client/generator/Makefile b/packages/subiquity_client/generator/Makefile index 0704c5994..386d4f497 100644 --- a/packages/subiquity_client/generator/Makefile +++ b/packages/subiquity_client/generator/Makefile @@ -2,7 +2,7 @@ types_py = ../subiquity/subiquity/common/types.py types_dart = ../lib/src/types.dart generate: - python3 generator.py $(types_py) $(types_dart) + python3 generator.py $(types_py) --output $(types_dart) dart format $(types_dart) check: diff --git a/packages/subiquity_client/generator/generator.py b/packages/subiquity_client/generator/generator.py index 9d725653a..cab74d9a7 100755 --- a/packages/subiquity_client/generator/generator.py +++ b/packages/subiquity_client/generator/generator.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import argparse import ast import fileinput import os @@ -355,17 +356,43 @@ def __str__(self): }} """ +class PythonInputFile(argparse.FileType): + def __init__(self): + super().__init__("r", encoding="utf-8") + + def __call__(self, path: str): + if path != "-" and not path.endswith(".py"): + raise argparse.ArgumentTypeError(f"{path} does not have the .py extension") + + return super().__call__(path) + + +def DartOutputFile(path: str | None) -> str | None: + if path is None or path == "-": + return None + + if not path.endswith(".dart"): + raise argparse.ArgumentTypeError(f"{path} does not have a .dart extension") + return path + def main(): - input = sys.argv[1] if len(sys.argv) > 1 else "" - output = sys.argv[2] if len(sys.argv) == 3 else None + parser = argparse.ArgumentParser() + + # FileType is supposed to work with default="-" but it does not when used + # with positional arguments. + parser.add_argument("python-input-files", nargs="+", + type=PythonInputFile(), metavar="input.py") + parser.add_argument("--output", default="-", metavar="output.dart", + type=DartOutputFile) + + args = vars(parser.parse_args()) - if not input.endswith(".py") or (output is not None and not output.endswith(".dart")): - print("usage: generator ()") - return + input_files = args["python-input-files"] + output = args["output"] generator = Generator() - with open(sys.argv[1], "r") as file: + for file in input_files: generator.parse(file.read()) data = generator.to_dart()