60 lines
2.5 KiB
Diff
60 lines
2.5 KiB
Diff
|
|
From 7edb8c9b83ad583616406af61e0de61393996a3e Mon Sep 17 00:00:00 2001
|
||
|
|
From: Yong Tang <yong.tang.github@outlook.com>
|
||
|
|
Date: Sat, 6 Feb 2021 20:24:54 +0000
|
||
|
|
Subject: [PATCH] Fix crash of tf.strings.substr when pos and len have
|
||
|
|
different shapes
|
||
|
|
|
||
|
|
This PR tries to address the issue raised in 46900 where
|
||
|
|
tf.strings.substr will crash when pos and len have different shapes.
|
||
|
|
According to the documentation of tf.strings.substr, ValueError
|
||
|
|
should be raised instead when pos and len do not have the same shape.
|
||
|
|
|
||
|
|
This PR adds a shape check in the kernel to allow a graceful error to be thrown (instead of a crash).
|
||
|
|
|
||
|
|
This PR fixes 46900.
|
||
|
|
|
||
|
|
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
|
||
|
|
---
|
||
|
|
tensorflow/core/kernels/substr_op.cc | 6 ++++++
|
||
|
|
tensorflow/python/kernel_tests/substr_op_test.py | 10 ++++++++++
|
||
|
|
2 files changed, 16 insertions(+)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
|
||
|
|
index e382381e12232..0c94ba35b249a 100644
|
||
|
|
--- a/tensorflow/core/kernels/substr_op.cc
|
||
|
|
+++ b/tensorflow/core/kernels/substr_op.cc
|
||
|
|
@@ -51,6 +51,12 @@ class SubstrOp : public OpKernel {
|
||
|
|
const Tensor& len_tensor = context->input(2);
|
||
|
|
const TensorShape& input_shape = input_tensor.shape();
|
||
|
|
const TensorShape& pos_shape = pos_tensor.shape();
|
||
|
|
+ const TensorShape& len_shape = len_tensor.shape();
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, (pos_shape == len_shape),
|
||
|
|
+ errors::InvalidArgument("pos and len should have the same shape, got: ",
|
||
|
|
+ pos_shape.DebugString(), " vs. ",
|
||
|
|
+ len_shape.DebugString()));
|
||
|
|
|
||
|
|
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
|
||
|
|
|
||
|
|
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
|
||
|
|
index 9302152e82bfa..ad7b6050c2901 100644
|
||
|
|
--- a/tensorflow/python/kernel_tests/substr_op_test.py
|
||
|
|
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
|
||
|
|
@@ -492,6 +492,16 @@ def testInvalidUnit(self):
|
||
|
|
with self.assertRaises(ValueError):
|
||
|
|
string_ops.substr(b"test", 3, 1, unit="UTF8")
|
||
|
|
|
||
|
|
+ def testInvalidPos(self):
|
||
|
|
+ # Test case for GitHub issue 46900.
|
||
|
|
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||
|
|
+ x = string_ops.substr(b"abc", len=1, pos=[1, -1])
|
||
|
|
+ self.evaluate(x)
|
||
|
|
+
|
||
|
|
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||
|
|
+ x = string_ops.substr(b"abc", len=1, pos=[1, 2])
|
||
|
|
+ self.evaluate(x)
|
||
|
|
+
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
test.main()
|