73 lines
3.3 KiB
Diff
73 lines
3.3 KiB
Diff
From b761c9b652af2107cfbc33efd19be0ce41daa33e Mon Sep 17 00:00:00 2001
|
|
From: Amit Patankar <amitpatankar@google.com>
|
|
Date: Thu, 15 Apr 2021 13:28:49 -0700
|
|
Subject: [PATCH] Fix `tf.raw_ops.RaggedTensorToTensor` failing DCHECK on invalid row partitions.
|
|
|
|
PiperOrigin-RevId: 368706628
|
|
Change-Id: I5c9ea4833f38835ee183ca50d63251dc89c9f3bc
|
|
---
|
|
.../kernels/ragged_tensor_to_tensor_op.cc | 20 ++++++++++---------
|
|
1 file changed, 11 insertions(+), 9 deletions(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
|
|
index 433d910f6090c..434c853b63daa 100644
|
|
--- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
|
|
@@ -208,7 +208,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
|
}
|
|
|
|
void CalculateOutputIndexRowSplit(
|
|
- const RowPartitionTensor& row_split,
|
|
+ OpKernelContext* context, const RowPartitionTensor& row_split,
|
|
const vector<INDEX_TYPE>& parent_output_index,
|
|
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
|
|
vector<INDEX_TYPE>* result) {
|
|
@@ -233,7 +233,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
|
}
|
|
}
|
|
if (row_split_size > 0) {
|
|
- DCHECK_EQ(result->size(), row_split(row_split_size - 1));
|
|
+ OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
|
|
+ errors::InvalidArgument("Invalid row split size."));
|
|
}
|
|
}
|
|
|
|
@@ -259,7 +260,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
|
// result[7] = -1 because parent_output_index[value_rowids[6]] == -1
|
|
// result[8] = parent_output_index[value_rowids[7]]
|
|
void CalculateOutputIndexValueRowID(
|
|
- const RowPartitionTensor& value_rowids,
|
|
+ OpKernelContext* context, const RowPartitionTensor& value_rowids,
|
|
const vector<INDEX_TYPE>& parent_output_index,
|
|
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
|
|
vector<INDEX_TYPE>* result) {
|
|
@@ -293,7 +294,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
|
}
|
|
result->push_back(current_output_index);
|
|
}
|
|
- DCHECK_EQ(result->size(), value_rowids.size());
|
|
+ OP_REQUIRES(context, result->size() == value_rowids.size(),
|
|
+ errors::InvalidArgument("Invalid row ids."));
|
|
}
|
|
|
|
Status CalculateOutputIndex(OpKernelContext* context, int dimension,
|
|
@@ -307,13 +309,13 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
|
switch (partition_type) {
|
|
case RowPartitionType::VALUE_ROWIDS:
|
|
CalculateOutputIndexValueRowID(
|
|
- row_partition_tensor, parent_output_index, output_index_multiplier,
|
|
- output_size, result);
|
|
+ context, row_partition_tensor, parent_output_index,
|
|
+ output_index_multiplier, output_size, result);
|
|
return tensorflow::Status::OK();
|
|
case RowPartitionType::ROW_SPLITS:
|
|
- CalculateOutputIndexRowSplit(row_partition_tensor, parent_output_index,
|
|
- output_index_multiplier, output_size,
|
|
- result);
|
|
+ CalculateOutputIndexRowSplit(
|
|
+ context, row_partition_tensor, parent_output_index,
|
|
+ output_index_multiplier, output_size, result);
|
|
return tensorflow::Status::OK();
|
|
default:
|
|
return errors::InvalidArgument(
|