
Commit d34957a
add ZILN loss for ltv prediction task
yangxudong committed Nov 12, 2024
1 parent 60ca6d8 commit d34957a
Showing 3 changed files with 17 additions and 9 deletions.
docs/source/models/loss.md: 4 additions & 0 deletions

@@ -27,6 +27,10 @@ EasyRec supports two ways to configure loss functions: 1) use a single loss function; 2
 | LISTWISE_DISTILL_LOSS | Loss function for distilling the ranking of a given list; quite similar to the listwise rank loss |
 | ZILN_LOSS | Loss function for LTV prediction tasks (num_class must be set to 3) |
 
+- ZILN_LOSS: when in use, the model has 3 optional outputs (in multi-objective tasks, the output names carry a target-specific suffix)
+  - probs: the predicted conversion probability
+  - y: the predicted LTV value
+  - logits: a tensor of shape `[batch_size, 3]`; the first column yields `probs` (via sigmoid), and the second and third columns are the mean and standard deviation of the learned LogNormal distribution
 - Note: SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   - Supports parameter configuration; upgraded to [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317)
   - Currently only available in the DropoutNet model; see [冷启动推荐模型DropoutNet深度解析与改进](https://zhuanlan.zhihu.com/p/475117993).
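The three outputs are tied together by the prediction function this commit modifies (see the next file). As a minimal sketch, assuming the `[p-logit, loc, scale]` column layout used by `zero_inflated_lognormal_pred`, `probs` and `y` are derived from `logits` like this:

```python
import tensorflow as tf

def decode_ziln_logits(logits):
  """Sketch: recover the documented ZILN outputs from the 'logits' tensor.

  Mirrors zero_inflated_lognormal_pred below; `logits` is [batch_size, 3].
  """
  probs = tf.sigmoid(logits[..., :1])         # conversion probability ('probs')
  loc = logits[..., 1:2]                      # mean of log(LTV) when converted
  scale = tf.math.softplus(logits[..., 2:3])  # std dev of log(LTV)
  # The mean of LogNormal(loc, scale) is exp(loc + scale^2 / 2); weighting it
  # by the conversion probability gives the expected LTV ('y').
  y = probs * tf.exp(loc + 0.5 * tf.square(scale))
  return probs, y
```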
easy_rec/python/loss/zero_inflated_lognormal.py: 2 additions & 1 deletion

@@ -17,6 +17,7 @@ def zero_inflated_lognormal_pred(logits):
     logits: [batch_size, 3] tensor of logits.
 
   Returns:
+    positive_probs: [batch_size, 1] tensor of positive probability.
     preds: [batch_size, 1] tensor of predicted mean.
   """
   logits = tf.convert_to_tensor(logits, dtype=tf.float32)
@@ -26,7 +27,7 @@ def zero_inflated_lognormal_pred(logits):
   preds = (
       positive_probs *
       tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale)))
-  return preds
+  return positive_probs, preds


 def zero_inflated_lognormal_loss(labels, logits, name=''):
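The body of `zero_inflated_lognormal_loss` is not shown in this hunk. For orientation, here is a minimal sketch of the zero-inflated LogNormal loss it implements (following the reference ZILN formulation; assumes non-negative labels where 0 means no conversion, and uses tensorflow_probability for the LogNormal density):

```python
import tensorflow as tf
import tensorflow_probability as tfp

def ziln_loss_sketch(labels, logits):
  """Sketch of a zero-inflated LogNormal (ZILN) loss.

  labels: [batch_size, 1] non-negative LTV values (0 = no conversion).
  logits: [batch_size, 3] tensor of (p-logit, loc, scale) parameters.
  """
  labels = tf.cast(labels, tf.float32)
  positive = tf.cast(labels > 0, tf.float32)

  # Classification term: did the user convert at all?
  classification_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=positive, logits=logits[..., :1])

  # Regression term: LogNormal negative log-likelihood on positives only.
  loc = logits[..., 1:2]
  scale = tf.maximum(tf.math.softplus(logits[..., 2:3]),
                     tf.sqrt(tf.keras.backend.epsilon()))
  safe_labels = positive * labels + (1. - positive)  # avoid log_prob(0)
  regression_loss = -positive * tfp.distributions.LogNormal(
      loc=loc, scale=scale).log_prob(safe_labels)

  return tf.reduce_mean(classification_loss + regression_loss)
```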
easy_rec/python/model/rank_model.py: 11 additions & 8 deletions

@@ -82,10 +82,12 @@ def _output_to_prediction_impl(self,
       prediction_dict['probs' + suffix] = probs[:, 1]
     elif loss_type == LossType.ZILN_LOSS:
       assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS'
-      probs = zero_inflated_lognormal_pred(output)
+      probs, preds = zero_inflated_lognormal_pred(output)
       tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
+      tf.summary.scalar('prediction/y', tf.reduce_mean(preds))
       prediction_dict['logits' + suffix] = output
       prediction_dict['probs' + suffix] = probs
+      prediction_dict['y' + suffix] = preds
     elif loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         output = tf.squeeze(output, axis=1)
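To make the ZILN branch above concrete, here is an illustrative use with a hypothetical multi-task suffix `'_ltv'` (the tensor value is made up; module and function names come from this commit):

```python
import tensorflow as tf
from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_pred

output = tf.constant([[0.3, 1.2, -0.4]])  # hypothetical [batch_size=1, 3] head output
probs, preds = zero_inflated_lognormal_pred(output)
prediction_dict = {
    'logits_ltv': output,  # raw ZILN parameters
    'probs_ltv': probs,    # [1, 1] predicted conversion probability
    'y_ltv': preds,        # [1, 1] predicted (expected) LTV
}
```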
@@ -146,7 +148,7 @@ def build_rtp_output_dict(self):
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
         LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
-        LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
+        LossType.LISTWISE_RANK_LOSS
     }
     if loss_types & binary_loss_set:
       if 'probs' in self._prediction_dict:
@@ -156,7 +158,7 @@ def build_rtp_output_dict(self):
           'failed to build RTP rank_predict output: classification model ' +
           "expect 'probs' prediction, which is not found. Please check if" +
           ' build_predict_graph() is called.')
-    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
       if 'y' in self._prediction_dict:
         forwarded = self._prediction_dict['y']
       else:
Expand Down Expand Up @@ -377,7 +379,7 @@ def _build_metric_impl(self,
                                  metric.recall_at_topk.topk)
     elif metric.WhichOneof('metric') == 'mean_absolute_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_absolute_error' +
                     suffix] = metrics_tf.mean_absolute_error(
                         label, self._prediction_dict['y' + suffix])
@@ -389,7 +391,7 @@ def _build_metric_impl(self,
         assert False, 'mean_absolute_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_squared_error' +
                     suffix] = metrics_tf.mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -401,7 +403,7 @@ def _build_metric_impl(self,
         assert False, 'mean_squared_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'root_mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['root_mean_squared_error' +
                     suffix] = metrics_tf.root_mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -437,13 +439,14 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
         LossType.F1_REWEIGHTED_LOSS, LossType.PAIR_WISE_LOSS,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
-        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
-        LossType.ZILN_LOSS
+        LossType.PAIRWISE_LOGISTIC_LOSS, LossType.LISTWISE_DISTILL_LOSS
     }
     if loss_type in binary_loss_set:
       return ['probs' + suffix, 'logits' + suffix]
     if loss_type == LossType.JRC_LOSS:
       return ['probs' + suffix, 'pos_logits' + suffix]
+    if loss_type == LossType.ZILN_LOSS:
+      return ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
    if loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         return ['probs' + suffix, 'logits' + suffix]
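Net effect of the rank_model.py changes: ZILN_LOSS leaves the binary `probs`-scored branch and is treated like a regression loss, so RTP serving and the MAE/MSE/RMSE metrics all read the expected-LTV tensor `'y'`, while `_get_outputs_impl` exports all three tensors. A toy illustration of the resulting names (the `'_ltv'` suffix is hypothetical):

```python
# Export names for a ZILN head, per the _get_outputs_impl branch above.
suffix = '_ltv'  # hypothetical multi-task suffix
ziln_outputs = ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
print(ziln_outputs)  # ['probs_ltv', 'y_ltv', 'logits_ltv']
```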
