From 480641e3599775a8895254ffbc0fc45621334f68 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Sat, 24 Apr 2021 16:47:25 -0700 Subject: [PATCH] Validate (and ensure validation sticks) inputs for `MatrixTriangularSolve`. PiperOrigin-RevId: 370282444 Change-Id: Iaed61a0b0727cc42c830658b72eb69f785f48dc5 --- .../matrix_triangular_solve_op_impl.h | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h index 99249f792b6ed..ce5392e62b9fa 100644 --- a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h +++ b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h @@ -162,6 +162,9 @@ class BaseMatrixTriangularSolveOp : public OpKernel { const Tensor& in1 = ctx->input(1); ValidateInputTensors(ctx, in0, in1); + if (!ctx->status().ok()) { + return; + } MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); OP_REQUIRES( @@ -230,13 +233,22 @@ class MatrixTriangularSolveOp private: void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, const Tensor& in1) override { + const auto in0_num_dims = in0.dims(); OP_REQUIRES( - ctx, in0.dims() >= 2, - errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims())); + ctx, in0_num_dims >= 2, + errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims)); + const auto in1_num_dims = in1.dims(); OP_REQUIRES( - ctx, in1.dims() >= 2, - errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims())); + ctx, in1_num_dims >= 2, + errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims)); + + const auto in0_last_dim = in0.dim_size(in0_num_dims - 1); + const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2); + OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim, + errors::InvalidArgument( + "In[0] matrices in the last dimensions must be square (", + in0_last_dim, " =/= ", in0_prev_dim, ")")); } };