diff --git a/easy_rec/python/loss/listwise_loss.py b/easy_rec/python/loss/listwise_loss.py index e31385dac..f778f38f8 100644 --- a/easy_rec/python/loss/listwise_loss.py +++ b/easy_rec/python/loss/listwise_loss.py @@ -44,7 +44,7 @@ def listwise_rank_loss(labels, """ loss_name = name if name else 'listwise_rank_loss' logging.info('[{}] temperature: {}'.format(loss_name, temperature)) - + labels = tf.to_float(labels) if temperature != 1.0: logits /= temperature if label_is_logits: