65 lines
2.6 KiB
Diff
65 lines
2.6 KiB
Diff
|
|
diff -Nur a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
|
||
|
|
--- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc 2020-09-22 09:57:17.000000000 +0800
|
||
|
|
+++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc 2021-06-28 22:53:37.005305788 +0800
|
||
|
|
@@ -21,16 +21,12 @@
|
||
|
|
#include <utility>
|
||
|
|
#include <vector>
|
||
|
|
|
||
|
|
-#include "tensorflow/core/framework/op_kernel.h"
|
||
|
|
-#include "tensorflow/core/framework/register_types.h"
|
||
|
|
-
|
||
|
|
-#include "tensorflow/core/framework/op_kernel.h"
|
||
|
|
-#include "tensorflow/core/framework/register_types.h"
|
||
|
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||
|
|
#include "tensorflow/core/framework/tensor.h"
|
||
|
|
#include "tensorflow/core/framework/tensor_util.h"
|
||
|
|
#include "tensorflow/core/framework/types.h"
|
||
|
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||
|
|
+#include "tensorflow/core/util/overflow.h"
|
||
|
|
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
@@ -254,7 +250,22 @@
|
||
|
|
errors::InvalidArgument(
|
||
|
|
"Rank of input SparseTensor should be > 1, but saw rank: ", rank));
|
||
|
|
|
||
|
|
- TensorShape tensor_input_shape(input_shape->vec<int64>());
|
||
|
|
+    auto input_shape_vec = input_shape->vec<int64>();
+    // Reject negative or overflowing dims before building the TensorShape.
+    int64 new_num_elements = 1;
+    for (int i = 0; i < input_shape_vec.size(); i++) {
+      // MultiplyWithoutOverflow is only defined for non-negative operands.
+      OP_REQUIRES(context, input_shape_vec(i) >= 0,
+                  errors::InvalidArgument("Input shape dimension ", i,
+                                          " must be non-negative, got ",
+                                          input_shape_vec(i)));
+      new_num_elements =
+          MultiplyWithoutOverflow(new_num_elements, input_shape_vec(i));
+      OP_REQUIRES(
+          context, new_num_elements >= 0,
+          errors::Internal("Encountered overflow from large input shape."));
+    }
+    TensorShape tensor_input_shape(input_shape_vec);
|
||
|
|
gtl::InlinedVector<int64, 8> std_order(rank);
|
||
|
|
std::iota(std_order.begin(), std_order.end(), 0);
|
||
|
|
SparseTensor input_st;
|
||
|
|
@@ -262,8 +273,7 @@
|
||
|
|
tensor_input_shape, std_order,
|
||
|
|
&input_st));
|
||
|
|
|
||
|
|
- auto input_shape_t = input_shape->vec<int64>();
|
||
|
|
- const int64 N = input_shape_t(0);
|
||
|
|
+ const int64 N = input_shape_vec(0);
|
||
|
|
|
||
|
|
Tensor sparse_handles(DT_INT64, TensorShape({N}));
|
||
|
|
auto sparse_handles_t = sparse_handles.vec<int64>();
|
||
|
|
@@ -274,7 +284,7 @@
|
||
|
|
// minibatch entries.
|
||
|
|
TensorShape output_shape;
|
||
|
|
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
|
||
|
|
- input_shape_t.data() + 1,
|
||
|
|
+ input_shape_vec.data() + 1,
|
||
|
|
input_shape->NumElements() - 1, &output_shape));
|
||
|
|
|
||
|
|
// Get groups by minibatch dimension
|