From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001
From: Edward Loper
Date: Wed, 14 Oct 2020 09:41:17 -0700
Subject: [PATCH 1/1] Added gradients for RaggedTensorToVariant and
 RaggedTensorFromVariant.  (This allows gradients to pass through map_fn
 when it is applied to ragged tensors.)

PiperOrigin-RevId: 337108621
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
---
 tensorflow/core/kernels/BUILD                    |  15 ++
 .../kernels/ragged_tensor_from_variant_op.cc     | 164 +++++++---------
 .../kernels/ragged_tensor_to_variant_op.cc       | 180 +++++++++++-------
 tensorflow/core/kernels/ragged_tensor_variant.cc |  86 +++++++++
 tensorflow/core/kernels/ragged_tensor_variant.h  | 110 +++++++++++
 tensorflow/core/ops/ragged_conversion_ops.cc     |  20 +-
 tensorflow/python/ops/ragged/BUILD               |   1 +
 9 files changed, 478 insertions(+), 172 deletions(-)
 create mode 100644 tensorflow/core/framework/tensor_key.h
 create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc
 create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f5a480b3..12adb2b2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1529,10 +1529,22 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "ragged_tensor_variant",
+    srcs = ["ragged_tensor_variant.cc"],
+    hdrs = ["ragged_tensor_variant.h"],
+    deps = [
+        ":cwise_op",
+        "//tensorflow/core:framework",
+    ],
+)
+
 tf_kernel_library(
     name = "ragged_tensor_to_variant_op",
     srcs = ["ragged_tensor_to_variant_op.cc"],
     deps = [
+        ":concat_lib",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
     ],
@@ -1542,6 +1554,7 @@ tf_kernel_library(
     name = "ragged_tensor_from_variant_op",
     srcs = ["ragged_tensor_from_variant_op.cc"],
     deps = [
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
     ],
@@ -1554,6 +1567,7 @@ tf_cc_test(
     deps = [
         ":ops_testutil",
         ":ragged_tensor_to_variant_op",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -1570,6 +1584,7 @@ tf_cc_test(
     deps = [
         ":ops_testutil",
         ":ragged_tensor_from_variant_op",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index d7b6a89a..fa8853af 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -20,110 +20,76 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 namespace {
 
-struct RaggedTensor {
-  Tensor values;
-  std::vector<Tensor> nested_splits;
-};
-
-Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
-                                   int ragged_rank, DataType value_dtype,
-                                   DataType split_dtype,
-                                   std::vector<RaggedTensor>* decoded_ragged) {
+Status RaggedComponentsFromVariant(
+    const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
+    DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
   const auto& flat_variants = encoded_variant.flat<Variant>();
-  decoded_ragged->resize(flat_variants.size());
-  // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
-  // input.
+  decoded_ragged->reserve(flat_variants.size());
+
   for (int i = 0; i < flat_variants.size(); i++) {
     const auto& flat_variant = flat_variants(i);
-    const Tensor* encoded_list = flat_variant.get<Tensor>();
-    if (encoded_list == nullptr) {
+    const RaggedTensorVariant* decoded =
+        flat_variant.get<RaggedTensorVariant>();
+    if (decoded == nullptr) {
       return errors::InvalidArgument(
           "Input Variant element at index ", i,
-          " doesn't hold a Tensor: ", flat_variant.DebugString());
+          " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
     }
-    if (encoded_list->dims() != 1) {
+    decoded_ragged->push_back(*decoded);
+    decoded = &decoded_ragged->back();
+    // Check ragged rank & types
+    if (decoded->ragged_rank() != ragged_rank) {
       return errors::InvalidArgument(
-          "Encoded input Variant must have rank 1, but found rank: ",
-          encoded_list->dims(),
-          ". encoded input Variant: ", encoded_list->DebugString());
+          "Encoded input RaggedTensorVariant has ragged_rank=",
+          decoded->ragged_rank(), ".  Expected ragged_rank=", ragged_rank, ".");
     }
-    if (encoded_list->NumElements() != (ragged_rank + 1) &&
-        encoded_list->NumElements() != 1) {
+    if (decoded->values().dtype() != value_dtype) {
       return errors::InvalidArgument(
-          "Encoded input Variant must hold either input_ragged_rank + 1 "
-          "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
-          "input_ragged_rank: ",
-          ragged_rank,
-          ", encoded input Variant: ", encoded_list->DebugString());
+          "Expected values Tensor dtype: ", DataTypeString(value_dtype),
+          ", found: ", DataTypeString(decoded->values().dtype()));
     }
-    const auto& input_vec = encoded_list->vec<Variant>();
-
-    // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
-    // to create the component RaggedTensors.
-    (*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
-    for (int j = 0; j < ragged_rank; j++) {
-      const Tensor* split_tensor = input_vec(j).get<Tensor>();
-      if (split_tensor == nullptr) {
-        return errors::InvalidArgument(
-            "Encoded scalar element at index ", i,
-            " doesn't have a splits Tensor at split_index ", j, ": ",
-            input_vec(j).DebugString());
-      }
-      Tensor splits_tensor = *split_tensor;
-      if (splits_tensor.dtype() != split_dtype) {
+    if (decoded->values().dims() < 1) {
+      return errors::InvalidArgument(
+          "Ragged values must have rank >= 1; encoded scalar element at index ",
+          i, " has values Tensor: ", decoded->values().DebugString());
+    }
+    for (const auto& splits : decoded->nested_splits()) {
+      if (splits.dtype() != split_dtype) {
         return errors::InvalidArgument(
-            "Expected splits Tensor dtype: ", split_dtype,
-            ", found: ", splits_tensor.dtype());
+            "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
+            ", found: ", DataTypeString(splits.dtype()));
       }
-      if (splits_tensor.dims() != 1) {
+      if (splits.dims() != 1) {
         return errors::InvalidArgument(
             "Ragged splits must have rank 1; encoded scalar element at index ",
-            i, " has splits Tensor at split_index ", j, ": ",
-            splits_tensor.DebugString());
+            i, " has splits Tensor ", splits.DebugString());
       }
-      (*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
     }
-    const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
-    if (values_tensor == nullptr) {
-      return errors::InvalidArgument("Encoded scalar element at index ", i,
-                                     " doesn't have a values Tensor: ",
-                                     input_vec(ragged_rank).DebugString());
-    }
-    if (values_tensor->dtype() != value_dtype) {
-      return errors::InvalidArgument(
-          "Expected values Tensor dtype: ", DataTypeString(value_dtype),
-          ", found: ", DataTypeString(values_tensor->dtype()));
-    }
-    if (values_tensor->dims() < 1) {
-      return errors::InvalidArgument(
-          "Ragged values must have rank >= 1; encoded scalar element at index ",
-          i, " has values Tensor: ", values_tensor->DebugString());
-    }
-    (*decoded_ragged)[i].values = *values_tensor;
   }
   return Status::OK();
 }
 
 template <typename VALUE_TYPE, typename SPLIT_TYPE>
 Status NestedStackRaggedTensors(
-    const std::vector<RaggedTensor>& ragged_components,
+    const std::vector<RaggedTensorVariant>& ragged_components,
     const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
-    const int output_ragged_rank, RaggedTensor* output_ragged) {
-  output_ragged->nested_splits.reserve(output_ragged_rank);
+    const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
+  output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
   const int dims = nested_dim_sizes.size();
 
   // Populate first `dims - 1` splits.
   for (int i = 0; i < dims - 1; i++) {
     int dims_splits_size = nested_dim_sizes[i] + 1;
-    output_ragged->nested_splits.push_back(Tensor(
-        DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
-    auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
+    output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
+                                        TensorShape({dims_splits_size})));
+    auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
     int split_diff = nested_dim_sizes[i + 1];
     for (int j = 0; j < dims_splits_size; j++) {
       splits_vec(j) = j * split_diff;
@@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
 
   // Populate `dims`-th split.
   int splits_size = ragged_components.size() + 1;
-  output_ragged->nested_splits.push_back(
+  output_ragged->append_splits(
       Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
   auto dims_splits_vec =
-      output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
+      output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
   dims_splits_vec(0) = 0;
   for (int i = 0; i < ragged_components.size(); i++) {
-    int split_val = ragged_components[i].values.shape().dim_size(0);
-    if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
-      split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
+    int split_val = ragged_components[i].values().shape().dim_size(0);
+    if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
+      split_val = ragged_components[i].splits(0).NumElements() - 1;
     }
     dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
   }
@@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
     int split_index = dims + i;
     int split_size = 1;
     for (int j = 0; j < ragged_components.size(); j++) {
-      if (!ragged_components[j].nested_splits.empty()) {
-        split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
+      if (!ragged_components[j].nested_splits().empty()) {
+        split_size += ragged_components[j].splits(i).NumElements() - 1;
       }
     }
-    output_ragged->nested_splits.push_back(
+    output_ragged->append_splits(
         Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
     auto splits_vec =
-        output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
+        output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
     splits_vec(0) = 0;
     SPLIT_TYPE last_split_value = 0;
     int index = 1;
     for (int j = 0; j < ragged_components.size(); j++) {
-      if (ragged_components[j].nested_splits.empty()) {
+      if (ragged_components[j].nested_splits().empty()) {
         // Corner case: empty row. e.g [ [[x], [x]], [] ]
         continue;
       }
       auto component_splits_vec =
-          ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
+          ragged_components[j].splits(i).vec<SPLIT_TYPE>();
       for (int k = 1; k < component_splits_vec.size(); k++, index++) {
         splits_vec(index) = component_splits_vec(k) + last_split_value;
       }
@@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
   if (ragged_components.empty()) {
     component_values_shape = TensorShape({0});
   } else {
-    component_values_shape = ragged_components[0].values.shape();
+    component_values_shape = ragged_components[0].values().shape();
   }
 
   // Populate values.
   int values_size = component_values_shape.dim_size(0);
   for (int i = 1; i < ragged_components.size(); i++) {
-    if (ragged_components[i].values.dims() != component_values_shape.dims()) {
+    if (ragged_components[i].values().dims() != component_values_shape.dims()) {
       return errors::InvalidArgument(
           "Rank of values must match for all "
          "components; values shape at index 0: ",
           component_values_shape.DebugString(), ", values shape at index ", i,
-          ": ", ragged_components[i].values.shape().DebugString());
+          ": ", ragged_components[i].values().shape().DebugString());
     }
-    values_size += ragged_components[i].values.shape().dim_size(0);
+    values_size += ragged_components[i].values().shape().dim_size(0);
   }
   component_values_shape.set_dim(0, values_size);
-  output_ragged->values =
-      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
+  output_ragged->set_values(
+      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
   auto output_values_flat =
-      output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
+      output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
   int values_index = 0;
   for (int i = 0; i < ragged_components.size(); i++) {
     auto component_values_flat =
-        ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
-    int num_inner_elements = ragged_components[i].values.NumElements();
-    if (ragged_components[i].values.dim_size(0) > 0) {
-      num_inner_elements /= ragged_components[i].values.dim_size(0);
+        ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
+    int num_inner_elements = ragged_components[i].values().NumElements();
+    if (ragged_components[i].values().dim_size(0) > 0) {
+      num_inner_elements /= ragged_components[i].values().dim_size(0);
     }
-    for (int j = 0; j < ragged_components[i].values.dim_size(0);
+    for (int j = 0; j < ragged_components[i].values().dim_size(0);
          j++, values_index++) {
       for (int k = 0; k < num_inner_elements; k++) {
         output_values_flat(values_index, k) = component_values_flat(j, k);
@@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
     // Decode all variants.
     const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
     const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
-    std::vector<RaggedTensor> decoded_components;
+    std::vector<RaggedTensorVariant> decoded_components;
     OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
                                 encoded_variant, input_ragged_rank_,
                                 value_dtype, split_dtype, &decoded_components));
@@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
     for (int i = 0; i < encoded_variant.dims(); i++) {
       encoded_dim_sizes[i] = encoded_variant.dim_size(i);
     }
-    RaggedTensor output_ragged;
+    RaggedTensorVariant output_ragged;
     OP_REQUIRES_OK(
         context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
                      decoded_components, encoded_dim_sizes, input_ragged_rank_,
@@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
   int output_ragged_rank_;
 
   void ReturnRaggedTensor(OpKernelContext* context,
-                          RaggedTensor ragged_tensor) {
-    int ragged_rank = ragged_tensor.nested_splits.size();
+                          const RaggedTensorVariant& ragged_tensor) {
+    int ragged_rank = ragged_tensor.ragged_rank();
     OpOutputList splits_out;
     OP_REQUIRES_OK(context,
                    context->output_list("output_nested_splits", &splits_out));
     for (int i = 0; i < ragged_rank; i++) {
-      splits_out.set(i, ragged_tensor.nested_splits[i]);
+      splits_out.set(i, ragged_tensor.splits(i));
     }
-    context->set_output(ragged_rank, ragged_tensor.values);
+    context->set_output(ragged_rank, ragged_tensor.values());
   }
 };
 
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
index 3190534b..a60e5c62 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
@@ -18,50 +18,38 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
 
 namespace tensorflow {
 namespace {
 
-struct RaggedTensor {
-  Tensor values;
-  std::vector<Tensor> nested_splits;
-};
-
-Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
-  // Encode as a rank-1 Variant Tensor.
-  int ragged_rank = ragged.nested_splits.size();
-  *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
-  auto encoded_vec = encoded_list->vec<Variant>();
-  for (int i = 0; i < ragged_rank; i++) {
-    encoded_vec(i) = ragged.nested_splits[i];
-  }
-  encoded_vec(ragged_rank) = ragged.values;
-  return Status::OK();
-}
-
 template <typename VALUE_TYPE, typename SPLIT_TYPE>
-Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
-                              std::vector<RaggedTensor>* ragged_components) {
+Status UnbatchRaggedZerothDim(
+    const RaggedTensorVariant& batched_ragged,
+    std::vector<RaggedTensorVariant>* ragged_components) {
   // Set up the component Ragged Tensors.
-  int ragged_rank = batched_ragged.nested_splits.size();
-  auto batched_splits_top_vec =
-      batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
+  int ragged_rank = batched_ragged.ragged_rank();
+  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
   int num_components = batched_splits_top_vec.size() - 1;
   int num_splits = ragged_rank - 1;
   ragged_components->resize(num_components);
-  for (RaggedTensor ragged_component : *ragged_components) {
-    ragged_component.nested_splits.reserve(num_splits);
+  for (RaggedTensorVariant& ragged_component : *ragged_components) {
+    ragged_component.mutable_nested_splits()->reserve(num_splits);
   }
-  const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
-  int num_inner_elems = batched_ragged.values.NumElements();
-  if (batched_ragged.values.dim_size(0) > 1) {
-    num_inner_elems /= batched_ragged.values.dim_size(0);
+  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
+  int num_inner_elems = batched_ragged.values().NumElements();
+  if (batched_ragged.values().dim_size(0) > 1) {
+    num_inner_elems /= batched_ragged.values().dim_size(0);
   }
-  TensorShape values_shape = batched_ragged.values.shape();
+  TensorShape values_shape = batched_ragged.values().shape();
 
   // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
   if (num_splits == 0) {
@@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
       int limit = batched_splits_top_vec(i + 1);
       int num_values = limit - start;
       values_shape.set_dim(0, num_values);
-      (*ragged_components)[i].values =
-          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+      (*ragged_components)[i].set_values(
+          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
       auto ragged_component_values_flat =
-          (*ragged_components)[i].values.flat<VALUE_TYPE>();
+          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
       for (int j = 0; j < num_values * num_inner_elems; j++) {
         ragged_component_values_flat(j) =
             batched_flat(j + start * num_inner_elems);
@@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
   batched_splits_vec.reserve(ragged_rank);
   for (int i = 0; i < ragged_rank; i++) {
-    batched_splits_vec.push_back(
-        batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
+    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
   }
   std::vector<int> index(num_splits, 1);
   std::vector<int> ragged_component_values_size(num_components, 0);
@@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
         int last_index = ragged_component_splits_vec[j - 1].size() - 1;
         split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
       }
-      (*ragged_components)[i].nested_splits.push_back(
+      (*ragged_components)[i].append_splits(
           Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
       ragged_component_splits_vec.push_back(
-          (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
+          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
       SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
       ragged_component_splits_vec[j](0) = 0;
       for (int k = 1; k < split_size; k++, index[j]++) {
@@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   for (int i = 0; i < num_components; i++) {
     int num_values = ragged_component_values_size[i];
     values_shape.set_dim(0, num_values);
-    (*ragged_components)[i].values =
-        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+    (*ragged_components)[i].set_values(
+        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
     auto ragged_component_values_flat =
-        (*ragged_components)[i].values.flat<VALUE_TYPE>();
+        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
     for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
       ragged_component_values_flat(j) = batched_flat(value_index);
     }
@@ -152,24 +139,21 @@ class RaggedTensorToVariantOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                 &ragged_nested_splits_in));
     const int ragged_nested_splits_len = ragged_nested_splits_in.size();
-    RaggedTensor batched_ragged_input;
+    RaggedTensorVariant batched_ragged_input;
     // Read ragged_values input.
-    batched_ragged_input.values = context->input(ragged_nested_splits_len);
-    batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
+    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
+    batched_ragged_input.mutable_nested_splits()->reserve(
+        ragged_nested_splits_len);
     for (int i = 0; i < ragged_nested_splits_len; i++) {
-      batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
+      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
     }
 
     if (!batched_input_) {
-      // Encode the input as is.
-      Tensor encoded_list;
-      OP_REQUIRES_OK(context,
-                     RaggedToVariant(batched_ragged_input, &encoded_list));
       // Encode as a Scalar Variant Tensor.
       Tensor* encoded_scalar;
       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                        &encoded_scalar));
-      encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
+      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
       return;
     }
 
@@ -180,24 +164,19 @@ class RaggedTensorToVariantOp : public OpKernel {
                     "received rt_nested_splits of length 0."));
 
     // Unbatch the Ragged Tensor and encode the components.
-    std::vector<RaggedTensor> ragged_components;
+    std::vector<RaggedTensorVariant> unbatched_ragged_input;
     OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
-                                batched_ragged_input, &ragged_components));
-    std::vector<Tensor> encoded_components(ragged_components.size());
-    for (int i = 0; i < ragged_components.size(); i++) {
-      OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
-                                              &encoded_components[i]));
-    }
+                                batched_ragged_input, &unbatched_ragged_input));
 
     // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
-    Tensor* encoded_ragged;
-    int output_size = ragged_components.size();
+    Tensor* encoded_vector;
+    int output_size = unbatched_ragged_input.size();
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({output_size}),
-                                            &encoded_ragged));
-    auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
+                                            &encoded_vector));
+    auto encoded_vector_t = encoded_vector->vec<Variant>();
     for (int i = 0; i < output_size; i++) {
-      encoded_ragged_vec(i) = encoded_components[i];
+      encoded_vector_t(i) = unbatched_ragged_input[i];
     }
   }
 
@@ -205,12 +184,81 @@ class RaggedTensorToVariantOp : public OpKernel {
   bool batched_input_;
 };
 
-#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
-  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")               \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<value_type>("Tvalues")  \
-                              .TypeConstraint<split_type>("Tsplits"), \
-                          RaggedTensorToVariantOp<value_type, split_type>);
+template <typename VALUE_TYPE, typename SPLIT_TYPE>
+class RaggedTensorToVariantGradientOp : public OpKernel {
+ public:
+  using OpKernel::OpKernel;
+
+  void Compute(OpKernelContext* context) override {
+    // Read inputs.
+    Tensor encoded_variant = context->input(0);
+    Tensor row_splits = context->input(1);
+    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
+    TensorShape dense_values_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
+                                               &dense_values_shape));
+
+    const auto& flat_variants = encoded_variant.flat<Variant>();
+
+    // Get a Tensor containing the flat_values for each variant.
+    std::vector<Tensor> values;
+    for (int i = 0; i < flat_variants.size(); ++i) {
+      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
+        values.push_back(encoded->values());
+      } else {
+        // Missing value: this happens if only some of the variant values
+        // generated by ragged_tensor_to_variant impacted the value that we're
+        // calculating the gradient for.  In this case, we will see a
+        // default-constructed variant; so treat it as a zero tensor with the
+        // appropriate shape.
+        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
+        TensorShape zeros_shape = dense_values_shape;
+        zeros_shape.set_dim(0, piece_size);
+        Tensor zero(value_dtype, zeros_shape);
+        zero.flat<VALUE_TYPE>() =
+            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
+        values.push_back(zero);
+      }
+    }
+
+    if (values.size() == 1) {
+      // Just one flat_value tensor: return as-is.
+      context->set_output(0, values[0]);
+    } else {
+      // Multiple flat_values tensors: concatenate them together.
+      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
+      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
+      std::vector<std::unique_ptr<ConstPiece>> pieces;
+      pieces.reserve(values.size());
+      for (const Tensor& t : values) {
+        pieces.emplace_back(
+            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
+      }
+      Tensor* out = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, dense_values_shape, &out));
+      Piece out_flat =
+          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
+      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
+    }
+  }
+};
+
+#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
+  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<value_type>("Tvalues")        \
+                              .TypeConstraint<split_type>("Tsplits"),       \
+                          RaggedTensorToVariantOp<value_type, split_type>); \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("RaggedTensorToVariantGradient")                                 \
+          .Device(DEVICE_CPU)                                               \
+          .TypeConstraint<value_type>("Tvalues")                            \
+          .TypeConstraint<split_type>("Tsplits"),                           \
+      RaggedTensorToVariantGradientOp<value_type, split_type>);
+
 #define REGISTER_KERNELS(value_type)                  \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc
new file mode 100644
index 00000000..94663138
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "tensorflow/core/kernels/ragged_tensor_variant.h" + +namespace tensorflow { + +string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; } + +string RaggedTensorVariant::DebugString() const { + return absl::StrCat( + "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()), + ", ragged_rank=", nested_splits_.size(), ", splits_dtype=", + DataTypeString(nested_splits_.empty() ? DT_INVALID + : nested_splits_.back().dtype())); +} + +void RaggedTensorVariant::Encode(VariantTensorData* data) const { + data->set_type_name(TypeName()); + for (const auto& splits : nested_splits_) { + *data->add_tensors() = splits; + } + *data->add_tensors() = values_; +} + +bool RaggedTensorVariant::Decode(const VariantTensorData& data) { + if (data.tensors_size() < 1) { + return false; + } + nested_splits_.assign(data.tensors().begin(), + std::prev(data.tensors().end())); + values_ = data.tensors().back(); + return true; +} + +namespace { + +Status RaggedTensorVariantDeviceCopy( + const RaggedTensorVariant& from, RaggedTensorVariant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values())); + // TODO(b/170415165) Should we use `copy` to move splits from device<->host? + *to->mutable_nested_splits() = from.nested_splits(); + return Status::OK(); +} + +} // namespace + +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantZerosLike); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION( + ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant, + RaggedTensorVariantBinaryAdd); + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant, + "RaggedTensorVariant"); + +#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy) + +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); +REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); +REGISTER_RAGGED_TENSOR_VARIANT_COPY( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h new file mode 100644 index 00000000..730758a3 --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_variant.h @@ -0,0 +1,110 @@ +#include "tensorflow/core/framework/tensor_key.h" +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { + +// Class used to store a RaggedTensor as a Variant scalar. +class RaggedTensorVariant { + public: + RaggedTensorVariant() {} + RaggedTensorVariant(Tensor values, const std::vector& nested_splits) + : values_(std::move(values)), nested_splits_(nested_splits) {} + + // Variant support methods. + string TypeName() const; + string DebugString() const; + void Encode(VariantTensorData* data) const; + bool Decode(const VariantTensorData& data); + + // The flat_values of the RaggedTensor. + const Tensor& values() const { return values_; } + Tensor* mutable_values() { return &values_; } + void set_values(const Tensor& new_values) { values_ = new_values; } + + // The nested row_splits of the RaggedTensor. + int ragged_rank() const { return nested_splits_.size(); } + const std::vector& nested_splits() const { return nested_splits_; } + std::vector* mutable_nested_splits() { return &nested_splits_; } + const Tensor& splits(int i) const { return nested_splits_[i]; } + Tensor* mutable_splits(int i) { return &nested_splits_[i]; } + void set_nested_splits(const std::vector& nested_splits) { + nested_splits_ = nested_splits; + } + void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); } + + private: + Tensor values_; + std::vector nested_splits_; +}; + +template +Status RaggedTensorVariantZerosLike(OpKernelContext* c, + const RaggedTensorVariant& x, + RaggedTensorVariant* y) { + y->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR( + ZerosLikeTensor(c, x.values(), y->mutable_values())); + return Status::OK(); +} + +template +Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, + const RaggedTensorVariant& x, + const RaggedTensorVariant& y, + RaggedTensorVariant* out) { + if (x.values().dtype() != y.values().dtype()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different dtypes. One is ", + DataTypeString(x.values().dtype()), " and the other is ", + DataTypeString(y.values().dtype())); + } + if (x.ragged_rank() != y.ragged_rank()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different ragged rank. 
", "One is ", + x.ragged_rank(), " and the other is ", y.ragged_rank()); + } + for (int i = 0; i < x.ragged_rank(); ++i) { + if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants with different row_splits."); + } + } + out->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR(BinaryAddTensors(c, x.values(), y.values(), + out->mutable_values())); + return Status::OK(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index 6bee189c..8512bcf3 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes( Status RaggedTensorToSparseShapeFn(InferenceContext* c); Status RaggedTensorToVariantShapeFn(InferenceContext* c); Status RaggedTensorFromVariantShapeFn(InferenceContext* c); -tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c); +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c); +Status RaggedTensorToTensorShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant") .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedTensorFromVariantShapeFn); +REGISTER_OP("RaggedTensorToVariantGradient") + .Input("encoded_ragged_grad: variant") + .Input("row_splits: Tsplits") + .Input("dense_values_shape: int32") + .Output("dense_values_grad: Tvalues") + .Attr("Tvalues: type") + .Attr("Tsplits: {int32, int64} = DT_INT64") + .SetShapeFn(RaggedTensorToVariantGradientShapeFn); + REGISTER_OP("RaggedTensorToTensor") .Attr("T: type") .Attr("Tindex: {int64, int32}") @@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) { return Status::OK(); } +Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { + ShapeHandle shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape)); + c->set_output(0, shape); + return Status::OK(); +} + Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { int64 input_ragged_rank; TF_RETURN_IF_ERROR( diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 95e5602a..34372160 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -507,6 +507,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_grad", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", "//tensorflow/python/data/ops:dataset_ops", -- 2.27.0