From 4e2565483d0ffcadc719bd44893fb7f609bb5f12 Mon Sep 17 00:00:00 2001
From: Edward Loper
Date: Thu, 29 Jul 2021 09:50:01 -0700
Subject: [PATCH] Fix bug that could cause map_fn to produce incorrect results
 (rather than an error) when mapping over a ragged tensor with an
 inappropriate fn_output_signature. (Note: there are cases where the default
 value for fn_output_signature is not appropriate, so the user needs to
 explicitly specify the correct output signature.)

PiperOrigin-RevId: 387606546
Change-Id: Ib4ea27b9634e6ab413f211cfe809a69a90f0e2cd
---
Reviewer notes (commentary only, not part of the commit):
 * ragged_tensor_from_variant_op.cc: inside the loop over ragged_components,
   each component's values shape (with dim 0 removed) is compared against
   component_values_shape (with dim 0 removed). A mismatch returns
   errors::InvalidArgument before that component's values are read through
   flat_outer_dims().
 * ragged_map_fn_op_test.py: adds one test that expects
   errors.InvalidArgumentError from a nested map_fn call made without
   fn_output_signature, and one test that passes explicit RaggedTensorSpec
   signatures to both map_fn calls and checks the resulting nested value.
 * NOTE(review): line breaks and indentation below were reconstructed from a
   whitespace-collapsed copy of this patch; runs of spaces inside string
   literals may also have been collapsed. Confirm against the upstream
   commit before applying.
 * NOTE(review): flat_outer_dims() appears without template arguments and
   the From: line carries no address -- presumably angle-bracketed text was
   stripped during extraction; verify against the upstream commit.

 .../kernels/ragged_tensor_from_variant_op.cc  | 16 +++++++++++++
 .../ops/ragged/ragged_map_fn_op_test.py       | 23 +++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index d9993bb6d3907..c481d90638e4e 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -174,7 +174,23 @@ Status NestedStackRaggedTensors(
   auto output_values_flat =
       output_ragged->mutable_values()->flat_outer_dims();
   int values_index = 0;
+
+  TensorShape expected_value_shape = component_values_shape;
+  expected_value_shape.RemoveDim(0);
+
   for (int i = 0; i < ragged_components.size(); i++) {
+    // Check that the flat_values tensor shape is compatible.
+    TensorShape value_shape = ragged_components[i].values().shape();
+    value_shape.RemoveDim(0);
+    if (value_shape != expected_value_shape) {
+      return errors::InvalidArgument(
+          "All flat_values must have compatible shapes. Shape at index 0: ",
+          expected_value_shape, ". Shape at index ", i, ": ", value_shape,
+          ". If you are using tf.map_fn, then you may need to specify an "
+          "explicit fn_output_signature with appropriate ragged_rank, and/or "
+          "convert output tensors to RaggedTensors.");
+    }
+
     auto component_values_flat =
         ragged_components[i].values().flat_outer_dims();
     int num_inner_elements = ragged_components[i].values().NumElements();
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
index bead4923a0a4c..ace724ac8711d 100644
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
@@ -21,9 +21,11 @@
 import numpy as np
 
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import map_fn as map_fn_lib
 from tensorflow.python.ops import math_ops as mo
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -309,6 +311,27 @@ def testMapOnSparseTensor(self):
     )
     self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
 
+  def testRaggedMapWithIncorrectFnOutputSignature(self):
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                'All flat_values must have compatible shapes'):
+      y = map_fn_lib.map_fn(lambda r: map_fn_lib.map_fn(lambda y: r, r), x)
+      self.evaluate(y)
+
+  def testNestedRaggedMapWithFnOutputSignature(self):
+    ragged1d = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
+    ragged2d = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
+
+    x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
+    # pylint: disable=g-long-lambda
+    y = map_fn_lib.map_fn(
+        lambda r: map_fn_lib.map_fn(
+            lambda y: r, r, fn_output_signature=ragged1d),
+        x,
+        fn_output_signature=ragged2d)
+    expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]
+    self.assertAllEqual(y, expected)
+
 
 if __name__ == '__main__':
   googletest.main()