-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
370 additions
and
368 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
""" FormDict tools. | ||
FormDict is not a real class, just a normal dict. But we need to put somewhere functions related to it. | ||
""" | ||
import logging | ||
from argparse import Action, ArgumentParser | ||
from typing import Callable, Optional, TypeVar, Union, get_type_hints | ||
from unittest.mock import patch | ||
|
||
from tyro import cli | ||
from tyro._argparse_formatter import TyroArgumentParser | ||
|
||
from .FormField import FormField | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
ConfigInstance = TypeVar("ConfigInstance") | ||
ConfigClass = Callable[..., ConfigInstance] | ||
FormDict = dict[str, Union[FormField, 'FormDict']] | ||
""" Nested form that can have descriptions (through FormField) instead of plain values. """ | ||
|
||
|
||
def formdict_repr(d: FormDict) -> dict: | ||
""" For the testing purposes, returns a new dict when all FormFields are replaced with their values. """ | ||
out = {} | ||
for k, v in d.items(): | ||
if isinstance(v, FormField): | ||
v = v.val | ||
out[k] = formdict_repr(v) if isinstance(v, dict) else v | ||
return out | ||
|
||
|
||
def dict_to_formdict(data: dict) -> FormDict: | ||
fd = {} | ||
for key, val in data.items(): | ||
if isinstance(val, dict): # nested config hierarchy | ||
fd[key] = dict_to_formdict(val) | ||
else: # scalar value | ||
# NOTE name=param is not set (yet?) in `config_to_formdict`, neither `src` | ||
fd[key] = FormField(val, "", name=key, src=(data, key)) | ||
return fd | ||
|
||
|
||
def config_to_formdict(args: ConfigInstance, descr: dict, _path="") -> FormDict: | ||
""" Convert the dataclass produced by tyro into dict of dicts. """ | ||
main = "" | ||
params = {main: {}} if not _path else {} | ||
for param, val in vars(args).items(): | ||
annotation = None | ||
if val is None: | ||
wanted_type = get_type_hints(args.__class__).get(param) | ||
if wanted_type in (Optional[int], Optional[str]): | ||
# Since tkinter_form does not handle None yet, we have help it. | ||
# We need it to be able to write a number and if empty, return None. | ||
# This would fail: `severity: int | None = None` | ||
# Here, we convert None to str(""), in normalize_types we convert it back. | ||
annotation = wanted_type | ||
val = "" | ||
else: | ||
# An unknown type annotation encountered- | ||
# Since tkinter_form does not handle None yet, this will display as checkbox. | ||
# Which is not probably wanted. | ||
val = False | ||
logger.warn(f"Annotation {wanted_type} of `{param}` not supported by Mininterface." | ||
"None converted to False.") | ||
if hasattr(val, "__dict__"): # nested config hierarchy | ||
params[param] = config_to_formdict(val, descr, _path=f"{_path}{param}.") | ||
elif not _path: # scalar value in root | ||
params[main][param] = FormField(val, descr.get(param), annotation, param, src2=(args, param)) | ||
else: # scalar value in nested | ||
params[param] = FormField(val, descr.get(f"{_path}{param}"), annotation, param, src2=(args, param)) | ||
return params | ||
|
||
|
||
def get_args_allow_missing(config: ConfigClass, kwargs: dict, parser: ArgumentParser): | ||
""" Fetch missing required options in GUI. """ | ||
# On missing argument, tyro fail. We cannot determine which one was missing, except by intercepting | ||
# the error message function. Then, we reconstruct the missing options. | ||
# NOTE But we should rather invoke a GUI with the missing options only. | ||
original_error = TyroArgumentParser.error | ||
eavesdrop = "" | ||
|
||
def custom_error(self, message: str): | ||
nonlocal eavesdrop | ||
if not message.startswith("the following arguments are required:"): | ||
return original_error(self, message) | ||
eavesdrop = message | ||
raise SystemExit(2) # will be catched | ||
try: | ||
with patch.object(TyroArgumentParser, 'error', custom_error): | ||
return cli(config, **kwargs) | ||
except BaseException as e: | ||
if hasattr(e, "code") and e.code == 2 and eavesdrop: # Some arguments are missing. Determine which. | ||
for arg in eavesdrop.partition(":")[2].strip().split(", "): | ||
argument: Action = next(iter(p for p in parser._actions if arg in p.option_strings)) | ||
argument.default = "DEFAULT" # NOTE I do not know whether used | ||
if "." in argument.dest: # missing nested required argument handler not implemented, we make tyro fail in CLI | ||
pass | ||
else: | ||
match argument.metavar: | ||
case "INT": | ||
setattr(kwargs["default"], argument.dest, 0) | ||
case "STR": | ||
setattr(kwargs["default"], argument.dest, "") | ||
case _: | ||
pass # missing handler not implemented, we make tyro fail in CLI | ||
return cli(config, **kwargs) # second attempt | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Any, Iterable, Optional, TypeVar, get_args | ||
|
||
from .auxiliary import flatten | ||
|
||
if TYPE_CHECKING: | ||
from .FormDict import FormDict | ||
|
||
try: | ||
from tkinter_form import Value | ||
except ImportError: | ||
@dataclass | ||
class Value: | ||
""" This class helps to enrich the field with a description. """ | ||
val: Any | ||
description: str | ||
|
||
|
||
FFValue = TypeVar("FFValue") | ||
TD = TypeVar("TD") | ||
""" dict """ | ||
TK = TypeVar("TK") | ||
""" dict key """ | ||
|
||
|
||
@dataclass | ||
class FormField(Value): | ||
""" Bridge between the input values and a UI widget. | ||
Helps to creates a widget from the input value (includes description etc.), | ||
then transforms the value back (str to int conversion etc). | ||
(Ex: Merge the dict of dicts from the GUI back into the object holding the configuration.) | ||
""" | ||
|
||
annotation: Any | None = None | ||
""" Used for validation. To convert an empty '' to None. """ | ||
name: str | None = None # NOTE: Only TextualInterface uses this by now. | ||
|
||
src: tuple[TD, TK] | None = None | ||
""" The original dict to be updated when UI ends. """ | ||
src2: tuple[TD, TK] | None = None | ||
""" The original object to be updated when UI ends. | ||
NOTE should be merged to `src` | ||
""" | ||
|
||
def __post_init__(self): | ||
self._original_desc = self.description | ||
|
||
def set_error_text(self, s): | ||
self.description = f"{s} {self._original_desc}" | ||
|
||
def update(self, ui_value): | ||
""" UI value → FormField value → original value. (With type conversion and checks.) | ||
The value has been updated in a UI. | ||
Update accordingly the value in the original linked dict | ||
the mininterface was invoked with. | ||
Validates the type and do the transformation. | ||
(Ex: Some values might be nulled from "".) | ||
""" | ||
fixed_value = ui_value | ||
if self.annotation: | ||
if ui_value == "" and type(None) in get_args(self.annotation): | ||
# The user is not able to set the value to None, they left it empty. | ||
# Cast back to None as None is one of the allowed types. | ||
# Ex: `severity: int | None = None` | ||
fixed_value = None | ||
elif self.annotation == Optional[int]: | ||
try: | ||
fixed_value = int(ui_value) | ||
except ValueError: | ||
pass | ||
|
||
if not isinstance(fixed_value, self.annotation): | ||
self.set_error_text(f"Type must be `{self.annotation}`!") | ||
return False # revision needed | ||
|
||
# keep values if revision needed | ||
# We merge new data to the origin. If form is re-submitted, the values will stay there. | ||
self.val = ui_value | ||
|
||
# Store to the source user data | ||
if self.src: | ||
d, k = self.src | ||
d[k] = fixed_value | ||
elif self.src2: | ||
d, k = self.src2 | ||
setattr(d, k, fixed_value) | ||
else: | ||
# This might be user-created object. The user reads directly from this. There is no need to update anything. | ||
pass | ||
return True | ||
# Fixing types: | ||
# This code would support tuple[int, int]: | ||
# | ||
# self.types = get_args(self.annotation) \ | ||
# if isinstance(self.annotation, UnionType) else (self.annotation, ) | ||
# "All possible types in a tuple. Ex 'int | str' -> (int, str)" | ||
# | ||
# | ||
# def convert(self): | ||
# """ Convert the self.value to the given self.type. | ||
# The value might be in str due to CLI or TUI whereas the programs wants bool. | ||
# """ | ||
# # if self.value == "True": | ||
# # return True | ||
# # if self.value == "False": | ||
# # return False | ||
# if type(self.val) is str and str not in self.types: | ||
# try: | ||
# return literal_eval(self.val) # ex: int, tuple[int, int] | ||
# except: | ||
# raise ValueError(f"{self.name}: Cannot convert value {self.val}") | ||
# return self.val | ||
|
||
@staticmethod | ||
def submit_values(updater: Iterable[tuple["FormField", FFValue]]) -> bool: | ||
""" Returns whether the form is alright or whether we should revise it. | ||
Input is tuple of the FormFields and their new values from the UI. | ||
""" | ||
# Why list? We need all the FormField values be updates from the UI. | ||
# If the revision is needed, the UI fetches the values from the FormField. | ||
# We need the keep the values so that the user does not have to re-write them. | ||
return all(list(ff.update(ui_value) for ff, ui_value in updater)) | ||
|
||
@staticmethod | ||
def submit(fd: "FormDict", ui: dict): | ||
""" Returns whether the form is alright or whether we should revise it. | ||
Input is the FormDict and the UI dict in the very same form. | ||
""" | ||
return FormField.submit_values(zip(flatten(fd), flatten(ui))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.