58 lines
2.2 KiB
Diff
58 lines
2.2 KiB
Diff
|
|
From 701cfaca222a82afbeeb17496bd718baa65a67d2 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Robert Neale <rneale@google.com>
|
||
|
|
Date: Tue, 26 Oct 2021 16:50:02 -0700
|
||
|
|
Subject: [PATCH] Fix heap out-of-bounds error in
|
||
|
|
tf.raw_ops.SparseCountSparseOutput shape inference when it is called with
|
||
|
|
invalid inputs, and add a test for it.
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 405766415
|
||
|
|
Change-Id: I77d244ef35f351ef7b6f821efd959cac2c66db24
|
||
|
|
---
|
||
|
|
tensorflow/core/ops/count_ops.cc | 2 ++
|
||
|
|
tensorflow/python/ops/bincount_ops_test.py | 19 +++++++++++++++++++
|
||
|
|
2 files changed, 21 insertions(+)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/ops/count_ops.cc b/tensorflow/core/ops/count_ops.cc
|
||
|
|
index 4f9631310df92..aa6c0437337af 100644
|
||
|
|
--- a/tensorflow/core/ops/count_ops.cc
|
||
|
|
+++ b/tensorflow/core/ops/count_ops.cc
|
||
|
|
@@ -41,6 +41,8 @@ Status DenseCountSparseOutputShapeFn(InferenceContext *c) {
|
||
|
|
}
|
||
|
|
|
||
|
|
Status SparseCountSparseOutputShapeFn(InferenceContext *c) {
|
||
|
|
+ ShapeHandle unused;
|
||
|
|
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
|
||
|
|
auto rank = c->Dim(c->input(0), 1);
|
||
|
|
auto nvals = c->UnknownDim();
|
||
|
|
c->set_output(0, c->Matrix(nvals, rank)); // out.indices
|
||
|
|
diff --git a/tensorflow/python/ops/bincount_ops_test.py b/tensorflow/python/ops/bincount_ops_test.py
|
||
|
|
index 3c7a2a5da9daf..de7d1423870d7 100644
|
||
|
|
--- a/tensorflow/python/ops/bincount_ops_test.py
|
||
|
|
+++ b/tensorflow/python/ops/bincount_ops_test.py
|
||
|
|
@@ -831,6 +831,25 @@ def test_ragged_input_different_shape_fails(self):
|
||
|
|
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||
|
|
|
||
|
|
|
||
|
|
+class RawOpsHeapOobTest(test.TestCase, parameterized.TestCase):
|
||
|
|
+
|
||
|
|
+ @test_util.run_v1_only("Test security error")
|
||
|
|
+ def testSparseCountSparseOutputBadIndicesShapeTooSmall(self):
|
||
|
|
+ indices = [1]
|
||
|
|
+ values = [[1]]
|
||
|
|
+ weights = []
|
||
|
|
+ dense_shape = [10]
|
||
|
|
+ with self.assertRaisesRegex(ValueError,
|
||
|
|
+ "Shape must be rank 2 but is rank 1 for"):
|
||
|
|
+ self.evaluate(
|
||
|
|
+ gen_count_ops.SparseCountSparseOutput(
|
||
|
|
+ indices=indices,
|
||
|
|
+ values=values,
|
||
|
|
+ dense_shape=dense_shape,
|
||
|
|
+ weights=weights,
|
||
|
|
+ binary_output=True))
|
||
|
|
+
|
||
|
|
+
|
||
|
|
@test_util.run_all_in_graph_and_eager_modes
|
||
|
|
@test_util.disable_tfrt
|
||
|
|
class RawOpsTest(test.TestCase, parameterized.TestCase):
|