From 52df91c5634e6c666843849a1c6ff29b3d2676be Mon Sep 17 00:00:00 2001 From: Pankaj Kanwar Date: Mon, 12 Oct 2020 10:30:20 -0700 Subject: [PATCH] Create a V2 Op to stop the gradient when the input is out of range. PiperOrigin-RevId: 336692325 Change-Id: I36fd3fcfc58a30d5218beca512fbfc7c24b8b5cb --- tensorflow/cc/gradients/array_grad.cc | 29 ++-- tensorflow/compiler/tests/unary_ops_test.py | 6 +- .../api_def_QuantizeAndDequantizeV4.pbtxt | 8 ++ .../api_def_QuantizeAndDequantizeV4Grad.pbtxt | 8 ++ .../api_def_QuantizeAndDequantizeV4.pbtxt | 3 + .../api_def_QuantizeAndDequantizeV4Grad.pbtxt | 3 + .../api_def_QuantizeAndDequantizeV4.pbtxt | 4 + .../api_def_QuantizeAndDequantizeV4Grad.pbtxt | 4 + .../kernels/quantize_and_dequantize_op.cc | 126 ++++++++++++++++++ .../core/kernels/quantize_and_dequantize_op.h | 71 ++++++++++ .../quantize_and_dequantize_op_gpu.cu.cc | 40 ++++++ .../quantize_and_dequantize_op_test.cc | 48 +++++++ tensorflow/core/ops/array_ops.cc | 64 +++++++++ .../python/kernel_tests/array_ops_test.py | 21 ++- tensorflow/python/ops/array_ops.py | 113 +++++++++++++++- .../tools/api/golden/v1/tensorflow.pbtxt | 4 + .../golden/v1/tensorflow.quantization.pbtxt | 4 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 8 ++ .../tools/api/golden/v2/tensorflow.pbtxt | 4 + .../golden/v2/tensorflow.quantization.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 8 ++ 21 files changed, 564 insertions(+), 16 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt diff 
--git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index e9173227..480243a2 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,13 +15,12 @@ limitations under the License. #include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" - namespace tensorflow { namespace ops { namespace { @@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); -Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - grad_outputs->push_back(Identity(scope, grad_inputs[0])); - grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(NoGradient()); +Status QuantizeAndDequantizeV4GradHelper(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input input = op.input(0); + Input input_min = op.input(1); + Input input_max = op.input(2); + int64 axis; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); + auto qdq_v4_grad = QuantizeAndDequantizeV4Grad( + scope, grad_inputs[0], input, input_min, input_max, + QuantizeAndDequantizeV4Grad::Axis(axis)); + grad_outputs->push_back(qdq_v4_grad.input_backprop); + grad_outputs->push_back(qdq_v4_grad.input_min_backprop); + grad_outputs->push_back(qdq_v4_grad.input_max_backprop); return scope.status(); } -REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4", + QuantizeAndDequantizeV4GradHelper); Status 
QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 162693a9..dacd7232 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -535,7 +535,7 @@ class UnaryOpsTest(xla_test.XLATestCase): for dtype in self.float_types: def quantize_and_dequantize_v2(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -127, 127, signed_input=True, num_bits=8) self._assertOpOutputMatchesExpected( @@ -544,7 +544,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def quantize_and_dequantize_v2_round_half_up(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1, 1.0, @@ -568,7 +568,7 @@ class UnaryOpsTest(xla_test.XLATestCase): dtype=dtype)) def quantize_and_dequantize_v2_round_half_to_even(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1.0, 1.0, diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt new file mode 100644 index 00000000..a84ccb78 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4" + summary: "Returns the gradient of `QuantizeAndDequantizeV4`." + description: <