Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Jan 3, 2025
1 parent 74b1edc commit 8ffbc29
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions easy_rec/python/utils/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,29 +199,38 @@ def convert_tf_flags_to_argparse(flags):
import ast
parser = argparse.ArgumentParser()

args = set()
args = {}
for flag in flags._flags().values():
flag_name = flag.name
if flag_name in args:
args[flag_name][0] = True
continue
args.add(flag_name)
default = flag.value
flag_type = type(default)
help_str = flag.help or ''
args[flag_name] = [
False, flag_type, default, help_str,
flag.choices if hasattr(flag, 'choices') else None
]

for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
if flag_type == bool:
parser.add_argument(
'--' + flag_name,
dest=flag_name,
action='store_true' if default else 'store_false',
help=help_str)
elif flag_type == str:
if hasattr(flag, 'choices') and flag.choices:
if choices:
parser.add_argument(
'--' + flag_name,
type=str,
choices=flag.choices,
choices=choices,
default=default,
help=help_str)
elif multi:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
Expand All @@ -231,9 +240,12 @@ def convert_tf_flags_to_argparse(flags):
type=lambda s: ast.literal_eval(s),
default=default,
help=help_str)
else:
elif flag_type in (int, float):
parser.add_argument(
'--' + flag_name, type=flag_type, default=default, help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
return parser


Expand All @@ -245,7 +257,7 @@ def filter_unknown_args(flags, args):
if len(unknown) > 1:
logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
for key, value in vars(args).items():
if type(value) != bool and not value:
if type(value) in (list, dict) and not value:
continue
known_args.append('--' + key + '=' + str(value))
logging.info('defined arguments: %s', ', '.join(known_args[1:]))
Expand Down

0 comments on commit 8ffbc29

Please sign in to comment.