89 lines
4.0 KiB
Diff
|
|
From 4e2565483d0ffcadc719bd44893fb7f609bb5f12 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Edward Loper <edloper@google.com>
|
||
|
|
Date: Thu, 29 Jul 2021 09:50:01 -0700
|
||
|
|
Subject: [PATCH] Fix bug that could cause map_fn to produce incorrect results
|
||
|
|
(rather than an error) when mapping over a ragged tensor with an
|
||
|
|
inappropriate fn_output_signature. (Note: there are cases where the default
|
||
|
|
value for fn_output_signature is not appropriate, so the user needs to
|
||
|
|
explicitly specify the correct output signature.)
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 387606546
|
||
|
|
Change-Id: Ib4ea27b9634e6ab413f211cfe809a69a90f0e2cd
|
||
|
|
---
|
||
|
|
.../kernels/ragged_tensor_from_variant_op.cc | 16 +++++++++++++
|
||
|
|
.../ops/ragged/ragged_map_fn_op_test.py | 23 +++++++++++++++++++
|
||
|
|
2 files changed, 39 insertions(+)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
index d9993bb6d3907..c481d90638e4e 100644
|
||
|
|
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
|
||
|
|
@@ -174,7 +174,23 @@ Status NestedStackRaggedTensors(
|
||
|
|
auto output_values_flat =
|
||
|
|
output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
int values_index = 0;
|
||
|
|
+
|
||
|
|
+ TensorShape expected_value_shape = component_values_shape;
|
||
|
|
+ expected_value_shape.RemoveDim(0);
|
||
|
|
+
|
||
|
|
for (int i = 0; i < ragged_components.size(); i++) {
|
||
|
|
+ // Check that the flat_values tensor shape is compatible.
|
||
|
|
+ TensorShape value_shape = ragged_components[i].values().shape();
|
||
|
|
+ value_shape.RemoveDim(0);
|
||
|
|
+ if (value_shape != expected_value_shape) {
|
||
|
|
+ return errors::InvalidArgument(
|
||
|
|
+ "All flat_values must have compatible shapes. Shape at index 0: ",
|
||
|
|
+ expected_value_shape, ". Shape at index ", i, ": ", value_shape,
|
||
|
|
+ ". If you are using tf.map_fn, then you may need to specify an "
|
||
|
|
+ "explicit fn_output_signature with appropriate ragged_rank, and/or "
|
||
|
|
+ "convert output tensors to RaggedTensors.");
|
||
|
|
+ }
|
||
|
|
+
|
||
|
|
auto component_values_flat =
|
||
|
|
ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
|
||
|
|
int num_inner_elements = ragged_components[i].values().NumElements();
|
||
|
|
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
|
||
|
|
index bead4923a0a4c..ace724ac8711d 100644
|
||
|
|
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
|
||
|
|
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
|
||
|
|
@@ -21,9 +21,11 @@
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from tensorflow.python.framework import dtypes
|
||
|
|
+from tensorflow.python.framework import errors
|
||
|
|
from tensorflow.python.framework import sparse_tensor
|
||
|
|
from tensorflow.python.framework import test_util
|
||
|
|
from tensorflow.python.ops import array_ops
|
||
|
|
+from tensorflow.python.ops import map_fn as map_fn_lib
|
||
|
|
from tensorflow.python.ops import math_ops as mo
|
||
|
|
from tensorflow.python.ops import string_ops
|
||
|
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||
|
|
@@ -309,6 +311,27 @@ def testMapOnSparseTensor(self):
|
||
|
|
)
|
||
|
|
self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
|
||
|
|
|
||
|
|
+ def testRaggedMapWithIncorrectFnOutputSignature(self):
|
||
|
|
+ x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
|
||
|
|
+ with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||
|
|
+ 'All flat_values must have compatible shapes'):
|
||
|
|
+ y = map_fn_lib.map_fn(lambda r: map_fn_lib.map_fn(lambda y: r, r), x)
|
||
|
|
+ self.evaluate(y)
|
||
|
|
+
|
||
|
|
+ def testNestedRaggedMapWithFnOutputSignature(self):
|
||
|
|
+ ragged1d = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
|
||
|
|
+ ragged2d = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
|
||
|
|
+
|
||
|
|
+ x = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])
|
||
|
|
+ # pylint: disable=g-long-lambda
|
||
|
|
+ y = map_fn_lib.map_fn(
|
||
|
|
+ lambda r: map_fn_lib.map_fn(
|
||
|
|
+ lambda y: r, r, fn_output_signature=ragged1d),
|
||
|
|
+ x,
|
||
|
|
+ fn_output_signature=ragged2d)
|
||
|
|
+ expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], [[1]]]
|
||
|
|
+ self.assertAllEqual(y, expected)
|
||
|
|
+
|
||
|
|
|
||
|
|
if __name__ == '__main__':
|
||
|
|
googletest.main()
|