Skip to content

Commit

Permalink
[Op] Update IteratorV2 OpDef
Browse files Browse the repository at this point in the history
Signed-off-by: JunqiHu <[email protected]>
  • Loading branch information
Mesilenceki committed Oct 25, 2023
1 parent 648f694 commit daaf7a6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/data/iterator_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("recoverable", &recoverable_));
}

// The resource is deleted from the resource manager only when it is private
Expand Down Expand Up @@ -308,7 +309,11 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
}

ResourceMgr* mgr = context->resource_manager();
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
if (recoverable_ == false) {
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false));
} else {
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
}

IteratorResource* resource;
OP_REQUIRES_OK(
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/data/iterator_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class IteratorHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
const int graph_def_version_;
string name_;
bool recoverable_;
};

// Like IteratorHandleOp, but creates handles which are never shared, and does
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/data/multi_device_iterator_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
std::unique_ptr<FunctionHandleCache> function_handle_cache =
absl::make_unique<FunctionHandleCache>(flr);
ResourceMgr* mgr = context->resource_manager();
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

MultiDeviceIterator* resource;

Expand Down
11 changes: 11 additions & 0 deletions tensorflow/core/ops/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,13 +555,24 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);

#ifndef TF_API_COMPATIBLE_1150
REGISTER_OP("IteratorV2")
.Output("handle: resource")
.Attr("shared_name: string")
.Attr("container: string")
.Attr("recoverable: bool = false")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
#else
REGISTER_OP("IteratorV2")
.Output("handle: resource")
.Attr("shared_name: string")
.Attr("container: string")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
#endif

REGISTER_OP("AnonymousIterator")
.Output("handle: resource")
Expand Down

0 comments on commit daaf7a6

Please sign in to comment.