From a776040a5e7ebf76eeb7eb923bf1ae417dd4d233 Mon Sep 17 00:00:00 2001
From: Laura Pak
Date: Mon, 12 Jul 2021 11:55:27 -0700
Subject: [PATCH] Disallow dims input of 0 in tf.raw_ops.UnravelIndex

PiperOrigin-RevId: 384284198
Change-Id: Ia1804ef1aec57b4d857ea507e6891bcccde18e9b
---
 tensorflow/core/kernels/unravel_index_op.cc      | 9 +++++++++
 tensorflow/python/kernel_tests/array_ops_test.py | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

NOTE(review): the author's e-mail address on the "From:" header above was lost
when this file was mangled and cannot be recovered from here; restore it from
upstream TensorFlow commit a776040a before running `git am`, which requires an
address on that header.
NOTE(review): the new check only rejects dims(i) == 0; negative entries in dims
(and overflow of dims.prod()) still reach the index bound check below it.
Consider whether follow-up upstream hardening should be backported as well.

diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index b45ff5e5..fff4527d 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -53,6 +53,15 @@ class UnravelIndexOp : public OpKernel {
                                 dims_tensor.shape().DebugString(), "\""));
 
     auto dims = dims_tensor.vec<Tidx>();
+    // Make sure dims does not contain a zero
+    for (int i = 0; i < dims.size(); i++) {
+      OP_REQUIRES(
+          ctx, dims(i) != 0,
+          errors::InvalidArgument("Input dims cannot contain a dim of zero, "
+                                  "but dims contains zero at index ",
+                                  i));
+
+    }
 
     // Chek to make sure indices is not out of boundary
     Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index dbff3a1b..6ce6a9e6 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1441,7 +1441,7 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
     with self.cached_session():
       for dtype in [dtypes.int32, dtypes.int64]:
         with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                     "index is out of bound as with dims"):
+                                     "dims cannot contain a dim of zero"):
           indices = constant_op.constant([2, 5, 7], dtype=dtype)
           dims = constant_op.constant([3, 0], dtype=dtype)
           self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
-- 
2.27.0