54 lines
2.1 KiB
Diff
54 lines
2.1 KiB
Diff
From 480641e3599775a8895254ffbc0fc45621334f68 Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Sat, 24 Apr 2021 16:47:25 -0700
|
|
Subject: [PATCH] Validate (and ensure validation sticks) inputs for
|
|
`MatrixTriangularSolve`.
|
|
|
|
PiperOrigin-RevId: 370282444
|
|
Change-Id: Iaed61a0b0727cc42c830658b72eb69f785f48dc5
|
|
---
|
|
.../matrix_triangular_solve_op_impl.h | 20 +++++++++++++++----
|
|
1 file changed, 16 insertions(+), 4 deletions(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
|
|
index 99249f792b6ed..ce5392e62b9fa 100644
|
|
--- a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
|
|
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
|
|
@@ -162,6 +162,9 @@ class BaseMatrixTriangularSolveOp : public OpKernel {
|
|
const Tensor& in1 = ctx->input(1);
|
|
|
|
ValidateInputTensors(ctx, in0, in1);
|
|
+ if (!ctx->status().ok()) {
|
|
+ return;
|
|
+ }
|
|
|
|
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
|
|
OP_REQUIRES(
|
|
@@ -230,13 +233,22 @@ class MatrixTriangularSolveOp
|
|
private:
|
|
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
|
const Tensor& in1) override {
|
|
+ const auto in0_num_dims = in0.dims();
|
|
OP_REQUIRES(
|
|
- ctx, in0.dims() >= 2,
|
|
- errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
|
|
+ ctx, in0_num_dims >= 2,
|
|
+ errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));
|
|
|
|
+ const auto in1_num_dims = in1.dims();
|
|
OP_REQUIRES(
|
|
- ctx, in1.dims() >= 2,
|
|
- errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
|
|
+ ctx, in1_num_dims >= 2,
|
|
+ errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));
|
|
+
|
|
+ const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
|
|
+ const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
|
|
+ OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
|
|
+ errors::InvalidArgument(
|
|
+ "In[0] matrices in the last dimensions must be square (",
|
|
+ in0_last_dim, " =/= ", in0_prev_dim, ")"));
|
|
}
|
|
};
|
|
|