61 lines
2.6 KiB
Diff
61 lines
2.6 KiB
Diff
From 63333b967844327856352f484aeddd1509b10604 Mon Sep 17 00:00:00 2001
|
|
From: Yong Tang <yong.tang.github@outlook.com>
|
|
Date: Sat, 6 Feb 2021 18:53:33 +0000
|
|
Subject: [PATCH 1/2] Fix crash with tf.transpose when a is complex and
|
|
conjugate is True
|
|
|
|
This PR addresses GitHub issue #46891, in which
|
|
tf.transpose crashes when `a` is complex and `conjugate` is True.
|
|
The crash comes from the `CHECK_GE(in.dims(), 2)` assertion in `DoTransposeImpl`:
|
|
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor.h#L169
|
|
|
|
However, the ndims < 2 case is already handled properly here:
|
|
https://github.com/tensorflow/tensorflow/blob/57bbc5e0d4b93483b8ae853352173516f1c08018/tensorflow/core/kernels/transpose_functor_cpu.cc#L103-L105
|
|
Therefore, this check can be removed.
|
|
|
|
This PR fixes #46891.
|
|
|
|
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
|
|
---
|
|
tensorflow/core/kernels/transpose_functor.h | 1 -
|
|
tensorflow/python/kernel_tests/transpose_op_test.py | 4 ++++
|
|
2 files changed, 4 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
|
|
index 0c22b11b..479ad7af 100644
|
|
--- a/tensorflow/core/kernels/transpose_functor.h
|
|
+++ b/tensorflow/core/kernels/transpose_functor.h
|
|
@@ -166,7 +166,6 @@ template <typename Device>
|
|
Status DoTransposeImpl(const Device& d, const Tensor& in,
|
|
const gtl::ArraySlice<int32> perm, bool conjugate,
|
|
Tensor* out) {
|
|
- CHECK_GE(in.dims(), 2);
|
|
CHECK_EQ(in.dims(), out->dims());
|
|
CHECK_EQ(in.dims(), perm.size());
|
|
CHECK_EQ(in.dtype(), out->dtype());
|
|
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
|
|
index 87096211..ed634ae7 100644
|
|
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
|
|
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
|
|
@@ -387,6 +387,8 @@ class TransposeTest(test.TestCase):
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
def testComplex64(self):
|
|
+ self._testBoth(np.array(np.complex(1, 2)).astype(np.complex64))
|
|
+ self._testBoth(np.complex(1, 2) * np.arange(0, 21).astype(np.complex64))
|
|
self._testBoth(
|
|
np.complex(1, 2) *
|
|
np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
|
|
@@ -399,6 +401,8 @@ class TransposeTest(test.TestCase):
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
def testComplex128(self):
|
|
+ self._testBoth(np.array(np.complex(1, 2)).astype(np.complex128))
|
|
+ self._testBoth(np.complex(1, 2) * np.arange(0, 21).astype(np.complex128))
|
|
self._testBoth(
|
|
np.complex(1, 2) *
|
|
np.arange(0, 21).reshape([3, 7]).astype(np.complex128))
|
|
--
|
|
2.27.0
|
|
|