diff --git a/easy_rec/python/utils/io_util.py b/easy_rec/python/utils/io_util.py index cfe20d4ac..d5a949630 100644 --- a/easy_rec/python/utils/io_util.py +++ b/easy_rec/python/utils/io_util.py @@ -199,15 +199,21 @@ def convert_tf_flags_to_argparse(flags): import ast parser = argparse.ArgumentParser() - args = set() + args = {} for flag in flags._flags().values(): flag_name = flag.name if flag_name in args: + args[flag_name][0] = True continue - args.add(flag_name) default = flag.value flag_type = type(default) help_str = flag.help or '' + args[flag_name] = [ + False, flag_type, default, help_str, + flag.choices if hasattr(flag, 'choices') else None + ] + + for flag_name, (multi, flag_type, default, help_str, choices) in args.items(): if flag_type == bool: parser.add_argument( '--' + flag_name, @@ -215,13 +221,16 @@ def convert_tf_flags_to_argparse(flags): action='store_true' if default else 'store_false', help=help_str) elif flag_type == str: - if hasattr(flag, 'choices') and flag.choices: + if choices: parser.add_argument( '--' + flag_name, type=str, - choices=flag.choices, + choices=choices, default=default, help=help_str) + elif multi: + parser.add_argument( + '--' + flag_name, type=str, default=default, help=help_str) else: parser.add_argument( '--' + flag_name, type=str, default=default, help=help_str) @@ -231,9 +240,12 @@ def convert_tf_flags_to_argparse(flags): type=lambda s: ast.literal_eval(s), default=default, help=help_str) - else: + elif flag_type in (int, float): parser.add_argument( '--' + flag_name, type=flag_type, default=default, help=help_str) + else: + parser.add_argument( + '--' + flag_name, type=str, default=default, help=help_str) return parser @@ -245,7 +257,7 @@ def filter_unknown_args(flags, args): if len(unknown) > 1: logging.info('undefined arguments: %s', ', '.join(unknown[1:])) for key, value in vars(args).items(): - if type(value) != bool and not value: + if type(value) in (list, dict) and not value: continue known_args.append('--' + key + '=' + str(value)) logging.info('defined arguments: %s', ', '.join(known_args[1:]))