Skip to content

Commit

Permalink
[Embedding] Fix shared embedding frequency counting problem. (#962)
Browse files Browse the repository at this point in the history
Signed-off-by: 泊霆 <[email protected]>
Co-authored-by: 泊霆 <[email protected]>
  • Loading branch information
Mesilenceki and Mesilenceki authored Feb 4, 2024
1 parent d84837f commit 2b15e8a
Show file tree
Hide file tree
Showing 12 changed files with 386 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
op {
graph_op_name: "UniqueWithExtraCounts"
visibility: HIDDEN
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
op {
graph_op_name: "UniqueWithExtraCounts"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
op {
graph_op_name: "UniqueWithExtraCounts"
visibility: HIDDEN
}
121 changes: 87 additions & 34 deletions tensorflow/core/kernels/unique_ali_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/task_runner.h"
#include "tensorflow/core/kernels/unique_ali_op_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
Expand All @@ -41,40 +41,43 @@ const char* kStlHashMapString = "STL";
const char* kAbslHashMapString = "ABSL";
const char* kGoogleHashMapString = "GOOGLE";
const int64 kDefaultUniqueRatioHint = 4;
}
} // namespace

template <typename T, typename TIndex>
class UniqueAliOp : public OpKernel {
public:
explicit UniqueAliOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv,
kPartitionSize, &partition_size_));
OP_REQUIRES(context, partition_size_ > 0,
errors::InvalidArgument("Invaild PARTITION_SIZE=",
partition_size_));
OP_REQUIRES_OK(
context, ReadInt64FromEnvVar(kUniqueOpPartitionSizeEnv, kPartitionSize,
&partition_size_));
OP_REQUIRES(
context, partition_size_ > 0,
errors::InvalidArgument("Invaild PARTITION_SIZE=", partition_size_));

OP_REQUIRES_OK(context, ReadBoolFromEnvVar(kUniqueOpSerialEnv,
false, &serial_));
OP_REQUIRES_OK(context,
ReadBoolFromEnvVar(kUniqueOpSerialEnv, false, &serial_));

// NOTE(zycao>: Hash map insertion and lookup performance is dominating in
// Unique Op. Based on benchmark results, 'google::dense_hash_map' will be
// used as default for most key types except string.
//
// By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a particular
// hash map could be seleteed to use. Possible choices are listed below:
// By setting "DEEPREC_UNIQUE_OP_HASH_MAP" environment variable, a
// particular hash map could be seleteed to use. Possible choices are listed
// below:
// "MULTIMAP" for multimap parrallel process,
// "STL" for std::unordred_map,
// "ABSL" for absl::flat_hash_map,
// "GOOGLE" for google::dense_hash_map.
std::string hash_map_str;
OP_REQUIRES_OK(context, ReadStringFromEnvVar(kUniqueOpHashMapEnv,
kGoogleHashMapString,
&hash_map_str));
OP_REQUIRES_OK(
context, ReadStringFromEnvVar(kUniqueOpHashMapEnv, kGoogleHashMapString,
&hash_map_str));
std::transform(hash_map_str.begin(), hash_map_str.end(),
hash_map_str.begin(), ::toupper);

OP_REQUIRES_OK(context, ReadInt64FromEnvVar(kUniqueOpUniqRatioHint,
kDefaultUniqueRatioHint, &unique_ratio_hint_));
kDefaultUniqueRatioHint,
&unique_ratio_hint_));
OP_REQUIRES(context, unique_ratio_hint_ > 0,
errors::InvalidArgument("Invaild ", kUniqueOpUniqRatioHint, "=",
unique_ratio_hint_));
Expand All @@ -83,7 +86,8 @@ class UniqueAliOp : public OpKernel {
map_flag_ = MULTIMAP;
static char print_once = [] {
LOG(INFO) << "MultiMapCompute preserved "
"dense hash map key: " << kPreseverdEmptyKey;
"dense hash map key: "
<< kPreseverdEmptyKey;
return '\0';
}();
} else if (!hash_map_str.compare(kStlHashMapString)) {
Expand All @@ -95,7 +99,6 @@ class UniqueAliOp : public OpKernel {
} else {
map_flag_ = GOOGLE;
}

}

void Compute(OpKernelContext* context) override {
Expand All @@ -110,16 +113,14 @@ class UniqueAliOp : public OpKernel {
Tensor output;
Tensor output_counter;
if (context->num_inputs() == 1) {
UniqueWithoutAxis<T, TIndex>(context, input,
&idx, &output, &output_counter, num_outputs(),
partition_size_, serial_, unique_ratio_hint_,
map_flag_);
UniqueWithoutAxis<T, TIndex>(
context, input, &idx, &output, &output_counter, num_outputs(),
partition_size_, serial_, unique_ratio_hint_, map_flag_);
} else {
const Tensor& axis_tensor = context->input(1);
UniqueWithAxis<T, TIndex>(context, input,
axis_tensor, &idx, &output, &output_counter,
num_outputs(), partition_size_, serial_,
unique_ratio_hint_, map_flag_);
UniqueWithAxis<T, TIndex>(context, input, axis_tensor, &idx, &output,
&output_counter, num_outputs(), partition_size_,
serial_, unique_ratio_hint_, map_flag_);
}
context->set_output(0, output);
context->set_output(1, idx);
Expand All @@ -128,33 +129,65 @@ class UniqueAliOp : public OpKernel {
}
}

protected:
bool serial_ = false;
int64 partition_size_ = 0;
int64 unique_ratio_hint_;
UniqueMaps map_flag_ = GOOGLE; // "GOOGLE" dense hash map is default
};

template <typename T, typename TIndex>
class UniqueWithCountAliOp : public UniqueAliOp<T, TIndex> {
using UniqueAliOp<T, TIndex>::serial_;
using UniqueAliOp<T, TIndex>::partition_size_;
using UniqueAliOp<T, TIndex>::unique_ratio_hint_;
using UniqueAliOp<T, TIndex>::map_flag_;
using OpKernel::num_outputs;

public:
explicit UniqueWithCountAliOp(OpKernelConstruction* context)
: UniqueAliOp<T, TIndex>(context) {
OP_REQUIRES_OK(context, context->GetAttr("N", &num_sparse_));
}

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
Tensor idx;
Tensor output;
Tensor output_counter;
UniqueWithExtraCounts<T, TIndex>(
context, input, &idx, &output, &output_counter, num_outputs(),
partition_size_, serial_, unique_ratio_hint_, num_sparse_, map_flag_);
context->set_output(0, output);
context->set_output(1, idx);
context->set_output(2, output_counter);
}

private:
int num_sparse_;
};

#define REGISTER_UNIQUE(type) \
REGISTER_KERNEL_BUILDER(Name("Unique") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
UniqueAliOp<type, int32>); \
UniqueAliOp<type, int32>) \
REGISTER_KERNEL_BUILDER(Name("Unique") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueAliOp<type, int64>); \
UniqueAliOp<type, int64>) \
REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
UniqueAliOp<type, int32>); \
UniqueAliOp<type, int32>) \
REGISTER_KERNEL_BUILDER(Name("UniqueV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueAliOp<type, int64>); \
UniqueAliOp<type, int64>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
Expand All @@ -164,7 +197,7 @@ class UniqueAliOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueAliOp<type, int64>); \
UniqueAliOp<type, int64>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
Expand All @@ -174,7 +207,17 @@ class UniqueAliOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueAliOp<type, int64>)
UniqueAliOp<type, int64>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
UniqueWithCountAliOp<type, int32>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueWithCountAliOp<type, int64>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
REGISTER_UNIQUE(string)
#undef REGISTER_UNIQUE
Expand All @@ -198,12 +241,22 @@ REGISTER_UNIQUE(string)
.HostMemory("count") \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueAliOp<type, int64>);
UniqueAliOp<type, int64>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int32>("out_idx"), \
UniqueWithCountAliOp<type, int32>) \
REGISTER_KERNEL_BUILDER(Name("UniqueWithExtraCounts") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<int64>("out_idx"), \
UniqueWithCountAliOp<type, int64>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
REGISTER_UNIQUE(string)
#undef REGISTER_UNIQUE
#endif //GOOGLE_CUDA
#endif // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("Unique")
.Device(DEVICE_SYCL)
Expand Down
Loading

0 comments on commit 2b15e8a

Please sign in to comment.