From 871a40e6407d0ba0ff60e91b5f82f1fee4d381fa Mon Sep 17 00:00:00 2001
From: "weisu.yxd"
Date: Wed, 12 Jun 2024 15:10:19 +0800
Subject: [PATCH] add attention layer and AITM model

---
 docs/source/component/backbone.md | 1 +
 docs/source/component/component.md | 27 +++
 easy_rec/python/layers/keras/__init__.py | 1 +
 easy_rec/python/layers/keras/attention.py | 280 ++++++++++++++++++++++
 easy_rec/python/model/easy_rec_model.py | 2 +
 easy_rec/python/model/multi_task_model.py | 29 ++-
 easy_rec/python/protos/tower.proto | 5 +-
 7 files changed, 343 insertions(+), 2 deletions(-)
 create mode 100644 easy_rec/python/layers/keras/attention.py

diff --git a/docs/source/component/backbone.md b/docs/source/component/backbone.md
index 2a0ec03a5..1bcf9e7d5 100644
--- a/docs/source/component/backbone.md
+++ b/docs/source/component/backbone.md
@@ -1118,6 +1118,7 @@ MovieLens-1M数据集效果:
 | Cross | bit-wise交叉 | DCN v2模型的组件 | [案例3](#dcn) |
 | BiLinear | 双线性 | FiBiNet模型的组件 | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
 | FiBiNet | SENet & BiLinear | FiBiNet模型 | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
+| Attention | Dot-product attention | Component of the Transformer model | |
 
 ## 3.特征重要度学习组件
 
 - SENet
diff --git a/docs/source/component/component.md b/docs/source/component/component.md
index 897e53162..fbb276f80 100644
--- a/docs/source/component/component.md
+++ b/docs/source/component/component.md
@@ -79,6 +79,33 @@
 | senet | SENet | | protobuf message |
 | mlp | MLP | | protobuf message |
 
+- Attention
+
+Dot-product attention layer, a.k.a. Luong-style attention.
+
+The calculation follows these steps:
+
+1. Calculate attention scores using query and key with shape (batch_size, Tq, Tv).
+2. Use scores to calculate a softmax distribution with shape (batch_size, Tq, Tv).
+3. Use the softmax distribution to create a linear combination of value with shape (batch_size, Tq, dim).
+
+| Parameter | Type | Default | Description |
+| -------- | -------- | --- | ---------------- |
+| use_scale | bool | False | If True, will create a scalar variable to scale the attention scores. |
+| scale_by_dim | bool | False | If True (and use_scale is False), scale the attention scores by 1/sqrt(dim of key), as in scaled dot-product attention. |
+| score_mode | string | dot | Function to use to compute attention scores, one of {"dot", "concat"}. "dot" refers to the dot product between the query and key vectors. "concat" refers to the hyperbolic tangent of the concatenation of the query and key vectors. |
+| dropout | float | 0.0 | Float between 0 and 1. Fraction of the units to drop for the attention scores. |
+| seed | int | None | A Python integer to use as random seed in case of dropout. |
+| return_attention_scores | bool | False | If True, returns the attention scores (after masking and softmax) as an additional output argument. |
+| use_causal_mask | bool | False | Set to True for decoder self-attention. Adds a mask such that position i cannot attend to positions j > i. This prevents the flow of information from the future towards the past. |
+
+- inputs: List of the following tensors:
+  - query: Query tensor of shape (batch_size, Tq, dim).
+  - value: Value tensor of shape (batch_size, Tv, dim).
+  - key: Optional key tensor of shape (batch_size, Tv, dim). If not given, will use value for both key and value, which is the most common case.
+- output:
+  - Attention outputs of shape (batch_size, Tq, dim).
+  - (Optional) Attention scores after masking and softmax with shape (batch_size, Tq, Tv).
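+
+A minimal Python sketch of calling this layer directly (illustrative only: the tensor shapes and names are assumptions; the `Parameter` usage mirrors `multi_task_model.py` in this patch):
+
+```python
+import tensorflow as tf
+from google.protobuf import struct_pb2
+
+from easy_rec.python.layers.keras.attention import Attention
+from easy_rec.python.layers.utils import Parameter
+
+# Build layer params from a Struct, as done for this layer elsewhere in this patch.
+st_params = struct_pb2.Struct()
+st_params.update({'use_scale': True, 'dropout': 0.1})
+params = Parameter(st_params, True)
+
+attention = Attention(params, name='demo_attention')
+query = tf.zeros([8, 10, 64])  # (batch_size, Tq, dim)
+value = tf.zeros([8, 20, 64])  # (batch_size, Tv, dim)
+# key is omitted, so value is used as key; output shape is (8, 10, 64).
+output = attention([query, value])
+```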
+
 ## 3.特征重要度学习组件
 
 - SENet
diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py
index 0e59090ce..e4f1f641e 100644
--- a/easy_rec/python/layers/keras/__init__.py
+++ b/easy_rec/python/layers/keras/__init__.py
@@ -1,4 +1,5 @@
+from .attention import Attention
 from .auxiliary_loss import AuxiliaryLoss
 from .blocks import MLP
 from .blocks import Gate
 from .blocks import Highway
diff --git a/easy_rec/python/layers/keras/attention.py b/easy_rec/python/layers/keras/attention.py
new file mode 100644
index 000000000..d8ee108c8
--- /dev/null
+++ b/easy_rec/python/layers/keras/attention.py
@@ -0,0 +1,280 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Attention layers that can be used in sequence DNN/CNN models.
+
+This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
+Attention is formed by three tensors: Query, Key and Value.
+"""
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class Attention(Layer):
+  """Dot-product attention layer, a.k.a. Luong-style attention.
+
+  Inputs are a list with 2 or 3 elements:
+  1. A `query` tensor of shape `(batch_size, Tq, dim)`.
+  2. A `value` tensor of shape `(batch_size, Tv, dim)`.
+  3. An optional `key` tensor of shape `(batch_size, Tv, dim)`. If none is
+    supplied, `value` will be used as `key`.
+
+  The calculation follows these steps:
+  1. Calculate attention scores using `query` and `key` with shape
+    `(batch_size, Tq, Tv)`.
+  2. Use scores to calculate a softmax distribution with shape
+    `(batch_size, Tq, Tv)`.
+  3. Use the softmax distribution to create a linear combination of `value`
+    with shape `(batch_size, Tq, dim)`.
+
+  Args:
+    use_scale: If `True`, will create a scalar variable to scale the
+      attention scores.
+    scale_by_dim: If `True` (and `use_scale` is `False`), scale the
+      attention scores by `1 / sqrt(dim)` of the key, as in scaled
+      dot-product attention.
+    dropout: Float between 0 and 1. Fraction of the units to drop for the
+      attention scores. Defaults to `0.0`.
+    seed: A Python integer to use as random seed in case of `dropout`.
+    score_mode: Function to use to compute attention scores, one of
+      `{"dot", "concat"}`. `"dot"` refers to the dot product between the
+      query and key vectors. `"concat"` refers to the hyperbolic tangent
+      of the concatenation of the `query` and `key` vectors.
+    return_attention_scores: bool, if `True`, returns the attention scores
+      (after masking and softmax) as an additional output argument.
+    use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds
+      a mask such that position `i` cannot attend to positions `j > i`.
+      This prevents the flow of information from the future towards the
+      past. Defaults to `False`.
+
+  Call Args:
+    inputs: List of the following tensors:
+      - `query`: Query tensor of shape `(batch_size, Tq, dim)`.
+      - `value`: Value tensor of shape `(batch_size, Tv, dim)`.
+      - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If
+        not given, will use `value` for both `key` and `value`, which is
+        the most common case.
+    mask: List of the following tensors:
+      - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.
+        If given, the output will be zero at the positions where
+        `mask==False`.
+      - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.
+        If given, will apply the mask such that values at positions
+        where `mask==False` do not contribute to the result.
+    training: Python boolean indicating whether the layer should behave in
+      training mode (adding dropout) or in inference mode (no dropout).
+
+  Output:
+    Attention outputs of shape `(batch_size, Tq, dim)`.
+    (Optional) Attention scores after masking and softmax with shape
+    `(batch_size, Tq, Tv)`.
+  """
+
+  def __init__(self, params, name='attention', reuse=None, **kwargs):
+    super(Attention, self).__init__(name=name, **kwargs)
+    self.use_scale = params.get_or_default('use_scale', False)
+    self.scale_by_dim = params.get_or_default('scale_by_dim', False)
+    self.score_mode = params.get_or_default('score_mode', 'dot')
+    if self.score_mode not in ["dot", "concat"]:
+      raise ValueError(
+          "Invalid value for argument score_mode. "
+          "Expected one of {'dot', 'concat'}. "
+          "Received: score_mode=%s" % self.score_mode
+      )
+    self.dropout = params.get_or_default('dropout', 0.0)
+    self.seed = params.get_or_default('seed', None)
+    self.scale = None
+    self.concat_score_weight = None
+    self.return_attention_scores = params.get_or_default(
+        'return_attention_scores', False)
+    self.use_causal_mask = params.get_or_default('use_causal_mask', False)
+
+  def build(self, input_shape):
+    self._validate_inputs(input_shape)
+    if self.use_scale:
+      self.scale = self.add_weight(
+          name="scale",
+          shape=(),
+          initializer="ones",
+          dtype=self.dtype,
+          trainable=True,
+      )
+    if self.score_mode == "concat":
+      self.concat_score_weight = self.add_weight(
+          name="concat_score_weight",
+          shape=(),
+          initializer="ones",
+          dtype=self.dtype,
+          trainable=True,
+      )
+    self.built = True
+
+  def _calculate_scores(self, query, key):
+    """Calculates attention scores as a query-key dot product.
+
+    Args:
+      query: Query tensor of shape `(batch_size, Tq, dim)`.
+      key: Key tensor of shape `(batch_size, Tv, dim)`.
+
+    Returns:
+      Tensor of shape `(batch_size, Tq, Tv)`.
+    """
+    if self.score_mode == "dot":
+      # tf.transpose takes `perm`, not `axes`.
+      scores = tf.matmul(query, tf.transpose(key, perm=[0, 2, 1]))
+      if self.scale is not None:
+        scores *= self.scale
+      elif self.scale_by_dim:
+        # Scaled dot-product attention: divide scores by sqrt(d_k).
+        dk = tf.cast(tf.shape(key)[-1], tf.float32)
+        scores /= tf.math.sqrt(dk)
+    elif self.score_mode == "concat":
+      # Reshape tensors to enable broadcasting.
+      # Reshape into [batch_size, Tq, 1, dim].
+      q_reshaped = tf.expand_dims(query, axis=-2)
+      # Reshape into [batch_size, 1, Tv, dim].
+      k_reshaped = tf.expand_dims(key, axis=-3)
+      if self.scale is not None:
+        scores = self.concat_score_weight * tf.reduce_sum(
+            tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
+        )
+      else:
+        scores = self.concat_score_weight * tf.reduce_sum(
+            tf.tanh(q_reshaped + k_reshaped), axis=-1
+        )
+    return scores
+
+  def _apply_scores(self, scores, value, scores_mask=None, training=False):
+    """Applies attention scores to the given value tensor.
+
+    To use this method in your attention layer, follow the steps:
+
+    * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of
+      shape `(batch_size, Tv)` to calculate the attention `scores`.
+    * Pass `scores` and `value` tensors to this method. The method applies
+      `scores_mask`, calculates
+      `attention_distribution = softmax(scores)`, then returns
+      `matmul(attention_distribution, value)`.
+    * Apply `query_mask` and return the result.
+
+    Args:
+      scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.
+      value: Value tensor of shape `(batch_size, Tv, dim)`.
+      scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)`
+        or `(batch_size, Tq, Tv)`. If given, scores at positions where
+        `scores_mask==False` do not contribute to the result. It must
+        contain at least one `True` value in each line along the last
+        dimension.
+      training: Python boolean indicating whether the layer should behave
+        in training mode (adding dropout) or in inference mode
+        (no dropout).
+
+    Returns:
+      Tensor of shape `(batch_size, Tq, dim)`.
+      Attention scores after masking and softmax with shape
+      `(batch_size, Tq, Tv)`.
+    """
+    if scores_mask is not None:
+      padding_mask = tf.logical_not(scores_mask)
+      # Bias so padding positions do not contribute to attention
+      # distribution. Note 65504. is the max float16 value.
+      max_value = 65504.0 if scores.dtype == "float16" else 1.0e9
+      scores -= max_value * tf.cast(padding_mask, dtype=scores.dtype)
+
+    weights = tf.nn.softmax(scores, axis=-1)
+    if training and self.dropout > 0:
+      # tf 1.x dropout semantics: the second argument is keep_prob.
+      weights = tf.nn.dropout(
+          weights,
+          keep_prob=1.0 - self.dropout,
+          seed=self.seed
+      )
+    return tf.matmul(weights, value), weights
+
+  def _calculate_score_mask(self, scores, v_mask, use_causal_mask):
+    if use_causal_mask:
+      # Creates a lower triangular mask, so position i cannot attend to
+      # positions j > i. This prevents the flow of information from the
+      # future into the past.
+      score_shape = tf.shape(scores)
+      # causal_mask_shape = [1, Tq, Tv].
+      mask_shape = (1, score_shape[-2], score_shape[-1])
+      ones_mask = tf.ones(shape=mask_shape, dtype="int32")
+      row_index = tf.cumsum(ones_mask, axis=-2)
+      col_index = tf.cumsum(ones_mask, axis=-1)
+      causal_mask = tf.greater_equal(row_index, col_index)
+
+      if v_mask is not None:
+        # Mask of shape [batch_size, 1, Tv].
+        v_mask = tf.expand_dims(v_mask, axis=-2)
+        return tf.logical_and(v_mask, causal_mask)
+      return causal_mask
+    else:
+      # If not using causal mask, return the value mask as is,
+      # or None if the value mask is not provided.
+      return v_mask
+
+  def call(
+      self,
+      inputs,
+      mask=None,
+      training=False,
+  ):
+    self._validate_inputs(inputs=inputs, mask=mask)
+    q = inputs[0]
+    v = inputs[1]
+    k = inputs[2] if len(inputs) > 2 else v
+    q_mask = mask[0] if mask else None
+    v_mask = mask[1] if mask else None
+    scores = self._calculate_scores(query=q, key=k)
+    scores_mask = self._calculate_score_mask(
+        scores, v_mask, self.use_causal_mask
+    )
+    result, attention_scores = self._apply_scores(
+        scores=scores, value=v, scores_mask=scores_mask, training=training
+    )
+    if q_mask is not None:
+      # Mask of shape [batch_size, Tq, 1].
+      q_mask = tf.expand_dims(q_mask, axis=-1)
+      result *= tf.cast(q_mask, dtype=result.dtype)
+    if self.return_attention_scores:
+      return result, attention_scores
+    return result
+
+  def compute_mask(self, inputs, mask=None):
+    self._validate_inputs(inputs=inputs, mask=mask)
+    if mask is None or mask[0] is None:
+      return None
+    return tf.convert_to_tensor(mask[0])
+
+  def compute_output_shape(self, input_shape):
+    """Query shape, with the last dimension replaced by value's last dim."""
+    return list(input_shape[0][:-1]) + [input_shape[1][-1]]
+
+  def _validate_inputs(self, inputs, mask=None):
+    """Validates arguments of the call method."""
+    class_name = self.__class__.__name__
+    if not isinstance(inputs, list):
+      raise ValueError(
+          "{class_name} layer must be called on a list of inputs, "
+          "namely [query, value] or [query, value, key]. "
+          "Received: inputs={inputs}.".format(
+              class_name=class_name, inputs=inputs)
+      )
+    if len(inputs) < 2 or len(inputs) > 3:
+      raise ValueError(
+          "%s layer accepts inputs list of length 2 or 3, "
+          "namely [query, value] or [query, value, key]. "
+          "Received length: %d." % (class_name, len(inputs))
+      )
+    if mask is not None:
+      if not isinstance(mask, list):
+        raise ValueError(
+            "{class_name} layer mask must be a list, "
+            "namely [query_mask, value_mask]. "
+            "Received: mask={mask}.".format(class_name=class_name, mask=mask)
+        )
+      if len(mask) < 2 or len(mask) > 3:
+        raise ValueError(
+            "{class_name} layer accepts mask list of length 2 or 3. "
+            "Received: inputs={inputs}, mask={mask}.".format(
+                class_name=class_name, inputs=inputs, mask=mask)
+        )
+
+  def get_config(self):
+    base_config = super(Attention, self).get_config()
+    config = {
+        "use_scale": self.use_scale,
+        "score_mode": self.score_mode,
+        "dropout": self.dropout,
+    }
+    return dict(list(base_config.items()) + list(config.items()))
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index e45010553..f2408ba47 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -120,6 +120,8 @@ def backbone(self):
     kwargs = {
         'loss_dict': self._loss_dict,
         'metric_dict': self._metric_dict,
+        'prediction_dict': self._prediction_dict,
+        'labels': self._labels,
         constant.SAMPLE_WEIGHT: self._sample_weight
     }
     return self._backbone_net(self._is_training, **kwargs)
diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py
index cff58e079..88e7e4b04 100644
--- a/easy_rec/python/model/multi_task_model.py
+++ b/easy_rec/python/model/multi_task_model.py
@@ -4,13 +4,16 @@
 from collections import OrderedDict
 
 import tensorflow as tf
+from google.protobuf import struct_pb2
+from tensorflow.python.keras.layers import Dense
 
 from easy_rec.python.builders import loss_builder
 from easy_rec.python.layers.dnn import DNN
+from easy_rec.python.layers.keras.attention import Attention
+from easy_rec.python.layers.utils import Parameter
 from easy_rec.python.model.rank_model import RankModel
 from easy_rec.python.protos import tower_pb2
 from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
 from easy_rec.python.protos.loss_pb2 import LossType
 
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -82,6 +85,30 @@ def build_predict_graph(self):
             tower_inputs, axis=-1, name=tower_name + '/relation_input')
         relation_fea = relation_dnn(relation_input)
         relation_features[tower_name] = relation_fea
+      elif task_tower_cfg.use_ait_module:
+        tower_inputs = [tower_features[tower_name]]
+        for relation_tower_name in task_tower_cfg.relation_tower_names:
+          tower_inputs.append(relation_features[relation_tower_name])
+        if len(tower_inputs) == 1:
+          relation_fea = tower_inputs[0]
+          relation_features[tower_name] = relation_fea
+        else:
+          if task_tower_cfg.HasField('ait_project_dim'):
+            dim = task_tower_cfg.ait_project_dim
+          else:
+            # Dense needs a static python int, not a dynamic tf.shape() tensor.
+            dim = int(tower_inputs[0].shape[-1])
+          # Each tower gets its own query/key/value projections, stacked
+          # along a new "towers" axis: (batch_size, num_towers, dim).
+          queries = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)
+          keys = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)
+          values = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)
+          st_params = struct_pb2.Struct()
+          st_params.update({'scale_by_dim': True})
+          params = Parameter(st_params, True)
+          attention_layer = Attention(params, name='AITM_%s' % tower_name)
+          result = attention_layer([queries, values, keys])
+          # Keep the attention output at the current tower's query position:
+          # (batch_size, dim).
+          relation_fea = result[:, 0, :]
+          relation_features[tower_name] = relation_fea
       else:
         relation_fea = tower_features[tower_name]
diff --git a/easy_rec/python/protos/tower.proto b/easy_rec/python/protos/tower.proto
index 580708825..f6981da5b 100644
--- a/easy_rec/python/protos/tower.proto
+++ b/easy_rec/python/protos/tower.proto
@@ -58,7 +58,7 @@ message BayesTaskTower {
   optional DNN relation_dnn = 8;
   // training loss weights
   optional float weight = 9 [default = 1.0];
-  // label name for indcating the sample space for the task tower
+  // label name for indicating the sample space for the task tower
   optional string task_space_indicator_label = 10;
   // the loss weight for sample in the task space
   optional float in_task_space_weight = 11 [default = 1.0];
@@ -72,4 +72,7 @@ message BayesTaskTower {
   repeated Loss losses = 15;
   // whether to use sample weight in this tower
   required bool use_sample_weight = 16 [default = true];
+  // whether to use the AIT (Adaptive Information Transfer) module of AITM
+  optional bool use_ait_module = 17 [default = false];
+  optional uint32 ait_project_dim = 18;  // defaults to the tower input dim
 };
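
For reference, a self-contained sketch of the AIT transfer step added to `multi_task_model.py` above (tower names, batch size, and dimensions are illustrative; the real code runs inside `build_predict_graph`):

```python
import tensorflow as tf
from google.protobuf import struct_pb2
from tensorflow.python.keras.layers import Dense

from easy_rec.python.layers.keras.attention import Attention
from easy_rec.python.layers.utils import Parameter

# Hypothetical tower outputs: the current tower first, then its relation towers.
cvr_fea = tf.zeros([4, 32])  # (batch_size, tower_dim)
ctr_fea = tf.zeros([4, 32])
tower_inputs = [cvr_fea, ctr_fea]

dim = 16  # plays the role of ait_project_dim
# Project each tower output and stack along a new "towers" axis:
# (batch_size, num_towers, dim).
queries = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)
keys = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)
values = tf.stack([Dense(dim)(x) for x in tower_inputs], axis=1)

st_params = struct_pb2.Struct()
st_params.update({'scale_by_dim': True})  # scaled dot-product, as in the patch
params = Parameter(st_params, True)

attention_layer = Attention(params, name='AITM_cvr')
result = attention_layer([queries, values, keys])  # (batch_size, num_towers, dim)
# The row at the current tower's query position becomes its relation feature.
relation_fea = result[:, 0, :]  # (batch_size, dim)
```

Each tower thus attends over itself and its upstream towers, and the attention output at its own query position is passed on as that tower's relation feature.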