From 63333b967844327856352f484aeddd1509b10604 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 6 Feb 2021 18:53:33 +0000 Subject: [PATCH 1/2] Fix crash with tf.transpose when a is complex and conjugate is True This PR addresses issue 46891, where tf.transpose crashes when the input tensor `a` is complex and `conjugate` is True. The crash comes from the CHECK_GE(in.dims(), 2) in DoTransposeImpl: https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor.h#L169 However, inputs with ndims < 2 are already handled properly in transpose_functor_cpu.cc: https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor_cpu.cc#L103-L105 so the check can be removed. This PR fixes 46891. Signed-off-by: Yong Tang --- tensorflow/core/kernels/transpose_functor.h | 1 - tensorflow/python/kernel_tests/transpose_op_test.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h index 0c22b11b..479ad7af 100644 --- a/tensorflow/core/kernels/transpose_functor.h +++ b/tensorflow/core/kernels/transpose_functor.h @@ -166,7 +166,6 @@ template Status DoTransposeImpl(const Device& d, const Tensor& in, const gtl::ArraySlice perm, bool conjugate, Tensor* out) { - CHECK_GE(in.dims(), 2); CHECK_EQ(in.dims(), out->dims()); CHECK_EQ(in.dims(), perm.size()); CHECK_EQ(in.dtype(), out->dtype()); diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py index 87096211..ed634ae7 100644 --- a/tensorflow/python/kernel_tests/transpose_op_test.py +++ b/tensorflow/python/kernel_tests/transpose_op_test.py @@ -387,6 +387,8 @@ class TransposeTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testComplex64(self): + self._testBoth(np.array(np.complex(1, 2)).astype(np.complex64)) + self._testBoth(np.complex(1, 2) * np.arange(0, 21).astype(np.complex64)) self._testBoth( 
np.complex(1, 2) * np.arange(0, 21).reshape([3, 7]).astype(np.complex64)) @@ -399,6 +401,8 @@ class TransposeTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testComplex128(self): + self._testBoth(np.array(np.complex(1, 2)).astype(np.complex128)) + self._testBoth(np.complex(1, 2) * np.arange(0, 21).astype(np.complex128)) self._testBoth( np.complex(1, 2) * np.arange(0, 21).reshape([3, 7]).astype(np.complex128)) -- 2.27.0