26 lines
1.1 KiB
Diff
26 lines
1.1 KiB
Diff
From 704866eabe03a9aeda044ec91a8d0c83fc1ebdbe Mon Sep 17 00:00:00 2001
|
|
From: Amit Patankar <amitpatankar@google.com>
|
|
Date: Tue, 27 Apr 2021 14:41:40 -0700
|
|
Subject: [PATCH] Fix overflow CHECK issue with `tf.raw_ops.UnsortedSegmentJoin`.
|
|
|
|
PiperOrigin-RevId: 370766155
|
|
Change-Id: I33e7c6626224e1060a8a4ab51ad5d861c6d4c63e
|
|
---
|
|
tensorflow/core/kernels/unsorted_segment_join_op.cc | 2 ++
|
|
1 file changed, 2 insertions(+)
|
|
|
|
diff --git a/tensorflow/core/kernels/unsorted_segment_join_op.cc b/tensorflow/core/kernels/unsorted_segment_join_op.cc
|
|
index 7464e165e46c8..9acfe7fb1e495 100644
|
|
--- a/tensorflow/core/kernels/unsorted_segment_join_op.cc
|
|
+++ b/tensorflow/core/kernels/unsorted_segment_join_op.cc
|
|
@@ -90,6 +90,8 @@ class UnsortedSegmentJoinOp : public OpKernel {
|
|
const int32 segment_dims = segment_id_shape.dims();
|
|
|
|
const Tensor& num_segments_tensor = context->input(2);
|
|
+ OP_REQUIRES(context, num_segments_tensor.NumElements() != 0,
|
|
+ errors::InvalidArgument("Number of segments cannot be empty."));
|
|
auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();
|
|
|
|
OP_REQUIRES(context, segment_dims != 0,
|