From 3ebedd7e345453d68e279cfc3e4072648e5e12e5 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 28 Apr 2021 12:58:07 -0700 Subject: [PATCH] Prevent division by 0 in OneHot implementation If the input indices tensor is degenerate, the implementation would divide by zero. See https://github.com/tensorflow/tensorflow/blob/745d57df6d5e9bc568666a2a48ed8dd629c27241/tensorflow/lite/kernels/one_hot.cc#L68-L72 PiperOrigin-RevId: 370966870 Change-Id: Ie018337811c8016b5a1d3a277d00d5f2e19a2058 --- tensorflow/lite/kernels/one_hot.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/lite/kernels/one_hot.cc b/tensorflow/lite/kernels/one_hot.cc index f7b4e8e7e19d5..75bfb48d6b19c 100644 --- a/tensorflow/lite/kernels/one_hot.cc +++ b/tensorflow/lite/kernels/one_hot.cc @@ -69,6 +69,11 @@ void OneHotComputeImpl(const OneHotContext& op_context) { for (int i = 0; i < op_context.axis; ++i) { prefix_dim_size *= op_context.indices->dims->data[i]; } + if (prefix_dim_size == 0) { + // If indices tensor is degenerate, return a degenerate tensor, just like + // TensorFlow does. + return; + } const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size; const int depth = *op_context.depth->data.i32;