From ff8894044dfae5568ecbf2ed514c1a37dc394f1b Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Fri, 30 Jul 2021 18:58:29 -0700 Subject: [PATCH] Add one missing validation to `matrix_set_diag_op.cc` PiperOrigin-RevId: 387923408 Change-Id: If6a97b9098c13879400f56c22f91555cdf0ce5d7 --- tensorflow/core/kernels/matrix_set_diag_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc index 07b6e69de67dc..4e89433718b46 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op.cc @@ -70,6 +70,9 @@ class MatrixSetDiagOp : public OpKernel { errors::InvalidArgument( "diag_index must be a scalar or vector, received shape: ", diag_index.shape().DebugString())); + OP_REQUIRES( + context, diag_index.NumElements() > 0, + errors::InvalidArgument("diag_index must have at least one element")); lower_diag_index = diag_index.flat()(0); upper_diag_index = lower_diag_index; if (TensorShapeUtils::IsVector(diag_index.shape())) {