diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py index 66de98922..e45010553 100644 --- a/easy_rec/python/model/easy_rec_model.py +++ b/easy_rec/python/model/easy_rec_model.py @@ -11,6 +11,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from easy_rec.python.compat import regularizers from easy_rec.python.layers import input_layer @@ -379,13 +380,13 @@ def _get_restore_vars(self, ckpt_var_map_path): name2var[var_name] = [one_var] if is_part else one_var if ckpt_var_map_path != '': - if not tf.gfile.Exists(ckpt_var_map_path): + if not gfile.Exists(ckpt_var_map_path): logging.warning('%s not exist' % ckpt_var_map_path) return name2var # load var map name_map = {} - with open(ckpt_var_map_path, 'r') as fin: + with gfile.GFile(ckpt_var_map_path, 'r') as fin: for one_line in fin: one_line = one_line.strip() line_tok = [x for x in one_line.split() if x != ''] @@ -393,14 +394,16 @@ def _get_restore_vars(self, ckpt_var_map_path): logging.warning('Failed to process: %s' % one_line) continue name_map[line_tok[0]] = line_tok[1] - var_map = {} + update_map = {} + old_keys = [] for var_name in name2var: if var_name in name_map: in_ckpt_name = name_map[var_name] - var_map[in_ckpt_name] = name2var[var_name] - else: - logging.warning('Failed to find in var_map_file(%s): %s' % - (ckpt_var_map_path, var_name)) + update_map[in_ckpt_name] = name2var[var_name] + old_keys.append(var_name) + for tmp_key in old_keys: + del name2var[tmp_key] + name2var.update(update_map) return name2var else: var_filter, scope_update = self.get_restore_filter()