From 7b8db6083b34520688dbc71f341f7aeaf156bf17 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Fri, 19 Mar 2021 16:16:41 -0700
Subject: [PATCH] Implement grouped convolution on CPU

To get better compute resource utilization, the group-compute loop has to be
parallelized, but that involves a lot of changes in the Conv2D primitives.
Will address that later if it turns out to be critical for some users.

Fix for: https://github.com/tensorflow/tensorflow/issues/29005

PiperOrigin-RevId: 363991782
Change-Id: I97f375b1133833c4de5181199316be7cbf4ebee0
---
 tensorflow/core/kernels/BUILD                 |   1 +
 tensorflow/core/kernels/conv_2d.h             |  54 +++++++
 tensorflow/core/kernels/conv_ops.cc           | 133 ++++++++++++++++--
 .../python/kernel_tests/conv_ops_test.py      |  20 +--
 4 files changed, 189 insertions(+), 19 deletions(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8e49f1e0a5caf..bc455626f4322 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3818,6 +3818,7 @@ tf_kernel_library(
         ":ops_util",
         "@com_google_absl//absl/base:dynamic_annotations",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "//third_party/eigen3",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index b9a8c977e11ee..87df4a848dd56 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -43,6 +43,9 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                          padding_bottom);
 }
 
+// TODO(ezhulenev): Non-templated `operator()` overloads are required by explicit
+// template instantiations for the GPU device. However, they are almost certainly
+// not used in any of the kernel implementations. Check if they can be removed.
 template
 struct SpatialConvolution {
@@ -55,6 +58,16 @@ struct SpatialConvolution {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
                            row_dilation, col_dilation, padding, output_kernel);
   }
+
+  template
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
+                           row_dilation, col_dilation, padding, output_kernel);
+  }
+
   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
                   typename TTypes<T, 4>::ConstTensor input,
                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
@@ -67,6 +80,18 @@ struct SpatialConvolution {
         col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
         padding_top, padding_bottom, padding_left, padding_right);
   }
+
+  template
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(
+        d, output, input, filter, row_stride, col_stride, row_dilation,
+        col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
+        padding_top, padding_bottom, padding_left, padding_right);
+  }
 };
 
 template
@@ -84,6 +109,20 @@ struct SpatialConvolution<Device, Eigen::half> {
                                   row_dilation, output_kernel)
             .template cast<Eigen::half>();
   }
+
+  template
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, const Eigen::PaddingType& padding,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(input.template cast<float>(),
+                                  filter.template cast<float>(), col_stride,
+                                  row_stride, padding, col_dilation,
+                                  row_dilation, output_kernel)
+            .template cast<Eigen::half>();
+  }
+
   void operator()(const Device& d,
                   typename TTypes<Eigen::half, 4>::Tensor output,
                   typename TTypes<Eigen::half, 4>::ConstTensor input,
@@ -100,6 +139,21 @@ struct SpatialConvolution<Device, Eigen::half> {
                             padding_bottom)
             .template cast<Eigen::half>();
   }
+
+  template
+  void operator()(const Device& d, Output output, Input input, Filter filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(
+            input.template cast<float>(), filter.template cast<float>(),
+            col_stride, row_stride, Eigen::PaddingType::PADDING_VALID,
+            col_dilation, row_dilation, output_kernel, padding_left,
+            padding_right, padding_top, padding_bottom)
+            .template cast<Eigen::half>();
+  }
 };
 
 template
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 025a8e37a94e9..8fdfe04bd1c67 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include
 #include
 
+#include "absl/synchronization/blocking_counter.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/kernel_shape_util.h"
@@ -138,6 +139,98 @@ struct LaunchGeneric {
     }
   }
 };
+
+// Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
+// implementation in cuDNN, this is far from optimal and needs more work to
+// deliver competitive performance. Currently it exists to close the feature
+// parity gap between convolution operations on different devices.
+template <typename T>
+struct LaunchGrouped {
+  void operator()(OpKernelContext* ctx, const Tensor& input,
+                  const Tensor& filter, int row_stride, int col_stride,
+                  int row_dilation, int col_dilation, const Padding& padding,
+                  const std::vector<int64>& explicit_paddings, Tensor* output,
+                  TensorFormat data_format) {
+    DCHECK(data_format == FORMAT_NHWC)
+        << "Grouped conv implementation only "
+           "supports NHWC tensor format for now.";
+
+    const int64 in_depth = input.dim_size(3);
+    const int64 patch_depth = filter.dim_size(2);
+    const int64 num_groups = in_depth / patch_depth;
+
+    // Shuffle input/filter tensors to have group as a leading dimension.
+    std::array<int64, 5> shuffle({3, 0, 1, 2, 4});
+
+    // Compute pre-shuffle dimensions.
+    auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+      return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
+              num_groups, tensor.dim_size(3) / num_groups};
+    };
+
+    // Compute post-shuffle dimensions.
+    auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+      return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
+              tensor.dim_size(2), tensor.dim_size(3) / num_groups};
+    };
+
+    auto& device = ctx->eigen_device<CPUDevice>();
+
+    absl::BlockingCounter shuffles_completed(2);
+    auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
+
+    // Shuffle input into temporary tensor.
+    Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
+    input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+        input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
+
+    // Shuffle filter into temporary tensor.
+    Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
+    filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+        filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
+
+    // Wait for the completion of input/filter shuffles.
+    shuffles_completed.Wait();
+
+    // Write group convolution results into temporary output tensor.
+    Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
+
+    for (int64 i = 0; i < num_groups; ++i) {
+      // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
+      // will lead to deadlock, SpatialConvolution has to use async Eigen
+      // assignment). This requires small changes to Eigen to support async
+      // execution for the tensor chipping operation.
+
+      // TODO(ezhulenev): Grouped convolution should also support 1x1 filter
+      // optimization.
+
+      auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
+      auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
+      auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);
+
+      if (padding == EXPLICIT) {
+        functor::SpatialConvolution<CPUDevice, T>()(
+            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+            static_cast<int>(explicit_paddings[2]),
+            static_cast<int>(explicit_paddings[3]),
+            static_cast<int>(explicit_paddings[4]),
+            static_cast<int>(explicit_paddings[5]));
+      } else {
+        functor::SpatialConvolution<CPUDevice, T>()(
+            ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+            filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+            BrainPadding2EigenPadding(padding));
+      }
+    }
+
+    // Shuffle temporary output back into pre-shuffled shape.
+    std::array<int64, 5> rev_shuffle({1, 2, 3, 0, 4});
+    output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
+        output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
+  }
+};
+
 }  // namespace
 
 template <typename T>
@@ -155,14 +248,6 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                         ToString(data_format)));
       return;
     }
-    const int64 in_depth = GetTensorDim(input, data_format, 'C');
-    OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
-                errors::Unimplemented(
-                    "The Conv2D op currently does not support grouped "
-                    "convolutions on the CPU. A grouped convolution was "
-                    "attempted to be run because the input depth of ",
-                    in_depth, " does not match the filter input depth of ",
-                    filter.dim_size(2)));
 
     for (int64 explicit_padding : explicit_paddings) {
       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
@@ -170,9 +255,35 @@ struct LaunchConv2DOp<CPUDevice, T> {
         return;
       }
     }
-    LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  row_dilation, col_dilation, padding,
-                                  explicit_paddings, output, data_format);
+
+    const int64 in_depth = input.dim_size(3);
+    const int64 out_depth = output->dim_size(3);
+    const int64 patch_depth = filter.dim_size(2);
+
+    if (in_depth % patch_depth != 0) {
+      ctx->SetStatus(errors::InvalidArgument(
+          "input depth must be evenly divisible by filter depth: ", in_depth,
+          " vs ", patch_depth));
+      return;
+    }
+
+    const int64 num_groups = in_depth / patch_depth;
+    if (out_depth % num_groups != 0 || out_depth < num_groups) {
+      ctx->SetStatus(errors::InvalidArgument(
+          "output depth must be evenly divisible by number of groups: ",
+          out_depth, " vs ", num_groups));
+      return;
+    }
+
+    if (in_depth != patch_depth) {
+      LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
+                         row_dilation, col_dilation, padding, explicit_paddings,
+                         output, data_format);
+    } else {
+      LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
+                                    row_dilation, col_dilation, padding,
+                                    explicit_paddings, output, data_format);
+    }
   }
 };
 
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 44a67ccc55f0a..92af04359caa9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -834,17 +834,21 @@ def MakeConv2d(inputs, filters):
         results[0], results[1], atol=tol_to_use, rtol=tol_to_use)
 
   @test_util.run_in_graph_and_eager_modes
-  @test_util.run_cuda_only
   def testConv2DGroupConvFwd(self):
-    for data_format in ["NHWC", "NCHW"]:
+    if test.is_gpu_available(cuda_only=True):
+      data_formats = ["NHWC", "NCHW"]
+    else:
+      data_formats = ["NHWC"]
+    for data_format in data_formats:
       for dilation in [1, 2]:
         for stride in [1, 2]:
-          self._VerifyGroupConvFwd([10, 32, 32, 16], [3, 3, 4, 8],
-                                   dilations=[dilation, dilation],
-                                   strides=[stride, stride],
-                                   padding="SAME",
-                                   data_format=data_format,
-                                   dtype=dtypes.float32)
+          for filter_dims in [[3, 3, 4, 8], [1, 1, 2, 16]]:
+            self._VerifyGroupConvFwd([10, 32, 32, 16], filter_dims,
+                                     dilations=[dilation, dilation],
+                                     strides=[stride, stride],
+                                     padding="SAME",
+                                     data_format=data_format,
+                                     dtype=dtypes.float32)
 
   @test_util.deprecated_graph_mode_only
   @test_util.run_cuda_only
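
Note (illustrative usage, not part of the patch): a minimal sketch of what this change enables, assuming a TensorFlow build that includes it. With an NHWC input of depth 16 and a filter of shape [3, 3, 4, 8], the filter input depth (4) splits the input into 16 / 4 = 4 groups, and the output depth (8) must be divisible by that group count, mirroring the checks added to LaunchConv2DOp above; the shapes match those exercised by testConv2DGroupConvFwd.

    import tensorflow as tf

    # Grouped Conv2D on the CPU: input depth 16, filter input depth 4 -> 4 groups.
    # CPU builds without this change reject the op as an unsupported grouped
    # convolution instead of computing it.
    with tf.device("/CPU:0"):
      x = tf.random.normal([10, 32, 32, 16])  # NHWC input
      f = tf.random.normal([3, 3, 4, 8])      # HWIO filter, 16 / 4 = 4 groups
      y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
      print(y.shape)  # (10, 32, 32, 8); out_depth 8 is divisible by the 4 groups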