From aab9998916c2ffbd8f0592059fad352622f89cda Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Wed, 29 Sep 2021 13:00:50 -0700
Subject: [PATCH] Add shape checks to FusedBatchNorm kernels.

FusedBatchNorm: previously an empty `mean` or `variance` input skipped the
per-channel size check. Now, whenever is_training=false or
exponential_avg_factor != 1, both inputs must have exactly one element per
channel of x. The error message names which of the two conditions applied.

FusedBatchNormGrad: x and y_backprop must now have the same shape, and
scale, reserve_space_1 and reserve_space_2 must each have one element per
channel of x. Each error message reports the size of the tensor that
failed the check.

Python tests exercise every new error path in eager mode.
---
 .../core/kernels/fused_batch_norm_op.cc       |  40 +++++-
 .../python/ops/nn_fused_batchnorm_test.py     | 122 ++++++++++++++++++
 2 files changed, 155 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index bd5dab36..b19323f0 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1279,18 +1279,20 @@ class FusedBatchNormOpBase : public OpKernel {
         errors::InvalidArgument("offset must have the same number of elements "
                                 "as the channels of x, got ",
                                 offset.NumElements(), " and ", num_channels));
-    if (estimated_mean.NumElements() != 0) {
+    if (!is_training_ || exponential_avg_factor_ != 1.) {
+      std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
+                                            : "When is_training=false";
       OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
                   errors::InvalidArgument(
-                      "mean must be empty or have the same number of "
-                      "elements as the channels of x, got ",
+                      prefix_msg,
+                      ", mean must have the same number "
+                      "of elements as the channels of x, got ",
                       estimated_mean.NumElements(), " and ",num_channels));
-    }
-    if (estimated_variance.NumElements() != 0) {
       OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
                   errors::InvalidArgument(
-                      "variance must be empty or have the same number of "
-                      "elements as the channels of x, got ",
+                      prefix_msg,
+                      ", variance must have the same "
+                      "number of elements as the channels of x, got ",
                       estimated_variance.NumElements(), " and ", num_channels));
     }
 
@@ -1434,6 +1436,30 @@ class FusedBatchNormGradOpBase : public OpKernel {
         errors::InvalidArgument(
             "saved variance must be 1-dimensional",
             saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+    OP_REQUIRES(
+        context, x.shape() == y_backprop.shape(),
+        errors::InvalidArgument(
+            "x and y_backprop must have same shape, but x has shape ",
+            x.shape(), " and y_backprop has shape ", y_backprop.shape()));
+
+    const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
+    OP_REQUIRES(
+        context, scale.NumElements() == num_channels,
+        errors::InvalidArgument("scale must have the same number of elements "
+                                "as the channels of x, got ",
+                                scale.NumElements(), " and ", num_channels));
+    OP_REQUIRES(
+        context, saved_mean_or_pop_mean.NumElements() == num_channels,
+        errors::InvalidArgument("reserve_space_1 must have the same number of "
+                                "elements as the channels of x, got ",
+                                saved_mean_or_pop_mean.NumElements(), " and ",
+                                num_channels));
+    OP_REQUIRES(
+        context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
+        errors::InvalidArgument("reserve_space_2 must have the same number of "
+                                "elements as the channels of x, got ",
+                                saved_maybe_inv_var_or_pop_var.NumElements(),
+                                " and ", num_channels));
 
     Tensor* x_backprop = nullptr;
     OP_REQUIRES_OK(context,
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 1742a919..8fecd1c7 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -20,10 +20,13 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
@@ -610,6 +613,125 @@ class BatchNormalizationTest(test.TestCase):
     }
     self._testBatchNormGradGrad(config)
 
+  def testEagerShapeErrors(self):
+    with context.eager_mode():
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((3,))
+      offset = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'scale must have the same number of elements'):
+        nn_impl.fused_batch_norm(x, scale, offset)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((3,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'offset must have the same number of elements'):
+        nn_impl.fused_batch_norm(x, scale, offset)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((0,))
+      variance = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When is_training=false, mean must have the same number of elements'):
+        nn_impl.fused_batch_norm(
+            x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((2,))
+      variance = array_ops.ones((0,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When is_training=false, variance must have the same number of '):
+        nn_impl.fused_batch_norm(
+            x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((0,))
+      variance = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When exponential_avg_factor != 1, mean must have the same number of '
+          'elements'):
+        nn_impl.fused_batch_norm(
+            x,
+            scale,
+            offset,
+            mean=mean,
+            variance=variance,
+            exponential_avg_factor=0.5)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((2,))
+      variance = array_ops.ones((0,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When exponential_avg_factor != 1, variance must have the same '
+          'number of elements'):
+        nn_impl.fused_batch_norm(
+            x,
+            scale,
+            offset,
+            mean=mean,
+            variance=variance,
+            exponential_avg_factor=0.5)
+
+  def testEagerShapeGradErrors(self):
+    with context.eager_mode():
+      y_backprop = array_ops.ones((2, 2, 2, 3))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
+                                  'x and y_backprop must have same shape,'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((3,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'scale must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((3,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'reserve_space_1 must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((3,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'reserve_space_2 must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
 
 if __name__ == '__main__':
   test.main()
-- 
2.23.0