Skip to content

Commit

Permalink
fix doc build problem
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Dec 14, 2023
1 parent 5cafa73 commit ac9fb01
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
50 changes: 29 additions & 21 deletions docs/source/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,31 +236,39 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
多目标学习任务中,人工指定多个损失函数的静态权重通常不能获得最好的效果。EasyRec支持损失函数权重自适应学习,示例如下:

```protobuf
losses {
loss_type: CLASSIFICATION
learn_loss_weight: true
}
losses {
loss_type: BINARY_FOCAL_LOSS
learn_loss_weight: true
binary_focal_loss {
gamma: 2.0
alpha: 0.85
}
}
losses {
loss_type: PAIRWISE_FOCAL_LOSS
learn_loss_weight: true
pairwise_focal_loss {
session_name: "client_str"
hinge_margin: 1.0
}
}
loss_weight_strategy: Uncertainty
losses {
loss_type: CLASSIFICATION
learn_loss_weight: true
}
losses {
loss_type: BINARY_FOCAL_LOSS
learn_loss_weight: true
binary_focal_loss {
gamma: 2.0
alpha: 0.85
}
}
losses {
loss_type: PAIRWISE_FOCAL_LOSS
learn_loss_weight: true
pairwise_focal_loss {
session_name: "client_str"
hinge_margin: 1.0
}
}
```

通过`learn_loss_weight`参数配置是否需要开启权重自适应学习,默认不开启。开启之后,`weight`参数不再生效。

参考论文:《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》
- loss_weight_strategy: Uncertainty
- 表示通过不确定性来度量损失函数的权重;目前在`learn_loss_weight: true`时必须要设置该值
- loss_weight_strategy: Random
- 表示损失函数的权重设定为归一化的随机数

参考论文:
- 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》
-[Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)

## 训练命令

Expand Down
4 changes: 3 additions & 1 deletion easy_rec/python/model/multi_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos import tower_pb2
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -188,7 +189,8 @@ def get_learnt_loss(self, loss_type, name, value):
else:
return tf.exp(-uncertainty) * value + 0.5 * uncertainty
else:
raise ValueError('Unsupported loss weight strategy: ' + strategy.Name)
strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy)
raise ValueError('Unsupported loss weight strategy: ' + strategy_name)

def build_loss_graph(self):
"""Build loss graph for multi task model."""
Expand Down

0 comments on commit ac9fb01

Please sign in to comment.