From eccb7ec454e6617738554a255d77f08e60ee0808 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Mon, 19 Oct 2020 17:56:36 -0700
Subject: [PATCH] Prevent segfault in `quantize_and_dequantize`

`QuantizeAndDequantizeV2Op::Compute` calls `input.dim_size(axis_)` whenever
`axis_` is not -1. Validate first that `axis_` is smaller than the rank of
the input and fail with `InvalidArgument` otherwise.

The pieces passed to `errors::InvalidArgument` are joined without any
separator, so the first literal must end with a space; the resulting text
("Shape must be at least rank 11 but is rank 1" for axis=10 on a rank-1
input) is what the new `testBadAxis` test matches.

---
 .../core/kernels/quantize_and_dequantize_op.cc   |  4 ++++
 tensorflow/python/kernel_tests/array_ops_test.py | 14 ++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
index 8f71d09c..fda54208 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& input = ctx->input(0);
+    OP_REQUIRES(
+        ctx, (axis_ == -1 || axis_ < input.shape().dims()),
+        errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
+                                " but is rank ", input.shape().dims()));
     const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
     Tensor input_min_tensor;
     Tensor input_max_tensor;
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index dbff3a1b..c498ff62 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1541,6 +1541,20 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
                 axis=(axis - 4)))
         self.assertAllClose(fake_quantized, expected)
 
+  def testBadAxis(self):
+    input_tensor = [2.5, 2.5]
+    input_min = [0, 0]
+    input_max = [1, 1]
+    error_message_pattern = "Shape must be at least rank 11 but is rank 1"
+    # TODO(b/171260356): Eager mode and graph mode throw different error types
+    error = errors.InvalidArgumentError if context.executing_eagerly(
+    ) else ValueError
+    with self.assertRaisesRegex(error, error_message_pattern): self.evaluate(
+          array_ops.quantize_and_dequantize_v2(
+              input=input_tensor,
+              input_min=input_min,
+              input_max=input_max,
+              axis=10))
 
 @test_util.run_all_in_graph_and_eager_modes
 class SortedSearchTest(test_util.TensorFlowTestCase):
-- 
2.23.0