219 lines
9.1 KiB
Diff
219 lines
9.1 KiB
Diff
From aab9998916c2ffbd8f0592059fad352622f89cda Mon Sep 17 00:00:00 2001
|
|
From: Reed Wanderman-Milne <reedwm@google.com>
|
|
Date: Wed, 29 Sep 2021 13:00:50 -0700
|
|
Subject: [PATCH] Add shape checks to FusedBatchNorm kernels.
|
|
|
|
---
|
|
.../core/kernels/fused_batch_norm_op.cc | 38 +++++-
|
|
.../python/ops/nn_fused_batchnorm_test.py | 122 ++++++++++++++++++
|
|
2 files changed, 153 insertions(+), 7 deletions(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
|
|
index bd5dab36..b19323f0 100644
|
|
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
|
|
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
|
|
@@ -1279,18 +1279,20 @@ class FusedBatchNormOpBase : public OpKernel {
|
|
errors::InvalidArgument("offset must have the same number of elements "
|
|
"as the channels of x, got ",
|
|
offset.NumElements(), " and ", num_channels));
|
|
- if (estimated_mean.NumElements() != 0) {
|
|
+ if (!is_training_ || exponential_avg_factor_ != 1.) {
|
|
+ std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
|
|
+ : "When is_training=false";
|
|
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
|
|
errors::InvalidArgument(
|
|
- "mean must be empty or have the same number of "
|
|
- "elements as the channels of x, got ",
|
|
+ prefix_msg,
|
|
+ ", mean must have the same number "
|
|
+ "of elements as the channels of x, got ",
|
|
estimated_mean.NumElements(), " and ",num_channels));
|
|
- }
|
|
- if (estimated_variance.NumElements() != 0) {
|
|
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
|
|
errors::InvalidArgument(
|
|
- "variance must be empty or have the same number of "
|
|
- "elements as the channels of x, got ",
|
|
+ prefix_msg,
|
|
+ ", variance must have the same "
|
|
+ "number of elements as the channels of x, got ",
|
|
estimated_variance.NumElements(), " and ", num_channels));
|
|
}
|
|
|
|
@@ -1434,6 +1436,28 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
|
errors::InvalidArgument(
|
|
"saved variance must be 1-dimensional",
|
|
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
|
|
+ OP_REQUIRES(
|
|
+ context, x.shape() == y_backprop.shape(),
|
|
+ errors::InvalidArgument(
|
|
+ "x and y_backprop must have same shape, but x has shape ",
|
|
+ x.shape(), " and y_backprop has shape ", y_backprop.shape()));
|
|
+
|
|
+ const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
|
|
+ OP_REQUIRES(
|
|
+ context, scale.NumElements() == num_channels,
|
|
+ errors::InvalidArgument("scale must have the same number of elements "
|
|
+ "as the channels of x, got ",
|
|
+ scale.NumElements(), " and ", num_channels));
|
|
+ OP_REQUIRES(
|
|
+ context, saved_mean_or_pop_mean.NumElements() == num_channels,
|
|
+ errors::InvalidArgument("reserve_space_1 must have the same number of "
|
|
+ "elements as the channels of x, got ",
|
|
+ saved_mean_or_pop_mean.NumElements(), " and ", num_channels));  // Report the size of the tensor that failed the check, not scale's.
|
|
+ OP_REQUIRES(
|
|
+ context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
|
|
+ errors::InvalidArgument("reserve_space_2 must have the same number of "
|
|
+ "elements as the channels of x, got ",
|
|
+ saved_maybe_inv_var_or_pop_var.NumElements(), " and ", num_channels));  // Report the size of the tensor that failed the check, not scale's.
|
|
|
|
Tensor* x_backprop = nullptr;
|
|
OP_REQUIRES_OK(context,
|
|
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
|
|
index 1742a919..8fecd1c7 100644
|
|
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
|
|
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
|
|
@@ -20,10 +20,13 @@ from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
+from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
+from tensorflow.python.framework import errors_impl
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
+from tensorflow.python.ops import gen_nn_ops
|
|
from tensorflow.python.ops import gradient_checker
|
|
from tensorflow.python.ops import gradients_impl
|
|
from tensorflow.python.ops import math_ops
|
|
@@ -610,6 +613,125 @@ class BatchNormalizationTest(test.TestCase):
|
|
}
|
|
self._testBatchNormGradGrad(config)
|
|
|
|
+ def testEagerShapeErrors(self):  # Forward kernel must reject mis-sized scale/offset/mean/variance.
|
|
+ with context.eager_mode():
|
|
+ x = array_ops.ones((2, 2, 2, 2))  # Every dim is 2, so x has 2 channels in any layout.
|
|
+ scale = array_ops.ones((3,))
|
|
+ offset = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'scale must have the same number of elements'):
|
|
+ nn_impl.fused_batch_norm(x, scale, offset)
|
|
+
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ offset = array_ops.ones((3,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'offset must have the same number of elements'):
|
|
+ nn_impl.fused_batch_norm(x, scale, offset)
|
|
+
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ offset = array_ops.ones((2,))
|
|
+ mean = array_ops.ones((0,))  # Empty mean is rejected when is_training=False.
|
|
+ variance = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'When is_training=false, mean must have the same number of elements'):
|
|
+ nn_impl.fused_batch_norm(
|
|
+ x, scale, offset, mean=mean, variance=variance, is_training=False)
|
|
+
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ offset = array_ops.ones((2,))
|
|
+ mean = array_ops.ones((2,))
|
|
+ variance = array_ops.ones((0,))  # Empty variance is rejected when is_training=False.
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'When is_training=false, variance must have the same number of '):
|
|
+ nn_impl.fused_batch_norm(
|
|
+ x, scale, offset, mean=mean, variance=variance, is_training=False)
|
|
+
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ offset = array_ops.ones((2,))
|
|
+ mean = array_ops.ones((0,))  # Empty mean is rejected when exponential_avg_factor != 1.
|
|
+ variance = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'When exponential_avg_factor != 1, mean must have the same number of '
|
|
+ 'elements'):
|
|
+ nn_impl.fused_batch_norm(
|
|
+ x,
|
|
+ scale,
|
|
+ offset,
|
|
+ mean=mean,
|
|
+ variance=variance,
|
|
+ exponential_avg_factor=0.5)
|
|
+
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ offset = array_ops.ones((2,))
|
|
+ mean = array_ops.ones((2,))
|
|
+ variance = array_ops.ones((0,))  # Empty variance is rejected when exponential_avg_factor != 1.
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'When exponential_avg_factor != 1, variance must have the same '
|
|
+ 'number of elements'):
|
|
+ nn_impl.fused_batch_norm(
|
|
+ x,
|
|
+ scale,
|
|
+ offset,
|
|
+ mean=mean,
|
|
+ variance=variance,
|
|
+ exponential_avg_factor=0.5)
|
|
+
|
|
+ def testEagerShapeGradErrors(self):  # Grad kernel must reject mismatched x/y_backprop shapes and mis-sized per-channel inputs.
|
|
+ with context.eager_mode():
|
|
+ y_backprop = array_ops.ones((2, 2, 2, 3))  # Last dim differs from x's, so the shapes mismatch.
|
|
+ x = array_ops.ones((2, 2, 2, 2))  # Every dim is 2, so x has 2 channels in any layout.
|
|
+ scale = array_ops.ones((2,))
|
|
+ reserve_space_1 = array_ops.ones((2,))
|
|
+ reserve_space_2 = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
|
|
+ 'x and y_backprop must have same shape,'):
|
|
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
|
|
+ reserve_space_1, reserve_space_2)
|
|
+
|
|
+ y_backprop = array_ops.ones((2, 2, 2, 2))
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((3,))  # 3 elements vs. 2 channels.
|
|
+ reserve_space_1 = array_ops.ones((2,))
|
|
+ reserve_space_2 = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'scale must have the same number of elements'):
|
|
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
|
|
+ reserve_space_1, reserve_space_2)
|
|
+
|
|
+ y_backprop = array_ops.ones((2, 2, 2, 2))
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ reserve_space_1 = array_ops.ones((3,))  # 3 elements vs. 2 channels.
|
|
+ reserve_space_2 = array_ops.ones((2,))
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'reserve_space_1 must have the same number of elements'):
|
|
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
|
|
+ reserve_space_1, reserve_space_2)
|
|
+
|
|
+ y_backprop = array_ops.ones((2, 2, 2, 2))
|
|
+ x = array_ops.ones((2, 2, 2, 2))
|
|
+ scale = array_ops.ones((2,))
|
|
+ reserve_space_1 = array_ops.ones((2,))
|
|
+ reserve_space_2 = array_ops.ones((3,))  # 3 elements vs. 2 channels.
|
|
+ with self.assertRaisesRegex(
|
|
+ errors_impl.InvalidArgumentError,
|
|
+ 'reserve_space_2 must have the same number of elements'):
|
|
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
|
|
+ reserve_space_1, reserve_space_2)
|
|
+
|
|
|
|
if __name__ == '__main__':
|
|
test.main()
|
|
--
|
|
2.23.0
|
|
|