From 1a11d01c1fdd6683e9aa210dccde81de127dbf3e Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Mon, 14 Sep 2020 15:52:22 -0700
Subject: [PATCH 1/1] support reduce ops for 5d tensors in layout optimizer

---
 .../generic_layout_optimizer_transposer.cc   | 27 +++++++++-
 tensorflow/core/kernels/data_format_ops.cc   | 10 ++--
 tensorflow/core/kernels/data_format_ops.h    | 53 ++++++++++++++-----
 .../python/grappler/layout_optimizer_test.py | 39 ++++++++++++++
 tensorflow/python/ops/nn_test.py             | 27 ++++++++++
 5 files changed, 136 insertions(+), 20 deletions(-)

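Notes (illustration only, not part of the diff): the pattern this patch targets is a reduce op consuming a 5D tensor, which the generic layout optimizer previously skipped because it only matched rank-4 fanins. Below is a minimal graph that exercises the new path, adapted from testReduceOpsFor5DTensors in this patch; it assumes a TF1-compatible, CUDA-enabled build, since the NCDHW rewrite only fires on GPU.

# Illustrative repro (ours), adapted from testReduceOpsFor5DTensors below.
# Mean over axes [0, 1, 2, 3] of a 5D Conv3D output is the reduce pattern
# that the layout optimizer can now rewrite into NCDHW.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.random.truncated_normal([1, 4, 2, 3, 3], seed=0)  # NDHWC input
w = tf.random.truncated_normal([2, 2, 2, 3, 3], seed=0)  # DHW + in/out channels
conv = tf.nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME')
y = tf.reduce_mean(conv, axis=[0, 1, 2, 3], keepdims=True)

with tf.Session() as sess:
    print(sess.run(y).shape)  # (1, 1, 1, 1, 3)
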
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index ab7d8fcd..fbbeffc7 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1283,11 +1283,31 @@ bool ReduceTransposer::IsReduceAxisSupported(
 Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  if (rank == 5) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
       !IsReduceAxisSupported(*context, *node) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    // Change back to the original layout due to early exit.
+    if (rank == 5) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
@@ -1295,6 +1315,11 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
     TF_RETURN_IF_ERROR(
         UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   }
+  // Change back the format from 5D to 4D layout.
+  if (rank == 5) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }

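Note on the hunk above: for rank-5 inputs the transposer temporarily widens the context's 4D layout pair to the matching 3D-convolution layouts, and restores the original pair both on the early-exit path and after the mutation is built. A plain-Python sketch of that widening (helper name is ours, for illustration):

# Sketch (ours) of the widening in ReduceTransposer::TransposeNode: the
# context normally tracks a 4D pair such as NHWC -> NCHW; while a rank-5
# reduce node is processed it is switched to NDHWC -> NCDHW, then restored.
def widen_to_5d(fmt):
    return "NDHWC" if fmt == "NHWC" else "NCDHW"

src, dst = "NHWC", "NCHW"
print(widen_to_5d(src), widen_to_5d(dst))  # NDHWC NCDHW
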
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index 181aa1b8..e9c71f17 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -37,14 +37,14 @@ class DataFormatDimMapOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context, src_format.size() == 4,
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
                 errors::InvalidArgument(strings::StrCat(
-                    "Source format must of length 4, received src_format = ",
-                    src_format)));
+                    "Source format must be of length 4 or 5, received "
+                    "src_format = ", src_format)));
     OP_REQUIRES(
-        context, dst_format.size() == 4,
+        context, dst_format.size() == 4 || dst_format.size() == 5,
         errors::InvalidArgument(strings::StrCat(
-            "Destination format must of length 4, received dst_format = ",
+            "Destination format must be of length 4 or 5, received dst_format = ",
             dst_format)));
     dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
     for (int i = 0; i < src_format.size(); ++i) {

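Note on this file: the constructor loop after the hunk (unchanged here) fills dst_idx_ with the position of each src_format character inside dst_format, so a 5-character format now yields a length-5 table. A pure-Python model of that table (helper name is ours):

# Pure-Python model (ours) of the dst_idx_ table built by DataFormatDimMapOp:
# entry i is the position of src_format[i] within dst_format.
def build_dst_idx(src_format, dst_format):
    return [dst_format.index(dim) for dim in src_format]

print(build_dst_idx("NHWC", "NCHW"))    # [0, 2, 3, 1]
print(build_dst_idx("NDHWC", "NCDHW"))  # [0, 2, 3, 4, 1]
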
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index bc416fa7..89b54901 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -28,24 +28,49 @@ template <typename Device, typename T>
 struct DataFormatDimMap {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                   typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) {
-    auto zero = x.constant(0);
-    auto one = x.constant(1);
-    auto two = x.constant(2);
+    if (dst.size() == 4) {
+      auto zero = x.constant(0);
+      auto one = x.constant(1);
+      auto two = x.constant(2);
 
-    auto f_zero = x.constant(dst(0));
-    auto f_one = x.constant(dst(1));
-    auto f_two = x.constant(dst(2));
-    auto f_three = x.constant(dst(3));
+      auto f_zero = x.constant(dst(0));
+      auto f_one = x.constant(dst(1));
+      auto f_two = x.constant(dst(2));
+      auto f_three = x.constant(dst(3));
 
-    auto four = x.constant(4);
-    auto x_mod = (x + four) % 4;
+      auto four = x.constant(4);
+      auto x_mod = (x + four) % 4;
 
-    auto is_zero = (x_mod == zero);
-    auto is_one = (x_mod == one);
-    auto is_two = (x_mod == two);
+      auto is_zero = (x_mod == zero);
+      auto is_one = (x_mod == one);
+      auto is_two = (x_mod == two);
 
-    y.device(d) = is_zero.select(
-        f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+      y.device(d) = is_zero.select(
+          f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+    } else {
+      auto zero = x.constant(0);
+      auto one = x.constant(1);
+      auto two = x.constant(2);
+      auto three = x.constant(3);
+
+      auto f_zero = x.constant(dst(0));
+      auto f_one = x.constant(dst(1));
+      auto f_two = x.constant(dst(2));
+      auto f_three = x.constant(dst(3));
+      auto f_four = x.constant(dst(4));
+
+      auto five = x.constant(5);
+      auto x_mod = (x + five) % 5;
+
+      auto is_zero = (x_mod == zero);
+      auto is_one = (x_mod == one);
+      auto is_two = (x_mod == two);
+      auto is_three = (x_mod == three);
+
+      y.device(d) = is_zero.select(
+          f_zero, is_one.select(f_one, is_two.select(f_two,
+                  is_three.select(f_three, f_four))));
+    }
   }
 };

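Note on the new branch: it is the Eigen-expression form of a gather. A possibly negative axis is normalized with (x + 5) % 5 and then looked up in dst; the nested select() chain computes the same thing element-wise. An equivalent NumPy sketch (ours, not the TF kernel):

# NumPy equivalent (ours) of the dst.size() == 5 branch: normalize possibly
# negative axes modulo 5, then gather from the dst table.
import numpy as np

def dim_map_5d(x, dst):
    x = np.asarray(x)
    return np.asarray(dst)[(x + 5) % 5]

dst = [0, 2, 3, 4, 1]  # NDHWC -> NCDHW, as DataFormatDimMapOp builds it
print(dim_map_5d([1, -4, -3, -2], dst))  # [2 2 3 4]

The printed values match y_val_expected in testNDHWCtoNCDHW below.
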
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 10f86980..f90da7ed 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -215,6 +215,9 @@ class LayoutOptimizerTest(test.TestCase):
   def _assert_map_nhwc_to_nchw(self, name, nodes):
     self.assertIn(name + '-DimMapNHWCToNCHW-LayoutOptimizer', nodes)
 
+  def _assert_map_ndhwc_to_ncdhw(self, name, nodes):
+    self.assertIn(name + '-DataFormatDimMapNDHWCToNCDHW-LayoutOptimizer', nodes)
+
   def _assert_vec_nchw_to_nhwc(self, name, nodes):
     self.assertIn(name + '-VecPermuteNCHWToNHWC-LayoutOptimizer', nodes)
 
@@ -286,6 +289,42 @@ class LayoutOptimizerTest(test.TestCase):
 
     self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  @test_util.deprecated_graph_mode_only
+  def testReduceOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = math_ops.reduce_mean(conv3d, [0, 1, 2, 3], keepdims=True)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+        print(node.name)
+
+      # The reduce op Mean needs to dim-map its reduction indices to NCDHW.
+      # Then, the output needs to be transposed back to NDHWC.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_map_ndhwc_to_ncdhw('Mean-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testSplitWithNonConstAxis(self):
     if test.is_gpu_available(cuda_only=True):

diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index bfe11b63..55d11a35 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1207,6 +1207,33 @@ class DataFormatDimMapTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, y_val_expected)
 
+  def testNDHWCtoNCDHW(self):
+    x_val = [1, -4, -3, -2]
+    y_val_expected = [2, 2, 3, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoDHWNC(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoWHDCN(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
   def testArbitraryASCII(self):
     x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
     y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
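Note: the expected vectors in these tests follow directly from the index-table model sketched earlier; for example, for NDHWC -> WHDCN (helper is ours, for cross-checking only):

# Cross-check (ours) of the new tests' expectations with the same pure-Python
# model: build the index table, normalize each axis, then gather.
def dim_map(x, src_format, dst_format):
    table = [dst_format.index(dim) for dim in src_format]
    n = len(src_format)
    return [table[(axis + n) % n] for axis in x]

print(dim_map([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN"))
# -> [4, 2, 1, 0, 3, 4, 2, 1, 0, 3], as asserted in testNDHWCtoWHDCN
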
--
2.27.0