From 701cfaca222a82afbeeb17496bd718baa65a67d2 Mon Sep 17 00:00:00 2001 From: Robert Neale Date: Tue, 26 Oct 2021 16:50:02 -0700 Subject: [PATCH] Fix heap out of bounds error in tf.raw_ops.SparseCountSparseOutput shape inference when it is called with invalid inputs, and add a test for it. PiperOrigin-RevId: 405766415 Change-Id: I77d244ef35f351ef7b6f821efd959cac2c66db24 --- tensorflow/core/ops/count_ops.cc | 2 ++ tensorflow/python/ops/bincount_ops_test.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/tensorflow/core/ops/count_ops.cc b/tensorflow/core/ops/count_ops.cc index 4f9631310df92..aa6c0437337af 100644 --- a/tensorflow/core/ops/count_ops.cc +++ b/tensorflow/core/ops/count_ops.cc @@ -41,6 +41,8 @@ Status DenseCountSparseOutputShapeFn(InferenceContext *c) { } Status SparseCountSparseOutputShapeFn(InferenceContext *c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); auto rank = c->Dim(c->input(0), 1); auto nvals = c->UnknownDim(); c->set_output(0, c->Matrix(nvals, rank)); // out.indices diff --git a/tensorflow/python/ops/bincount_ops_test.py b/tensorflow/python/ops/bincount_ops_test.py index 3c7a2a5da9daf..de7d1423870d7 100644 --- a/tensorflow/python/ops/bincount_ops_test.py +++ b/tensorflow/python/ops/bincount_ops_test.py @@ -831,6 +831,25 @@ def test_ragged_input_different_shape_fails(self): self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) +class RawOpsHeapOobTest(test.TestCase, parameterized.TestCase): + + @test_util.run_v1_only("Test security error") + def testSparseCountSparseOutputBadIndicesShapeTooSmall(self): + indices = [1] + values = [[1]] + weights = [] + dense_shape = [10] + with self.assertRaisesRegex(ValueError, + "Shape must be rank 2 but is rank 1 for"): + self.evaluate( + gen_count_ops.SparseCountSparseOutput( + indices=indices, + values=values, + dense_shape=dense_shape, + weights=weights, + binary_output=True)) + + @test_util.run_all_in_graph_and_eager_modes @test_util.disable_tfrt class RawOpsTest(test.TestCase, parameterized.TestCase):