From 9c87c32c710d0b5b53dc6fd3bfde4046e1f7a5ad Mon Sep 17 00:00:00 2001
From: Laura Pak
Date: Tue, 27 Jul 2021 12:11:33 -0700
Subject: [PATCH] Disallow empty node_id_range in
 tf.raw_ops.BoostedTreesCalculateBestFeatureSplitV2 and
 tf.raw_ops.BoostedTreesCalculateBestGainsPerFeature

Both shape checks must run before node_id_range_t->vec<int32>() is called:
vec<int32>() CHECK-fails (aborting the process) on a tensor that is not
rank 1, which would make the rank check unreachable for bad inputs.

PiperOrigin-RevId: 387165936
Change-Id: I2f70341af96236b2776c2a592c917d549c1fc1e2
---
 .../core/kernels/boosted_trees/stats_ops.cc   | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 73e5b85f..48ace868 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -53,6 +53,17 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
     // node_id_range
     const Tensor* node_id_range_t;
     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+    // Check shape first: vec<int32>() CHECK-fails on a non-rank-1 tensor.
+    OP_REQUIRES(
+        context, node_id_range_t->dims() == 1,
+        errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
+                                "given node_id_range has dims of ",
+                                node_id_range_t->dims()));
+    OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
+                errors::InvalidArgument(
+                    "node_id_range must be a rank 1 tensor with shape=[2], but "
+                    "given node_id_range has shape ",
+                    node_id_range_t->dim_size(0), " on its first dim"));
     const auto node_id_range = node_id_range_t->vec<int32>();
     const int32 node_id_first = node_id_range(0);  // inclusive
     const int32 node_id_last = node_id_range(1);   // exclusive
@@ -586,6 +597,17 @@ class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
     const Tensor* node_id_range_t;
     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+    // Check shape first: vec<int32>() CHECK-fails on a non-rank-1 tensor.
+    OP_REQUIRES(
+        context, node_id_range_t->dims() == 1,
+        errors::InvalidArgument("node_id_range must be a rank 1 tensor, but "
+                                "given node_id_range has dims of ",
+                                node_id_range_t->dims()));
+    OP_REQUIRES(context, node_id_range_t->dim_size(0) == 2,
+                errors::InvalidArgument(
+                    "node_id_range must be a rank 1 tensor with shape=[2], but "
+                    "given node_id_range has shape ",
+                    node_id_range_t->dim_size(0), " on its first dim"));
     const auto node_id_range = node_id_range_t->vec<int32>();
     const int32 node_id_first = node_id_range(0);  // Inclusive.
     const int32 node_id_last = node_id_range(1);   // Exclusive.
 
-- 
2.27.0