118 lines
5.5 KiB
Diff
118 lines
5.5 KiB
Diff
|
|
From e6a7c7cc18c3aaad1ae0872cb0a959f5c923d2bd Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Tue, 20 Apr 2021 14:45:33 -0700
|
||
|
|
Subject: [PATCH] Remove `OP_REQUIRES` call from helper function.
|
||
|
|
|
||
|
|
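Since the `OP_REQUIRES` macro expands to a `return;` (among other things), calling it in a helper function only ends the helper function's execution early, but the kernel will still run from start to end. Thus, all the expected validations are actually broken/useless, as the code ploughs on into the next crash anyway. This change makes the helper return a `Status` and has the kernel check it with `OP_REQUIRES_OK`, so a failed validation actually stops the kernel.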
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 369524386
|
||
|
|
Change-Id: I54f6cf9328445675ccc392e661b04336b229c9da
|
||
|
|
---
|
||
|
|
.../core/kernels/sparse/sparse_cholesky_op.cc | 67 ++++++++++---------
|
||
|
|
1 file changed, 34 insertions(+), 33 deletions(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
|
||
|
|
index 9a939276f0b6c..47ab252317de5 100644
|
||
|
|
--- a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
|
||
|
|
@@ -17,6 +17,8 @@ limitations under the License.
|
||
|
|
#include <numeric>
|
||
|
|
#include <vector>
|
||
|
|
|
||
|
|
+#include "tensorflow/core/framework/op_requires.h"
|
||
|
|
+
|
||
|
|
#define EIGEN_USE_THREADS
|
||
|
|
|
||
|
|
#include "third_party/eigen3/Eigen/Core"
|
||
|
|
@@ -82,8 +84,8 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
|
||
|
|
|
||
|
|
int64 num_rows;
|
||
|
|
int batch_size;
|
||
|
|
- ValidateInputs(ctx, *input_matrix, input_permutation_indices, &batch_size,
|
||
|
|
- &num_rows);
|
||
|
|
+ OP_REQUIRES_OK(ctx, ValidateInputs(*input_matrix, input_permutation_indices,
|
||
|
|
+ &batch_size, &num_rows));
|
||
|
|
|
||
|
|
// Allocate batch pointers.
|
||
|
|
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
|
||
|
|
@@ -226,49 +228,48 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
|
||
|
|
}
|
||
|
|
|
||
|
|
private:
|
||
|
|
- void ValidateInputs(OpKernelContext* ctx,
|
||
|
|
- const CSRSparseMatrix& sparse_matrix,
|
||
|
|
- const Tensor& permutation_indices, int* batch_size,
|
||
|
|
- int64* num_rows) {
|
||
|
|
- OP_REQUIRES(ctx, sparse_matrix.dtype() == DataTypeToEnum<T>::value,
|
||
|
|
- errors::InvalidArgument(
|
||
|
|
- "Asked for a CSRSparseMatrix of type ",
|
||
|
|
- DataTypeString(DataTypeToEnum<T>::value),
|
||
|
|
- " but saw dtype: ", DataTypeString(sparse_matrix.dtype())));
|
||
|
|
+ Status ValidateInputs(const CSRSparseMatrix& sparse_matrix,
|
||
|
|
+ const Tensor& permutation_indices, int* batch_size,
|
||
|
|
+ int64* num_rows) {
|
||
|
|
+ if (sparse_matrix.dtype() != DataTypeToEnum<T>::value)
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "Asked for a CSRSparseMatrix of type ",
|
||
|
|
+ DataTypeString(DataTypeToEnum<T>::value),
|
||
|
|
+ " but saw dtype: ", DataTypeString(sparse_matrix.dtype()));
|
||
|
|
|
||
|
|
const Tensor& dense_shape = sparse_matrix.dense_shape();
|
||
|
|
const int rank = dense_shape.dim_size(0);
|
||
|
|
- OP_REQUIRES(ctx, rank == 2 || rank == 3,
|
||
|
|
- errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
|
||
|
|
- "but dense_shape has size ", rank));
|
||
|
|
+ if (rank < 2 || rank > 3)
|
||
|
|
+ return errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
|
||
|
|
+ "but dense_shape has size ", rank);
|
||
|
|
const int row_dim = (rank == 2) ? 0 : 1;
|
||
|
|
auto dense_shape_vec = dense_shape.vec<int64>();
|
||
|
|
*num_rows = dense_shape_vec(row_dim);
|
||
|
|
const int64 num_cols = dense_shape_vec(row_dim + 1);
|
||
|
|
- OP_REQUIRES(ctx, *num_rows == num_cols,
|
||
|
|
- errors::InvalidArgument("sparse matrix must be square; got: ",
|
||
|
|
- *num_rows, " != ", num_cols));
|
||
|
|
+ if (*num_rows != num_cols)
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "sparse matrix must be square; got: ", *num_rows, " != ", num_cols);
|
||
|
|
const TensorShape& perm_shape = permutation_indices.shape();
|
||
|
|
- OP_REQUIRES(
|
||
|
|
- ctx, perm_shape.dims() + 1 == rank,
|
||
|
|
- errors::InvalidArgument(
|
||
|
|
- "sparse matrix must have the same rank as permutation; got: ", rank,
|
||
|
|
- " != ", perm_shape.dims(), " + 1."));
|
||
|
|
- OP_REQUIRES(
|
||
|
|
- ctx, perm_shape.dim_size(rank - 2) == *num_rows,
|
||
|
|
- errors::InvalidArgument(
|
||
|
|
- "permutation must have the same number of elements in each batch "
|
||
|
|
- "as the number of rows in sparse matrix; got: ",
|
||
|
|
- perm_shape.dim_size(rank - 2), " != ", *num_rows));
|
||
|
|
+ if (perm_shape.dims() + 1 != rank)
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "sparse matrix must have the same rank as permutation; got: ", rank,
|
||
|
|
+ " != ", perm_shape.dims(), " + 1.");
|
||
|
|
+ if (perm_shape.dim_size(rank - 2) != *num_rows)
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "permutation must have the same number of elements in each batch "
|
||
|
|
+ "as the number of rows in sparse matrix; got: ",
|
||
|
|
+ perm_shape.dim_size(rank - 2), " != ", *num_rows);
|
||
|
|
|
||
|
|
*batch_size = sparse_matrix.batch_size();
|
||
|
|
if (*batch_size > 1) {
|
||
|
|
- OP_REQUIRES(
|
||
|
|
- ctx, perm_shape.dim_size(0) == *batch_size,
|
||
|
|
- errors::InvalidArgument("permutation must have the same batch size "
|
||
|
|
- "as sparse matrix; got: ",
|
||
|
|
- perm_shape.dim_size(0), " != ", *batch_size));
|
||
|
|
+ if (perm_shape.dim_size(0) != *batch_size)
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "permutation must have the same batch size "
|
||
|
|
+ "as sparse matrix; got: ",
|
||
|
|
+ perm_shape.dim_size(0), " != ", *batch_size);
|
||
|
|
}
|
||
|
|
+
|
||
|
|
+ return Status::OK();
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|