65 lines
3.1 KiB
Diff
From 8a793b5d7f59e37ac7f3cd0954a750a2fe76bad4 Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Thu, 29 Jul 2021 18:23:45 -0700
|
|
Subject: [PATCH] Prevent division by 0 in common shape functions.
|
|
|
|
PiperOrigin-RevId: 387712197
|
|
Change-Id: Id25c7460e35b68aeeeac23b9a88e455b443ee149
|
|
---
|
|
tensorflow/core/framework/common_shape_fns.cc | 11 +++++++++++
|
|
1 file changed, 11 insertions(+)
|
|
|
|
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
|
|
index b9efddf4..e25d5581 100644
|
|
--- a/tensorflow/core/framework/common_shape_fns.cc
|
|
+++ b/tensorflow/core/framework/common_shape_fns.cc
|
|
@@ -659,6 +659,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
|
|
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
|
|
int64 input_depth_value = c->Value(input_depth_dim),
|
|
filter_input_depth_value = c->Value(filter_input_depth_dim);
|
|
+ if (filter_input_depth_value == 0)
|
|
+ return errors::InvalidArgument("Depth of filter must not be 0");
|
|
if (input_depth_value % filter_input_depth_value != 0)
|
|
return errors::InvalidArgument(
|
|
"Depth of input (", input_depth_value,
|
|
@@ -668,6 +670,8 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
|
|
int64 num_groups = input_depth_value / filter_input_depth_value;
|
|
if (c->ValueKnown(output_depth_dim)) {
|
|
int64 output_depth_value = c->Value(output_depth_dim);
|
|
+ if (num_groups == 0)
|
|
+ return errors::InvalidArgument("Number of groups must not be 0");
|
|
if (output_depth_value % num_groups != 0)
|
|
return errors::InvalidArgument(
|
|
"Depth of output (", output_depth_value,
|
|
@@ -798,6 +802,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
|
|
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
|
|
int64 input_depth_value = c->Value(input_depth_dim),
|
|
filter_input_depth_value = c->Value(filter_input_depth_dim);
|
|
+ if (filter_input_depth_value == 0)
|
|
+ return errors::InvalidArgument("Depth of filter must not be 0");
|
|
if (input_depth_value % filter_input_depth_value != 0)
|
|
return errors::InvalidArgument(
|
|
"Depth of input (", input_depth_value,
|
|
@@ -807,6 +813,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
|
|
int64 num_groups = input_depth_value / filter_input_depth_value;
|
|
if (c->ValueKnown(output_depth_dim)) {
|
|
int64 output_depth_value = c->Value(output_depth_dim);
|
|
+ if (num_groups == 0)
|
|
+ return errors::InvalidArgument("Number of groups must not be 0");
|
|
if (output_depth_value % num_groups != 0)
|
|
return errors::InvalidArgument(
|
|
"Depth of output (", output_depth_value,
|
|
@@ -2364,6 +2372,9 @@ Status SparseReduceShapeFn(InferenceContext* c) {
|
|
|
|
int64 ndims = shape_vec.size();
|
|
absl::flat_hash_set<int64> axes;
|
|
+ if (ndims == 0)
|
|
+ return errors::InvalidArgument(
|
|
+ "Number of dims in shape tensor must not be 0");
|
|
for (int i = 0; i < axes_vec.size(); i++) {
|
|
axes.insert((axes_vec(i) + ndims) % ndims);
|
|
}
|
|
--
|
|
2.27.0
|
|
|