From 4aacb30888638da75023e6601149415b39763d76 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Tue, 3 Aug 2021 12:28:58 -0700
Subject: [PATCH] Disallow division by zero FPE in `tf.raw_ops.ResourceScatterDiv`

Had to update a test that was broken.

PiperOrigin-RevId: 388516976
Change-Id: Ic358e6bf0559e011539974d453fc7aa18b427e9c
---
NOTE(review): this text sits between `---` and the diff and is ignored
by git am / git apply. ValidateInput<T> below walks `updates` with an
`int` index while Tensor::NumElements() returns int64; an int64 index
would keep the zero check from overflowing on very large `updates`
tensors. The hunks are kept identical to the upstream commit.

 .../core/kernels/resource_variable_ops.cc     | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index b9c883c7..e056d9cb 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -844,6 +844,35 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
 #undef REGISTER_GATHER_ND_ALL_INDICES
 #undef REGISTER_GATHER_ND_FULL
 
+namespace {
+
+template <typename Device>
+bool isCPUDevice() {
+  return false;
+}
+
+template <>
+bool isCPUDevice<CPUDevice>() {
+  return true;
+}
+
+template <typename T>
+bool ValidateInput(const Tensor& updates) {
+  const auto updates_flat = updates.flat<T>();
+  const T zero(0);
+  for (int i = 0; i < updates.NumElements(); i++) {
+    if (updates_flat(i) == zero) return false;
+  }
+  return true;
+}
+
+template <>
+bool ValidateInput<Variant>(const Tensor& updates) {
+  return true;
+}
+
+}  // namespace
+
 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
 class ResourceScatterUpdateOp : public OpKernel {
  public:
@@ -910,6 +939,12 @@ class ResourceScatterUpdateOp : public OpKernel {
                                 " indexing: ", params->dim_size(0), " > ",
                                 std::numeric_limits<Index>::max()));
 
+    // Prevent division by 0
+    if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
+      OP_REQUIRES(c, ValidateInput<T>(updates),
+                  errors::InvalidArgument("updates must not contain 0"));
+    }
+
     if (N > 0) {
       auto indices_flat = indices.flat<Index>();
       auto params_flat = params->flat_outer_dims<T>();
-- 
2.27.0