From 6972f9dfe325636b3db4e0bc517ee22a159365c0 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Thu, 6 May 2021 17:45:51 -0700
Subject: [PATCH] Add missing validation to FusedBatchNorm.

Check that `scale` and `offset` have exactly one element per channel of
`x`, and that `estimated_mean` / `estimated_variance` are either empty or
have exactly one element per channel of `x`. The channel count is read
with GetTensorDim(x, tensor_format_, 'C'), and the activation-mode check
now uses that count instead of always reading x.dim_size(3).
---
 .../core/kernels/fused_batch_norm_op.cc | 29 ++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 59470c8a..bd5dab36 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1267,6 +1267,33 @@ class FusedBatchNormOpBase : public OpKernel {
         context, estimated_variance.dims() == 1,
         errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                 estimated_variance.shape().DebugString()));
+
+    const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
+    OP_REQUIRES(
+        context, scale.NumElements() == num_channels,
+        errors::InvalidArgument("scale must have the same number of elements "
+                                "as the channels of x, got ",
+                                scale.NumElements(), " and ", num_channels));
+    OP_REQUIRES(
+        context, offset.NumElements() == num_channels,
+        errors::InvalidArgument("offset must have the same number of elements "
+                                "as the channels of x, got ",
+                                offset.NumElements(), " and ", num_channels));
+    if (estimated_mean.NumElements() != 0) {
+      OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
+                  errors::InvalidArgument(
+                      "mean must be empty or have the same number of "
+                      "elements as the channels of x, got ",
+                      estimated_mean.NumElements(), " and ", num_channels));
+    }
+    if (estimated_variance.NumElements() != 0) {
+      OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
+                  errors::InvalidArgument(
+                      "variance must be empty or have the same number of "
+                      "elements as the channels of x, got ",
+                      estimated_variance.NumElements(), " and ", num_channels));
+    }
+
     if (has_side_input_) {
       OP_REQUIRES(context, side_input->shape() == x.shape(),
                   errors::InvalidArgument(
@@ -1279,7 +1306,7 @@ class FusedBatchNormOpBase : public OpKernel {
       // NOTE(ezhulenev): This requirement is coming from implementation
       // details of cudnnBatchNormalizationForwardTrainingEx.
       OP_REQUIRES(
-          context, !is_training_ || x.dim_size(3) % 4 == 0,
+          context, !is_training_ || num_channels % 4 == 0,
           errors::InvalidArgument("FusedBatchNorm with activation requires "
                                   "channel dimension to be a multiple of 4."));
     }
-- 
2.23.0