70 lines
3.7 KiB
Diff
70 lines
3.7 KiB
Diff
From 02cc160e29d20631de3859c6653184e3f876b9d7 Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Tue, 3 Aug 2021 15:51:47 -0700
|
|
Subject: [PATCH] Prevent nullptr deref in SparseTensorSliceDataset
|
|
|
|
The arguments must describe a valid sparse tensor: if the indices are empty, the values must be empty too, and vice versa.
|
|
|
|
Also added a test, adapted from the existing empty-sparse-tensor test, that runs with an invalid sparse tensor input.
|
|
|
|
PiperOrigin-RevId: 388562757
|
|
Change-Id: Id8b54cd7c2316025b4f9a77292c8fb5344d17609
|
|
---
|
|
.../data/sparse_tensor_slice_dataset_op.cc | 11 ++++++++++
|
|
.../from_sparse_tensor_slices_test.py | 20 +++++++++++++++++++
|
|
2 files changed, 31 insertions(+)
|
|
|
|
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
|
|
index 00b71d41a7ca1..58bb4b0b6d806 100644
|
|
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
|
|
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
|
|
@@ -241,6 +241,17 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
|
|
errors::InvalidArgument(
|
|
"Input indices should be a matrix but received shape ",
|
|
indices->shape().DebugString()));
|
|
+
|
|
+ const auto num_indices = indices->NumElements();
|
|
+ const auto num_values = values->NumElements();
|
|
+ if (num_indices == 0 || num_values == 0) {
|
|
+ OP_REQUIRES(ctx, num_indices == num_values,
|
|
+ errors::InvalidArgument(
|
|
+ "If indices or values are empty, the other one must also "
|
|
+ "be. Got indices of shape ",
|
|
+ indices->shape().DebugString(), " and values of shape ",
|
|
+ values->shape().DebugString()));
|
|
+ }
|
|
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
|
|
errors::InvalidArgument(
|
|
"Input values should be a vector but received shape ",
|
|
diff --git a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
|
|
index 25ecbf20680b3..8f93530010cae 100644
|
|
--- a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
|
|
+++ b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
|
|
@@ -118,6 +118,26 @@ def testEmptySparseTensorSlices(self):
|
|
with self.assertRaises(errors.OutOfRangeError):
|
|
sess.run(get_next)
|
|
|
|
+ @combinations.generate(combinations.combine(tf_api_version=1, mode=["graph"]))
|
|
+ def testEmptySparseTensorSlicesInvalid(self):
|
|
+ """Test a dataset based on invalid `tf.sparse.SparseTensor`."""
|
|
+ st = array_ops.sparse_placeholder(dtypes.float64)
|
|
+ iterator = dataset_ops.make_initializable_iterator(
|
|
+ dataset_ops.Dataset.from_sparse_tensor_slices(st))
|
|
+ init_op = iterator.initializer
|
|
+
|
|
+ with self.cached_session() as sess:
|
|
+ # Test with a sparse tensor whose indices are empty but whose values are non-empty.
|
|
+ empty_indices = np.empty((0, 4), dtype=np.int64)
|
|
+ non_empty_values = [1, 2, 3, 4]
|
|
+ empty_dense_shape = [0, 4, 37, 9]
|
|
+ sparse_feed = sparse_tensor.SparseTensorValue(empty_indices,
|
|
+ non_empty_values,
|
|
+ empty_dense_shape)
|
|
+ # Here, we expect running the initializer with this feed to raise InvalidArgumentError.
|
|
+ with self.assertRaises(errors.InvalidArgumentError):
|
|
+ sess.run(init_op, feed_dict={st: sparse_feed})
|
|
+
|
|
@combinations.generate(combinations.combine(tf_api_version=2, mode=["eager"]))
|
|
def testFromSparseTensorSlicesError(self):
|
|
with self.assertRaises(AttributeError):
|