diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py index f3affa1bc..eedee439f 100644 --- a/easy_rec/python/layers/keras/custom_ops.py +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -3,17 +3,25 @@ """Convenience blocks for using custom ops.""" import logging import os - import tensorflow as tf from tensorflow.python.framework import ops -import easy_rec -# LIB_PATH = tf.sysconfig.get_link_flags()[0][2:] -# LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH') -# if LIB_PATH not in LD_LIBRARY_PATH: -# os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH]) -# logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH')) +curr_dir, _ = os.path.split(__file__) +parent_dir = os.path.dirname(curr_dir) +ops_idr = os.path.dirname(parent_dir) +ops_dir = os.path.join(ops_idr, 'python', 'ops') +if 'PAI' in tf.__version__: + ops_dir = os.path.join(ops_dir, '1.12_pai') +elif tf.__version__.startswith('1.12'): + ops_dir = os.path.join(ops_dir, '1.12') +elif tf.__version__.startswith('1.15'): + if 'IS_ON_PAI' in os.environ: + ops_dir = os.path.join(ops_dir, 'DeepRec') + else: + ops_dir = os.path.join(ops_dir, '1.15') +else: + ops_dir = None if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -23,8 +31,8 @@ class EditDistance(tf.keras.layers.Layer): def __init__(self, params, name='edit_distance', reuse=None, **kwargs): super(EditDistance, self).__init__(name, **kwargs) - - custom_op_path = os.path.join(easy_rec.ops_dir, 'libedit_distance.so') + logging.info("ops_dir is %s" % ops_dir) + custom_op_path = os.path.join(ops_dir, 'libedit_distance.so') try: custom_ops = tf.load_op_library(custom_op_path) logging.info('load edit_distance op from %s succeed' % custom_op_path) diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index a13022fd8..e4b45a041 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -403,7 +403,7 @@ def test_highway(self): self.assertTrue(self._success) @unittest.skipIf( - LooseVersion(tf.__version__) < LooseVersion('2.0.0'), + LooseVersion(tf.__version__) >= LooseVersion('2.0.0'), 'EditDistanceOp only work before tf version == 2.0') def test_custom_op(self): self._success = test_utils.test_single_train_eval( diff --git a/easy_rec/version.py b/easy_rec/version.py index edc79a5a5..235c9c2a6 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,3 +1,3 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '0.7.5' +__version__ = '0.7.6'