71 lines
2.1 KiB
Diff
71 lines
2.1 KiB
Diff
From 4aacb30888638da75023e6601149415b39763d76 Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Tue, 3 Aug 2021 12:28:58 -0700
|
|
Subject: [PATCH] Disallow division by zero FPE in
|
|
`tf.raw_ops.ResourceScatterDiv`
|
|
|
|
Had to update a test that was broken (that test change is not part of this patch, which modifies only resource_variable_ops.cc).
|
|
|
|
PiperOrigin-RevId: 388516976
|
|
Change-Id: Ic358e6bf0559e011539974d453fc7aa18b427e9c
|
|
---
|
|
.../core/kernels/resource_variable_ops.cc | 35 +++++++++++++++++++
|
|
1 file changed, 35 insertions(+)
|
|
|
|
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
|
|
index b9c883c7..e056d9cb 100644
|
|
--- a/tensorflow/core/kernels/resource_variable_ops.cc
|
|
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
|
|
@@ -844,6 +844,35 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
|
|
#undef REGISTER_GATHER_ND_ALL_INDICES
|
|
#undef REGISTER_GATHER_ND_FULL
|
|
|
|
+namespace {
+
+// Returns true only for CPUDevice (see the explicit specialization).
+template <typename Device>
+bool isCPUDevice() { return false; }
+template <>
+bool isCPUDevice<CPUDevice>() { return true; }
+
+// Returns false if any element of `updates` equals T(0), true otherwise.
+// The counter is int64 since NumElements() may exceed INT_MAX.
+template <typename T>
+bool ValidateInput(const Tensor& updates) {
+  const auto updates_flat = updates.flat<T>();
+  const T zero(0);
+  const int64 num_elements = updates.NumElements();
+  for (int64 i = 0; i < num_elements; ++i) {
+    if (updates_flat(i) == zero) return false;
+  }
+  return true;
+}
+
+// Variant updates are always accepted; no zero comparison is attempted.
+template <>
+bool ValidateInput<Variant>(const Tensor& updates) {
+  return true;
+}
+
+}  // namespace
+
|
|
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
|
|
class ResourceScatterUpdateOp : public OpKernel {
|
|
public:
|
|
@@ -910,6 +939,12 @@ class ResourceScatterUpdateOp : public OpKernel {
|
|
" indexing: ", params->dim_size(0), " > ",
|
|
std::numeric_limits<Index>::max()));
|
|
|
|
+ // Prevent division by 0
|
|
+ if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
|
|
+ OP_REQUIRES(c, ValidateInput<T>(updates),
|
|
+ errors::InvalidArgument("updates must not contain 0"));
|
|
+ }
|
|
+
|
|
if (N > 0) {
|
|
auto indices_flat = indices.flat<Index>();
|
|
auto params_flat = params->flat_outer_dims<T>();
|
|
--
|
|
2.27.0
|
|
|