52 lines
2.1 KiB
Diff
52 lines
2.1 KiB
Diff
From 1e206baedf8bef0334cca3eb92bab134ef525a28 Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Fri, 16 Jul 2021 14:23:21 -0700
|
|
Subject: [PATCH] Prevent a division by 0 in division ops.
|
|
|
|
PiperOrigin-RevId: 385223169
|
|
Change-Id: Ia4228960b5d2aa44480385f74bdd70d21a3613c3
|
|
---
|
|
tensorflow/lite/kernels/div.cc | 17 ++++++++++++++++-
|
|
1 file changed, 16 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/tensorflow/lite/kernels/div.cc b/tensorflow/lite/kernels/div.cc
|
|
index c9eb1db5..aafe00f0 100644
|
|
--- a/tensorflow/lite/kernels/div.cc
|
|
+++ b/tensorflow/lite/kernels/div.cc
|
|
@@ -204,9 +204,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
|
|
- if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
|
|
+ // TODO(b/193904910): This can be written with C++ templates
|
|
+#define TF_LITE_CHECK_DIV_NON_ZERO(data_type) \
|
|
+ const auto* input2_data = GetTensorData<data_type>(input2); \
|
|
+ const size_t input2_elements = input2->bytes / sizeof(data_type); \
|
|
+ for (size_t i = 0; i < input2_elements; i++) { \
|
|
+ TF_LITE_ENSURE(context, input2_data[i] != 0); \
|
|
+ }
|
|
+
|
|
+ if (output->type == kTfLiteFloat32) {
|
|
+ // Division by zero seems OK in this case: just like in TF, infinities are
|
|
+ // returned, so we don't do a check at this point.
|
|
+ EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
|
|
+ } else if (output->type == kTfLiteInt32) {
|
|
+ TF_LITE_CHECK_DIV_NON_ZERO(int32_t);
|
|
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
|
|
} else if (output->type == kTfLiteUInt8) {
|
|
+ TF_LITE_CHECK_DIV_NON_ZERO(uint8_t);
|
|
TF_LITE_ENSURE_OK(
|
|
context, EvalQuantized<kernel_type>(context, node, params, data, input1,
|
|
input2, output));
|
|
@@ -217,6 +231,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
output->type);
|
|
return kTfLiteError;
|
|
}
|
|
+#undef TF_LITE_CHECK_DIV_NON_ZERO
|
|
|
|
return kTfLiteOk;
|
|
}
|
|
--
|
|
2.27.0
|
|
|