From 7b8db6083b34520688dbc71f341f7aeaf156bf17 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Fri, 19 Mar 2021 16:16:41 -0700
Subject: [PATCH] Implement grouped convolution on CPU

To get better compute resource utilization, the group-compute loop has to be
parallelized, but that involves a lot of changes in the Conv2D primitives.
Will address this later if it turns out to be critical for some users.

Fix for: https://github.com/tensorflow/tensorflow/issues/29005
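
Example (illustrative; the shapes mirror the new test in conv_ops_test.py) of
a grouped convolution that previously failed on CPU with "Unimplemented: The
Conv2D op currently does not support grouped convolutions on the CPU" and is
handled by this change:

  import tensorflow as tf

  x = tf.random.normal([10, 32, 32, 16])  # input depth 16
  f = tf.random.normal([3, 3, 4, 8])      # filter depth 4 => 4 groups
  with tf.device("/CPU:0"):
    y = tf.nn.conv2d(x, f, strides=1, padding="SAME")
  print(y.shape)  # (10, 32, 32, 8)
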
PiperOrigin-RevId: 363991782
Change-Id: I97f375b1133833c4de5181199316be7cbf4ebee0
---
 tensorflow/core/kernels/BUILD                 |   1 +
 tensorflow/core/kernels/conv_2d.h             |  54 +++++++
 tensorflow/core/kernels/conv_ops.cc           | 133 ++++++++++++++++--
 .../python/kernel_tests/conv_ops_test.py      |  20 +--
 4 files changed, 189 insertions(+), 19 deletions(-)
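
Note for reviewers: LaunchGrouped (see conv_ops.cc below) reshapes the NHWC
input to [N, H, W, G, C/G], shuffles the group dimension to the front, runs a
dense SpatialConvolution per group, and shuffles the result back; the two
input/filter shuffles are overlapped via an absl::BlockingCounter. A NumPy
sketch of the same scheme (illustrative only; these names are not part of the
change):

  import numpy as np

  def grouped_conv2d_reference(x, f, conv2d):
    # x: [N, H, W, C], f: [Kh, Kw, C/G, O]; `conv2d` is any dense NHWC
    # convolution taking (input, filter), standing in for SpatialConvolution.
    n, h, w, c = x.shape
    kh, kw, patch_depth, o = f.shape
    g = c // patch_depth                        # number of groups
    assert c % patch_depth == 0 and o % g == 0  # checks in LaunchConv2DOp
    xg = np.moveaxis(x.reshape(n, h, w, g, c // g), 3, 0)  # [G,N,H,W,C/G]
    fg = np.moveaxis(f.reshape(kh, kw, patch_depth, g, o // g), 3, 0)
    yg = np.stack([conv2d(xg[i], fg[i]) for i in range(g)])  # per-group conv
    y = np.moveaxis(yg, 0, 3)                   # [N,Ho,Wo,G,O/G]
    return y.reshape(*y.shape[:3], o)           # merge groups back into depth
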
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8e49f1e0a5caf..bc455626f4322 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3818,6 +3818,7 @@ tf_kernel_library(
         ":ops_util",
         "@com_google_absl//absl/base:dynamic_annotations",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "//third_party/eigen3",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index b9a8c977e11ee..87df4a848dd56 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -43,6 +43,9 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                          padding_bottom);
 }
 
+// TODO(ezhulenev): Non-templated `operator()` overloads are required by
+// explicit template instantiations for the GPU device, but they are almost
+// certainly unused in the kernel implementations; check if they can be removed.
 template <typename Device, typename T,
           typename OutputKernel = const Eigen::NoOpOutputKernel>
 struct SpatialConvolution {
@@ -55,6 +58,16 @@ struct SpatialConvolution {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
                            row_dilation, col_dilation, padding, output_kernel);
   }
+
+  template <typename Input, typename Filter, typename Output>
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
+                           row_dilation, col_dilation, padding, output_kernel);
+  }
+
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
@@ -67,6 +80,18 @@ struct SpatialConvolution {
         col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
         padding_top, padding_bottom, padding_left, padding_right);
   }
+
+  template <typename Input, typename Filter, typename Output>
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(
+        d, output, input, filter, row_stride, col_stride, row_dilation,
+        col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
+        padding_top, padding_bottom, padding_left, padding_right);
+  }
 };
 
 template <typename Device, typename OutputKernel>
@@ -84,6 +109,20 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
                                  row_dilation, output_kernel)
             .template cast<Eigen::half>();
   }
+
+  template <typename Input, typename Filter, typename Output>
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(input.template cast<float>(),
+                                  filter.template cast<float>(), col_stride,
+                                  row_stride, padding, col_dilation,
+                                  row_dilation, output_kernel)
+            .template cast<Eigen::half>();
+  }
+
   void operator()(const Device& d,
                   typename TTypes<Eigen::half, 4>::Tensor output,
                   typename TTypes<Eigen::half, 4>::ConstTensor input,
@@ -100,6 +139,21 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
                                  padding_bottom)
             .template cast<Eigen::half>();
   }
+
+  template <typename Input, typename Filter, typename Output>
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(
+            input.template cast<float>(), filter.template cast<float>(),
+            col_stride, row_stride, Eigen::PaddingType::PADDING_VALID,
+            col_dilation, row_dilation, output_kernel, padding_left,
+            padding_right, padding_top, padding_bottom)
+            .template cast<Eigen::half>();
+  }
 };
 
 template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 025a8e37a94e9..8fdfe04bd1c67 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include <map>
 #include <vector>
 
+#include "absl/synchronization/blocking_counter.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
@@ -138,6 +139,98 @@ struct LaunchGeneric {
     }
   }
 };
+
+// Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
+// implementation in cuDNN, this is far from optimal and needs more work to
+// deliver competitive performance. Currently it exists to close the feature
+// parity gap between convolution operations on different devices.
+template <typename T>
+struct LaunchGrouped {
+  void operator()(OpKernelContext* ctx, const Tensor& input,
+                  const Tensor& filter, int row_stride, int col_stride,
+                  int row_dilation, int col_dilation, const Padding& padding,
+                  const std::vector<int64>& explicit_paddings, Tensor* output,
+                  TensorFormat data_format) {
+    DCHECK(data_format == FORMAT_NHWC)
+        << "Grouped conv implementation only "
+           "supports NHWC tensor format for now.";
+
+    const int64 in_depth = input.dim_size(3);
+    const int64 patch_depth = filter.dim_size(2);
+    const int64 num_groups = in_depth / patch_depth;
+
+    // Shuffle input/filter tensors to have group as a leading dimension.
+    std::array<int64, 5> shuffle({3, 0, 1, 2, 4});
+
+    // Compute pre-shuffle dimensions.
+    auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+      return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
+              num_groups, tensor.dim_size(3) / num_groups};
+    };
+
+    // Compute post-shuffle dimensions.
+    auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+      return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
+              tensor.dim_size(2), tensor.dim_size(3) / num_groups};
+    };
+
+    auto& device = ctx->eigen_device<CPUDevice>();
+
+    absl::BlockingCounter shuffles_completed(2);
+    auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
+
+    // Shuffle input into temporary tensor.
+    Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
+    input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+        input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
+
+    // Shuffle filter into temporary tensor.
+    Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
+    filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+        filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
+
+    // Wait for the completion of input/filter shuffles.
+    shuffles_completed.Wait();
+
+    // Write group convolution results into temporary output tensor.
+    Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
+
+    for (int64 i = 0; i < num_groups; ++i) {
+      // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
+      // will lead to a deadlock; SpatialConvolution has to use async Eigen
+      // assignment). This requires small changes to Eigen to support async
+      // execution for the tensor chipping operation.
+
+      // TODO(ezhulenev): Grouped convolution should also support the 1x1
+      // filter optimization.
+
+      auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
+      auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
+      auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);
+
+      if (padding == EXPLICIT) {
+        functor::SpatialConvolution<CPUDevice, T>()(
+            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+            static_cast<int>(explicit_paddings[2]),
+            static_cast<int>(explicit_paddings[3]),
+            static_cast<int>(explicit_paddings[4]),
+            static_cast<int>(explicit_paddings[5]));
+      } else {
+        functor::SpatialConvolution<CPUDevice, T>()(
+            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+            BrainPadding2EigenPadding(padding));
+      }
+    }
+
+    // Shuffle temporary output back into pre-shuffled shape.
+    std::array<int64, 5> rev_shuffle({1, 2, 3, 0, 4});
+    output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
+        output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
+  }
+};
+
 }  // namespace
 
 template <typename T>
@@ -155,14 +248,6 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                         ToString(data_format)));
       return;
     }
-    const int64 in_depth = GetTensorDim(input, data_format, 'C');
-    OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
-                errors::Unimplemented(
-                    "The Conv2D op currently does not support grouped "
-                    "convolutions on the CPU. A grouped convolution was "
-                    "attempted to be run because the input depth of ",
-                    in_depth, " does not match the filter input depth of ",
-                    filter.dim_size(2)));
 
     for (int64 explicit_padding : explicit_paddings) {
       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
@@ -170,9 +255,35 @@ struct LaunchConv2DOp<CPUDevice, T> {
         return;
       }
     }
-    LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  row_dilation, col_dilation, padding,
-                                  explicit_paddings, output, data_format);
+
+    const int64 in_depth = input.dim_size(3);
+    const int64 out_depth = output->dim_size(3);
+    const int64 patch_depth = filter.dim_size(2);
+
+    if (in_depth % patch_depth != 0) {
+      ctx->SetStatus(errors::InvalidArgument(
+          "input depth must be evenly divisible by filter depth: ", in_depth,
+          " vs ", patch_depth));
+      return;
+    }
+
+    const int64 num_groups = in_depth / patch_depth;
+    if (out_depth % num_groups != 0 || out_depth < num_groups) {
+      ctx->SetStatus(errors::InvalidArgument(
+          "output depth must be evenly divisible by number of groups: ",
+          out_depth, " vs ", num_groups));
+      return;
+    }
+
+    if (in_depth != patch_depth) {
+      LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
+                         row_dilation, col_dilation, padding, explicit_paddings,
+                         output, data_format);
+    } else {
+      LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
+                                    row_dilation, col_dilation, padding,
+                                    explicit_paddings, output, data_format);
+    }
   }
 };
 
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 44a67ccc55f0a..92af04359caa9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -834,17 +834,21 @@ def MakeConv2d(inputs, filters):
         results[0], results[1], atol=tol_to_use, rtol=tol_to_use)
 
   @test_util.run_in_graph_and_eager_modes
-  @test_util.run_cuda_only
   def testConv2DGroupConvFwd(self):
-    for data_format in ["NHWC", "NCHW"]:
+    if test.is_gpu_available(cuda_only=True):
+      data_formats = ["NHWC", "NCHW"]
+    else:
+      data_formats = ["NHWC"]
+    for data_format in data_formats:
       for dilation in [1, 2]:
         for stride in [1, 2]:
-          self._VerifyGroupConvFwd([10, 32, 32, 16], [3, 3, 4, 8],
-                                   dilations=[dilation, dilation],
-                                   strides=[stride, stride],
-                                   padding="SAME",
-                                   data_format=data_format,
-                                   dtype=dtypes.float32)
+          for filter_dims in [[3, 3, 4, 8], [1, 1, 2, 16]]:
+            self._VerifyGroupConvFwd([10, 32, 32, 16], filter_dims,
+                                     dilations=[dilation, dilation],
+                                     strides=[stride, stride],
+                                     padding="SAME",
+                                     data_format=data_format,
+                                     dtype=dtypes.float32)
 
   @test_util.deprecated_graph_mode_only
   @test_util.run_cuda_only
|