From 14755416e364f17fb1870882fa778c7fec7f16e3 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Mon, 7 Dec 2020 20:31:31 -0800 Subject: [PATCH] Prevent CHECK-fail in LSTM/GRU with zero-length input. PiperOrigin-RevId: 346239181 Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index a97850bd..5ae19f27 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1474,7 +1474,9 @@ class CudnnRnnSequenceTensorDescriptor static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor(); @@ -1495,7 +1497,9 @@ class CudnnRnnSequenceTensorDescriptor const absl::Span<const int>& seq_lengths, bool time_major, cudnnDataType_t data_type) { #if CUDNN_VERSION >= 7201 - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor(); -- 2.27.0