51 lines
2.3 KiB
Diff
51 lines
2.3 KiB
Diff
|
|
From 1071f554dbd09f7e101324d366eec5f4fe5a3ece Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Thu, 29 Jul 2021 18:23:29 -0700
|
||
|
|
Subject: [PATCH] Add missing validation to `RaggedTensorToSparse`.
|
||
|
|
|
||
|
|
The kernel must verify that the splits describe a valid ragged tensor: `rt_nested_splits` must be non-empty, and each splits vector must be non-decreasing.
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 387712169
|
||
|
|
Change-Id: I2499175324b82b65d159a260c7f83b98ceb5cc7d
|
||
|
|
---
|
||
|
|
.../core/kernels/ragged_tensor_to_sparse_kernel.cc | 12 +++++++++++-
|
||
|
|
1 file changed, 11 insertions(+), 1 deletion(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc
|
||
|
|
index a47ff372cf087..639280a26b40d 100644
|
||
|
|
--- a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc
|
||
|
|
@@ -21,6 +21,7 @@ limitations under the License.
|
||
|
|
#include "tensorflow/core/framework/register_types.h"
|
||
|
|
#include "tensorflow/core/framework/tensor.h"
|
||
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||
|
|
+#include "tensorflow/core/platform/errors.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
|
||
|
|
@@ -38,7 +39,8 @@ class RaggedTensorToSparseOp : public OpKernel {
|
||
|
|
OP_REQUIRES_OK(
|
||
|
|
context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
|
||
|
|
const int rt_nested_splits_len = rt_nested_splits_in.size();
|
||
|
|
- DCHECK_GT(rt_nested_splits_len, 0); // Enforced by REGISTER_OP.
|
||
|
|
+ OP_REQUIRES(context, rt_nested_splits_len > 0,
|
||
|
|
+ errors::InvalidArgument("rt_nested_splits must be non empty"));
|
||
|
|
std::vector<ConstFlatSplits> rt_nested_splits;
|
||
|
|
rt_nested_splits.reserve(rt_nested_splits_len);
|
||
|
|
for (int i = 0; i < rt_nested_splits_len; ++i) {
|
||
|
|
@@ -162,6 +164,14 @@ class RaggedTensorToSparseOp : public OpKernel {
|
||
|
|
if (rt_nested_splits[i](0) != 0) {
|
||
|
|
return InvalidArgument("First value of ragged splits must be 0.");
|
||
|
|
}
|
||
|
|
+ for (int j = 1; j < rt_nested_splits[i].size(); ++j) {
|
||
|
|
+ if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) {
|
||
|
|
+ return InvalidArgument(
|
||
|
|
+ "Ragged splits should be non decreasing, but we got ",
|
||
|
|
+ rt_nested_splits[i](j - 1), " followed by ",
|
||
|
|
+ rt_nested_splits[i](j));
|
||
|
|
+ }
|
||
|
|
+ }
|
||
|
|
if (i > 0) {
|
||
|
|
SPLITS_TYPE last_split =
|
||
|
|
rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
|