diff --git a/easy_rec/python/layers/keras/mask_net.py b/easy_rec/python/layers/keras/mask_net.py
index de5624944..92077b82f 100644
--- a/easy_rec/python/layers/keras/mask_net.py
+++ b/easy_rec/python/layers/keras/mask_net.py
@@ -31,6 +31,7 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs):
     self._projection_dim = params.get_or_default('projection_dim', None)
     self.reuse = reuse
     self.final_relu = Activation('relu', name='relu')
+    self._name = name
 
   def build(self, input_shape):
     assert len(input_shape) >= 2, 'MaskBlock must have at least two inputs'
@@ -49,30 +50,30 @@ def build(self, input_shape):
         activation='relu',
         kernel_initializer='he_uniform',
         kernel_regularizer=self.l2_reg,
-        name='aggregation')
+        name='%s/aggregation' % self._name)
     self.weight_layer = Dense(input_dim, name='weights')
     if self._projection_dim is not None:
       self.project_layer = Dense(
           self._projection_dim,
           kernel_regularizer=self.l2_reg,
           use_bias=False,
-          name='project')
+          name='%s/project' % self._name)
     if self.config.input_layer_norm:
       # It is recommended to apply layer norm before calling MaskBlock; otherwise the input is layer-normed on every call
       if tf.__version__ >= '2.0':
         self.input_layer_norm = tf.keras.layers.LayerNormalization(
-            name='input_ln')
+            name='%s/input_ln' % self._name)
       else:
-        self.input_layer_norm = LayerNormalization(name='input_ln')
+        self.input_layer_norm = LayerNormalization(name='%s/input_ln' % self._name)
 
     if self.config.HasField('output_size'):
       self.output_layer = Dense(
-          self.config.output_size, use_bias=False, name='output')
+          self.config.output_size, use_bias=False, name='%s/output' % self._name)
       if tf.__version__ >= '2.0':
         self.output_layer_norm = tf.keras.layers.LayerNormalization(
             name='output_ln')
       else:
-        self.output_layer_norm = LayerNormalization(name='output_ln')
+        self.output_layer_norm = LayerNormalization(name='%s/output_ln' % self._name)
 
   def call(self, inputs, **kwargs):
     net, mask_input = inputs
@@ -104,6 +105,7 @@ class MaskNet(Layer):
 
   def __init__(self, params, name='mask_net', reuse=None, **kwargs):
     super(MaskNet, self).__init__(name, **kwargs)
+    self._name = name
     self.reuse = reuse
     self.params = params
     self.config = params.get_pb_config()
@@ -118,15 +120,15 @@ def __init__(self, params, name='mask_net', reuse=None, **kwargs):
     for i, block_conf in enumerate(self.config.mask_blocks):
       params = Parameter.make_from_pb(block_conf)
       params.l2_regularizer = self.params.l2_regularizer
-      mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
+      mask_layer = MaskBlock(params, name='%s/block_%d' % (self._name, i), reuse=self.reuse)
       self.mask_layers.append(mask_layer)
 
     if self.config.input_layer_norm:
       if tf.__version__ >= '2.0':
         self.input_layer_norm = tf.keras.layers.LayerNormalization(
-            name='input_ln')
+            name='%s/input_ln' % self._name)
       else:
-        self.input_layer_norm = LayerNormalization(name='input_ln')
+        self.input_layer_norm = LayerNormalization(name='%s/input_ln' % self._name)
 
   def call(self, inputs, training=None, **kwargs):
     if self.config.input_layer_norm:
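
For reference, below is a minimal sketch of the naming pattern this patch adopts. It is illustrative only: the Block class, the unit count, and the input shapes are stand-ins, not EasyRec's MaskBlock, and it assumes TF 2.x with the built-in Keras 2, where '/' is accepted in layer names (the patch itself relies on this). The point is that prefixing each inner layer's name with the owning block's name keeps variable names unique and deterministic when several blocks are stacked, instead of relying on automatic _1/_2 uniquification or colliding under variable-scope reuse.

import tensorflow as tf


class Block(tf.keras.layers.Layer):
  """Illustrative stand-in for MaskBlock; only the naming pattern matters."""

  def __init__(self, name='block', **kwargs):
    super(Block, self).__init__(name=name, **kwargs)
    # Mirror the patch: remember the name so build() can prefix sub-layers.
    self._name_prefix = name

  def build(self, input_shape):
    # Before the patch every block named this Dense plain 'aggregation',
    # so stacked blocks produced ambiguous, auto-suffixed variable names.
    self.aggregation = tf.keras.layers.Dense(
        4, name='%s/aggregation' % self._name_prefix)
    super(Block, self).build(input_shape)

  def call(self, inputs):
    return self.aggregation(inputs)


blocks = [Block(name='block_%d' % i) for i in range(2)]
x = tf.ones([2, 8])
for block in blocks:
  x = block(x)
for block in blocks:
  for w in block.weights:
    # Each kernel's name now carries its block prefix,
    # e.g. '.../block_0/aggregation/kernel:0'.
    print(w.name)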