49 lines
2.1 KiB
Diff
49 lines
2.1 KiB
Diff
|
|
From a776040a5e7ebf76eeb7eb923bf1ae417dd4d233 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Laura Pak <lpak@google.com>
|
||
|
|
Date: Mon, 12 Jul 2021 11:55:27 -0700
|
||
|
|
Subject: [PATCH] Disallow dims input of 0 in tf.raw_ops.UnravelIndex
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 384284198
|
||
|
|
Change-Id: Ia1804ef1aec57b4d857ea507e6891bcccde18e9b
|
||
|
|
---
|
||
|
|
tensorflow/core/kernels/unravel_index_op.cc | 9 +++++++++
|
||
|
|
tensorflow/python/kernel_tests/array_ops_test.py | 2 +-
|
||
|
|
2 files changed, 10 insertions(+), 1 deletion(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
|
||
|
|
index b45ff5e5..fff4527d 100644
|
||
|
|
--- a/tensorflow/core/kernels/unravel_index_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/unravel_index_op.cc
|
||
|
|
@@ -53,6 +53,15 @@ class UnravelIndexOp : public OpKernel {
|
||
|
|
dims_tensor.shape().DebugString(), "\""));
|
||
|
|
|
||
|
|
auto dims = dims_tensor.vec<Tidx>();
|
||
|
|
+ // Make sure dims does not contain a zero
|
||
|
|
+ for (int i = 0; i < dims.size(); i++) {
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ ctx, dims(i) != 0,
|
||
|
|
+ errors::InvalidArgument("Input dims cannot contain a dim of zero, "
|
||
|
|
+ "but dims contains zero at index ",
|
||
|
|
+ i));
|
||
|
|
+
|
||
|
|
+ }
|
||
|
|
|
||
|
|
// Chek to make sure indices is not out of boundary
|
||
|
|
Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
|
||
|
|
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
index dbff3a1b..6ce6a9e6 100644
|
||
|
|
--- a/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
@@ -1441,7 +1441,7 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
|
||
|
|
with self.cached_session():
|
||
|
|
for dtype in [dtypes.int32, dtypes.int64]:
|
||
|
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||
|
|
- "index is out of bound as with dims"):
|
||
|
|
+ "dims cannot contain a dim of zero"):
|
||
|
|
indices = constant_op.constant([2, 5, 7], dtype=dtype)
|
||
|
|
dims = constant_op.constant([3, 0], dtype=dtype)
|
||
|
|
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
|
||
|
|
--
|
||
|
|
2.27.0
|
||
|
|
|