267 lines
11 KiB
Diff
267 lines
11 KiB
Diff
|
|
From ebc70b7a592420d3d2f359e4b1694c236b82c7ae Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Mon, 7 Dec 2020 11:15:21 -0800
|
||
|
|
Subject: [PATCH] Validate that `DataFormat*` attributes form a permutation.
|
||
|
|
|
||
|
|
The `src_format` and `dst_format` attributes for the `DataFormatDimMap` and `DataFormatVecPermute` raw ops are supposed to determine a permutation. However, this was not validated and could result in uninitialized memory accesses as well as writes outside of bounds and potential crashes.
|
||
|
|
|
||
|
|
While here, we also test that the format attributes have the needed length, add tests for all validation failure cases, remove unnecessary calls to `strings::StrCat`, and fix a few grammar errors.
|
||
|
|
|
||
|
|
This will be cherry-picked on the supported release branches.
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 346135579
|
||
|
|
Change-Id: I1c76392382c89ad8f072d5bc93d70669851eb404
|
||
|
|
---
|
||
|
|
tensorflow/core/kernels/data_format_ops.cc | 72 ++++++++++++++--
|
||
|
|
tensorflow/python/ops/nn_test.py | 96 ++++++++++++++++++++++
|
||
|
|
2 files changed, 161 insertions(+), 7 deletions(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
|
||
|
|
index e9c71f17..abe2fbc3 100644
|
||
|
|
--- a/tensorflow/core/kernels/data_format_ops.cc
|
||
|
|
+++ b/tensorflow/core/kernels/data_format_ops.cc
|
||
|
|
@@ -18,16 +18,52 @@ limitations under the License.
|
||
|
|
#define EIGEN_USE_THREADS
|
||
|
|
|
||
|
|
#include "tensorflow/core/kernels/data_format_ops.h"
|
||
|
|
+
|
||
|
|
+#include <map>
|
||
|
|
+
|
||
|
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
||
|
|
#include "tensorflow/core/framework/register_types.h"
|
||
|
|
#include "tensorflow/core/framework/tensor.h"
|
||
|
|
+#include "tensorflow/core/platform/errors.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
|
||
|
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||
|
|
typedef Eigen::GpuDevice GPUDevice;
|
||
|
|
|
||
|
|
+// Ensure that `src` and `dst` define a valid permutation.
|
||
|
|
+// Ops defined in this file assume that user specifies a permutation via two
|
||
|
|
+// string attributes. This check validates that these attributes properly define
|
||
|
|
+// it to prevent security vulnerabilities.
|
||
|
|
+static bool IsValidPermutation(const std::string& src, const std::string& dst) {
|
||
|
|
+ if (src.size() != dst.size()) {
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ std::map<char, bool> characters;
|
||
|
|
+
|
||
|
|
+ // Every character in `src` must be present only once
|
||
|
|
+ for (const auto c : src) {
|
||
|
|
+ if (characters[c]) {
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+ characters[c] = true;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ // Every character in `dst` must show up in `src` exactly once
|
||
|
|
+ for (const auto c : dst) {
|
||
|
|
+ if (!characters[c]) {
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+ characters[c] = false;
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ // At this point, characters[] has been switched to true and false exactly
|
||
|
|
+ // once for all characters in `src` (and `dst`) so we have a valid permutation
|
||
|
|
+ return true;
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
template <typename Device, typename T>
|
||
|
|
class DataFormatDimMapOp : public OpKernel {
|
||
|
|
public:
|
||
|
|
@@ -38,14 +74,18 @@ class DataFormatDimMapOp : public OpKernel {
|
||
|
|
string dst_format;
|
||
|
|
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||
|
|
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
|
||
|
|
- errors::InvalidArgument(strings::StrCat(
|
||
|
|
- "Source format must of length 4 or 5, received "
|
||
|
|
- "src_format = ", src_format)));
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Source format must be of length 4 or 5, received "
|
||
|
|
+ "src_format = ", src_format));
|
||
|
|
+ OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
|
||
|
|
+ errors::InvalidArgument("Destination format must be of length "
|
||
|
|
+ "4 or 5, received dst_format = ",
|
||
|
|
+ dst_format));
|
||
|
|
OP_REQUIRES(
|
||
|
|
- context, dst_format.size() == 4 || dst_format.size() == 5,
|
||
|
|
- errors::InvalidArgument(strings::StrCat(
|
||
|
|
- "Destination format must of length 4 or 5, received dst_format = ",
|
||
|
|
- dst_format)));
|
||
|
|
+ context, IsValidPermutation(src_format, dst_format),
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Destination and source format must determine a permutation, got ",
|
||
|
|
+ src_format, " and ", dst_format));
|
||
|
|
dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
|
||
|
|
for (int i = 0; i < src_format.size(); ++i) {
|
||
|
|
for (int j = 0; j < dst_format.size(); ++j) {
|
||
|
|
@@ -77,8 +117,22 @@ class DataFormatVecPermuteOp : public OpKernel {
|
||
|
|
: OpKernel(context) {
|
||
|
|
string src_format;
|
||
|
|
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
||
|
|
+ OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Source format must be of length 4 or 5, received "
|
||
|
|
+ "src_format = ",
|
||
|
|
+ src_format));
|
||
|
|
string dst_format;
|
||
|
|
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||
|
|
+ OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
|
||
|
|
+ errors::InvalidArgument("Destination format must be of length "
|
||
|
|
+ "4 or 5, received dst_format = ",
|
||
|
|
+ dst_format));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, IsValidPermutation(src_format, dst_format),
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Destination and source format must determine a permutation, got ",
|
||
|
|
+ src_format, " and ", dst_format));
|
||
|
|
src_format_ = src_format;
|
||
|
|
dst_format_ = dst_format;
|
||
|
|
}
|
||
|
|
@@ -124,6 +178,10 @@ class DataFormatVecPermuteOp : public OpKernel {
|
||
|
|
};
|
||
|
|
keep_only_spatial_dimensions(&src_format_str);
|
||
|
|
keep_only_spatial_dimensions(&dst_format_str);
|
||
|
|
+ OP_REQUIRES(context,
|
||
|
|
+ src_format_str.size() == 2 && dst_format_str.size() == 2,
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "Format specifier must contain H and W for 2D case"));
|
||
|
|
}
|
||
|
|
ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
|
||
|
|
|
||
|
|
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
|
||
|
|
index 55d11a35..d2094a7d 100644
|
||
|
|
--- a/tensorflow/python/ops/nn_test.py
|
||
|
|
+++ b/tensorflow/python/ops/nn_test.py
|
||
|
|
@@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||
|
|
from tensorflow.python.eager import def_function
|
||
|
|
from tensorflow.python.framework import constant_op
|
||
|
|
from tensorflow.python.framework import dtypes
|
||
|
|
+from tensorflow.python.framework import errors
|
||
|
|
from tensorflow.python.framework import ops
|
||
|
|
from tensorflow.python.framework import tensor_spec
|
||
|
|
from tensorflow.python.framework import test_util
|
||
|
|
@@ -1234,6 +1235,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
||
|
|
y_val = self.evaluate(y)
|
||
|
|
self.assertAllEqual(y_val, y_val_expected)
|
||
|
|
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
def testArbitraryASCII(self):
|
||
|
|
x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||
|
|
y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
|
||
|
|
@@ -1243,6 +1245,46 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
||
|
|
y_val = self.evaluate(y)
|
||
|
|
self.assertAllEqual(y_val, y_val_expected)
|
||
|
|
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testInvalidLength(self):
|
||
|
|
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||
|
|
+ "Source format must be of length 4 or 5"):
|
||
|
|
+ op = nn_ops.data_format_dim_map(
|
||
|
|
+ x, src_format="12345678", dst_format="87654321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testDuplicateSrc(self):
|
||
|
|
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testDuplicateDst(self):
|
||
|
|
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testExtraSpecifiers(self):
|
||
|
|
+ x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
|
||
|
|
class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||
|
|
|
||
|
|
@@ -1344,6 +1386,60 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||
|
|
y_val = self.evaluate(y)
|
||
|
|
self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
|
||
|
|
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testInvalidLength(self):
|
||
|
|
+ x = [0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||
|
|
+ "Source format must be of length 4 or 5"):
|
||
|
|
+ op = nn_ops.data_format_vec_permute(
|
||
|
|
+ x, src_format="12345678", dst_format="87654321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testDuplicateSrc(self):
|
||
|
|
+ x = [0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_vec_permute(
|
||
|
|
+ x, src_format="1233", dst_format="4321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testDuplicateDst(self):
|
||
|
|
+ x = [0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_vec_permute(
|
||
|
|
+ x, src_format="1234", dst_format="3321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def testExtraSpecifiers(self):
|
||
|
|
+ x = [0, 1, 2, 3]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Destination and source format must determine a permutation"):
|
||
|
|
+ op = nn_ops.data_format_vec_permute(
|
||
|
|
+ x, src_format="1234", dst_format="5321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
+ @test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||
|
|
+ def test2DNoWH(self):
|
||
|
|
+ x = [[0, 1], [2, 3]]
|
||
|
|
+ with self.assertRaisesRegex(
|
||
|
|
+ errors.InvalidArgumentError,
|
||
|
|
+ "Format specifier must contain H and W for 2D case"):
|
||
|
|
+ op = nn_ops.data_format_vec_permute(
|
||
|
|
+ x, src_format="1234", dst_format="4321")
|
||
|
|
+ with test_util.use_gpu():
|
||
|
|
+ self.evaluate(op)
|
||
|
|
+
|
||
|
|
|
||
|
|
@test_util.run_all_in_graph_and_eager_modes
|
||
|
|
class AvgPoolTest(test_lib.TestCase):
|
||
|
|
--
|
||
|
|
2.27.0
|
||
|
|
|