91 lines
4.4 KiB
Diff
91 lines
4.4 KiB
Diff
From f7cc8755ac6683131fdfa7a8a121f9d7a9dec6fb Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Wed, 5 May 2021 11:40:50 -0700
|
|
Subject: [PATCH] Add several missing validations in SDCA
|
|
|
|
PiperOrigin-RevId: 372172877
|
|
Change-Id: Id366da962432e18dcbfac847d11e98488bebb70a
|
|
---
|
|
tensorflow/core/kernels/sdca_internal.cc | 36 ++++++++++++++++++++++++
|
|
1 file changed, 36 insertions(+)
|
|
|
|
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
|
|
index cbc754af0e9bb..11a3be8bf46a7 100644
|
|
--- a/tensorflow/core/kernels/sdca_internal.cc
|
|
+++ b/tensorflow/core/kernels/sdca_internal.cc
|
|
@@ -99,6 +99,10 @@ Status ModelWeights::Initialize(OpKernelContext* const context) {
|
|
OpInputList sparse_weights_inputs;
|
|
TF_RETURN_IF_ERROR(
|
|
context->input_list("sparse_weights", &sparse_weights_inputs));
|
|
+ if (sparse_indices_inputs.size() != sparse_weights_inputs.size())
|
|
+ return errors::InvalidArgument(
|
|
+ "sparse_indices and sparse_weights must have the same length, got ",
|
|
+ sparse_indices_inputs.size(), " and ", sparse_weights_inputs.size());
|
|
OpInputList dense_weights_inputs;
|
|
TF_RETURN_IF_ERROR(
|
|
context->input_list("dense_weights", &dense_weights_inputs));
|
|
@@ -106,10 +110,20 @@ Status ModelWeights::Initialize(OpKernelContext* const context) {
|
|
OpOutputList sparse_weights_outputs;
|
|
TF_RETURN_IF_ERROR(context->output_list("out_delta_sparse_weights",
|
|
&sparse_weights_outputs));
|
|
+ if (sparse_weights_outputs.size() != sparse_weights_inputs.size())
|
|
+ return errors::InvalidArgument(
|
|
+ "out_delta_sparse_weights and sparse_weights must have the same "
|
|
+ "length, got ",
|
|
+ sparse_weights_outputs.size(), " and ", sparse_weights_inputs.size());
|
|
|
|
OpOutputList dense_weights_outputs;
|
|
TF_RETURN_IF_ERROR(
|
|
context->output_list("out_delta_dense_weights", &dense_weights_outputs));
|
|
+ if (dense_weights_outputs.size() != dense_weights_inputs.size())
|
|
+ return errors::InvalidArgument(
|
|
+ "out_delta_dense_weights and dense_weights must have the same length, "
|
|
+ "got ",
|
|
+ dense_weights_outputs.size(), " and ", dense_weights_inputs.size());
|
|
|
|
for (int i = 0; i < sparse_weights_inputs.size(); ++i) {
|
|
Tensor* delta_t;
|
|
@@ -327,13 +341,28 @@ Status Examples::Initialize(OpKernelContext* const context,
|
|
OpInputList sparse_example_indices_inputs;
|
|
TF_RETURN_IF_ERROR(context->input_list("sparse_example_indices",
|
|
&sparse_example_indices_inputs));
|
|
+ if (sparse_example_indices_inputs.size() != num_sparse_features)
|
|
+ return errors::InvalidArgument(
|
|
+ "Expected ", num_sparse_features,
|
|
+ " tensors in sparse_example_indices but got ",
|
|
+ sparse_example_indices_inputs.size());
|
|
OpInputList sparse_feature_indices_inputs;
|
|
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_indices",
|
|
&sparse_feature_indices_inputs));
|
|
+ if (sparse_feature_indices_inputs.size() != num_sparse_features)
|
|
+ return errors::InvalidArgument(
|
|
+ "Expected ", num_sparse_features,
|
|
+ " tensors in sparse_feature_indices but got ",
|
|
+ sparse_feature_indices_inputs.size());
|
|
OpInputList sparse_feature_values_inputs;
|
|
if (num_sparse_features_with_values > 0) {
|
|
TF_RETURN_IF_ERROR(context->input_list("sparse_feature_values",
|
|
&sparse_feature_values_inputs));
|
|
+ if (sparse_feature_values_inputs.size() != num_sparse_features_with_values)
|
|
+ return errors::InvalidArgument(
|
|
+ "Expected ", num_sparse_features_with_values,
|
|
+ " tensors in sparse_feature_values but got ",
|
|
+ sparse_feature_values_inputs.size());
|
|
}
|
|
|
|
const Tensor* example_weights_t;
|
|
@@ -400,6 +429,13 @@ Status Examples::CreateSparseFeatureRepresentation(
|
|
sparse_example_indices_inputs[i].template flat<int64>();
|
|
auto feature_indices =
|
|
sparse_feature_indices_inputs[i].template flat<int64>();
|
|
+ if (example_indices.size() != feature_indices.size()) {
|
|
+ mutex_lock l(mu);
|
|
+ result = errors::InvalidArgument(
|
|
+ "Found mismatched example_indices and feature_indices [",
|
|
+ example_indices, "] vs [", feature_indices, "]");
|
|
+ return;
|
|
+ }
|
|
|
|
// Parse features for each example. Features for a particular example
|
|
// are at the offsets (start_id, end_id]
|