From 1e206baedf8bef0334cca3eb92bab134ef525a28 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Fri, 16 Jul 2021 14:23:21 -0700
Subject: [PATCH] Prevent a division by 0 in division ops.

Integer division by zero is undefined behavior, so Eval now fails (via
TF_LITE_ENSURE) for the int32 and quantized uint8 Div kernels when any
divisor element is zero. For uint8 the check compares against input2's
zero point, since a quantized value q encodes scale * (q - zero_point).
Float32 is left unchecked because IEEE 754 division by zero yields inf/NaN
instead of trapping.

PiperOrigin-RevId: 385223169
Change-Id: Ia4228960b5d2aa44480385f74bdd70d21a3613c3
---
 tensorflow/lite/kernels/div.cc | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/kernels/div.cc b/tensorflow/lite/kernels/div.cc
index c9eb1db5..aafe00f0 100644
--- a/tensorflow/lite/kernels/div.cc
+++ b/tensorflow/lite/kernels/div.cc
@@ -204,9 +204,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+  // Ensures (via TF_LITE_ENSURE) that no element of input2 equals
+  // `zero_value`, i.e. that no divisor is zero.
+  // TODO(b/193904910): This can be written with C++ templates.
+#define TF_LITE_CHECK_DIV_NON_ZERO(data_type, zero_value)             \
+  do {                                                                \
+    const auto* input2_data = GetTensorData<data_type>(input2);       \
+    const size_t input2_elements = input2->bytes / sizeof(data_type); \
+    for (size_t i = 0; i < input2_elements; i++) {                    \
+      TF_LITE_ENSURE(context, input2_data[i] != (zero_value));        \
+    }                                                                 \
+  } while (0)
+
+  if (output->type == kTfLiteFloat32) {
+    // Division by zero is not checked for float: IEEE 754 division yields
+    // +/-inf or NaN rather than trapping.
+    EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
+  } else if (output->type == kTfLiteInt32) {
+    TF_LITE_CHECK_DIV_NON_ZERO(int32_t, 0);
     EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
   } else if (output->type == kTfLiteUInt8) {
+    // A quantized value q encodes scale * (q - zero_point), so the divisor is
+    // zero when q equals input2's zero point, not when q == 0.
+    TF_LITE_CHECK_DIV_NON_ZERO(uint8_t, input2->params.zero_point);
     TF_LITE_ENSURE_OK(
         context, EvalQuantized<kernel_type>(context, node, params, data, input1,
                                              input2, output));
@@ -217,6 +237,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         output->type);
     return kTfLiteError;
   }
+#undef TF_LITE_CHECK_DIV_NON_ZERO
 
   return kTfLiteOk;
 }
-- 
2.27.0