add masknet variable name prefix #478

Closed
20 changes: 11 additions & 9 deletions easy_rec/python/layers/keras/mask_net.py
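
This PR stores the constructor name on each MaskBlock/MaskNet instance and prepends it to the names of the sub-layers created in build() ('aggregation', 'project', 'input_ln', 'output', 'output_ln'), and to each per-block 'block_%d' name. Keras does not automatically prefix a nested layer's variables with its parent layer's name, so two blocks would otherwise both emit variables like aggregation/kernel, or pick up non-deterministic auto-uniquified names such as aggregation_1 that break checkpoint restore. Below is a minimal standalone sketch of the pattern; the Block class, shapes, and printed names are illustrative, not EasyRec code, exact variable-name behavior varies across TF versions, and newer Keras releases may reject '/' in layer names:

    import tensorflow as tf

    class Block(tf.keras.layers.Layer):
      """Illustrative stand-in for MaskBlock, not the PR's code."""

      def __init__(self, name='block', **kwargs):
        super(Block, self).__init__(name=name, **kwargs)
        self._name = name  # same trick the PR uses to keep the name around

      def build(self, input_shape):
        # Same pattern as the PR: prefix the sub-layer name with the parent
        # name; without it, every Block names this layer plain 'aggregation'.
        self.aggregation = tf.keras.layers.Dense(
            4, name='%s/aggregation' % self._name)
        super(Block, self).build(input_shape)

      def call(self, inputs):
        return self.aggregation(inputs)

    b0, b1 = Block(name='block_0'), Block(name='block_1')
    _ = b0(tf.zeros([1, 8])), b1(tf.zeros([1, 8]))
    print([w.name for w in b0.weights + b1.weights])
    # e.g. ['block_0/aggregation/kernel:0', 'block_0/aggregation/bias:0',
    #       'block_1/aggregation/kernel:0', 'block_1/aggregation/bias:0']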
@@ -31,6 +31,7 @@ def __init__(self, params, name='mask_block', reuse=None, **kwargs):
     self._projection_dim = params.get_or_default('projection_dim', None)
     self.reuse = reuse
     self.final_relu = Activation('relu', name='relu')
+    self._name = name
 
   def build(self, input_shape):
     assert len(input_shape) >= 2, 'MaskBlock must have at least two inputs'
@@ -49,30 +50,30 @@ def build(self, input_shape):
         activation='relu',
         kernel_initializer='he_uniform',
         kernel_regularizer=self.l2_reg,
-        name='aggregation')
+        name='%s/aggregation' % self._name)
     self.weight_layer = Dense(input_dim, name='weights')
     if self._projection_dim is not None:
       self.project_layer = Dense(
           self._projection_dim,
           kernel_regularizer=self.l2_reg,
           use_bias=False,
-          name='project')
+          name='%s/project' % self._name)
     if self.config.input_layer_norm:
       # Prefer doing layer norm before calling MaskBlock; otherwise every call must apply LN to the input
       if tf.__version__ >= '2.0':
         self.input_layer_norm = tf.keras.layers.LayerNormalization(
-            name='input_ln')
+            name='%s/input_ln' % self._name)
       else:
-        self.input_layer_norm = LayerNormalization(name='input_ln')
+        self.input_layer_norm = LayerNormalization(name='%s/input_ln' % self._name)
 
     if self.config.HasField('output_size'):
       self.output_layer = Dense(
-          self.config.output_size, use_bias=False, name='output')
+          self.config.output_size, use_bias=False, name='%s/output' % self._name)
       if tf.__version__ >= '2.0':
         self.output_layer_norm = tf.keras.layers.LayerNormalization(
             name='output_ln')
       else:
-        self.output_layer_norm = LayerNormalization(name='output_ln')
+        self.output_layer_norm = LayerNormalization(name='%s/output_ln' % self._name)
 
   def call(self, inputs, **kwargs):
     net, mask_input = inputs
@@ -104,6 +105,7 @@ class MaskNet(Layer):
 
   def __init__(self, params, name='mask_net', reuse=None, **kwargs):
     super(MaskNet, self).__init__(name, **kwargs)
+    self._name = name
     self.reuse = reuse
     self.params = params
     self.config = params.get_pb_config()
@@ -118,15 +120,15 @@ def __init__(self, params, name='mask_net', reuse=None, **kwargs):
     for i, block_conf in enumerate(self.config.mask_blocks):
       params = Parameter.make_from_pb(block_conf)
       params.l2_regularizer = self.params.l2_regularizer
-      mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
+      mask_layer = MaskBlock(params, name='%s/block_%d' % (self._name, i), reuse=self.reuse)
       self.mask_layers.append(mask_layer)
 
     if self.config.input_layer_norm:
       if tf.__version__ >= '2.0':
         self.input_layer_norm = tf.keras.layers.LayerNormalization(
-            name='input_ln')
+            name='%s/input_ln' % self._name)
       else:
-        self.input_layer_norm = LayerNormalization(name='input_ln')
+        self.input_layer_norm = LayerNormalization(name='%s/input_ln' % self._name)
 
   def call(self, inputs, training=None, **kwargs):
     if self.config.input_layer_norm:
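
With the default constructor names ('mask_net', and 'block_%d' per block), the prefixed layers come out as follows; these names are read off the format strings above, not taken from a test run:

    mask_net/input_ln
    mask_net/block_0/aggregation
    mask_net/block_0/project
    mask_net/block_0/input_ln
    mask_net/block_0/output
    mask_net/block_0/output_ln
    mask_net/block_1/...  (same set per block)

Note that in this diff the 'weights' Dense and the tf.keras.layers.LayerNormalization branch of 'output_ln' keep their unprefixed names.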