99 lines
4.2 KiB
Diff
99 lines
4.2 KiB
Diff
|
|
From 429f009d2b2c09028647dd4bb7b3f6f414bbaad7 Mon Sep 17 00:00:00 2001
|
||
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
||
|
|
Date: Wed, 28 Jul 2021 13:25:18 -0700
|
||
|
|
Subject: [PATCH] Add remaining missing validation to
|
||
|
|
`BoostedTreesCalculateBestFeatureSplit`
|
||
|
|
|
||
|
|
PiperOrigin-RevId: 387423006
|
||
|
|
Change-Id: I8eaf30efb223011519e60707bfa751b275d3a443
|
||
|
|
---
|
||
|
|
.../core/kernels/boosted_trees/stats_ops.cc | 20 ++++++++++++++++++-
|
||
|
|
1 file changed, 19 insertions(+), 1 deletion(-)
|
||
|
|
|
||
|
|
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
index 851e5b78..b26ab49e 100644
|
||
|
|
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
|
||
|
|
@@ -14,6 +14,7 @@ limitations under the License.
|
||
|
|
==============================================================================*/
|
||
|
|
|
||
|
|
#include <limits>
|
||
|
|
+#include <string>
|
||
|
|
#include <vector>
|
||
|
|
|
||
|
|
#include "third_party/eigen3/Eigen/Core"
|
||
|
|
@@ -22,6 +23,7 @@ limitations under the License.
|
||
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||
|
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||
|
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||
|
|
+#include "tensorflow/core/platform/errors.h"
|
||
|
|
#include "tensorflow/core/platform/logging.h"
|
||
|
|
|
||
|
|
namespace tensorflow {
|
||
|
|
@@ -244,12 +246,18 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
||
|
|
// node_id_range
|
||
|
|
const Tensor* node_id_range_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, node_id_range_t->NumElements() == 2,
|
||
|
|
+ errors::InvalidArgument("node_id_range argument must have shape [2]"));
|
||
|
|
const auto node_id_range = node_id_range_t->vec<int32>();
|
||
|
|
const int32 node_id_first = node_id_range(0); // inclusive
|
||
|
|
const int32 node_id_last = node_id_range(1); // exclusive
|
||
|
|
|
||
|
|
const Tensor* stats_summary_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("stats_summary", &stats_summary_t));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, stats_summary_t->shape().dims() == 4,
|
||
|
|
+ errors::InvalidArgument("stats_summary argument must have rank 4"));
|
||
|
|
TTypes<float, 4>::ConstTensor stats_summary =
|
||
|
|
stats_summary_t->tensor<float, 4>();
|
||
|
|
const int32 feature_dims = stats_summary_t->dim_size(1);
|
||
|
|
@@ -262,6 +270,8 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
||
|
|
|
||
|
|
const Tensor* l1_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("l1", &l1_t));
|
||
|
|
+ OP_REQUIRES(context, l1_t->NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("l1 argument must be a scalar"));
|
||
|
|
const auto l1 = l1_t->scalar<float>()();
|
||
|
|
DCHECK_GE(l1, 0);
|
||
|
|
if (logits_dim_ > 1) {
|
||
|
|
@@ -271,17 +281,25 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
||
|
|
|
||
|
|
const Tensor* l2_t;
|
||
|
|
OP_REQUIRES_OK(context, context->input("l2", &l2_t));
|
||
|
|
+ OP_REQUIRES(context, l2_t->NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("l2 argument must be a scalar"));
|
||
|
|
const auto l2 = l2_t->scalar<float>()();
|
||
|
|
DCHECK_GE(l2, 0);
|
||
|
|
|
||
|
|
const Tensor* tree_complexity_t;
|
||
|
|
OP_REQUIRES_OK(context,
|
||
|
|
context->input("tree_complexity", &tree_complexity_t));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, tree_complexity_t->NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("tree_complexity argument must be a scalar"));
|
||
|
|
const auto tree_complexity = tree_complexity_t->scalar<float>()();
|
||
|
|
|
||
|
|
const Tensor* min_node_weight_t;
|
||
|
|
OP_REQUIRES_OK(context,
|
||
|
|
context->input("min_node_weight", &min_node_weight_t));
|
||
|
|
+ OP_REQUIRES(
|
||
|
|
+ context, min_node_weight_t->NumElements() == 1,
|
||
|
|
+ errors::InvalidArgument("min_node_weight argument must be a scalar"));
|
||
|
|
const auto min_node_weight = min_node_weight_t->scalar<float>()();
|
||
|
|
|
||
|
|
std::vector<int32> output_node_ids;
|
||
|
|
@@ -290,7 +308,7 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
||
|
|
std::vector<int32> output_thresholds;
|
||
|
|
std::vector<Eigen::VectorXf> output_left_node_contribs;
|
||
|
|
std::vector<Eigen::VectorXf> output_right_node_contribs;
|
||
|
|
- std::vector<string> output_split_types;
|
||
|
|
+ std::vector<std::string> output_split_types;
|
||
|
|
|
||
|
|
// TODO(tanzheny) parallelize the computation.
|
||
|
|
// Iterate each node and find the best gain per node.
|
||
|
|
--
|
||
|
|
2.27.0
|
||
|
|
|