From 12c727cee857fa19be717f336943d95fca4ffe4f Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 6 May 2021 14:02:47 -0700 Subject: [PATCH] Validate inputs of `FractionalAvgPoolGrad`. PiperOrigin-RevId: 372420640 Change-Id: Icc583928e6cdc3062e12498e4d2337a8fe3da016 --- tensorflow/core/kernels/fractional_avg_pool_op.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc index b8a5083e5340f..0452638a06679 100644 --- a/tensorflow/core/kernels/fractional_avg_pool_op.cc +++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc @@ -250,6 +250,19 @@ class FractionalAvgPoolGradOp : public OpKernel { const int64 out_cols = out_backprop.dim_size(2); const int64 out_depth = out_backprop.dim_size(3); + OP_REQUIRES(context, row_seq_tensor.NumElements() > out_rows, + errors::InvalidArgument("Given out_backprop shape ", + out_backprop.shape().DebugString(), + ", row_seq_tensor must have at least ", + out_rows + 1, " elements, but got ", + row_seq_tensor.NumElements())); + OP_REQUIRES(context, col_seq_tensor.NumElements() > out_cols, + errors::InvalidArgument("Given out_backprop shape ", + out_backprop.shape().DebugString(), + ", col_seq_tensor must have at least ", + out_cols + 1, " elements, but got ", + col_seq_tensor.NumElements())); + auto row_seq_tensor_flat = row_seq_tensor.flat(); auto col_seq_tensor_flat = col_seq_tensor.flat(); auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat();