diff --git a/docs/source/models/loss.md b/docs/source/models/loss.md
index bda640f4e..e098aa0a6 100644
--- a/docs/source/models/loss.md
+++ b/docs/source/models/loss.md
@@ -27,6 +27,10 @@ EasyRec supports two ways of configuring loss functions: 1) a single loss functi
 | LISTWISE_DISTILL_LOSS | loss function for distilling a given list ranking; fairly similar to the listwise rank loss |
 | ZILN_LOSS | loss function for LTV prediction tasks (num_class must be set to 3) |
 
+- ZILN_LOSS: when used, the model exposes 3 optional outputs (in multi-task models, each output name carries a task-specific suffix)
+  - probs: the predicted conversion probability
+  - y: the predicted LTV value
+  - logits: a tensor of shape `[batch_size, 3]`; the first column corresponds to `probs`, and the second and third columns parameterize the mean and standard deviation of the learned LogNormal distribution
 - Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   - supports parameter configuration, and has been upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317),
   - currently only available in the DropoutNet model; see [冷启动推荐模型DropoutNet深度解析与改进](https://zhuanlan.zhihu.com/p/475117993).
diff --git a/easy_rec/python/loss/zero_inflated_lognormal.py b/easy_rec/python/loss/zero_inflated_lognormal.py
index da1e03d25..e3ae3110e 100644
--- a/easy_rec/python/loss/zero_inflated_lognormal.py
+++ b/easy_rec/python/loss/zero_inflated_lognormal.py
@@ -17,6 +17,7 @@ def zero_inflated_lognormal_pred(logits):
     logits: [batch_size, 3] tensor of logits.
 
   Returns:
+    positive_probs: [batch_size, 1] tensor of positive probability.
     preds: [batch_size, 1] tensor of predicted mean.
   """
   logits = tf.convert_to_tensor(logits, dtype=tf.float32)
@@ -26,7 +27,7 @@ def zero_inflated_lognormal_pred(logits):
   preds = (
       positive_probs *
       tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
-  return preds
+  return positive_probs, preds
 
 
 def zero_inflated_lognormal_loss(labels, logits, name=''):
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index 7fcc37126..640f52502 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -82,10 +82,12 @@ def _output_to_prediction_impl(self,
       prediction_dict['probs' + suffix] = probs[:, 1]
     elif loss_type == LossType.ZILN_LOSS:
       assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS'
-      probs = zero_inflated_lognormal_pred(output)
+      probs, preds = zero_inflated_lognormal_pred(output)
       tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
+      tf.summary.scalar('prediction/y', tf.reduce_mean(preds))
       prediction_dict['logits' + suffix] = output
       prediction_dict['probs' + suffix] = probs
+      prediction_dict['y' + suffix] = preds
     elif loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         output = tf.squeeze(output, axis=1)
@@ -146,7 +148,7 @@ def build_rtp_output_dict(self):
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
         LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
-        LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
+        LossType.LISTWISE_RANK_LOSS
     }
     if loss_types & binary_loss_set:
       if 'probs' in self._prediction_dict:
@@ -156,7 +158,7 @@
             'failed to build RTP rank_predict output: classification model ' +
             "expect 'probs' prediction, which is not found. Please check if" +
             ' build_predict_graph() is called.')
-    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
       if 'y' in self._prediction_dict:
         forwarded = self._prediction_dict['y']
       else:
@@ -377,7 +379,7 @@ def _build_metric_impl(self,
                                        metric.recall_at_topk.topk)
     elif metric.WhichOneof('metric') == 'mean_absolute_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_absolute_error' +
                     suffix] = metrics_tf.mean_absolute_error(
                         label, self._prediction_dict['y' + suffix])
@@ -389,7 +391,7 @@
         assert False, 'mean_absolute_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_squared_error' +
                     suffix] = metrics_tf.mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -401,7 +403,7 @@
         assert False, 'mean_squared_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'root_mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['root_mean_squared_error' +
                     suffix] = metrics_tf.root_mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -437,13 +439,14 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
-        LossType.ZILN_LOSS
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in binary_loss_set:
       return ['probs' + suffix, 'logits' + suffix]
     if loss_type == LossType.JRC_LOSS:
       return ['probs' + suffix, 'pos_logits' + suffix]
+    if loss_type == LossType.ZILN_LOSS:
+      return ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
     if loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         return ['probs' + suffix, 'logits' + suffix]
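
For context, here is a minimal, self-contained sketch (not the library code itself) of the prediction path that `zero_inflated_lognormal_pred` implements after this change. The sigmoid/softplus column transforms are assumed from the standard ZILN parameterization; the diff itself only shows the `exp(loc + 0.5 * scale^2)` step and the new two-value return.

```python
import tensorflow as tf


def ziln_pred_sketch(logits):
  """Sketch of the (positive_probs, preds) pair now returned.

  Args:
    logits: [batch_size, 3] tensor of raw model outputs.

  Returns:
    positive_probs: [batch_size, 1] tensor, predicted P(LTV > 0).
    preds: [batch_size, 1] tensor, predicted mean LTV.
  """
  logits = tf.convert_to_tensor(logits, dtype=tf.float32)
  positive_probs = tf.sigmoid(logits[..., :1])  # column 0 -> conversion probability
  loc = logits[..., 1:2]                        # column 1 -> mean of log(LTV)
  scale = tf.math.softplus(logits[..., 2:])     # column 2 -> std of log(LTV), kept positive
  # Mean of a zero-inflated LogNormal: p * exp(mu + sigma^2 / 2).
  preds = positive_probs * tf.exp(loc + 0.5 * tf.square(scale))
  return positive_probs, preds


probs, y = ziln_pred_sketch([[2.0, 0.5, 0.1], [-1.0, 1.2, 0.3]])
print(probs.shape, y.shape)  # (2, 1) (2, 1)
```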
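A hypothetical consumer-side companion: given the exported `logits` output described in the doc bullet above, the exported `probs` and `y` can be reproduced offline. The array values and variable names here are illustrative only, not part of EasyRec's API, and the column transforms rest on the same sigmoid/softplus assumption as the sketch above.

```python
import numpy as np

# Illustrative [batch_size, 3] 'logits' fetched from a deployed model.
raw = np.array([[2.0, 0.5, 0.1]], dtype=np.float32)

p = 1.0 / (1.0 + np.exp(-raw[:, 0]))     # column 0 -> the exported 'probs'
mu = raw[:, 1]                           # column 1 -> LogNormal loc
sigma = np.log1p(np.exp(raw[:, 2]))      # column 2 -> scale, via softplus
ltv = p * np.exp(mu + 0.5 * sigma ** 2)  # should reproduce the exported 'y'
print(p, ltv)
```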