tensorflow/CVE-2021-29516-3.patch (905 lines, 40 KiB), committed 2021-09-13 10:32:00 +08:00

From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Wed, 14 Oct 2020 09:41:17 -0700
Subject: [PATCH 1/1] Added gradients for RaggedTensorToVariant and
RaggedTensorFromVariant. (This allows gradients to pass through map_fn when
it is applied to ragged tensors.)
PiperOrigin-RevId: 337108621
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
---
tensorflow/core/kernels/BUILD | 15 ++
.../kernels/ragged_tensor_from_variant_op.cc | 164 +++++++---------
.../kernels/ragged_tensor_to_variant_op.cc | 180 +++++++++++-------
.../core/kernels/ragged_tensor_variant.cc | 86 +++++++++
.../core/kernels/ragged_tensor_variant.h | 110 +++++++++++
tensorflow/core/ops/ragged_conversion_ops.cc | 20 +-
tensorflow/python/ops/ragged/BUILD | 1 +
9 files changed, 478 insertions(+), 172 deletions(-)
create mode 100644 tensorflow/core/framework/tensor_key.h
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc
create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f5a480b3..12adb2b2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1529,10 +1529,22 @@ tf_cc_test(
],
)
+cc_library(
+ name = "ragged_tensor_variant",
+ srcs = ["ragged_tensor_variant.cc"],
+ hdrs = ["ragged_tensor_variant.h"],
+ deps = [
+ ":cwise_op",
+ "//tensorflow/core:framework",
+ ],
+)
+
tf_kernel_library(
name = "ragged_tensor_to_variant_op",
srcs = ["ragged_tensor_to_variant_op.cc"],
deps = [
+ ":concat_lib",
+ ":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
@@ -1542,6 +1554,7 @@ tf_kernel_library(
name = "ragged_tensor_from_variant_op",
srcs = ["ragged_tensor_from_variant_op.cc"],
deps = [
+ ":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
@@ -1554,6 +1567,7 @@ tf_cc_test(
deps = [
":ops_testutil",
":ragged_tensor_to_variant_op",
+ ":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -1570,6 +1584,7 @@ tf_cc_test(
deps = [
":ops_testutil",
":ragged_tensor_from_variant_op",
+ ":ragged_tensor_variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index d7b6a89a..fa8853af 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -20,110 +20,76 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace {
-struct RaggedTensor {
- Tensor values;
- std::vector<Tensor> nested_splits;
-};
-
-Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
- int ragged_rank, DataType value_dtype,
- DataType split_dtype,
- std::vector<RaggedTensor>* decoded_ragged) {
+Status RaggedComponentsFromVariant(
+ const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
+ DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
const auto& flat_variants = encoded_variant.flat<Variant>();
- decoded_ragged->resize(flat_variants.size());
- // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
- // input.
+ decoded_ragged->reserve(flat_variants.size());
+
for (int i = 0; i < flat_variants.size(); i++) {
const auto& flat_variant = flat_variants(i);
- const Tensor* encoded_list = flat_variant.get<Tensor>();
- if (encoded_list == nullptr) {
+ const RaggedTensorVariant* decoded =
+ flat_variant.get<RaggedTensorVariant>();
+ if (decoded == nullptr) {
return errors::InvalidArgument(
"Input Variant element at index ", i,
- " doesn't hold a Tensor: ", flat_variant.DebugString());
+ " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
}
- if (encoded_list->dims() != 1) {
+ decoded_ragged->push_back(*decoded);
+ decoded = &decoded_ragged->back();
+ // Check ragged rank & types
+ if (decoded->ragged_rank() != ragged_rank) {
return errors::InvalidArgument(
- "Encoded input Variant must have rank 1, but found rank: ",
- encoded_list->dims(),
- ". encoded input Variant: ", encoded_list->DebugString());
+ "Encoded input RaggedTensorVariant has ragged_rank=",
+ decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
}
- if (encoded_list->NumElements() != (ragged_rank + 1) &&
- encoded_list->NumElements() != 1) {
+ if (decoded->values().dtype() != value_dtype) {
return errors::InvalidArgument(
- "Encoded input Variant must hold either input_ragged_rank + 1 "
- "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
- "input_ragged_rank: ",
- ragged_rank,
- ", encoded input Variant: ", encoded_list->DebugString());
+ "Expected values Tensor dtype: ", DataTypeString(value_dtype),
+ ", found: ", DataTypeString(decoded->values().dtype()));
}
- const auto& input_vec = encoded_list->vec<Variant>();
-
- // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
- // to create the component RaggedTensors.
- (*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
- for (int j = 0; j < ragged_rank; j++) {
- const Tensor* split_tensor = input_vec(j).get<Tensor>();
- if (split_tensor == nullptr) {
- return errors::InvalidArgument(
- "Encoded scalar element at index ", i,
- " doesn't have a splits Tensor at split_index ", j, ": ",
- input_vec(j).DebugString());
- }
- Tensor splits_tensor = *split_tensor;
- if (splits_tensor.dtype() != split_dtype) {
+ if (decoded->values().dims() < 1) {
+ return errors::InvalidArgument(
+ "Ragged values must have rank >= 1; encoded scalar element at index ",
+ i, " has values Tensor: ", decoded->values().DebugString());
+ }
+ for (const auto& splits : decoded->nested_splits()) {
+ if (splits.dtype() != split_dtype) {
return errors::InvalidArgument(
- "Expected splits Tensor dtype: ", split_dtype,
- ", found: ", splits_tensor.dtype());
+ "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
+ ", found: ", DataTypeString(splits.dtype()));
}
- if (splits_tensor.dims() != 1) {
+ if (splits.dims() != 1) {
return errors::InvalidArgument(
"Ragged splits must have rank 1; encoded scalar element at index ",
- i, " has splits Tensor at split_index ", j, ": ",
- splits_tensor.DebugString());
+ i, " has splits Tensor ", splits.DebugString());
}
- (*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
- }
- const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
- if (values_tensor == nullptr) {
- return errors::InvalidArgument("Encoded scalar element at index ", i,
- " doesn't have a values Tensor: ",
- input_vec(ragged_rank).DebugString());
- }
- if (values_tensor->dtype() != value_dtype) {
- return errors::InvalidArgument(
- "Expected values Tensor dtype: ", DataTypeString(value_dtype),
- ", found: ", DataTypeString(values_tensor->dtype()));
- }
- if (values_tensor->dims() < 1) {
- return errors::InvalidArgument(
- "Ragged values must have rank >= 1; encoded scalar element at index ",
- i, " has values Tensor: ", values_tensor->DebugString());
}
- (*decoded_ragged)[i].values = *values_tensor;
}
return Status::OK();
}
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status NestedStackRaggedTensors(
- const std::vector<RaggedTensor>& ragged_components,
+ const std::vector<RaggedTensorVariant>& ragged_components,
const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
- const int output_ragged_rank, RaggedTensor* output_ragged) {
- output_ragged->nested_splits.reserve(output_ragged_rank);
+ const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
+ output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
const int dims = nested_dim_sizes.size();
// Populate first `dims - 1` splits.
for (int i = 0; i < dims - 1; i++) {
int dims_splits_size = nested_dim_sizes[i] + 1;
- output_ragged->nested_splits.push_back(Tensor(
- DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
- auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
+ output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
+ TensorShape({dims_splits_size})));
+ auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
int split_diff = nested_dim_sizes[i + 1];
for (int j = 0; j < dims_splits_size; j++) {
splits_vec(j) = j * split_diff;
@@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
// Populate `dims`-th split.
int splits_size = ragged_components.size() + 1;
- output_ragged->nested_splits.push_back(
+ output_ragged->append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
auto dims_splits_vec =
- output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
+ output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
dims_splits_vec(0) = 0;
for (int i = 0; i < ragged_components.size(); i++) {
- int split_val = ragged_components[i].values.shape().dim_size(0);
- if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
- split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
+ int split_val = ragged_components[i].values().shape().dim_size(0);
+ if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
+ split_val = ragged_components[i].splits(0).NumElements() - 1;
}
dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
}
@@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
int split_index = dims + i;
int split_size = 1;
for (int j = 0; j < ragged_components.size(); j++) {
- if (!ragged_components[j].nested_splits.empty()) {
- split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
+ if (!ragged_components[j].nested_splits().empty()) {
+ split_size += ragged_components[j].splits(i).NumElements() - 1;
}
}
- output_ragged->nested_splits.push_back(
+ output_ragged->append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
auto splits_vec =
- output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
+ output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
splits_vec(0) = 0;
SPLIT_TYPE last_split_value = 0;
int index = 1;
for (int j = 0; j < ragged_components.size(); j++) {
- if (ragged_components[j].nested_splits.empty()) {
+ if (ragged_components[j].nested_splits().empty()) {
// Corner case: empty row. e.g [ [[x], [x]], [] ]
continue;
}
auto component_splits_vec =
- ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
+ ragged_components[j].splits(i).vec<SPLIT_TYPE>();
for (int k = 1; k < component_splits_vec.size(); k++, index++) {
splits_vec(index) = component_splits_vec(k) + last_split_value;
}
@@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
if (ragged_components.empty()) {
component_values_shape = TensorShape({0});
} else {
- component_values_shape = ragged_components[0].values.shape();
+ component_values_shape = ragged_components[0].values().shape();
}
// Populate values.
int values_size = component_values_shape.dim_size(0);
for (int i = 1; i < ragged_components.size(); i++) {
- if (ragged_components[i].values.dims() != component_values_shape.dims()) {
+ if (ragged_components[i].values().dims() != component_values_shape.dims()) {
return errors::InvalidArgument(
"Rank of values must match for all "
"components; values shape at index 0: ",
component_values_shape.DebugString(), ", values shape at index ", i,
- ": ", ragged_components[i].values.shape().DebugString());
+ ": ", ragged_components[i].values().shape().DebugString());
}
- values_size += ragged_components[i].values.shape().dim_size(0);
+ values_size += ragged_components[i].values().shape().dim_size(0);
}
component_values_shape.set_dim(0, values_size);
- output_ragged->values =
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
+ output_ragged->set_values(
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
auto output_values_flat =
- output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
+ output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
int values_index = 0;
for (int i = 0; i < ragged_components.size(); i++) {
auto component_values_flat =
- ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
- int num_inner_elements = ragged_components[i].values.NumElements();
- if (ragged_components[i].values.dim_size(0) > 0) {
- num_inner_elements /= ragged_components[i].values.dim_size(0);
+ ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
+ int num_inner_elements = ragged_components[i].values().NumElements();
+ if (ragged_components[i].values().dim_size(0) > 0) {
+ num_inner_elements /= ragged_components[i].values().dim_size(0);
}
- for (int j = 0; j < ragged_components[i].values.dim_size(0);
+ for (int j = 0; j < ragged_components[i].values().dim_size(0);
j++, values_index++) {
for (int k = 0; k < num_inner_elements; k++) {
output_values_flat(values_index, k) = component_values_flat(j, k);
@@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
// Decode all variants.
const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
- std::vector<RaggedTensor> decoded_components;
+ std::vector<RaggedTensorVariant> decoded_components;
OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
encoded_variant, input_ragged_rank_,
value_dtype, split_dtype, &decoded_components));
@@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
for (int i = 0; i < encoded_variant.dims(); i++) {
encoded_dim_sizes[i] = encoded_variant.dim_size(i);
}
- RaggedTensor output_ragged;
+ RaggedTensorVariant output_ragged;
OP_REQUIRES_OK(
context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
decoded_components, encoded_dim_sizes, input_ragged_rank_,
@@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
int output_ragged_rank_;
void ReturnRaggedTensor(OpKernelContext* context,
- RaggedTensor ragged_tensor) {
- int ragged_rank = ragged_tensor.nested_splits.size();
+ const RaggedTensorVariant& ragged_tensor) {
+ int ragged_rank = ragged_tensor.ragged_rank();
OpOutputList splits_out;
OP_REQUIRES_OK(context,
context->output_list("output_nested_splits", &splits_out));
for (int i = 0; i < ragged_rank; i++) {
- splits_out.set(i, ragged_tensor.nested_splits[i]);
+ splits_out.set(i, ragged_tensor.splits(i));
}
- context->set_output(ragged_rank, ragged_tensor.values);
+ context->set_output(ragged_rank, ragged_tensor.values());
}
};
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
index 3190534b..a60e5c62 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
@@ -18,50 +18,38 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
namespace {
-struct RaggedTensor {
- Tensor values;
- std::vector<Tensor> nested_splits;
-};
-
-Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
- // Encode as a rank-1 Variant Tensor.
- int ragged_rank = ragged.nested_splits.size();
- *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
- auto encoded_vec = encoded_list->vec<Variant>();
- for (int i = 0; i < ragged_rank; i++) {
- encoded_vec(i) = ragged.nested_splits[i];
- }
- encoded_vec(ragged_rank) = ragged.values;
- return Status::OK();
-}
-
template <typename VALUE_TYPE, typename SPLIT_TYPE>
-Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
- std::vector<RaggedTensor>* ragged_components) {
+Status UnbatchRaggedZerothDim(
+ const RaggedTensorVariant& batched_ragged,
+ std::vector<RaggedTensorVariant>* ragged_components) {
// Set up the component Ragged Tensors.
- int ragged_rank = batched_ragged.nested_splits.size();
- auto batched_splits_top_vec =
- batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
+ int ragged_rank = batched_ragged.ragged_rank();
+ auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
int num_components = batched_splits_top_vec.size() - 1;
int num_splits = ragged_rank - 1;
ragged_components->resize(num_components);
- for (RaggedTensor ragged_component : *ragged_components) {
- ragged_component.nested_splits.reserve(num_splits);
+ for (RaggedTensorVariant& ragged_component : *ragged_components) {
+ ragged_component.mutable_nested_splits()->reserve(num_splits);
}
- const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
- int num_inner_elems = batched_ragged.values.NumElements();
- if (batched_ragged.values.dim_size(0) > 1) {
- num_inner_elems /= batched_ragged.values.dim_size(0);
+ const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
+ int num_inner_elems = batched_ragged.values().NumElements();
+ if (batched_ragged.values().dim_size(0) > 1) {
+ num_inner_elems /= batched_ragged.values().dim_size(0);
}
- TensorShape values_shape = batched_ragged.values.shape();
+ TensorShape values_shape = batched_ragged.values().shape();
// Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
if (num_splits == 0) {
@@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
int limit = batched_splits_top_vec(i + 1);
int num_values = limit - start;
values_shape.set_dim(0, num_values);
- (*ragged_components)[i].values =
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+ (*ragged_components)[i].set_values(
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
auto ragged_component_values_flat =
- (*ragged_components)[i].values.flat<VALUE_TYPE>();
+ (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
for (int j = 0; j < num_values * num_inner_elems; j++) {
ragged_component_values_flat(j) =
batched_flat(j + start * num_inner_elems);
@@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
batched_splits_vec.reserve(ragged_rank);
for (int i = 0; i < ragged_rank; i++) {
- batched_splits_vec.push_back(
- batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
+ batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
}
std::vector<int> index(num_splits, 1);
std::vector<int> ragged_component_values_size(num_components, 0);
@@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
int last_index = ragged_component_splits_vec[j - 1].size() - 1;
split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
}
- (*ragged_components)[i].nested_splits.push_back(
+ (*ragged_components)[i].append_splits(
Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
ragged_component_splits_vec.push_back(
- (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
+ (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
ragged_component_splits_vec[j](0) = 0;
for (int k = 1; k < split_size; k++, index[j]++) {
@@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
for (int i = 0; i < num_components; i++) {
int num_values = ragged_component_values_size[i];
values_shape.set_dim(0, num_values);
- (*ragged_components)[i].values =
- Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+ (*ragged_components)[i].set_values(
+ Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
auto ragged_component_values_flat =
- (*ragged_components)[i].values.flat<VALUE_TYPE>();
+ (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
ragged_component_values_flat(j) = batched_flat(value_index);
}
@@ -152,24 +139,21 @@ class RaggedTensorToVariantOp : public OpKernel {
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
&ragged_nested_splits_in));
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
- RaggedTensor batched_ragged_input;
+ RaggedTensorVariant batched_ragged_input;
// Read ragged_values input.
- batched_ragged_input.values = context->input(ragged_nested_splits_len);
- batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
+ batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
+ batched_ragged_input.mutable_nested_splits()->reserve(
+ ragged_nested_splits_len);
for (int i = 0; i < ragged_nested_splits_len; i++) {
- batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
+ batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
}
if (!batched_input_) {
- // Encode the input as is.
- Tensor encoded_list;
- OP_REQUIRES_OK(context,
- RaggedToVariant(batched_ragged_input, &encoded_list));
// Encode as a Scalar Variant Tensor.
Tensor* encoded_scalar;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
&encoded_scalar));
- encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
+ encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
return;
}
@@ -180,24 +164,19 @@ class RaggedTensorToVariantOp : public OpKernel {
"received rt_nested_splits of length 0."));
// Unbatch the Ragged Tensor and encode the components.
- std::vector<RaggedTensor> ragged_components;
+ std::vector<RaggedTensorVariant> unbatched_ragged_input;
OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
- batched_ragged_input, &ragged_components));
- std::vector<Tensor> encoded_components(ragged_components.size());
- for (int i = 0; i < ragged_components.size(); i++) {
- OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
- &encoded_components[i]));
- }
+ batched_ragged_input, &unbatched_ragged_input));
// Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
- Tensor* encoded_ragged;
- int output_size = ragged_components.size();
+ Tensor* encoded_vector;
+ int output_size = unbatched_ragged_input.size();
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({output_size}),
- &encoded_ragged));
- auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
+ &encoded_vector));
+ auto encoded_vector_t = encoded_vector->vec<Variant>();
for (int i = 0; i < output_size; i++) {
- encoded_ragged_vec(i) = encoded_components[i];
+ encoded_vector_t(i) = unbatched_ragged_input[i];
}
}
@@ -205,12 +184,81 @@ class RaggedTensorToVariantOp : public OpKernel {
bool batched_input_;
};
-#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
- REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<value_type>("Tvalues") \
- .TypeConstraint<split_type>("Tsplits"), \
- RaggedTensorToVariantOp<value_type, split_type>);
+template <typename VALUE_TYPE, typename SPLIT_TYPE>
+class RaggedTensorToVariantGradientOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ // Read inputs.
+ Tensor encoded_variant = context->input(0);
+ Tensor row_splits = context->input(1);
+ auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
+ TensorShape dense_values_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
+ &dense_values_shape));
+
+ const auto& flat_variants = encoded_variant.flat<Variant>();
+
+ // Get a Tensor containing the flat_values for each variant.
+ std::vector<Tensor> values;
+ for (int i = 0; i < flat_variants.size(); ++i) {
+ if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
+ values.push_back(encoded->values());
+ } else {
+ // Missing value: this happens if only some of the variant values
+ // generated by ragged_tensor_to_variant impacted the value that we're
+ // calculating the gradient for. In this case, we will see a
+ // default-constructed variant; so treat it as a zero tensor with the
+ // appropriate shape.
+ const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+ int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
+ TensorShape zeros_shape = dense_values_shape;
+ zeros_shape.set_dim(0, piece_size);
+ Tensor zero(value_dtype, zeros_shape);
+ zero.flat<VALUE_TYPE>() =
+ zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
+ values.push_back(zero);
+ }
+ }
+
+ if (values.size() == 1) {
+ // Just one flat_value tensor: return as-is.
+ context->set_output(0, values[0]);
+ } else {
+ // Multiple flat_values tensors: concatenate them together.
+ using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
+ using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
+ std::vector<std::unique_ptr<ConstPiece>> pieces;
+ pieces.reserve(values.size());
+ for (const Tensor& t : values) {
+ pieces.emplace_back(
+ new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, dense_values_shape, &out));
+ Piece out_flat =
+ out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
+ ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
+ }
+ }
+};
+
+#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \
+ REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<value_type>("Tvalues") \
+ .TypeConstraint<split_type>("Tsplits"), \
+ RaggedTensorToVariantOp<value_type, split_type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RaggedTensorToVariantGradient") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<value_type>("Tvalues") \
+ .TypeConstraint<split_type>("Tsplits"), \
+ RaggedTensorToVariantGradientOp<value_type, split_type>);
+
#define REGISTER_KERNELS(value_type) \
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc
new file mode 100644
index 00000000..94663138
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
+
+namespace tensorflow {
+
+string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
+
+string RaggedTensorVariant::DebugString() const {
+ return absl::StrCat(
+ "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
+ ", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
+ DataTypeString(nested_splits_.empty() ? DT_INVALID
+ : nested_splits_.back().dtype()));
+}
+
+void RaggedTensorVariant::Encode(VariantTensorData* data) const {
+ data->set_type_name(TypeName());
+ for (const auto& splits : nested_splits_) {
+ *data->add_tensors() = splits;
+ }
+ *data->add_tensors() = values_;
+}
+
+bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
+ if (data.tensors_size() < 1) {
+ return false;
+ }
+ nested_splits_.assign(data.tensors().begin(),
+ std::prev(data.tensors().end()));
+ values_ = data.tensors().back();
+ return true;
+}
+
+namespace {
+
+Status RaggedTensorVariantDeviceCopy(
+ const RaggedTensorVariant& from, RaggedTensorVariant* to,
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+ TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
+ // TODO(b/170415165) Should we use `copy` to move splits from device<->host?
+ *to->mutable_nested_splits() = from.nested_splits();
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
+ RaggedTensorVariantZerosLike<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
+ RaggedTensorVariantBinaryAdd<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
+ "RaggedTensorVariant");
+
+#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
+
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h
new file mode 100644
index 00000000..730758a3
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.h
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_key.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
+
+namespace tensorflow {
+
+// Stores a RaggedTensor (flat values + nested row_splits) as a Variant scalar.
+class RaggedTensorVariant {
+ public:
+  RaggedTensorVariant() {}  // No splits; default-constructed values.
+  RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
+      : values_(std::move(values)), nested_splits_(nested_splits) {}
+
+  // Variant support methods (Encode/Decode serialize splits, then values).
+  string TypeName() const;
+  string DebugString() const;
+  void Encode(VariantTensorData* data) const;
+  bool Decode(const VariantTensorData& data);  // false if data has no tensors.
+
+  // The flat_values of the RaggedTensor.
+  const Tensor& values() const { return values_; }
+  Tensor* mutable_values() { return &values_; }
+  void set_values(const Tensor& new_values) { values_ = new_values; }
+
+  // The nested row_splits of the RaggedTensor; ragged_rank() == their count.
+  int ragged_rank() const { return nested_splits_.size(); }
+  const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
+  std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
+  const Tensor& splits(int i) const { return nested_splits_[i]; }  // Unchecked.
+  Tensor* mutable_splits(int i) { return &nested_splits_[i]; }  // Unchecked.
+  void set_nested_splits(const std::vector<Tensor>& nested_splits) {
+    nested_splits_ = nested_splits;
+  }
+  void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
+
+ private:
+  Tensor values_;
+  std::vector<Tensor> nested_splits_;
+};
+
+template <typename Device>  // zeros_like: keep splits, zero the values.
+Status RaggedTensorVariantZerosLike(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    RaggedTensorVariant* y) {
+  y->set_nested_splits(x.nested_splits());
+  TF_RETURN_IF_ERROR(
+      ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
+  return Status::OK();
+}
+
+template <typename Device>  // add: dtype, rank, splits must match.
+Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    const RaggedTensorVariant& y,
+                                    RaggedTensorVariant* out) {
+  if (x.values().dtype() != y.values().dtype()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different dtypes. One is ",
+        DataTypeString(x.values().dtype()), " and the other is ",
+        DataTypeString(y.values().dtype()));
+  }
+  if (x.ragged_rank() != y.ragged_rank()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
+        x.ragged_rank(), " and the other is ", y.ragged_rank());
+  }
+  for (int i = 0; i < x.ragged_rank(); ++i) {  // Compare each level's splits.
+    if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
+      return errors::InvalidArgument(
+          "Can't add RaggedTensorVariants with different row_splits.");
+    }
+  }
+  out->set_nested_splits(x.nested_splits());  // Output reuses x's splits.
+  TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
+                                              out->mutable_values()));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc
index 6bee189c..8512bcf3 100644
--- a/tensorflow/core/ops/ragged_conversion_ops.cc
+++ b/tensorflow/core/ops/ragged_conversion_ops.cc
@@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
-tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
+Status RaggedTensorToTensorShapeFn(InferenceContext* c);
//==============================================================================
// Registered Ops
@@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
.Attr("Tsplits: {int32, int64} = DT_INT64")
.SetShapeFn(RaggedTensorFromVariantShapeFn);
+REGISTER_OP("RaggedTensorToVariantGradient")  // Grad of RaggedTensorToVariant.
+    .Input("encoded_ragged_grad: variant")
+    .Input("row_splits: Tsplits")
+    .Input("dense_values_shape: int32")  // Output shape; see shape fn.
+    .Output("dense_values_grad: Tvalues")  // dtype given by Tvalues.
+    .Attr("Tvalues: type")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")  // dtype of row_splits.
+    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
+
REGISTER_OP("RaggedTensorToTensor")
.Attr("T: type")
.Attr("Tindex: {int64, int32}")
@@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
+  ShapeHandle shape;
+  TF_RETURN_IF_ERROR(  // Input 2 = dense_values_shape.
+      c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
+  c->set_output(0, shape);  // dense_values_grad.
+  return Status::OK();
+}
+
Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
int64 input_ragged_rank;
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 95e5602a..34372160 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -507,6 +507,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
+ "//tensorflow/python:tensor_array_grad",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/data/ops:dataset_ops",
--
2.27.0