905 lines
40 KiB
Diff
905 lines
40 KiB
Diff
|
|
From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Edward Loper <edloper@google.com>
|
||
|
|
Date: Wed, 14 Oct 2020 09:41:17 -0700
|
||
|
|
Subject: [PATCH 1/1] Added gradients for RaggedTensorToVariant and
|
||
|
|
RaggedTensorFromVariant. (This allows gradients to pass through map_fn when
|
||
|
|
it is applied to ragged tensors.)
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 337108621
|
||
|
|
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
|
||
|
|
---
|
||
|
|
tensorflow/core/kernels/BUILD | 15 ++
|
||
|
|
.../kernels/ragged_tensor_from_variant_op.cc | 164 +++++++---------
|
||
|
|
.../kernels/ragged_tensor_to_variant_op.cc | 180 +++++++++++-------
|
||
|
|
.../core/kernels/ragged_tensor_variant.cc | 86 +++++++++
|
||
|
|
.../core/kernels/ragged_tensor_variant.h | 110 +++++++++++
|
||
|
|
tensorflow/core/ops/ragged_conversion_ops.cc | 20 +-
|
||
|
|
tensorflow/python/ops/ragged/BUILD | 1 +
|
||
|
|
9 files changed, 478 insertions(+), 172 deletions(-)
|
||
|
|
create mode 100644 tensorflow/core/framework/tensor_key.h
|
||
|
|
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc
|
||
|
|
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
|
||
|
|
index f5a480b3..12adb2b2 100644
|
||
|
|
--- a/tensorflow/core/kernels/BUILD
|
||
|
|
+++ b/tensorflow/core/kernels/BUILD
|
||
|
|
@@ -1529,10 +1529,22 @@ tf_cc_test(
|
||
|
|
],
|
||
|
|
)
|
||
|
|
|
||
|
|
+cc_library(
|
||
|
|
+ name = "ragged_tensor_variant",
|
||
|
|
+ srcs = ["ragged_tensor_variant.cc"],
|
||
|
|
+ hdrs = ["ragged_tensor_variant.h"],
|
||
|
|
+ deps = [
|
||
|
|
+ ":cwise_op",
|
||
|
|
+ "//tensorflow/core:framework",
|
||
|
|
+ ],
|
||
|
|
+)
|
||
|
|
+
|
||
|
|
tf_kernel_library(
|
||
|
|
name = "ragged_tensor_to_variant_op",
|
||
|
|
srcs = ["ragged_tensor_to_variant_op.cc"],
|
||
|
|
deps = [
|
||
|
|
+ ":concat_lib",
|
||
|
|
+ ":ragged_tensor_variant",
|
||
|
|
"//tensorflow/core:framework",
|
||
|
|
"//tensorflow/core:lib",
|
||
|
|
],
|
||
|
|
@@ -1542,6 +1554,7 @@ tf_kernel_library(
|
||
|
|
name = "ragged_tensor_from_variant_op",
|
||
|
|
srcs = ["ragged_tensor_from_variant_op.cc"],
|
||
|
|
deps = [
|
||
|
|
+ ":ragged_tensor_variant",
|
||
|
|
"//tensorflow/core:framework",
|
||
|
|
"//tensorflow/core:lib",
|
||
|
|
],
|
||
|
|
@@ -1554,6 +1567,7 @@ tf_cc_test(
|
||
|
|
deps = [
|
||
|
|
":ops_testutil",
|
||
|
|
":ragged_tensor_to_variant_op",
|
||
|
|
+ ":ragged_tensor_variant",
|
||
|
|
"//tensorflow/core:framework",
|
||
|
|
"//tensorflow/core:lib",
|
||
|
|
"//tensorflow/core:test",
|
||
|
|
@@ -1570,6 +1584,7 @@ tf_cc_test(
|
||
|
|
deps = [
|
||
|
|
":ops_testutil",
|
||
|
|
":ragged_tensor_from_variant_op",
|
||
|
|
+ ":ragged_tensor_variant",
|
||
|
|
"//tensorflow/core:framework",
|
||
|
|
"//tensorflow/core:lib",
|
||
|
|
"//tensorflow/core:test",
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
index d7b6a89a..fa8853af 100644
|
||
|
|
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
@@ -20,110 +20,76 @@ limitations under the License.
|
||
|
|
#include "tensorflow/core/framework/tensor.h"
|
||
|
|
#include "tensorflow/core/framework/variant.h"
|
||
|
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||
|
|
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||
|
|
#include "tensorflow/core/lib/core/errors.h"
|
||
|
|
#include "tensorflow/core/lib/core/status.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
namespace {
|
||
|
|
|
||
|
|
-struct RaggedTensor {
|
||
|
|
- Tensor values;
|
||
|
|
- std::vector<Tensor> nested_splits;
|
||
|
|
-};
|
||
|
|
-
|
||
|
|
-Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
|
||
|
|
- int ragged_rank, DataType value_dtype,
|
||
|
|
- DataType split_dtype,
|
||
|
|
- std::vector<RaggedTensor>* decoded_ragged) {
|
||
|
|
+Status RaggedComponentsFromVariant(
|
||
|
|
+ const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
|
||
|
|
+ DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
|
||
|
|
const auto& flat_variants = encoded_variant.flat<Variant>();
|
||
|
|
- decoded_ragged->resize(flat_variants.size());
|
||
|
|
- // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
|
||
|
|
- // input.
|
||
|
|
+ decoded_ragged->reserve(flat_variants.size());
|
||
|
|
+
|
||
|
|
for (int i = 0; i < flat_variants.size(); i++) {
|
||
|
|
const auto& flat_variant = flat_variants(i);
|
||
|
|
- const Tensor* encoded_list = flat_variant.get<Tensor>();
|
||
|
|
- if (encoded_list == nullptr) {
|
||
|
|
+ const RaggedTensorVariant* decoded =
|
||
|
|
+ flat_variant.get<RaggedTensorVariant>();
|
||
|
|
+ if (decoded == nullptr) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
"Input Variant element at index ", i,
|
||
|
|
- " doesn't hold a Tensor: ", flat_variant.DebugString());
|
||
|
|
+ " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
|
||
|
|
}
|
||
|
|
- if (encoded_list->dims() != 1) {
|
||
|
|
+ decoded_ragged->push_back(*decoded);
|
||
|
|
+ decoded = &decoded_ragged->back();
|
||
|
|
+ // Check ragged rank & types
|
||
|
|
+ if (decoded->ragged_rank() != ragged_rank) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
- "Encoded input Variant must have rank 1, but found rank: ",
|
||
|
|
- encoded_list->dims(),
|
||
|
|
- ". encoded input Variant: ", encoded_list->DebugString());
|
||
|
|
+ "Encoded input RaggedTensorVariant has ragged_rank=",
|
||
|
|
+ decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
|
||
|
|
}
|
||
|
|
- if (encoded_list->NumElements() != (ragged_rank + 1) &&
|
||
|
|
- encoded_list->NumElements() != 1) {
|
||
|
|
+ if (decoded->values().dtype() != value_dtype) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
- "Encoded input Variant must hold either input_ragged_rank + 1 "
|
||
|
|
- "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
|
||
|
|
- "input_ragged_rank: ",
|
||
|
|
- ragged_rank,
|
||
|
|
- ", encoded input Variant: ", encoded_list->DebugString());
|
||
|
|
+ "Expected values Tensor dtype: ", DataTypeString(value_dtype),
|
||
|
|
+ ", found: ", DataTypeString(decoded->values().dtype()));
|
||
|
|
}
|
||
|
|
- const auto& input_vec = encoded_list->vec<Variant>();
|
||
|
|
-
|
||
|
|
- // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
|
||
|
|
- // to create the component RaggedTensors.
|
||
|
|
- (*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
|
||
|
|
- for (int j = 0; j < ragged_rank; j++) {
|
||
|
|
- const Tensor* split_tensor = input_vec(j).get<Tensor>();
|
||
|
|
- if (split_tensor == nullptr) {
|
||
|
|
- return errors::InvalidArgument(
|
||
|
|
- "Encoded scalar element at index ", i,
|
||
|
|
- " doesn't have a splits Tensor at split_index ", j, ": ",
|
||
|
|
- input_vec(j).DebugString());
|
||
|
|
- }
|
||
|
|
- Tensor splits_tensor = *split_tensor;
|
||
|
|
- if (splits_tensor.dtype() != split_dtype) {
|
||
|
|
+ if (decoded->values().dims() < 1) {
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "Ragged values must have rank >= 1; encoded scalar element at index ",
|
||
|
|
+ i, " has values Tensor: ", decoded->values().DebugString());
|
||
|
|
+ }
|
||
|
|
+ for (const auto& splits : decoded->nested_splits()) {
|
||
|
|
+ if (splits.dtype() != split_dtype) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
- "Expected splits Tensor dtype: ", split_dtype,
|
||
|
|
- ", found: ", splits_tensor.dtype());
|
||
|
|
+ "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
|
||
|
|
+ ", found: ", DataTypeString(splits.dtype()));
|
||
|
|
}
|
||
|
|
- if (splits_tensor.dims() != 1) {
|
||
|
|
+ if (splits.dims() != 1) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
"Ragged splits must have rank 1; encoded scalar element at index ",
|
||
|
|
- i, " has splits Tensor at split_index ", j, ": ",
|
||
|
|
- splits_tensor.DebugString());
|
||
|
|
+ i, " has splits Tensor ", splits.DebugString());
|
||
|
|
}
|
||
|
|
- (*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
|
||
|
|
- }
|
||
|
|
- const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
|
||
|
|
- if (values_tensor == nullptr) {
|
||
|
|
- return errors::InvalidArgument("Encoded scalar element at index ", i,
|
||
|
|
- " doesn't have a values Tensor: ",
|
||
|
|
- input_vec(ragged_rank).DebugString());
|
||
|
|
- }
|
||
|
|
- if (values_tensor->dtype() != value_dtype) {
|
||
|
|
- return errors::InvalidArgument(
|
||
|
|
- "Expected values Tensor dtype: ", DataTypeString(value_dtype),
|
||
|
|
- ", found: ", DataTypeString(values_tensor->dtype()));
|
||
|
|
- }
|
||
|
|
- if (values_tensor->dims() < 1) {
|
||
|
|
- return errors::InvalidArgument(
|
||
|
|
- "Ragged values must have rank >= 1; encoded scalar element at index ",
|
||
|
|
- i, " has values Tensor: ", values_tensor->DebugString());
|
||
|
|
}
|
||
|
|
- (*decoded_ragged)[i].values = *values_tensor;
|
||
|
|
}
|
||
|
|
return Status::OK();
|
||
|
|
}
|
||
|
|
|
||
|
|
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||
|
|
Status NestedStackRaggedTensors(
|
||
|
|
- const std::vector<RaggedTensor>& ragged_components,
|
||
|
|
+ const std::vector<RaggedTensorVariant>& ragged_components,
|
||
|
|
const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
|
||
|
|
- const int output_ragged_rank, RaggedTensor* output_ragged) {
|
||
|
|
- output_ragged->nested_splits.reserve(output_ragged_rank);
|
||
|
|
+ const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
|
||
|
|
+ output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
|
||
|
|
const int dims = nested_dim_sizes.size();
|
||
|
|
|
||
|
|
// Populate first `dims - 1` splits.
|
||
|
|
for (int i = 0; i < dims - 1; i++) {
|
||
|
|
int dims_splits_size = nested_dim_sizes[i] + 1;
|
||
|
|
- output_ragged->nested_splits.push_back(Tensor(
|
||
|
|
- DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
|
||
|
|
- auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
|
||
|
|
+ output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
|
||
|
|
+ TensorShape({dims_splits_size})));
|
||
|
|
+ auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
|
||
|
|
int split_diff = nested_dim_sizes[i + 1];
|
||
|
|
for (int j = 0; j < dims_splits_size; j++) {
|
||
|
|
splits_vec(j) = j * split_diff;
|
||
|
|
@@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
|
||
|
|
|
||
|
|
// Populate `dims`-th split.
|
||
|
|
int splits_size = ragged_components.size() + 1;
|
||
|
|
- output_ragged->nested_splits.push_back(
|
||
|
|
+ output_ragged->append_splits(
|
||
|
|
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
|
||
|
|
auto dims_splits_vec =
|
||
|
|
- output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
|
||
|
|
+ output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
|
||
|
|
dims_splits_vec(0) = 0;
|
||
|
|
for (int i = 0; i < ragged_components.size(); i++) {
|
||
|
|
- int split_val = ragged_components[i].values.shape().dim_size(0);
|
||
|
|
- if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
|
||
|
|
- split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
|
||
|
|
+ int split_val = ragged_components[i].values().shape().dim_size(0);
|
||
|
|
+ if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
|
||
|
|
+ split_val = ragged_components[i].splits(0).NumElements() - 1;
|
||
|
|
}
|
||
|
|
dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
|
||
|
|
}
|
||
|
|
@@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
|
||
|
|
int split_index = dims + i;
|
||
|
|
int split_size = 1;
|
||
|
|
for (int j = 0; j < ragged_components.size(); j++) {
|
||
|
|
- if (!ragged_components[j].nested_splits.empty()) {
|
||
|
|
- split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
|
||
|
|
+ if (!ragged_components[j].nested_splits().empty()) {
|
||
|
|
+ split_size += ragged_components[j].splits(i).NumElements() - 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
- output_ragged->nested_splits.push_back(
|
||
|
|
+ output_ragged->append_splits(
|
||
|
|
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
|
||
|
|
auto splits_vec =
|
||
|
|
- output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
|
||
|
|
+ output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
|
||
|
|
splits_vec(0) = 0;
|
||
|
|
SPLIT_TYPE last_split_value = 0;
|
||
|
|
int index = 1;
|
||
|
|
for (int j = 0; j < ragged_components.size(); j++) {
|
||
|
|
- if (ragged_components[j].nested_splits.empty()) {
|
||
|
|
+ if (ragged_components[j].nested_splits().empty()) {
|
||
|
|
// Corner case: empty row. e.g [ [[x], [x]], [] ]
|
||
|
|
continue;
|
||
|
|
}
|
||
|
|
auto component_splits_vec =
|
||
|
|
- ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
|
||
|
|
+ ragged_components[j].splits(i).vec<SPLIT_TYPE>();
|
||
|
|
for (int k = 1; k < component_splits_vec.size(); k++, index++) {
|
||
|
|
splits_vec(index) = component_splits_vec(k) + last_split_value;
|
||
|
|
}
|
||
|
|
@@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
|
||
|
|
if (ragged_components.empty()) {
|
||
|
|
component_values_shape = TensorShape({0});
|
||
|
|
} else {
|
||
|
|
- component_values_shape = ragged_components[0].values.shape();
|
||
|
|
+ component_values_shape = ragged_components[0].values().shape();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Populate values.
|
||
|
|
int values_size = component_values_shape.dim_size(0);
|
||
|
|
for (int i = 1; i < ragged_components.size(); i++) {
|
||
|
|
- if (ragged_components[i].values.dims() != component_values_shape.dims()) {
|
||
|
|
+ if (ragged_components[i].values().dims() != component_values_shape.dims()) {
|
||
|
|
return errors::InvalidArgument(
|
||
|
|
"Rank of values must match for all "
|
||
|
|
"components; values shape at index 0: ",
|
||
|
|
component_values_shape.DebugString(), ", values shape at index ", i,
|
||
|
|
- ": ", ragged_components[i].values.shape().DebugString());
|
||
|
|
+ ": ", ragged_components[i].values().shape().DebugString());
|
||
|
|
}
|
||
|
|
- values_size += ragged_components[i].values.shape().dim_size(0);
|
||
|
|
+ values_size += ragged_components[i].values().shape().dim_size(0);
|
||
|
|
}
|
||
|
|
component_values_shape.set_dim(0, values_size);
|
||
|
|
- output_ragged->values =
|
||
|
|
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
|
||
|
|
+ output_ragged->set_values(
|
||
|
|
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
|
||
|
|
auto output_values_flat =
|
||
|
|
- output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
+ output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
int values_index = 0;
|
||
|
|
for (int i = 0; i < ragged_components.size(); i++) {
|
||
|
|
auto component_values_flat =
|
||
|
|
- ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
- int num_inner_elements = ragged_components[i].values.NumElements();
|
||
|
|
- if (ragged_components[i].values.dim_size(0) > 0) {
|
||
|
|
- num_inner_elements /= ragged_components[i].values.dim_size(0);
|
||
|
|
+ ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
+ int num_inner_elements = ragged_components[i].values().NumElements();
|
||
|
|
+ if (ragged_components[i].values().dim_size(0) > 0) {
|
||
|
|
+ num_inner_elements /= ragged_components[i].values().dim_size(0);
|
||
|
|
}
|
||
|
|
- for (int j = 0; j < ragged_components[i].values.dim_size(0);
|
||
|
|
+ for (int j = 0; j < ragged_components[i].values().dim_size(0);
|
||
|
|
j++, values_index++) {
|
||
|
|
for (int k = 0; k < num_inner_elements; k++) {
|
||
|
|
output_values_flat(values_index, k) = component_values_flat(j, k);
|
||
|
|
@@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||
|
|
// Decode all variants.
|
||
|
|
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
|
||
|
|
const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
|
||
|
|
- std::vector<RaggedTensor> decoded_components;
|
||
|
|
+ std::vector<RaggedTensorVariant> decoded_components;
|
||
|
|
OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
|
||
|
|
encoded_variant, input_ragged_rank_,
|
||
|
|
value_dtype, split_dtype, &decoded_components));
|
||
|
|
@@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||
|
|
for (int i = 0; i < encoded_variant.dims(); i++) {
|
||
|
|
encoded_dim_sizes[i] = encoded_variant.dim_size(i);
|
||
|
|
}
|
||
|
|
- RaggedTensor output_ragged;
|
||
|
|
+ RaggedTensorVariant output_ragged;
|
||
|
|
OP_REQUIRES_OK(
|
||
|
|
context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
|
||
|
|
decoded_components, encoded_dim_sizes, input_ragged_rank_,
|
||
|
|
@@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
|
||
|
|
int output_ragged_rank_;
|
||
|
|
|
||
|
|
void ReturnRaggedTensor(OpKernelContext* context,
|
||
|
|
- RaggedTensor ragged_tensor) {
|
||
|
|
- int ragged_rank = ragged_tensor.nested_splits.size();
|
||
|
|
+ const RaggedTensorVariant& ragged_tensor) {
|
||
|
|
+ int ragged_rank = ragged_tensor.ragged_rank();
|
||
|
|
OpOutputList splits_out;
|
||
|
|
OP_REQUIRES_OK(context,
|
||
|
|
context->output_list("output_nested_splits", &splits_out));
|
||
|
|
for (int i = 0; i < ragged_rank; i++) {
|
||
|
|
- splits_out.set(i, ragged_tensor.nested_splits[i]);
|
||
|
|
+ splits_out.set(i, ragged_tensor.splits(i));
|
||
|
|
}
|
||
|
|
- context->set_output(ragged_rank, ragged_tensor.values);
|
||
|
|
+ context->set_output(ragged_rank, ragged_tensor.values());
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
|
||
|
|
index 3190534b..a60e5c62 100644
|
||
|
|
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
|
||
|
|
@@ -18,50 +18,38 @@ limitations under the License.
|
||
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
||
|
|
#include "tensorflow/core/framework/register_types.h"
|
||
|
|
#include "tensorflow/core/framework/tensor.h"
|
||
|
|
+#include "tensorflow/core/framework/tensor_shape.h"
|
||
|
|
#include "tensorflow/core/framework/variant.h"
|
||
|
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||
|
|
+#include "tensorflow/core/framework/variant_op_registry.h"
|
||
|
|
+#include "tensorflow/core/kernels/concat_lib.h"
|
||
|
|
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||
|
|
#include "tensorflow/core/lib/core/errors.h"
|
||
|
|
#include "tensorflow/core/lib/core/status.h"
|
||
|
|
+#include "tensorflow/core/util/tensor_ops_util.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
namespace {
|
||
|
|
|
||
|
|
-struct RaggedTensor {
|
||
|
|
- Tensor values;
|
||
|
|
- std::vector<Tensor> nested_splits;
|
||
|
|
-};
|
||
|
|
-
|
||
|
|
-Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
|
||
|
|
- // Encode as a rank-1 Variant Tensor.
|
||
|
|
- int ragged_rank = ragged.nested_splits.size();
|
||
|
|
- *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
|
||
|
|
- auto encoded_vec = encoded_list->vec<Variant>();
|
||
|
|
- for (int i = 0; i < ragged_rank; i++) {
|
||
|
|
- encoded_vec(i) = ragged.nested_splits[i];
|
||
|
|
- }
|
||
|
|
- encoded_vec(ragged_rank) = ragged.values;
|
||
|
|
- return Status::OK();
|
||
|
|
-}
|
||
|
|
-
|
||
|
|
template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||
|
|
-Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||
|
|
- std::vector<RaggedTensor>* ragged_components) {
|
||
|
|
+Status UnbatchRaggedZerothDim(
|
||
|
|
+ const RaggedTensorVariant& batched_ragged,
|
||
|
|
+ std::vector<RaggedTensorVariant>* ragged_components) {
|
||
|
|
// Set up the component Ragged Tensors.
|
||
|
|
- int ragged_rank = batched_ragged.nested_splits.size();
|
||
|
|
- auto batched_splits_top_vec =
|
||
|
|
- batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
|
||
|
|
+ int ragged_rank = batched_ragged.ragged_rank();
|
||
|
|
+ auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
|
||
|
|
int num_components = batched_splits_top_vec.size() - 1;
|
||
|
|
int num_splits = ragged_rank - 1;
|
||
|
|
ragged_components->resize(num_components);
|
||
|
|
- for (RaggedTensor ragged_component : *ragged_components) {
|
||
|
|
- ragged_component.nested_splits.reserve(num_splits);
|
||
|
|
+ for (RaggedTensorVariant& ragged_component : *ragged_components) {
|
||
|
|
+ ragged_component.mutable_nested_splits()->reserve(num_splits);
|
||
|
|
}
|
||
|
|
- const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
|
||
|
|
- int num_inner_elems = batched_ragged.values.NumElements();
|
||
|
|
- if (batched_ragged.values.dim_size(0) > 1) {
|
||
|
|
- num_inner_elems /= batched_ragged.values.dim_size(0);
|
||
|
|
+ const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
|
||
|
|
+ int num_inner_elems = batched_ragged.values().NumElements();
|
||
|
|
+ if (batched_ragged.values().dim_size(0) > 1) {
|
||
|
|
+ num_inner_elems /= batched_ragged.values().dim_size(0);
|
||
|
|
}
|
||
|
|
- TensorShape values_shape = batched_ragged.values.shape();
|
||
|
|
+ TensorShape values_shape = batched_ragged.values().shape();
|
||
|
|
|
||
|
|
// Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
|
||
|
|
if (num_splits == 0) {
|
||
|
|
@@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||
|
|
int limit = batched_splits_top_vec(i + 1);
|
||
|
|
int num_values = limit - start;
|
||
|
|
values_shape.set_dim(0, num_values);
|
||
|
|
- (*ragged_components)[i].values =
|
||
|
|
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
|
||
|
|
+ (*ragged_components)[i].set_values(
|
||
|
|
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
|
||
|
|
auto ragged_component_values_flat =
|
||
|
|
- (*ragged_components)[i].values.flat<VALUE_TYPE>();
|
||
|
|
+ (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
|
||
|
|
for (int j = 0; j < num_values * num_inner_elems; j++) {
|
||
|
|
ragged_component_values_flat(j) =
|
||
|
|
batched_flat(j + start * num_inner_elems);
|
||
|
|
@@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||
|
|
std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
|
||
|
|
batched_splits_vec.reserve(ragged_rank);
|
||
|
|
for (int i = 0; i < ragged_rank; i++) {
|
||
|
|
- batched_splits_vec.push_back(
|
||
|
|
- batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
|
||
|
|
+ batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
|
||
|
|
}
|
||
|
|
std::vector<int> index(num_splits, 1);
|
||
|
|
std::vector<int> ragged_component_values_size(num_components, 0);
|
||
|
|
@@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||
|
|
int last_index = ragged_component_splits_vec[j - 1].size() - 1;
|
||
|
|
split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
|
||
|
|
}
|
||
|
|
- (*ragged_components)[i].nested_splits.push_back(
|
||
|
|
+ (*ragged_components)[i].append_splits(
|
||
|
|
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
|
||
|
|
ragged_component_splits_vec.push_back(
|
||
|
|
- (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
|
||
|
|
+ (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
|
||
|
|
SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
|
||
|
|
ragged_component_splits_vec[j](0) = 0;
|
||
|
|
for (int k = 1; k < split_size; k++, index[j]++) {
|
||
|
|
@@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
|
||
|
|
for (int i = 0; i < num_components; i++) {
|
||
|
|
int num_values = ragged_component_values_size[i];
|
||
|
|
values_shape.set_dim(0, num_values);
|
||
|
|
- (*ragged_components)[i].values =
|
||
|
|
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
|
||
|
|
+ (*ragged_components)[i].set_values(
|
||
|
|
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
|
||
|
|
auto ragged_component_values_flat =
|
||
|
|
- (*ragged_components)[i].values.flat<VALUE_TYPE>();
|
||
|
|
+ (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
|
||
|
|
for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
|
||
|
|
ragged_component_values_flat(j) = batched_flat(value_index);
|
||
|
|
}
|
||
|
|
@@ -152,24 +139,21 @@ class RaggedTensorToVariantOp : public OpKernel {
|
||
|
|
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
|
||
|
|
&ragged_nested_splits_in));
|
||
|
|
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
|
||
|
|
- RaggedTensor batched_ragged_input;
|
||
|
|
+ RaggedTensorVariant batched_ragged_input;
|
||
|
|
// Read ragged_values input.
|
||
|
|
- batched_ragged_input.values = context->input(ragged_nested_splits_len);
|
||
|
|
- batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
|
||
|
|
+ batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
|
||
|
|
+ batched_ragged_input.mutable_nested_splits()->reserve(
|
||
|
|
+ ragged_nested_splits_len);
|
||
|
|
for (int i = 0; i < ragged_nested_splits_len; i++) {
|
||
|
|
- batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
|
||
|
|
+ batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
|
||
|
|
}
|
||
|
|
|
||
|
|
if (!batched_input_) {
|
||
|
|
- // Encode the input as is.
|
||
|
|
- Tensor encoded_list;
|
||
|
|
- OP_REQUIRES_OK(context,
|
||
|
|
- RaggedToVariant(batched_ragged_input, &encoded_list));
|
||
|
|
// Encode as a Scalar Variant Tensor.
|
||
|
|
Tensor* encoded_scalar;
|
||
|
|
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
|
||
|
|
&encoded_scalar));
|
||
|
|
- encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
|
||
|
|
+ encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
|
||
|
|
return;
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -180,24 +164,19 @@ class RaggedTensorToVariantOp : public OpKernel {
|
||
|
|
"received rt_nested_splits of length 0."));
|
||
|
|
|
||
|
|
// Unbatch the Ragged Tensor and encode the components.
|
||
|
|
- std::vector<RaggedTensor> ragged_components;
|
||
|
|
+ std::vector<RaggedTensorVariant> unbatched_ragged_input;
|
||
|
|
OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
|
||
|
|
- batched_ragged_input, &ragged_components));
|
||
|
|
- std::vector<Tensor> encoded_components(ragged_components.size());
|
||
|
|
- for (int i = 0; i < ragged_components.size(); i++) {
|
||
|
|
- OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
|
||
|
|
- &encoded_components[i]));
|
||
|
|
- }
|
||
|
|
+ batched_ragged_input, &unbatched_ragged_input));
|
||
|
|
|
||
|
|
// Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
|
||
|
|
- Tensor* encoded_ragged;
|
||
|
|
- int output_size = ragged_components.size();
|
||
|
|
+ Tensor* encoded_vector;
|
||
|
|
+ int output_size = unbatched_ragged_input.size();
|
||
|
|
OP_REQUIRES_OK(context,
|
||
|
|
context->allocate_output(0, TensorShape({output_size}),
|
||
|
|
- &encoded_ragged));
|
||
|
|
- auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
|
||
|
|
+ &encoded_vector));
|
||
|
|
+ auto encoded_vector_t = encoded_vector->vec<Variant>();
|
||
|
|
for (int i = 0; i < output_size; i++) {
|
||
|
|
- encoded_ragged_vec(i) = encoded_components[i];
|
||
|
|
+ encoded_vector_t(i) = unbatched_ragged_input[i];
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
@@ -205,12 +184,81 @@ class RaggedTensorToVariantOp : public OpKernel {
|
||
|
|
bool batched_input_;
|
||
|
|
};
|
||
|
|
|
||
|
|
-#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
|
||
|
|
- REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
|
||
|
|
- .Device(DEVICE_CPU) \
|
||
|
|
- .TypeConstraint<value_type>("Tvalues") \
|
||
|
|
- .TypeConstraint<split_type>("Tsplits"), \
|
||
|
|
- RaggedTensorToVariantOp<value_type, split_type>);
|
||
|
|
+template <typename VALUE_TYPE, typename SPLIT_TYPE>
|
||
|
|
+class RaggedTensorToVariantGradientOp : public OpKernel {
|
||
|
|
+ public:
|
||
|
|
+ using OpKernel::OpKernel;
|
||
|
|
+
|
||
|
|
+ void Compute(OpKernelContext* context) override {
|
||
|
|
+ // Read inputs.
|
||
|
|
+ Tensor encoded_variant = context->input(0);
|
||
|
|
+ Tensor row_splits = context->input(1);
|
||
|
|
+ auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
|
||
|
|
+ TensorShape dense_values_shape;
|
||
|
|
+ OP_REQUIRES_OK(context,
|
||
|
|
+ TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
|
||
|
|
+ &dense_values_shape));
|
||
|
|
+
|
||
|
|
+ const auto& flat_variants = encoded_variant.flat<Variant>();
|
||
|
|
+
|
||
|
|
+ // Get a Tensor containing the flat_values for each variant.
|
||
|
|
+ std::vector<Tensor> values;
|
||
|
|
+ for (int i = 0; i < flat_variants.size(); ++i) {
|
||
|
|
+ if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
|
||
|
|
+ values.push_back(encoded->values());
|
||
|
|
+ } else {
|
||
|
|
+ // Missing value: this happens if only some of the variant values
|
||
|
|
+ // generated by ragged_tensor_to_variant impacted the value that we're
|
||
|
|
+ // calculating the gradient for. In this case, we will see a
|
||
|
|
+ // default-constructed variant; so treat it as a zero tensor with the
|
||
|
|
+ // appropriate shape.
|
||
|
|
+ const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
|
||
|
|
+ int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
|
||
|
|
+ TensorShape zeros_shape = dense_values_shape;
|
||
|
|
+ zeros_shape.set_dim(0, piece_size);
|
||
|
|
+ Tensor zero(value_dtype, zeros_shape);
|
||
|
|
+ zero.flat<VALUE_TYPE>() =
|
||
|
|
+ zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
|
||
|
|
+ values.push_back(zero);
|
||
|
|
+ }
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
+ if (values.size() == 1) {
|
||
|
|
+ // Just one flat_value tensor: return as-is.
|
||
|
|
+ context->set_output(0, values[0]);
|
||
|
|
+ } else {
|
||
|
|
+ // Multiple flat_values tensors: concatenate them together.
|
||
|
|
+ using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
|
||
|
|
+ using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
|
||
|
|
+ std::vector<std::unique_ptr<ConstPiece>> pieces;
|
||
|
|
+ pieces.reserve(values.size());
|
||
|
|
+ for (const Tensor& t : values) {
|
||
|
|
+ pieces.emplace_back(
|
||
|
|
+ new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
|
||
|
|
+ }
|
||
|
|
+ Tensor* out = nullptr;
|
||
|
|
+ OP_REQUIRES_OK(context,
|
||
|
|
+ context->allocate_output(0, dense_values_shape, &out));
|
||
|
|
+ Piece out_flat =
|
||
|
|
+ out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
|
||
|
|
+ ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
|
||
|
|
+ }
|
||
|
|
+ }
|
||
|
|
+};
|
||
|
|
+
|
||
|
|
+#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
|
||
|
|
+ REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
|
||
|
|
+ .Device(DEVICE_CPU) \
|
||
|
|
+ .TypeConstraint<value_type>("Tvalues") \
|
||
|
|
+ .TypeConstraint<split_type>("Tsplits"), \
|
||
|
|
+ RaggedTensorToVariantOp<value_type, split_type>); \
|
||
|
|
+ REGISTER_KERNEL_BUILDER( \
|
||
|
|
+ Name("RaggedTensorToVariantGradient") \
|
||
|
|
+ .Device(DEVICE_CPU) \
|
||
|
|
+ .TypeConstraint<value_type>("Tvalues") \
|
||
|
|
+ .TypeConstraint<split_type>("Tsplits"), \
|
||
|
|
+ RaggedTensorToVariantGradientOp<value_type, split_type>);
|
||
|
|
+
|
||
|
|
#define REGISTER_KERNELS(value_type) \
|
||
|
|
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
|
||
|
|
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc
|
||
|
|
new file mode 100644
|
||
|
|
index 00000000..94663138
|
||
|
|
--- /dev/null
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_variant.cc
|
||
|
|
@@ -0,0 +1,86 @@
|
||
|
|
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
|
+
|
||
|
|
+Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
+you may not use this file except in compliance with the License.
|
||
|
|
+You may obtain a copy of the License at
|
||
|
|
+
|
||
|
|
+ http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
+
|
||
|
|
+Unless required by applicable law or agreed to in writing, software
|
||
|
|
+distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
+See the License for the specific language governing permissions and
|
||
|
|
+limitations under the License.
|
||
|
|
+==============================================================================*/
|
||
|
|
+
|
||
|
|
+#define EIGEN_USE_THREADS
|
||
|
|
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||
|
|
+#define EIGEN_USE_GPU
|
||
|
|
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||
|
|
+
|
||
|
|
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
|
||
|
|
+
|
||
|
|
+namespace tensorflow {
|
||
|
|
+
|
||
|
|
+string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
|
||
|
|
+
|
||
|
|
+string RaggedTensorVariant::DebugString() const {
|
||
|
|
+ return absl::StrCat(
|
||
|
|
+ "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
|
||
|
|
+ ", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
|
||
|
|
+ DataTypeString(nested_splits_.empty() ? DT_INVALID
|
||
|
|
+ : nested_splits_.back().dtype()));
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+void RaggedTensorVariant::Encode(VariantTensorData* data) const {
|
||
|
|
+ data->set_type_name(TypeName());
|
||
|
|
+ for (const auto& splits : nested_splits_) {
|
||
|
|
+ *data->add_tensors() = splits;
|
||
|
|
+ }
|
||
|
|
+ *data->add_tensors() = values_;
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
|
||
|
|
+ if (data.tensors_size() < 1) {
|
||
|
|
+ return false;
|
||
|
|
+ }
|
||
|
|
+ nested_splits_.assign(data.tensors().begin(),
|
||
|
|
+ std::prev(data.tensors().end()));
|
||
|
|
+ values_ = data.tensors().back();
|
||
|
|
+ return true;
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+namespace {
|
||
|
|
+
|
||
|
|
+Status RaggedTensorVariantDeviceCopy(
|
||
|
|
+ const RaggedTensorVariant& from, RaggedTensorVariant* to,
|
||
|
|
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||
|
|
+ TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
|
||
|
|
+ // TODO(b/170415165) Should we use `copy` to move splits from device<->host?
|
||
|
|
+ *to->mutable_nested_splits() = from.nested_splits();
|
||
|
|
+ return Status::OK();
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+} // namespace
|
||
|
|
+
|
||
|
|
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||
|
|
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
|
||
|
|
+ RaggedTensorVariantZerosLike<CPUDevice>);
|
||
|
|
+
|
||
|
|
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
|
||
|
|
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
|
||
|
|
+ RaggedTensorVariantBinaryAdd<CPUDevice>);
|
||
|
|
+
|
||
|
|
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
|
||
|
|
+ "RaggedTensorVariant");
|
||
|
|
+
|
||
|
|
+#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \
|
||
|
|
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
|
||
|
|
+ RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
|
||
|
|
+
|
||
|
|
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
|
||
|
|
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
|
||
|
|
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(
|
||
|
|
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
|
||
|
|
+
|
||
|
|
+} // namespace tensorflow
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h
|
||
|
|
new file mode 100644
|
||
|
|
index 00000000..730758a3
|
||
|
|
--- /dev/null
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_variant.h
|
||
|
|
@@ -0,0 +1,110 @@
|
||
|
|
+#include "tensorflow/core/framework/tensor_key.h"
|
||
|
|
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
|
+
|
||
|
|
+Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
+you may not use this file except in compliance with the License.
|
||
|
|
+You may obtain a copy of the License at
|
||
|
|
+
|
||
|
|
+ http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
+
|
||
|
|
+Unless required by applicable law or agreed to in writing, software
|
||
|
|
+distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
+See the License for the specific language governing permissions and
|
||
|
|
+limitations under the License.
|
||
|
|
+==============================================================================*/
|
||
|
|
+
|
||
|
|
+#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
||
|
|
+#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
||
|
|
+
|
||
|
|
+#define EIGEN_USE_THREADS
|
||
|
|
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||
|
|
+#define EIGEN_USE_GPU
|
||
|
|
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||
|
|
+
|
||
|
|
+#include <vector>
|
||
|
|
+
|
||
|
|
+#include "tensorflow/core/framework/tensor.h"
|
||
|
|
+#include "tensorflow/core/framework/types.h"
|
||
|
|
+#include "tensorflow/core/framework/variant_op_registry.h"
|
||
|
|
+#include "tensorflow/core/framework/variant_tensor_data.h"
|
||
|
|
+#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||
|
|
+#include "tensorflow/core/util/tensor_ops_util.h"
|
||
|
|
+
|
||
|
|
+namespace tensorflow {
|
||
|
|
+
|
||
|
|
+// Class used to store a RaggedTensor as a Variant scalar.
|
||
|
|
+class RaggedTensorVariant {
|
||
|
|
+ public:
|
||
|
|
+ RaggedTensorVariant() {}
|
||
|
|
+ RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
|
||
|
|
+ : values_(std::move(values)), nested_splits_(nested_splits) {}
|
||
|
|
+
|
||
|
|
+ // Variant support methods.
|
||
|
|
+ string TypeName() const;
|
||
|
|
+ string DebugString() const;
|
||
|
|
+ void Encode(VariantTensorData* data) const;
|
||
|
|
+ bool Decode(const VariantTensorData& data);
|
||
|
|
+
|
||
|
|
+ // The flat_values of the RaggedTensor.
|
||
|
|
+ const Tensor& values() const { return values_; }
|
||
|
|
+ Tensor* mutable_values() { return &values_; }
|
||
|
|
+ void set_values(const Tensor& new_values) { values_ = new_values; }
|
||
|
|
+
|
||
|
|
+ // The nested row_splits of the RaggedTensor.
|
||
|
|
+ int ragged_rank() const { return nested_splits_.size(); }
|
||
|
|
+ const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
|
||
|
|
+ std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
|
||
|
|
+ const Tensor& splits(int i) const { return nested_splits_[i]; }
|
||
|
|
+ Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
|
||
|
|
+ void set_nested_splits(const std::vector<Tensor>& nested_splits) {
|
||
|
|
+ nested_splits_ = nested_splits;
|
||
|
|
+ }
|
||
|
|
+ void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
|
||
|
|
+
|
||
|
|
+ private:
|
||
|
|
+ Tensor values_;
|
||
|
|
+ std::vector<Tensor> nested_splits_;
|
||
|
|
+};
|
||
|
|
+
|
||
|
|
+template <typename Device>
|
||
|
|
+Status RaggedTensorVariantZerosLike(OpKernelContext* c,
|
||
|
|
+ const RaggedTensorVariant& x,
|
||
|
|
+ RaggedTensorVariant* y) {
|
||
|
|
+ y->set_nested_splits(x.nested_splits());
|
||
|
|
+ TF_RETURN_IF_ERROR(
|
||
|
|
+ ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
|
||
|
|
+ return Status::OK();
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+template <typename Device>
|
||
|
|
+Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
|
||
|
|
+ const RaggedTensorVariant& x,
|
||
|
|
+ const RaggedTensorVariant& y,
|
||
|
|
+ RaggedTensorVariant* out) {
|
||
|
|
+ if (x.values().dtype() != y.values().dtype()) {
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "Can't add RaggedTensorVariants of different dtypes. One is ",
|
||
|
|
+ DataTypeString(x.values().dtype()), " and the other is ",
|
||
|
|
+ DataTypeString(y.values().dtype()));
|
||
|
|
+ }
|
||
|
|
+ if (x.ragged_rank() != y.ragged_rank()) {
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
|
||
|
|
+ x.ragged_rank(), " and the other is ", y.ragged_rank());
|
||
|
|
+ }
|
||
|
|
+ for (int i = 0; i < x.ragged_rank(); ++i) {
|
||
|
|
+ if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "Can't add RaggedTensorVariants with different row_splits.");
|
||
|
|
+ }
|
||
|
|
+ }
|
||
|
|
+ out->set_nested_splits(x.nested_splits());
|
||
|
|
+ TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
|
||
|
|
+ out->mutable_values()));
|
||
|
|
+ return Status::OK();
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
+} // namespace tensorflow
|
||
|
|
+
|
||
|
|
+#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
|
||
|
|
diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc
|
||
|
|
index 6bee189c..8512bcf3 100644
|
||
|
|
--- a/tensorflow/core/ops/ragged_conversion_ops.cc
|
||
|
|
+++ b/tensorflow/core/ops/ragged_conversion_ops.cc
|
||
|
|
@@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
|
||
|
|
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
|
||
|
|
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
|
||
|
|
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
|
||
|
|
-tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
|
||
|
|
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
|
||
|
|
+Status RaggedTensorToTensorShapeFn(InferenceContext* c);
|
||
|
|
|
||
|
|
//==============================================================================
|
||
|
|
// Registered Ops
|
||
|
|
@@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
|
||
|
|
.Attr("Tsplits: {int32, int64} = DT_INT64")
|
||
|
|
.SetShapeFn(RaggedTensorFromVariantShapeFn);
|
||
|
|
|
||
|
|
+REGISTER_OP("RaggedTensorToVariantGradient")
|
||
|
|
+ .Input("encoded_ragged_grad: variant")
|
||
|
|
+ .Input("row_splits: Tsplits")
|
||
|
|
+ .Input("dense_values_shape: int32")
|
||
|
|
+ .Output("dense_values_grad: Tvalues")
|
||
|
|
+ .Attr("Tvalues: type")
|
||
|
|
+ .Attr("Tsplits: {int32, int64} = DT_INT64")
|
||
|
|
+ .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
|
||
|
|
+
|
||
|
|
REGISTER_OP("RaggedTensorToTensor")
|
||
|
|
.Attr("T: type")
|
||
|
|
.Attr("Tindex: {int64, int32}")
|
||
|
|
@@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
|
||
|
|
return Status::OK();
|
||
|
|
}
|
||
|
|
|
||
|
|
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
|
||
|
|
+ ShapeHandle shape;
|
||
|
|
+ TF_RETURN_IF_ERROR(
|
||
|
|
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
|
||
|
|
+ c->set_output(0, shape);
|
||
|
|
+ return Status::OK();
|
||
|
|
+}
|
||
|
|
+
|
||
|
|
Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
|
||
|
|
int64 input_ragged_rank;
|
||
|
|
TF_RETURN_IF_ERROR(
|
||
|
|
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
|
||
|
|
index 95e5602a..34372160 100644
|
||
|
|
--- a/tensorflow/python/ops/ragged/BUILD
|
||
|
|
+++ b/tensorflow/python/ops/ragged/BUILD
|
||
|
|
@@ -507,6 +507,7 @@ py_test(
|
||
|
|
"//tensorflow/python:framework_ops",
|
||
|
|
"//tensorflow/python:framework_test_lib",
|
||
|
|
"//tensorflow/python:platform_test",
|
||
|
|
+ "//tensorflow/python:tensor_array_grad",
|
||
|
|
"//tensorflow/python:tensor_shape",
|
||
|
|
"//tensorflow/python:tensor_spec",
|
||
|
|
"//tensorflow/python/data/ops:dataset_ops",
|
||
|
|
--
|
||
|
|
2.27.0
|
||
|
|
|