Reject input shapes whose element count overflows in sparse_tensors_map_ops.cc.

The element count of the input SparseTensor shape is validated before the
shape is used to construct a TensorShape. The product is accumulated in int64
and the scan stops at the first overflow. The single shape vector read for the
check is then reused for N and for the output shape instead of being re-read.

diff -Nur a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
--- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc	2020-09-22 09:57:17.000000000 +0800
+++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc	2021-06-28 22:53:37.005305788 +0800
@@ -21,16 +21,12 @@
 #include <utility>
 #include <vector>
 
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/util/overflow.h"
 #include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
@@ -254,7 +250,28 @@
         errors::InvalidArgument(
             "Rank of input SparseTensor should be > 1, but saw rank: ", rank));
 
-    TensorShape tensor_input_shape(input_shape->vec<int64>());
+    auto input_shape_vec = input_shape->vec<int64>();
+    // Accumulate in int64, the width used for shape entries in this function;
+    // an int accumulator would truncate element counts that need more than
+    // 32 bits.
+    int64 new_num_elements = 1;
+    bool overflow_occurred = false;
+    for (int i = 0; i < input_shape_vec.size(); i++) {
+      new_num_elements =
+          MultiplyWithoutOverflow(new_num_elements, input_shape_vec(i));
+      if (new_num_elements < 0) {
+        // Stop at the first overflow: the negative result is an error marker
+        // and must not be fed into the next multiplication.
+        overflow_occurred = true;
+        break;
+      }
+    }
+
+    OP_REQUIRES(
+        context, !overflow_occurred,
+        errors::Internal("Encountered overflow from large input shape."));
+
+    TensorShape tensor_input_shape(input_shape_vec);
     gtl::InlinedVector<int64, 8> std_order(rank);
     std::iota(std_order.begin(), std_order.end(), 0);
     SparseTensor input_st;
@@ -262,8 +279,7 @@
                                                  tensor_input_shape, std_order,
                                                  &input_st));
 
-    auto input_shape_t = input_shape->vec<int64>();
-    const int64 N = input_shape_t(0);
+    const int64 N = input_shape_vec(0);
 
     Tensor sparse_handles(DT_INT64, TensorShape({N}));
     auto sparse_handles_t = sparse_handles.vec<int64>();
@@ -274,7 +290,7 @@
     // minibatch entries.
     TensorShape output_shape;
     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                input_shape_t.data() + 1,
+                                input_shape_vec.data() + 1,
                                 input_shape->NumElements() - 1, &output_shape));
 
     // Get groups by minibatch dimension