55 lines
2.6 KiB
Diff
|
|
From 9c87c32c710d0b5b53dc6fd3bfde4046e1f7a5ad Mon Sep 17 00:00:00 2001
|
||
|
|
From: Laura Pak <lpak@google.com>
|
||
|
|
Date: Tue, 27 Jul 2021 12:11:33 -0700
|
||
|
|
Subject: [PATCH] Disallow empty node_id_range in
 tf.raw_ops.BoostedTreesCalculateBestFeatureSplitV2 and
 tf.raw_ops.BoostedTreesCalculateBestGainsPerFeature
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 387165936
|
||
|
|
Change-Id: I2f70341af96236b2776c2a592c917d549c1fc1e2
|
||
|
|
---
|
||
|
|
.../core/kernels/boosted_trees/stats_ops.cc | 20 +++++++++++++++++++
|
||
|
|
1 file changed, 20 insertions(+)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
index 73e5b85f..48ace868 100644
|
||
|
|
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
@@ -53,6 +53,16 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
|
||
|
|
// node_id_range
|
||
|
|
const Tensor* node_id_range_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, node_id_range_t->dims() == 1,
|
||
|
|
+ errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
|
||
|
|
+ "given node_id_range has dims of ",
|
||
|
|
+ node_id_range_t->dims()));
|
||
|
|
+ OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "node_id_range must be a rank 1 tensor with shape=[2], but "
|
||
|
|
+ "given node_id_range has shape ",
|
||
|
|
+ node_id_range_t->dim_size(0), " on its first dim"));
|
||
|
|
const auto node_id_range = node_id_range_t->vec<int32>();
|
||
|
|
const int32 node_id_first = node_id_range(0); // inclusive
|
||
|
|
const int32 node_id_last = node_id_range(1); // exclusive
|
||
|
|
@@ -586,6 +596,16 @@ class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
|
||
|
|
const Tensor* node_id_range_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
|
||
|
|
const auto node_id_range = node_id_range_t->vec<int32>();
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, node_id_range_t->dims() == 1,
|
||
|
|
+ errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
|
||
|
|
+ "given node_id_range has dim of ",
|
||
|
|
+ node_id_range_t->dims()));
|
||
|
|
+ OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
|
||
|
|
+ errors::InvalidArgument(
|
||
|
|
+ "node_id_range must be a rank 1 tensor with shape=[2], but "
|
||
|
|
+ "given node_id_range has shape ",
|
||
|
|
+ node_id_range_t->dim_size(0), " on its first dim"));
|
||
|
|
const int32 node_id_first = node_id_range(0); // Inclusive.
|
||
|
|
const int32 node_id_last = node_id_range(1); // Exclusive.
|
||
|
|
|
||
|
|
--
|
||
|
|
2.27.0
|
||
|
|
|