From 51300ba1cc2f487aefec6e6631fef03b0e08b298 Mon Sep 17 00:00:00 2001 From: Laura Pak Date: Mon, 3 May 2021 09:53:26 -0700 Subject: [PATCH] Fix heap buffer overflow in tf.raw_ops.UnicodeEncode. PiperOrigin-RevId: 371717714 Change-Id: If33443b28f158e58078f1268f6b92f2728d219e0 --- tensorflow/core/kernels/unicode_ops.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index d3a7ad7b2866f..e6c8f4dfc4228 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -533,6 +533,17 @@ class UnicodeEncodeOp : public OpKernel { const Tensor& input_splits = context->input(1); const auto input_splits_flat = input_splits.flat(); + // The operation treats the first element of input_splits as if it were + // zero regardless of its actual value, so require that splits begin with + // zero and end with the length of the input values vector. + OP_REQUIRES( + context, input_splits_flat(0) == 0, + errors::InvalidArgument("First value in input_splits must be zero.")); + OP_REQUIRES(context, + input_splits_flat(input_splits_flat.size() - 1) == + input_tensor_flat.size(), + errors::InvalidArgument("Last value in input_splits must be " + "equal to length of input_tensor.")); // Since we limit to a 2-D input (flat_values of rank 1 and a single splits // tensor), our output dimension will be 1 with it's size equal to the // number of splits (outer dimension or ragged tensor). 
@@ -548,6 +559,14 @@ class UnicodeEncodeOp : public OpKernel { for (int i = 1; i < input_splits_flat.size(); ++i) { icu::UnicodeString unicode_string; icu::UnicodeStringAppendable appendable_unicode_string(unicode_string); + OP_REQUIRES( + context, input_splits_flat(i - 1) <= input_splits_flat(i), + errors::InvalidArgument( + "Values in input_splits must be equal or in ascending order.")); + OP_REQUIRES( + context, input_splits_flat(i) <= input_tensor_flat.size(), + errors::InvalidArgument("Values in input_splits must be less than or " + "equal to input_tensor length.")); for (; idx < input_splits_flat(i); ++idx) { int32 code_point = input_tensor_flat(idx); // Check for invalid code point