52 lines
2.5 KiB
Diff
52 lines
2.5 KiB
Diff
|
|
From f6fde895ef9c77d848061c0517f19d0ec2682f3a Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Tue, 11 May 2021 18:32:03 -0700
|
||
|
|
Subject: [PATCH] Validate that a and b are proper sparse tensors
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 373274848
|
||
|
|
Change-Id: I3a665ac3a29dee9fb69bdf408a939330cb93ea75
|
||
|
|
---
|
||
|
|
.../kernels/sparse_sparse_binary_op_shared.cc | 15 +++++++++------
|
||
|
|
1 file changed, 9 insertions(+), 6 deletions(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
|
||
|
|
index 9fe42e05d879e..eb993a5965043 100644
|
||
|
|
--- a/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
|
||
|
|
+++ b/tensorflow/core/kernels/sparse_sparse_binary_op_shared.cc
|
||
|
|
@@ -150,6 +150,7 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
||
|
|
|
||
|
|
const int64 a_nnz = a_indices_t->dim_size(0);
|
||
|
|
const int64 b_nnz = b_indices_t->dim_size(0);
|
||
|
|
+
|
||
|
|
const auto a_values = a_values_t->vec<T>();
|
||
|
|
const auto b_values = b_values_t->vec<T>();
|
||
|
|
|
||
|
|
@@ -166,6 +167,14 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
||
|
|
"Input shapes should be a vector but received shapes ",
|
||
|
|
a_shape_t->shape().DebugString(), " and ",
|
||
|
|
b_shape_t->shape().DebugString()));
|
||
|
|
+ const int num_dims = a_indices_t->dim_size(1);
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ ctx, a_shape_t->NumElements() == num_dims,
|
||
|
|
+ errors::InvalidArgument("Second dimension of a_indices and length of "
|
||
|
|
+ "a_shape must match, got ",
|
||
|
|
+ num_dims, " and ", a_shape_t->NumElements()));
|
||
|
|
+ OP_REQUIRES(ctx, num_dims > 0,
|
||
|
|
+ errors::InvalidArgument("Tensors must not be empty"));
|
||
|
|
OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
|
||
|
|
errors::InvalidArgument(
|
||
|
|
"Operands do not have the same ranks; got shapes: ",
|
||
|
|
@@ -180,12 +189,6 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
||
|
|
" for dimension ", i));
|
||
|
|
}
|
||
|
|
|
||
|
|
- OP_REQUIRES(
|
||
|
|
- ctx, a_indices_t->dim_size(1) == b_indices_t->dim_size(1),
|
||
|
|
- errors::InvalidArgument(
|
||
|
|
- "Indices' dimensions do not match: got ", a_indices_t->dim_size(1),
|
||
|
|
- " and ", b_indices_t->dim_size(1), " for the second dimension."));
|
||
|
|
- const int num_dims = a_indices_t->dim_size(1);
|
||
|
|
const auto a_indices_mat = a_indices_t->matrix<int64>();
|
||
|
|
const auto b_indices_mat = b_indices_t->matrix<int64>();
|
||
|
|
std::vector<T> a_augmented_values, b_augmented_values;
|