From 203214568f5bc237603dbab6e1fd389f1572f5c9 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Fri, 30 Jul 2021 16:06:23 -0700
Subject: [PATCH] Reorganize and add more validation to MKL requantization

Validate every input before its data is read: `input` must be 4D,
`input_max` must have as many elements as `input_min`, and
`requested_output_min` / `requested_output_max` must each hold exactly one
element (each is checked against its own tensor). The QINT8 `DCHECK` on
`requested_output_min` becomes an `OP_REQUIRES`, so it is also enforced in
builds where `DCHECK` is compiled out.

PiperOrigin-RevId: 387901341
Change-Id: I2515b9034c64e113db0bcec8337d30643ab0a0f1
---
 .../mkl_requantize_per_channel_op.cc          | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc
index c0f9845cd4b08..6ffbd09b44f54 100644
--- a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc
+++ b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc
@@ -49,35 +49,45 @@ class MklRequantizePerChannelOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     try {
       const Tensor& input = ctx->input(kInputTensorIndex);
+      OP_REQUIRES(
+          ctx, input.dims() == 4,
+          errors::InvalidArgument("Current RequantizePerChannel operator "
+                                  "supports 4D tensors only."));
+
       const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
+      size_t depth = input_min_vec.NumElements();
       float* input_min_vec_data = (float*)const_cast<void*>(
           static_cast<const void*>(input_min_vec.flat<float>().data()));
+
       const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
+      OP_REQUIRES(
+          ctx, input_max_vec.NumElements() == depth,
+          errors::InvalidArgument("input_max has incorrect size, expected ",
+                                  depth, " was ", input_max_vec.NumElements()));
       float* input_max_vec_data = (float*)const_cast<void*>(
           static_cast<const void*>(input_max_vec.flat<float>().data()));
 
       const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
+      OP_REQUIRES(
+          ctx, input_requested_min.NumElements() == 1,
+          errors::InvalidArgument("requested_output_min must be a scalar"));
       const float input_requested_min_float =
           input_requested_min.flat<float>()(0);
+
       const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
+      OP_REQUIRES(
+          ctx, input_requested_max.NumElements() == 1,
+          errors::InvalidArgument("requested_output_max must be a scalar"));
       const float input_requested_max_float =
           input_requested_max.flat<float>()(0);
 
-      size_t depth = input_min_vec.NumElements();
-      OP_REQUIRES(
-          ctx, input.dims() == 4,
-          errors::InvalidArgument("Current RequantizePerChannel operator"
-                                  "supports 4D tensors only."));
-      OP_REQUIRES(
-          ctx, input_min_vec.dim_size(0) == depth,
-          errors::InvalidArgument("input_min has incorrect size, expected ",
-                                  depth, " was ", input_min_vec.dim_size(0)));
-      OP_REQUIRES(
-          ctx, input_max_vec.dim_size(0) == depth,
-          errors::InvalidArgument("input_max has incorrect size, expected ",
-                                  depth, " was ", input_max_vec.dim_size(0)));
-
-      if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);
+      if (out_type_ == DT_QINT8) {
+        OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
+                    errors::InvalidArgument(
+                        "If out_type is QINT8, requested_output_min must be "
+                        "negative, got ",
+                        input_requested_min_float));
+      }
 
       const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
       const float requested_min_max =