From 1071f554dbd09f7e101324d366eec5f4fe5a3ece Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 29 Jul 2021 18:23:29 -0700 Subject: [PATCH] Add missing validation to `RaggedTensorToSparse`. There needs to be a check that the splits allow for valid ragged tensors. PiperOrigin-RevId: 387712169 Change-Id: I2499175324b82b65d159a260c7f83b98ceb5cc7d --- .../core/kernels/ragged_tensor_to_sparse_kernel.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc index a47ff372cf087..639280a26b40d 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -38,7 +39,8 @@ class RaggedTensorToSparseOp : public OpKernel { OP_REQUIRES_OK( context, context->input_list("rt_nested_splits", &rt_nested_splits_in)); const int rt_nested_splits_len = rt_nested_splits_in.size(); - DCHECK_GT(rt_nested_splits_len, 0); // Enforced by REGISTER_OP. + OP_REQUIRES(context, rt_nested_splits_len > 0, + errors::InvalidArgument("rt_nested_splits must be non empty")); std::vector rt_nested_splits; rt_nested_splits.reserve(rt_nested_splits_len); for (int i = 0; i < rt_nested_splits_len; ++i) { @@ -162,6 +164,14 @@ class RaggedTensorToSparseOp : public OpKernel { if (rt_nested_splits[i](0) != 0) { return InvalidArgument("First value of ragged splits must be 0."); } + for (int j = 1; j < rt_nested_splits[i].size(); ++j) { + if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) { + return InvalidArgument( + "Ragged splits should be non decreasing, but we got ", + rt_nested_splits[i](j - 1), " followed by ", + rt_nested_splits[i](j)); + } + } if (i > 0) { SPLITS_TYPE last_split = rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);