From 1d8218f155c1d22c21afda8bf28e36e4094d9e88 Mon Sep 17 00:00:00 2001
From: Ben Barsdell <bbarsdell@nvidia.com>
Date: Fri, 8 Jan 2021 11:04:37 +1100
Subject: [PATCH 1/2] Refactor ReshapeSparseTensor into a template+class

- This is in preparation for adding a GPU implementation.
- No functional change.
---
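Reviewer note (not part of the commit message): the index remapping done by the
new ReshapeSparseTensorFunctor is a plain row-major linearize/delinearize of
each sparse index. Below is a minimal standalone sketch of that arithmetic in
ordinary C++, with made-up example shapes and indices and no TensorFlow types;
it is illustrative only and not part of this patch.

  // Row-major reshape of sparse indices: linearize each index with the input
  // strides, then decompose the linear id with the output strides.
  #include <cstdint>
  #include <iostream>
  #include <vector>

  static std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
    std::vector<int64_t> strides(shape.size());
    if (!shape.empty()) {
      strides.back() = 1;
      for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d) {
        strides[d] = strides[d + 1] * shape[d + 1];
      }
    }
    return strides;
  }

  int main() {
    const std::vector<int64_t> input_shape = {2, 3, 4};  // example values only
    const std::vector<int64_t> output_shape = {6, 4};    // same element count
    const std::vector<std::vector<int64_t>> indices = {{1, 2, 3}, {0, 1, 0}};
    const std::vector<int64_t> in_strides = RowMajorStrides(input_shape);
    const std::vector<int64_t> out_strides = RowMajorStrides(output_shape);
    for (const std::vector<int64_t>& index : indices) {
      int64_t id = 0;
      for (std::size_t j = 0; j < index.size(); ++j) {
        id += index[j] * in_strides[j];
      }
      std::vector<int64_t> out(output_shape.size());
      for (std::size_t j = 0; j < out.size(); ++j) {
        out[j] = id / out_strides[j];
        id %= out_strides[j];
      }
      std::cout << out[0] << " " << out[1] << "\n";  // prints "5 3" then "1 0"
    }
    return 0;
  }

Any future device specialization of the functor has to reproduce exactly this
mapping so that CPU and GPU kernels emit identical output indices.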
 .../kernels/deserialize_sparse_string_op.cc  |   8 +-
 tensorflow/core/kernels/reshape_util.cc      | 101 ++++++++++++------
 tensorflow/core/kernels/reshape_util.h       |  18 ++++
 tensorflow/core/kernels/sparse_reshape_op.cc |  12 ++-
 4 files changed, 98 insertions(+), 41 deletions(-)

diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
index 2e151078..3acd86ef 100644
--- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -35,6 +35,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
 namespace {
 
 using sparse::SparseTensor;
@@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
       target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
     }
 
-    ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
-                        0 /* output indices index */,
-                        2 /* output shape index */);
+    ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
+                                   target_shape, 0 /* output indices index */,
+                                   2 /* output shape index */);
     context->set_output(1, output.values());
   }
 
diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc
index 1fce80f7..d0d54738 100644
--- a/tensorflow/core/kernels/reshape_util.cc
+++ b/tensorflow/core/kernels/reshape_util.cc
@@ -31,6 +31,53 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+template <>
+struct ReshapeSparseTensorFunctor<CPUDevice> {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const {
+    const int64 input_rank = input_shape.dims();
+    const int64 output_rank = output_shape.dims();
+    const int64 nnz = input_indices.dimension(0);
+    gtl::InlinedVector<int64, 8> input_strides(input_rank);
+    if (input_rank > 0) {
+      input_strides[input_rank - 1] = 1;
+      for (int d = input_rank - 2; d >= 0; --d) {
+        input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+      }
+    }
+
+    gtl::InlinedVector<int64, 8> output_strides(output_rank);
+    if (output_rank > 0) {
+      output_strides[output_rank - 1] = 1;
+      for (int d = output_rank - 2; d >= 0; --d) {
+        output_strides[d] =
+            output_strides[d + 1] * output_shape.dim_size(d + 1);
+      }
+    }
+
+    for (int i = 0; i < nnz; ++i) {
+      int64 id = 0;
+      for (int j = 0; j < input_rank; ++j) {
+        id += input_indices(i, j) * input_strides[j];
+      }
+      for (int j = 0; j < output_rank; ++j) {
+        output_indices(i, j) = id / output_strides[j];
+        id %= output_strides[j];
+      }
+    }
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
@@ -111,40 +158,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
     return;
   }
 
-  gtl::InlinedVector<int64, 8> input_strides(input_rank);
-  if (input_rank > 0) {
-    input_strides[input_rank - 1] = 1;
-    for (int d = input_rank - 2; d >= 0; --d) {
-      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
-    }
-  }
-
-  gtl::InlinedVector<int64, 8> output_strides(output_rank);
-  if (output_rank > 0) {
-    output_strides[output_rank - 1] = 1;
-    for (int d = output_rank - 2; d >= 0; --d) {
-      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
-    }
-  }
-
-  Tensor *result_indices = nullptr;
-  OP_REQUIRES_OK(context,
-                 context->allocate_output(output_indices_idx,
-                                          TensorShape({nnz, output_rank}),
-                                          &result_indices));
-  auto input_ind = input_indices_in.matrix<int64>();
-  auto output_ind = result_indices->matrix<int64>();
-  for (int i = 0; i < nnz; ++i) {
-    int64 id = 0;
-    for (int j = 0; j < input_rank; ++j) {
-      id += input_ind(i, j) * input_strides[j];
-    }
-    for (int j = 0; j < output_rank; ++j) {
-      output_ind(i, j) = id / output_strides[j];
-      id %= output_strides[j];
-    }
-  }
-
   Tensor *result_shape = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                    TensorShape({output_rank}),
@@ -153,6 +166,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
   for (int j = 0; j < output_shape.dims(); ++j) {
     output_shape_vec(j) = output_shape.dim_size(j);
   }
+
+  Tensor *result_indices = nullptr;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(output_indices_idx,
+                                          TensorShape({nnz, output_rank}),
+                                          &result_indices));
+  if (nnz > 0) {
+    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
+                                input_shape, output_shape,
+                                input_indices_in.matrix<int64>(),
+                                result_indices->matrix<int64>()));
+  }
 }
 
+#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                    \
+  template void ReshapeSparseTensor<Device>(                       \
+      OpKernelContext *context, const Tensor &input_indices_in,    \
+      const Tensor &input_shape_in, const Tensor &target_shape_in, \
+      int output_indices_idx, int output_shape_idx)
+EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
+#undef EXPLICITLY_INSTANTIATE_FUNCTION
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h
index 7e1809e8..b3a35651 100644
--- a/tensorflow/core/kernels/reshape_util.h
+++ b/tensorflow/core/kernels/reshape_util.h
@@ -16,18 +16,36 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 #define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
 namespace tensorflow {
 
 class OpKernelContext;
 class Tensor;
 
 // Reshapes the input indices and input shape to the target shape.
+// Note: This template is explicitly instantiated for CPU device only.
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
                          const Tensor &target_shape_in, int output_indices_idx,
                          int output_shape_idx);
 
+namespace functor {
+
+template <typename Device>
+struct ReshapeSparseTensorFunctor {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const;
+};
+
+}  // namespace functor
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc
index 3896c959..490d9ffd 100644
--- a/tensorflow/core/kernels/sparse_reshape_op.cc
+++ b/tensorflow/core/kernels/sparse_reshape_op.cc
@@ -30,6 +30,9 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device>
 class SparseReshapeOp : public OpKernel {
  public:
  explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -46,12 +49,13 @@ class SparseReshapeOp : public OpKernel {
                 input_indices_in.dim_size(1) == input_shape_in.dim_size(0),
                 errors::InvalidArgument(
                     "Input tensor rank must match input shape length."));
-    ReshapeSparseTensor(context, context->input(0), context->input(1),
-                        context->input(2), 0 /* output indices index */,
-                        1 /* output shape index */);
+    ReshapeSparseTensor<Device>(
+        context, context->input(0), context->input(1), context->input(2),
+        0 /* output indices index */, 1 /* output shape index */);
   }
 };
 
 REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
-                        SparseReshapeOp)
+                        SparseReshapeOp<CPUDevice>)
+
 }  // namespace tensorflow
--
2.27.0