Add dynamic weights for task objectives
chengaofei committed May 28, 2024
1 parent 98b7c8c commit 53d9ba2
Showing 8 changed files with 661 additions and 2 deletions.
docs/source/quick_start/mc_tutorial.md (1 addition, 1 deletion)
@@ -39,7 +39,7 @@ pai -name easy_rec_ext -project algo_public
- -Dtables: defines other dependent tables (optional), such as the table used for negative sampling
- -Dcluster: defines the number of PS nodes and workers. For details see: [Introduction to PAI-TF job parameters](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5)
- -Deval_method: evaluation method
  - separate: evaluate on a worker (task_id=1)
  - separate: evaluate on a worker (task_id=1). In the training job's logview, open the stderr of worker#1_0; output like "Saving dict for global step 3949: auc = 0.7643898, global_step = 3949, loss = 0.38898173, loss/loss/cross_entropy_loss = 0.38898173, loss/loss/total_loss = 0.38898173" gives the evaluation metrics
  - none: no evaluation
  - master: evaluate on the master (task_id=0)
- -Dfine_tune_checkpoint: optional; restore parameters from a checkpoint for fine-tuning
easy_rec/python/input/input.py (17 additions, 1 deletion)
@@ -78,6 +78,7 @@ def __init__(self,
x.default_val for x in data_config.input_fields
]
self._label_fields = list(data_config.label_fields)
self._label_dynamic_weight = list(data_config.label_dynamic_weight)
self._feature_fields = list(data_config.feature_fields)
self._label_sep = list(data_config.label_sep)
self._label_dim = list(data_config.label_dim)
@@ -139,6 +140,8 @@ def __init__(self,
# add sample weight to effective fields
if self._data_config.HasField('sample_weight'):
self._effective_fields.append(self._data_config.sample_weight)
if len(self._label_dynamic_weight) > 0:
self._effective_fields.extend(self._label_dynamic_weight)

# add uid_field of GAUC and session_fields of SessionAUC
if self._pipeline_config is not None:
@@ -234,6 +237,7 @@ def get_feature_input_fields(self):
return [
x for x in self._input_fields
if x not in self._label_fields and x != self._data_config.sample_weight
and x not in self._label_dynamic_weight
]

def should_stop(self, curr_epoch):
@@ -269,13 +273,14 @@ def create_multi_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]

inputs = {}
for fid in effective_fids:
input_name = self._input_fields[fid]
if input_name == sample_weight_field:
if input_name == sample_weight_field or input_name in self._label_dynamic_weight:
continue
if placeholder_named_by_input:
placeholder_name = input_name
@@ -318,6 +323,7 @@ def create_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]
logging.info(
@@ -330,6 +336,8 @@ def create_placeholders(self, export_config):
ftype = self._input_field_types[fid]
tf_type = get_tf_type(ftype)
input_name = self._input_fields[fid]
if input_name in self._label_dynamic_weight:
continue
if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]:
features[input_name] = tf.string_to_number(
input_vals[:, tmp_id],
@@ -925,6 +933,14 @@ def _preprocess(self, field_dict):
if self._mode != tf.estimator.ModeKeys.PREDICT:
parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
self._data_config.sample_weight]
if len(self._label_dynamic_weight
) > 0 and self._mode != tf.estimator.ModeKeys.PREDICT:
for label_weight in self._label_dynamic_weight:
if field_dict[label_weight].dtype == tf.float32:
parsed_dict[label_weight] = field_dict[label_weight]
else:
parsed_dict[label_weight] = tf.cast(
field_dict[label_weight], dtype=tf.float64)

if Input.DATA_OFFSET in field_dict:
parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
easy_rec/python/model/multi_task_model.py (2 additions)
@@ -198,6 +198,8 @@ def build_loss_graph(self):
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
loss_weight = task_tower_cfg.weight
if task_tower_cfg.HasField('dynamic_weight'):
loss_weight *= self._feature_dict[task_tower_cfg.dynamic_weight]
if task_tower_cfg.use_sample_weight:
loss_weight *= self._sample_weight

easy_rec/python/protos/dataset.proto (2 additions)
@@ -292,6 +292,8 @@ message DatasetConfig {

// input field for sample weight
optional string sample_weight = 22;
// input field for label dynamic weight
repeated string label_dynamic_weight = 27;
// the compression type of tfrecord
optional string data_compression_type = 23 [default = ''];

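For reference, a minimal sketch of how the new label_dynamic_weight field might be wired into a data_config (EasyRec protobuf text format). The column names clk_weight and buy_weight are hypothetical, and all other required data_config settings (batch_size, input_type, separators, etc.) are omitted; as the input.py changes above indicate, each weight column must also be declared in input_fields so the input layer reads it alongside the labels.

```protobuf
data_config {
  # labels for the two task towers
  label_fields: "clk"
  label_fields: "buy"

  # hypothetical per-sample weight columns, one per objective
  label_dynamic_weight: "clk_weight"
  label_dynamic_weight: "buy_weight"

  input_fields {
    input_name: "clk"
    input_type: INT32
  }
  input_fields {
    input_name: "buy"
    input_type: INT32
  }
  input_fields {
    input_name: "clk_weight"
    input_type: FLOAT
  }
  input_fields {
    input_name: "buy_weight"
    input_type: FLOAT
  }
}
```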
easy_rec/python/protos/tower.proto (4 additions)
@@ -26,6 +26,8 @@ message TaskTower {
optional DNN dnn = 6;
// training loss weights
optional float weight = 7 [default = 1.0];
// training loss label dynamic weights
optional string dynamic_weight = 8;
// label name for indicating the sample space for the task tower
optional string task_space_indicator_label = 10;
// the loss weight for sample in the task space
@@ -72,4 +74,6 @@ message BayesTaskTower {
repeated Loss losses = 15;
// whether to use sample weight in this tower
required bool use_sample_weight = 16 [default = true];
// training loss label dynamic weights
optional string dynamic_weight = 17;
};
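Correspondingly, a hypothetical multi-task tower configuration that consumes those columns (surrounding model config omitted, names illustrative, not the commit's actual sample configs). Per the build_loss_graph change above, the per-sample value of the dynamic_weight column multiplies the tower's static weight, and the sample weight as well when use_sample_weight is enabled.

```protobuf
# inside a multi-task model config (e.g. mmoe { ... }); names are illustrative
task_towers {
  tower_name: "ctr"
  label_name: "clk"              # declared in data_config.label_fields
  dnn {
    hidden_units: [128, 64]
  }
  weight: 1.0                    # static loss weight
  dynamic_weight: "clk_weight"   # column declared in data_config.label_dynamic_weight
}
task_towers {
  tower_name: "cvr"
  label_name: "buy"
  dnn {
    hidden_units: [128, 64]
  }
  weight: 1.0
  dynamic_weight: "buy_weight"
}
```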
easy_rec/python/test/train_eval_test.py (12 additions)
@@ -937,12 +937,24 @@ def test_sequence_esmm(self):
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_esmm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_ple(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/ple_on_sequence_feature_taobao.config',
(Diffs for the two new sample config files referenced by the tests above, samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config and samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config, are not rendered in this view.)
