From 429f009d2b2c09028647dd4bb7b3f6f414bbaad7 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac
Date: Wed, 28 Jul 2021 13:25:18 -0700
Subject: [PATCH] Add remaining missing validation to
 `BoostedTreesCalculateBestFeatureSplit`

PiperOrigin-RevId: 387423006
Change-Id: I8eaf30efb223011519e60707bfa751b275d3a443
---
 .../core/kernels/boosted_trees/stats_ops.cc   | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

NOTE(review): this note sits between the diffstat and the first
"diff --git" line, a region `git am` does not apply.
  * The From: header carries no e-mail address. `git am` is expected
    to reject a patch without one, so restore the author address
    before applying with `git am` (`patch -p1` only reads the diff).
  * The node_id_range check compares only NumElements() to 2, while its
    message promises shape [2]. A rank-2 input such as shape [1, 2]
    passes and then reaches node_id_range_t->vec<int32>(), which
    presumably CHECK-fails on a non-vector tensor. Consider also
    requiring TensorShapeUtils::IsVector(node_id_range_t->shape()).
  * node_id_first / node_id_last are not range-checked in these hunks.
    Confirm that the loop consuming them bounds-checks against the
    first dimension of stats_summary.

diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 851e5b78..b26ab49e 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <limits>
+#include <string>
 #include <vector>
 
 #include "third_party/eigen3/Eigen/Core"
@@ -22,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
 #include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
@@ -244,12 +246,18 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
     // node_id_range
     const Tensor* node_id_range_t;
     OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+    OP_REQUIRES(
+        context, node_id_range_t->NumElements() == 2,
+        errors::InvalidArgument("node_id_range argument must have shape [2]"));
     const auto node_id_range = node_id_range_t->vec<int32>();
     const int32 node_id_first = node_id_range(0);  // inclusive
     const int32 node_id_last = node_id_range(1);   // exclusive
 
     const Tensor* stats_summary_t;
     OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
+    OP_REQUIRES(
+        context, stats_summary_t->shape().dims() == 4,
+        errors::InvalidArgument("stats_summary argument must have rank 4"));
     TTypes<float, 4>::ConstTensor stats_summary =
         stats_summary_t->tensor<float, 4>();
     const int32 feature_dims = stats_summary_t->dim_size(1);
@@ -262,6 +270,8 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
 
     const Tensor* l1_t;
     OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+    OP_REQUIRES(context, l1_t->NumElements() == 1,
+                errors::InvalidArgument("l1 argument must be a scalar"));
     const auto l1 = l1_t->scalar<float>()();
     DCHECK_GE(l1, 0);
     if (logits_dim_ > 1) {
@@ -271,17 +281,25 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
 
     const Tensor* l2_t;
     OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+    OP_REQUIRES(context, l2_t->NumElements() == 1,
+                errors::InvalidArgument("l2 argument must be a scalar"));
     const auto l2 = l2_t->scalar<float>()();
     DCHECK_GE(l2, 0);
 
     const Tensor* tree_complexity_t;
     OP_REQUIRES_OK(context,
                    context->input("tree_complexity", &tree_complexity_t));
+    OP_REQUIRES(
+        context, tree_complexity_t->NumElements() == 1,
+        errors::InvalidArgument("tree_complexity argument must be a scalar"));
     const auto tree_complexity = tree_complexity_t->scalar<float>()();
 
     const Tensor* min_node_weight_t;
     OP_REQUIRES_OK(context,
                    context->input("min_node_weight", &min_node_weight_t));
+    OP_REQUIRES(
+        context, min_node_weight_t->NumElements() == 1,
+        errors::InvalidArgument("min_node_weight argument must be a scalar"));
     const auto min_node_weight = min_node_weight_t->scalar<float>()();
 
     std::vector<int32> output_node_ids;
@@ -290,7 +308,7 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
     std::vector<int32> output_thresholds;
     std::vector<Eigen::VectorXf> output_left_node_contribs;
     std::vector<Eigen::VectorXf> output_right_node_contribs;
-    std::vector<string> output_split_types;
+    std::vector<std::string> output_split_types;
 
     // TODO(tanzheny) parallelize the computation.
     // Iterate each node and find the best gain per node.
-- 
2.27.0