From a7116dd3913c4a4afd2a3a938573aa7c785fdfc6 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Sat, 17 Apr 2021 20:55:53 -0700 Subject: [PATCH] Validate `MatrixDiagV{2,3}` arguments to prevent breakage. PiperOrigin-RevId: 369056033 Change-Id: Ic2018c297d3dd6f252dc1dd3667f1ed5cb1eaa42 --- .../core/kernels/matrix_diag_op.cc | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc index 69cc8170793ae..d4eb589836a85 100644 --- a/tensorflow/core/kernels/matrix_diag_op.cc +++ b/tensorflow/core/kernels/matrix_diag_op.cc @@ -192,9 +192,22 @@ class MatrixDiagOp : public OpKernel { upper_diag_index = diag_index.flat()(1); } } - num_rows = context->input(2).flat()(0); - num_cols = context->input(3).flat()(0); - padding_value = context->input(4).flat()(0); + + auto& num_rows_tensor = context->input(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_rows_tensor.shape()), + errors::InvalidArgument("num_rows must be a scalar")); + num_rows = num_rows_tensor.flat()(0); + + auto& num_cols_tensor = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_cols_tensor.shape()), + errors::InvalidArgument("num_cols must be a scalar")); + num_cols = num_cols_tensor.flat()(0); + + auto& padding_value_tensor = context->input(4); + OP_REQUIRES(context, + TensorShapeUtils::IsScalar(padding_value_tensor.shape()), + errors::InvalidArgument("padding_value must be a scalar")); + padding_value = padding_value_tensor.flat()(0); } // Size validations.