54 lines
2.3 KiB
Diff
54 lines
2.3 KiB
Diff
|
|
From eccb7ec454e6617738554a255d77f08e60ee0808 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Mon, 19 Oct 2020 17:56:36 -0700
|
||
|
|
Subject: [PATCH] Prevent segfault in `quantize_and_dequantize`
|
||
|
|
|
||
|
|
---
|
||
|
|
.../core/kernels/quantize_and_dequantize_op.cc | 4 ++++
|
||
|
|
tensorflow/python/kernel_tests/array_ops_test.py | 14 ++++++++++++++
|
||
|
|
2 files changed, 18 insertions(+)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
|
||
|
|
index 8f71d09c..fda54208 100644
|
||
|
|
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
|
||
|
|
@@ -71,6 +71,10 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
|
||
|
|
|
||
|
|
void Compute(OpKernelContext* ctx) override {
|
||
|
|
const Tensor& input = ctx->input(0);
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ ctx, (axis_ == -1 || axis_ < input.shape().dims()),
|
||
|
|
+ errors::InvalidArgument("Shape must be at least rank ", axis_ + 1,
|
||
|
|
+ " but is rank ", input.shape().dims()));
|
||
|
|
const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
|
||
|
|
Tensor input_min_tensor;
|
||
|
|
Tensor input_max_tensor;
|
||
|
|
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
index dbff3a1b..c498ff62 100644
|
||
|
|
--- a/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
|
||
|
|
@@ -1541,6 +1541,20 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
|
||
|
|
axis=(axis - 4)))
|
||
|
|
self.assertAllClose(fake_quantized, expected)
|
||
|
|
|
||
|
|
+ def testBadAxis(self):
|
||
|
|
+ input_tensor = [2.5, 2.5]
|
||
|
|
+ input_min = [0, 0]
|
||
|
|
+ input_max = [1, 1]
|
||
|
|
+ error_message_pattern = "Shape must be at least rank 11 but is rank 1"
|
||
|
|
+ # TODO(b/171260356): Eager mode and graph mode throw different error types
|
||
|
|
+ error = errors.InvalidArgumentError if context.executing_eagerly(
|
||
|
|
+ ) else ValueError
|
||
|
|
+ with self.assertRaisesRegex(error, error_message_pattern): self.evaluate(
|
||
|
|
+ array_ops.quantize_and_dequantize_v2(
|
||
|
|
+ input=input_tensor,
|
||
|
|
+ input_min=input_min,
|
||
|
|
+ input_max=input_max,
|
||
|
|
+ axis=10))
|
||
|
|
|
||
|
|
@test_util.run_all_in_graph_and_eager_modes
|
||
|
|
class SortedSearchTest(test_util.TensorFlowTestCase):
|
||
|
|
--
|
||
|
|
2.23.0
|
||
|
|
|