From 9d94482224acde044692d74107339a29f862cbac Mon Sep 17 00:00:00 2001
From: Advait Jain
Date: Wed, 15 Jul 2020 16:20:40 -0700
Subject: [PATCH] Change some getters to not be inline. This enables some

The out-of-line GetInput, GetVariableInput and GetOutput definitions
keep the negative tensor-index guard of the inline versions removed from
kernel_util.h: a negative index still yields nullptr rather than an
out-of-bounds access into context->tensors.
---
 tensorflow/lite/kernels/kernel_util.cc | 37 +++++++++++++++++++
 tensorflow/lite/kernels/kernel_util.h  | 49 +++++++++++---------------
 2 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index 164aec3f..f7d7c25b 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -27,6 +27,43 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 
 namespace tflite {
+const TfLiteTensor* GetInput(const TfLiteContext* context,
+                             const TfLiteNode* node, int index) {
+  const int tensor_index = node->inputs->data[index];
+  if (tensor_index < 0) {
+    return nullptr;
+  }
+  return &context->tensors[tensor_index];
+}
+
+TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
+                               int index) {
+  const int tensor_index = node->inputs->data[index];
+  if (tensor_index < 0) {
+    return nullptr;
+  }
+  TfLiteTensor* tensor = &context->tensors[tensor_index];
+  return (tensor->is_variable) ? tensor : nullptr;
+}
+
+TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
+                        int index) {
+  const int tensor_index = node->outputs->data[index];
+  if (tensor_index < 0) {
+    return nullptr;
+  }
+  return &context->tensors[tensor_index];
+}
+
+const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
+                                           const TfLiteNode* node, int index) {
+  const bool use_tensor = index < node->inputs->size &&
+                          node->inputs->data[index] != kTfLiteOptionalTensor;
+  if (use_tensor) {
+    return &context->tensors[node->inputs->data[index]];
+  }
+  return nullptr;
+}
 
 // Per-axis
 TfLiteStatus PopulateConvolutionQuantizationParams(
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index 59b1974c..371b712f 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -24,38 +24,31 @@ limitations under the License.
 
 namespace tflite {
 
-inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
-inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
-  return t->dims->data[dim];
-}
-inline const TfLiteTensor* GetInput(const TfLiteContext* context,
-                                    const TfLiteNode* node, int index) {
-  const int tensor_index = node->inputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  return &context->tensors[tensor_index];
-}
+// A fair number of functions in this header have historically been inline.
+// It is ok to change functions to not be inline if the latency with
+// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
+// made, move the newly non-inlined function declarations to the top of this
+// header file.
+const TfLiteTensor* GetInput(const TfLiteContext* context,
+                             const TfLiteNode* node, int index);
+
 // Note: You must check if result is not null:
 // TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
 // TF_LITE_ENSURE(context, my_tensor != nullptr);
-inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
-                                      const TfLiteNode* node, int index) {
-  const int tensor_index = node->inputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  TfLiteTensor* tensor = &context->tensors[tensor_index];
-  return (tensor->is_variable) ? tensor : nullptr;
-}
-inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
-                               int index) {
-  const int tensor_index = node->outputs->data[index];
-  if (tensor_index < 0) {
-    return nullptr;
-  }
-  return &context->tensors[tensor_index];
+TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
+                               int index);
+
+TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
+                        int index);
+
+const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
+                                           const TfLiteNode* node, int index);
+
+inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
+inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
+  return t->dims->data[dim];
 }
+
 inline TfLiteTensor* GetTemporary(TfLiteContext* context,
                                   const TfLiteNode* node, int index) {
   const int tensor_index = node->temporaries->data[index];
-- 
2.23.0