diff --git a/docs/source/component/custom_op.md b/docs/source/component/custom_op.md new file mode 100644 index 000000000..c4f1d47fb --- /dev/null +++ b/docs/source/component/custom_op.md @@ -0,0 +1,126 @@ +# 使用自定义 OP + +当内置的tf算子不能满足业务需求,或者通过组合现有算子实现需求的性能较差时,可以考虑自定义tf的OP。 + +1. 实现自定义算子,编译为动态库 + - 参考官方示例:[TensorFlow Custom Op](https://github.com/tensorflow/custom-op/) + - 注意:自定义Op的编译依赖tf版本需要与执行时的tf版本保持一致 + - 您可能需要为离线训练 与 在线推理服务 编译两个不同依赖环境的动态库 + - 在PAI平台上需要依赖 tf 1.12 版本编译 + - 在EAS的 [EasyRec Processor](https://help.aliyun.com/zh/pai/user-guide/easyrec) 中使用自定义Op需要依赖 tf 2.10.1 编译 +2. 在`EasyRec`中使用自定义Op的步骤 + 1. 下载EasyRec的最新[源代码](https://github.com/alibaba/EasyRec) + 2. 把上一步编译好的动态库放到`easy_rec/python/ops/${tf_version}`目录,注意版本要子目录名一致 + 3. 开发一个使用自定义Op的组件 + - 新组件的代码添加到 `easy_rec/python/layers/keras/custom_ops.py` + - `custom_ops.py` 提供了一个自定义Op组件的示例 + - 声明新组件,在`easy_rec/python/layers/keras/__init__.py`文件中添加导出语句 + 4. 编写模型配置文件,使用组件化的方式搭建模型,包含新定义的组件(参考下文) + 5. 运行`pai_jobs/deploy_ext.sh`脚本,打包EasyRec,并把打好的资源包(`easy_rec_ext_${version}_res.tar.gz`)上传到MaxCompute项目空间 + 6. (在DataWorks里 or 用odpscmd客户端工具) 训练 & 评估 & 导出 模型 + +## 导出自定义Op的动态库到 saved_model 的 assets 目录 + +```bash +pai -name easy_rec_ext +-Dcmd='export' +-Dconfig='oss://cold-start/EasyRec/custom_op/pipeline.config' +-Dexport_dir='oss://cold-start/EasyRec/custom_op/export/final_with_lib' +-Dextra_params='--asset_files oss://cold-start/EasyRec/config/libedit_distance.so' +-Dres_project='pai_rec_test_dev' +-Dversion='0.7.5' +-Dbuckets='oss://cold-start/' +-Darn='acs:ram::XXXXXXXXXX:role/aliyunodpspaidefaultrole' +-DossHost='oss-cn-beijing-internal.aliyuncs.com' +; +``` + +**注意**: +1. 在 训练、评估、导出 命令中需要用`-Dres_project`指定上传easyrec资源包的MaxCompute项目空间名 +2. 在 训练、评估、导出 命令中需要用`-Dversion`指定资源包的版本 +3. asset_files参数指定的动态库会被线上推理服务加载,因此需要在与线上推理服务一致的tf版本上编译。(目前是EAS平台的EasyRec Processor依赖 tf 2.10.1版本)。 + - 如果 asset_files 参数还需要指定其他文件路径(比如 fg.json),多个路径之间用英文逗号隔开。 +4. 再次强调一遍,**导出的动态库依赖的tf版本需要与推理服务依赖的tf版本保持一致** + +## 自定义Op的示例 + +```protobuf +feature_config: { + ... + features: { + feature_name: 'raw_genres' + input_names: 'genres' + feature_type: PassThroughFeature + } + features: { + feature_name: 'raw_title' + input_names: 'title' + feature_type: PassThroughFeature + } +} +model_config: { + model_class: 'RankModel' + model_name: 'MLP' + feature_groups: { + group_name: 'text' + feature_names: 'raw_genres' + feature_names: 'raw_title' + wide_deep: DEEP + } + feature_groups: { + group_name: 'features' + feature_names: 'user_id' + feature_names: 'movie_id' + feature_names: 'gender' + feature_names: 'age' + feature_names: 'occupation' + feature_names: 'zip_id' + feature_names: 'movie_year_bin' + wide_deep: DEEP + } + backbone { + blocks { + name: 'text' + inputs { + feature_group_name: 'text' + } + raw_input { + } + } + blocks { + name: 'edit_distance' + inputs { + block_name: 'text' + } + keras_layer { + class_name: 'EditDistance' + } + } + blocks { + name: 'mlp' + inputs { + feature_group_name: 'features' + } + inputs { + block_name: 'edit_distance' + } + keras_layer { + class_name: 'MLP' + mlp { + hidden_units: [256, 128] + } + } + } + } + model_params { + l2_regularization: 1e-5 + } + embedding_regularization: 1e-6 +} +``` + +1. 如果自定义Op需要处理原始输入特征,则在定义特征时指定 `feature_type: PassThroughFeature` + - 非 `PassThroughFeature` 类型的特征会在预处理阶段做一些变换,组件代码里拿不到原始值 +2. 自定义Op需要处理的原始输入特征按照顺序放置到同一个`feature group`内 +3. 配置一个类型为`raw_input`的输入组件,获取原始输入特征 + - 这是目前EasyRec支持的读取原始输入特征的唯一方式 diff --git a/docs/source/component/sequence.md b/docs/source/component/sequence.md new file mode 100644 index 000000000..de7026610 --- /dev/null +++ b/docs/source/component/sequence.md @@ -0,0 +1,78 @@ +# 序列化组件的配置方式 + +序列模型(DIN、BST)的组件化配置方式需要把输入特征放置在同一个`feature_group`内。 + +序列模型一般包含 `history behavior sequence` 与 `target item` 两部分,且每部分都可能包含多个属性(子特征)。 + +在序列组件输入的`feature_group`内,**按照顺序**定义 `history behavior sequence` 与 `target item`的各个子特征。 + +框架按照特征定义的类型`feature_type`字段来识别某个具体的特征是属于 `history behavior sequence` 还是 `target item`。 +所有 `SequenceFeature` 类型的子特征都被识别为`history behavior sequence`的一部分; 所有非`SequenceFeature` 类型的子特征都被识别为`target item`的一部分。 + +**两部分的子特征的顺序需要保持一致**。在下面的例子中, +- `concat([cate_id,brand], axis=-1)` 是`target item`最终的embedding(2D); +- `concat([tag_category_list, tag_brand_list], axis=-1)` 是`history behavior sequence`最终的embedding(3D) + +```protobuf +model_config: { + model_name: 'DIN' + model_class: 'RankModel + ... + feature_groups: { + group_name: 'sequence' + feature_names: "cate_id" + feature_names: "brand" + feature_names: "tag_category_list" + feature_names: "tag_brand_list" + wide_deep: DEEP + } + backbone { + blocks { + name: 'seq_input' + inputs { + feature_group_name: 'sequence' + } + input_layer { + output_seq_and_normal_feature: true + } + } + blocks { + name: 'DIN' + inputs { + block_name: 'seq_input' + } + keras_layer { + class_name: 'DIN' + din { + attention_dnn { + hidden_units: 32 + hidden_units: 1 + activation: "dice" + } + need_target_feature: true + } + } + } + ... + } +} +``` + +使用序列组件时,必须配置一个`input_layer`类型的`block`,并且配置`output_seq_and_normal_feature: true`参数,如下。 + +```protobuf +blocks { + name: 'seq_input' + inputs { + feature_group_name: 'sequence' + } + input_layer { + output_seq_and_normal_feature: true + } +} +``` + +## 完整的例子 + +- [DIN](../models/din.md) +- [BST](../models/bst.md) diff --git a/docs/source/index.rst b/docs/source/index.rst index 9cef0a0a5..f76eac87b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,8 @@ Welcome to easy_rec's documentation! component/backbone component/component + component/sequence + component/custom_op .. toctree:: :maxdepth: 3 diff --git a/docs/source/models/bst.md b/docs/source/models/bst.md index e6a62fda3..681b08a53 100644 --- a/docs/source/models/bst.md +++ b/docs/source/models/bst.md @@ -158,8 +158,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { @@ -219,6 +219,7 @@ model_config: { - feature_groups: 特征组 - 包含两个feature_group: dense 和sparse group - wide_deep: BST模型使用的都是Deep features, 所以都设置成DEEP + - 序列组件对应的feature_group的配置方式请查看 [参考文档](../component/sequence.md) - backbone: 通过组件化的方式搭建的主干网络,[参考文档](../component/backbone.md) - blocks: 由多个`组件块`组成的一个有向无环图(DAG),框架负责按照DAG的拓扑排序执行个`组件块`关联的代码逻辑,构建TF Graph的一个子图 - name/inputs: 每个`block`有一个唯一的名字(name),并且有一个或多个输入(inputs)和输出 diff --git a/docs/source/models/din.md b/docs/source/models/din.md index b86444c81..b54f4f363 100644 --- a/docs/source/models/din.md +++ b/docs/source/models/din.md @@ -133,8 +133,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { @@ -192,6 +192,7 @@ model_config: { - feature_groups: 特征组 - 包含两个feature_group: dense 和sparse group - wide_deep: DIN模型使用的都是Deep features, 所以都设置成DEEP + - 序列组件对应的feature_group的配置方式请查看 [参考文档](../component/sequence.md) - backbone: 通过组件化的方式搭建的主干网络,[参考文档](../component/backbone.md) - blocks: 由多个`组件块`组成的一个有向无环图(DAG),框架负责按照DAG的拓扑排序执行个`组件块`关联的代码逻辑,构建TF Graph的一个子图 - name/inputs: 每个`block`有一个唯一的名字(name),并且有一个或多个输入(inputs)和输出 diff --git a/easy_rec/python/feature_column/feature_column.py b/easy_rec/python/feature_column/feature_column.py index 9487b4bc6..20872d2b4 100644 --- a/easy_rec/python/feature_column/feature_column.py +++ b/easy_rec/python/feature_column/feature_column.py @@ -129,7 +129,7 @@ def _cmp_embed_config(a, b): self.parse_sequence_feature(config) elif config.feature_type == config.ExprFeature: self.parse_expr_feature(config) - else: + elif config.feature_type != config.PassThroughFeature: assert False, 'invalid feature type: %s' % config.feature_type except FeatureKeyError: pass diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py index 639542b63..cb76a86bd 100644 --- a/easy_rec/python/layers/keras/custom_ops.py +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -5,13 +5,19 @@ import os import tensorflow as tf +from tensorflow.python.framework import ops import easy_rec LIB_PATH = tf.sysconfig.get_link_flags()[0][2:] LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH') -os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH]) -logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH')) +if LIB_PATH not in LD_LIBRARY_PATH: + os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH]) + logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH')) + + +if tf.__version__ >= '2.0': + tf = tf.compat.v1 class EditDistance(tf.keras.layers.Layer): @@ -27,17 +33,20 @@ def __init__(self, params, name='edit_distance', reuse=None, **kwargs): logging.warning('load edit_distance op from %s failed: %s' % (custom_op_path, str(ex))) custom_ops = None - self.edit_distance = custom_ops.edit_distance_op + self.edit_distance = custom_ops.my_edit_distance + + self.txt_encoding = params.get_or_default('text_encoding', 'utf-8') + self.emb_size = params.get_or_default('embedding_size', 512) + emb_dim = params.get_or_default('embedding_dim', 4) + with tf.variable_scope(self.name, reuse=reuse): + self.embedding_table = tf.get_variable('embedding_table', + [self.emb_size, emb_dim], + tf.float32) def call(self, inputs, training=None, **kwargs): input1, input2 = inputs[:2] - print('input1:', input1) - print('input2:', input2) - str1 = tf.sparse.to_dense(input1, default_value='') - str2 = tf.sparse.to_dense(input1, default_value='') - print('str1:', str1) - print('str2:', str2) - dist = self.edit_distance(str1, str2, dtype=tf.float32) - print('dist:', dist) - dist = tf.reshape(dist, [-1, 1]) - return dist + with ops.device('/CPU:0'): + dist = self.edit_distance(input1, input2, normalize=False, dtype=tf.int32, encoding=self.txt_encoding) + ids = tf.clip_by_value(dist, 0, self.emb_size - 1) + embed = tf.nn.embedding_lookup(self.embedding_table, ids) + return embed diff --git a/easy_rec/python/ops/1.12/libedit_distance.so b/easy_rec/python/ops/1.12/libedit_distance.so new file mode 100755 index 000000000..51180bc6b Binary files /dev/null and b/easy_rec/python/ops/1.12/libedit_distance.so differ diff --git a/easy_rec/python/ops/1.12_pai/libedit_distance.so b/easy_rec/python/ops/1.12_pai/libedit_distance.so index 21c6743b1..e31899a7b 100755 Binary files a/easy_rec/python/ops/1.12_pai/libedit_distance.so and b/easy_rec/python/ops/1.12_pai/libedit_distance.so differ diff --git a/easy_rec/python/ops/1.15/libedit_distance.so b/easy_rec/python/ops/1.15/libedit_distance.so new file mode 100755 index 000000000..75bf35009 Binary files /dev/null and b/easy_rec/python/ops/1.15/libedit_distance.so differ diff --git a/easy_rec/python/ops/edit_distance_op.py b/easy_rec/python/ops/edit_distance_op.py deleted file mode 100644 index 4d7184100..000000000 --- a/easy_rec/python/ops/edit_distance_op.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging -import os - -import tensorflow as tf -from tensorflow.python.util.tf_export import tf_export - -import easy_rec - -custom_op_path = os.path.join(easy_rec.ops_dir, 'libedit_distance.so') -print('custom op path: %s' % custom_op_path) - -try: - custom_ops = tf.load_op_library(custom_op_path) - logging.info('load edit_distance op from %s succeed' % custom_op_path) -except Exception as ex: - print('custom op path: %s' % custom_op_path) - logging.warning('load edit_distance op failed: %s' % str(ex)) - custom_ops = None - - -@tf_export('edit_distance') -def edit_distance(input1, input2): - return custom_ops.edit_distance_op(input1, input2) diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto index 2f1627f29..6a708b110 100644 --- a/easy_rec/python/protos/feature_config.proto +++ b/easy_rec/python/protos/feature_config.proto @@ -42,6 +42,7 @@ message FeatureConfig { LookupFeature = 4; SequenceFeature = 5; ExprFeature = 6; + PassThroughFeature = 7; } enum FieldType { diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index b0b66d30c..bd0f21f07 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -402,6 +402,14 @@ def test_highway(self): 'samples/model_config/highway_on_movielens.config', self._test_dir) self.assertTrue(self._success) + @unittest.skipIf( + LooseVersion(tf.__version__) < LooseVersion('2.0.0'), + 'EditDistanceOp only work before tf version == 2.0') + def test_custom_op(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/mlp_on_movielens_with_custom_op.config', self._test_dir) + self.assertTrue(self._success) + def test_cdn(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/cdn_on_taobao.config', self._test_dir) diff --git a/samples/model_config/bst_backbone_on_taobao.config b/samples/model_config/bst_backbone_on_taobao.config index a46f0e09e..b801f87ef 100644 --- a/samples/model_config/bst_backbone_on_taobao.config +++ b/samples/model_config/bst_backbone_on_taobao.config @@ -257,8 +257,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { diff --git a/samples/model_config/din_backbone_on_taobao.config b/samples/model_config/din_backbone_on_taobao.config index d5e705747..7cb48ac56 100644 --- a/samples/model_config/din_backbone_on_taobao.config +++ b/samples/model_config/din_backbone_on_taobao.config @@ -257,8 +257,8 @@ model_config: { group_name: 'sequence' feature_names: "cate_id" feature_names: "brand" - feature_names: "tag_brand_list" feature_names: "tag_category_list" + feature_names: "tag_brand_list" wide_deep: DEEP } backbone { diff --git a/samples/model_config/mlp_on_movielens_with_custom_op.config b/samples/model_config/mlp_on_movielens_with_custom_op.config new file mode 100644 index 000000000..fb400ab1c --- /dev/null +++ b/samples/model_config/mlp_on_movielens_with_custom_op.config @@ -0,0 +1,257 @@ +train_input_path: "data/test/movielens_1m/ml_train_data" +eval_input_path: "data/test/movielens_1m/ml_test_data" +model_dir: "experiments/mlp_movielens_ckpt" + +train_config { + optimizer_config: { + adam_optimizer: { + learning_rate: { + constant_learning_rate { + learning_rate: 0.0001 + } + } + beta1: 0.9 + beta2: 0.999 + } + use_moving_average: false + } + log_step_count_steps: 100 + save_checkpoints_steps: 100 + sync_replicas: true + num_steps: 100 +} + +eval_config { + metrics_set: { + gauc { + uid_field: 'user_id' + } + } + metrics_set: { + auc {} + } +} + +data_config { + input_fields { + input_name:'rating' + input_type: INT32 + } + input_fields { + input_name:'label' + input_type: INT32 + } + input_fields { + input_name:'user_id' + input_type: INT32 + } + input_fields { + input_name:'movie_id' + input_type: INT32 + } + input_fields { + input_name:'gender' + input_type: INT32 + } + input_fields { + input_name: 'age' + input_type: INT32 + } + input_fields { + input_name: 'occupation' + input_type: INT32 + } + input_fields { + input_name: 'zip_id' + input_type: INT32 + default_val: '0' + } + input_fields { + input_name: 'genres' + input_type: STRING + default_val: 'unknown' + } + input_fields { + input_name: 'title' + input_type: STRING + default_val: 'unknown' + } + input_fields { + input_name: 'movie_year_bin' + input_type: INT32 + } + input_fields { + input_name: 'score_year_diff' + input_type: INT32 + default_val: '0' + } + input_fields { + input_name: 'score_time' + input_type: DOUBLE + } + input_fields { + input_name: 'embedding' + input_type: STRING + default_val: '' + } + + label_fields: 'label' + batch_size: 128 + num_epochs: 10000 + prefetch_size: 1 + input_type: CSVInput +} + +feature_config: { + features: { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 12000 + } + features: { + input_names: 'movie_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 6000 + } + features: { + input_names: 'gender' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 2 + } + features: { + input_names: 'zip_id' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 3405 + } + features: { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 21 + } + features: { + input_names: 'age' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 7 + } + features: { + input_names: 'genres' + feature_type: SequenceFeature + separator: '|' + embedding_dim: 16 + max_seq_len: 8 + hash_bucket_size: 100 + } + features: { + input_names: 'title' + feature_type: SequenceFeature + separator: ' ' + max_seq_len: 16 + embedding_dim: 16 + hash_bucket_size: 20000 + } + features: { + input_names: 'movie_year_bin' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 36 + } + features: { + input_names: 'score_year_diff' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 83 + } + features: { + input_names: 'score_time' + feature_type: RawFeature + embedding_dim: 16 + } + features: { + feature_name: 'raw_genres' + input_names: 'genres' + feature_type: PassThroughFeature + } + features: { + feature_name: 'raw_title' + input_names: 'title' + feature_type: PassThroughFeature + } +} +model_config: { + model_class: 'RankModel' + model_name: 'MLP' + feature_groups: { + group_name: 'text' + feature_names: 'raw_genres' + feature_names: 'raw_title' + wide_deep: DEEP + } + feature_groups: { + group_name: 'features' + feature_names: 'user_id' + feature_names: 'movie_id' + feature_names: 'gender' + feature_names: 'age' + feature_names: 'occupation' + feature_names: 'zip_id' + feature_names: 'movie_year_bin' + feature_names: 'score_year_diff' + feature_names: 'score_time' + wide_deep: DEEP + } + backbone { + blocks { + name: 'text' + inputs { + feature_group_name: 'text' + } + raw_input { + } + } + blocks { + name: 'edit_distance' + inputs { + block_name: 'text' + } + keras_layer { + class_name: 'EditDistance' + st_params { + fields { + key: 'text_encoding' + value: { string_value: 'latin' } + } + } + } + } + blocks { + name: 'mlp' + inputs { + feature_group_name: 'features' + } + inputs { + block_name: 'edit_distance' + } + keras_layer { + class_name: 'MLP' + mlp { + hidden_units: [256, 128] + } + } + } + } + model_params { + l2_regularization: 1e-5 + } + embedding_regularization: 1e-6 +} +export_config { + exporter_type: "best" + best_exporter_metric: "gauc" + exports_to_keep: 1 +}