From 698e01511f62a3c185754db78ebce0eee1f0184d Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Fri, 30 Apr 2021 06:36:59 -0700 Subject: [PATCH] Fix `tf.io.decode_raw` bugs and update documentation. Fixes cases where specifying `fixed_length` resulted in data loss, and even in segfaults and corruption of the Python interpreter. The fix is subtle but needed due to pointer arithmetic rules: `out_data` is a `T*`, so it must be advanced by the element count `width` rather than by the byte count `fixed_length` (which moved it `sizeof(T)` times too far). Makes sure that `fixed_length` does not change the output when present but not needed. Eliminates a needless copy and cast in the main codepath. PiperOrigin-RevId: 371322725 Change-Id: I514ef67a2961c86422f69d05122d31615e87896c --- .../core/kernels/decode_padded_raw_op.cc | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/decode_padded_raw_op.cc b/tensorflow/core/kernels/decode_padded_raw_op.cc index 12e8ec6a..d3e830c0 100644 --- a/tensorflow/core/kernels/decode_padded_raw_op.cc +++ b/tensorflow/core/kernels/decode_padded_raw_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { @@ -83,14 +84,13 @@ class DecodePaddedRawOp : public OpKernel { // can copy the memory directly.
if (!convert_data_endianness_ || sizeof(T) == 1) { for (int64 i = 0; i < flat_in.size(); ++i) { - const T* in_data = reinterpret_cast(flat_in(i).data()); - - if (flat_in(i).size() > fixed_length) { - memcpy(out_data, in_data, fixed_length); - } else { - memcpy(out_data, in_data, flat_in(i).size()); - } - out_data += fixed_length; + const auto to_copy = + std::min(flat_in(i).size(), static_cast(fixed_length)); + memcpy(out_data, flat_in(i).data(), to_copy); + // Note: increase out_data by width since it's already of type T* so + // each shift amount is implicitly multiplied by sizeof(T) according to + // pointer arithmetic rules. + out_data += width; } } else { // Otherwise, the data is not in the host's byte order, and rather than a @@ -105,7 +105,10 @@ class DecodePaddedRawOp : public OpKernel { p_in += sizeof(T), p_out += sizeof(T)) { std::reverse_copy(p_in, p_in + sizeof(T), p_out); } - out_data += fixed_length; + // Note: increase out_data by width since it's already of type T* so + // each shift amount is implicitly multiplied by sizeof(T) according to + // pointer arithmetic rules. + out_data += width; } } } -- 2.27.0