112 lines
6.2 KiB
Diff
|
|
From d6ed5bcfe1dcab9e85a4d39931bd18d99018e75b Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Fri, 23 Apr 2021 11:40:06 -0700
|
||
|
|
Subject: [PATCH] Add missing validation in
|
||
|
|
`QuantizedBatchNormWithGlobalNormalization`
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 370123451
|
||
|
|
Change-Id: Id234d6dab1ec21230bb8e503dba30f899af87f33
|
||
|
|
---
|
||
|
|
.../core/kernels/quantized_batch_norm_op.cc | 77 ++++++++++++++++---
|
||
|
|
1 file changed, 67 insertions(+), 10 deletions(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/quantized_batch_norm_op.cc b/tensorflow/core/kernels/quantized_batch_norm_op.cc
|
||
|
|
index b03da7ad17fab..6dfe07f97a400 100644
|
||
|
|
--- a/tensorflow/core/kernels/quantized_batch_norm_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/quantized_batch_norm_op.cc
|
||
|
|
@@ -173,20 +173,50 @@ class QuantizedBatchNormOp : public OpKernel {
|
||
|
|
|
||
|
|
void Compute(OpKernelContext* context) override {
|
||
|
|
const Tensor& input = context->input(0);
|
||
|
|
- const float input_min = context->input(1).flat<float>()(0);
|
||
|
|
- const float input_max = context->input(2).flat<float>()(0);
|
||
|
|
+ const auto& input_min_tensor = context->input(1);
|
||
|
|
+ OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("input_min must have 1 element"));
|
||
|
|
+ const float input_min = input_min_tensor.flat<float>()(0);
|
||
|
|
+ const auto& input_max_tensor = context->input(2);
|
||
|
|
+ OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("input_max must have 1 element"));
|
||
|
|
+ const float input_max = input_max_tensor.flat<float>()(0);
|
||
|
|
const Tensor& mean = context->input(3);
|
||
|
|
- const float mean_min = context->input(4).flat<float>()(0);
|
||
|
|
- const float mean_max = context->input(5).flat<float>()(0);
|
||
|
|
+ const auto& mean_min_tensor = context->input(4);
|
||
|
|
+ OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("mean_min must have 1 element"));
|
||
|
|
+ const float mean_min = mean_min_tensor.flat<float>()(0);
|
||
|
|
+ const auto& mean_max_tensor = context->input(5);
|
||
|
|
+ OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("mean_max must have 1 element"));
|
||
|
|
+ const float mean_max = mean_max_tensor.flat<float>()(0);
|
||
|
|
const Tensor& var = context->input(6);
|
||
|
|
- const float var_min = context->input(7).flat<float>()(0);
|
||
|
|
- const float var_max = context->input(8).flat<float>()(0);
|
||
|
|
+ const auto& var_min_tensor = context->input(7);
|
||
|
|
+ OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("var_min must have 1 element"));
|
||
|
|
+ const float var_min = var_min_tensor.flat<float>()(0);
|
||
|
|
+ const auto& var_max_tensor = context->input(8);
|
||
|
|
+ OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("var_max must have 1 element"));
|
||
|
|
+ const float var_max = var_max_tensor.flat<float>()(0);
|
||
|
|
const Tensor& beta = context->input(9);
|
||
|
|
- const float beta_min = context->input(10).flat<float>()(0);
|
||
|
|
- const float beta_max = context->input(11).flat<float>()(0);
|
||
|
|
+ const auto& beta_min_tensor = context->input(10);
|
||
|
|
+ OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("beta_min must have 1 element"));
|
||
|
|
+ const float beta_min = beta_min_tensor.flat<float>()(0);
|
||
|
|
+ const auto& beta_max_tensor = context->input(11);
|
||
|
|
+ OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("beta_max must have 1 element"));
|
||
|
|
+ const float beta_max = beta_max_tensor.flat<float>()(0);
|
||
|
|
const Tensor& gamma = context->input(12);
|
||
|
|
- const float gamma_min = context->input(13).flat<float>()(0);
|
||
|
|
- const float gamma_max = context->input(14).flat<float>()(0);
|
||
|
|
+ const auto& gamma_min_tensor = context->input(13);
|
||
|
|
+ OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("gamma_min must have 1 element"));
|
||
|
|
+ const float gamma_min = gamma_min_tensor.flat<float>()(0);
|
||
|
|
+ const auto& gamma_max_tensor = context->input(14);
|
||
|
|
+ OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("gamma_max must have 1 element"));
|
||
|
|
+ const float gamma_max = gamma_max_tensor.flat<float>()(0);
|
||
|
|
|
||
|
|
OP_REQUIRES(context, input.dims() == 4,
|
||
|
|
errors::InvalidArgument("input must be 4-dimensional",
|
||
|
|
@@ -203,6 +233,33 @@ class QuantizedBatchNormOp : public OpKernel {
|
||
|
|
OP_REQUIRES(context, gamma.dims() == 1,
|
||
|
|
errors::InvalidArgument("gamma must be 1-dimensional",
|
||
|
|
gamma.shape().DebugString()));
|
||
|
|
+ OP_REQUIRES(context, mean.NumElements() > 1,
|
||
|
|
+ errors::InvalidArgument("Must have at least a mean value",
|
||
|
|
+ gamma.shape().DebugString()));
|
||
|
|
+ OP_REQUIRES(context, mean.NumElements() > 1,
|
||
|
|
+ errors::InvalidArgument("Must have at least a mean value"));
|
||
|
|
+ const auto last_dim = input.shape().dims() - 1;
|
||
|
|
+ OP_REQUIRES(context,
|
||
|
|
+ mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
|
||
|
|
+ errors::InvalidArgument("Must provide as many means as the "
|
||
|
|
+ "last dimension of the input tensor: ",
|
||
|
|
+ mean.shape().DebugString(), " vs. ",
|
||
|
|
+ input.shape().DebugString()));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, mean.shape().dim_size(0) == var.shape().dim_size(0),
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Mean and variance tensors must have the same shape: ",
|
||
|
|
+ mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Mean and beta tensors must have the same shape: ",
|
||
|
|
+ mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Mean and gamma tensors must have the same shape: ",
|
||
|
|
+ mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));
|
||
|
|
|
||
|
|
Tensor* output = nullptr;
|
||
|
|
OP_REQUIRES_OK(context,
|