From 1fb8ac0412d8693ba800a42077604ce1ec777ed9 Mon Sep 17 00:00:00 2001 From: "weisu.yxd" Date: Fri, 20 Oct 2023 14:45:58 +0800 Subject: [PATCH] add custom op demo --- easy_rec/python/layers/keras/custom_ops.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py index cb76a86bd..f3affa1bc 100644 --- a/easy_rec/python/layers/keras/custom_ops.py +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -9,12 +9,11 @@ import easy_rec -LIB_PATH = tf.sysconfig.get_link_flags()[0][2:] -LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH') -if LIB_PATH not in LD_LIBRARY_PATH: - os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH]) - logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH')) - +# LIB_PATH = tf.sysconfig.get_link_flags()[0][2:] +# LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH') +# if LIB_PATH not in LD_LIBRARY_PATH: +# os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH]) +# logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH')) if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -46,7 +45,12 @@ def __init__(self, params, name='edit_distance', reuse=None, **kwargs): def call(self, inputs, training=None, **kwargs): input1, input2 = inputs[:2] with ops.device('/CPU:0'): - dist = self.edit_distance(input1, input2, normalize=False, dtype=tf.int32, encoding=self.txt_encoding) + dist = self.edit_distance( + input1, + input2, + normalize=False, + dtype=tf.int32, + encoding=self.txt_encoding) ids = tf.clip_by_value(dist, 0, self.emb_size - 1) embed = tf.nn.embedding_lookup(self.embedding_table, ids) return embed