tensorflow/CVE-2021-29551.patch

54 lines
2.1 KiB
Diff
Raw Normal View History

From 480641e3599775a8895254ffbc0fc45621334f68 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Sat, 24 Apr 2021 16:47:25 -0700
Subject: [PATCH] Validate (and ensure validation sticks) inputs for `MatrixTriangularSolve`.
PiperOrigin-RevId: 370282444
Change-Id: Iaed61a0b0727cc42c830658b72eb69f785f48dc5
---
.../matrix_triangular_solve_op_impl.h | 20 +++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
index 99249f792b6ed..ce5392e62b9fa 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
@@ -162,6 +162,9 @@ class BaseMatrixTriangularSolveOp : public OpKernel {
const Tensor& in1 = ctx->input(1);
ValidateInputTensors(ctx, in0, in1);
+ if (!ctx->status().ok()) {
+ return;
+ }
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
@@ -230,13 +233,22 @@ class MatrixTriangularSolveOp
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
+ const auto in0_num_dims = in0.dims();
OP_REQUIRES(
- ctx, in0.dims() >= 2,
- errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
+ ctx, in0_num_dims >= 2,
+ errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));
+ const auto in1_num_dims = in1.dims();
OP_REQUIRES(
- ctx, in1.dims() >= 2,
- errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
+ ctx, in1_num_dims >= 2,
+ errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));
+
+ const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
+ const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
+ OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
+ errors::InvalidArgument(
+ "In[0] matrices in the last dimensions must be square (",
+ in0_last_dim, " =/= ", in0_prev_dim, ")"));
}
};