144 lines
6.9 KiB
Diff
144 lines
6.9 KiB
Diff
From b1cc5e5a50e7cee09f2c6eb48eb40ee9c4125025 Mon Sep 17 00:00:00 2001
|
|
From: Amit Patankar <amitpatankar@google.com>
|
|
Date: Thu, 15 Apr 2021 13:03:19 -0700
|
|
Subject: [PATCH] Fix `tf.raw_ops.SparseCross` failing CHECK.
|
|
|
|
PiperOrigin-RevId: 368701671
|
|
Change-Id: Id805729dd9ba0bda36e4bb309408129b55fb649d
|
|
---
|
|
tensorflow/core/kernels/sparse_cross_op.cc | 55 +++++++++++++++++++---
|
|
1 file changed, 48 insertions(+), 7 deletions(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
|
|
index 583235b4a309b..43b3bedc74503 100644
|
|
--- a/tensorflow/core/kernels/sparse_cross_op.cc
|
|
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
|
|
@@ -27,6 +27,7 @@ limitations under the License.
|
|
#include "tensorflow/core/framework/tensor.h"
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
|
#include "tensorflow/core/framework/types.h"
|
|
+#include "tensorflow/core/framework/types.pb.h"
|
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
|
#include "tensorflow/core/lib/strings/str_util.h"
|
|
#include "tensorflow/core/platform/fingerprint.h"
|
|
@@ -460,10 +461,19 @@ int64 CalculateBatchSize(const OpInputList& shapes_list_in,
|
|
Status ValidateInput(const OpInputList& indices_list_in,
|
|
const OpInputList& values_list_in,
|
|
const OpInputList& shapes_list_in,
|
|
- const OpInputList& dense_list_in) {
|
|
+ const OpInputList& dense_list_in,
|
|
+ const DataType& internal_type) {
|
|
const auto size = indices_list_in.size();
|
|
+ // The dtype checks below apply only to SparseCrossOp. Other callers pass
|
|
+ // DT_INVALID as internal_type, which disables them.
|
|
+ bool check_type = internal_type != DT_INVALID;
|
|
// Validates indices_list_in OpInputList.
|
|
for (int i = 0; i < size; i++) {
|
|
+ if (check_type && indices_list_in[i].dtype() != DT_INT64) {
|
|
+ return errors::InvalidArgument("Input indices should be of type ",
|
|
+ DT_INT64, " but received ",
|
|
+ indices_list_in[i].dtype());
|
|
+ }
|
|
if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
|
|
return errors::InvalidArgument(
|
|
"Input indices should be a matrix but received shape ",
|
|
@@ -482,6 +492,14 @@ Status ValidateInput(const OpInputList& indices_list_in,
|
|
values_list_in.size());
|
|
}
|
|
for (int i = 0; i < size; i++) {
|
|
+ // Reject the mismatch where the expected internal type is string but the
|
|
+ // input values are int64.
|
|
+ if (check_type && internal_type == DT_STRING &&
|
|
+ values_list_in[i].dtype() == DT_INT64) {
|
|
+ return errors::InvalidArgument("Input values should be of internal type ",
|
|
+ internal_type, " but received ",
|
|
+ values_list_in[i].dtype());
|
|
+ }
|
|
if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
|
|
return errors::InvalidArgument(
|
|
"Input values should be a vector but received shape ",
|
|
@@ -502,6 +520,11 @@ Status ValidateInput(const OpInputList& indices_list_in,
|
|
shapes_list_in.size());
|
|
}
|
|
for (int i = 0; i < size; i++) {
|
|
+ if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
|
|
+ return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
|
|
+ " but received ",
|
|
+ shapes_list_in[i].dtype());
|
|
+ }
|
|
if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
|
|
return errors::InvalidArgument(
|
|
"Input shapes should be a vector but received shape ",
|
|
@@ -517,6 +540,14 @@ Status ValidateInput(const OpInputList& indices_list_in,
|
|
|
|
// Validates dense_list_in OpInputList
|
|
for (int i = 0; i < dense_list_in.size(); ++i) {
|
|
+ // Reject the mismatch where the expected internal type is string but the
|
|
+ // dense inputs are int64.
|
|
+ if (check_type && internal_type == DT_STRING &&
|
|
+ dense_list_in[i].dtype() == DT_INT64) {
|
|
+ return errors::InvalidArgument("Dense inputs should be of internal type ",
|
|
+ internal_type, " but received ",
|
|
+ dense_list_in[i].dtype());
|
|
+ }
|
|
if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
|
|
return errors::InvalidArgument(
|
|
"Dense inputs should be a matrix but received shape ",
|
|
@@ -698,6 +729,7 @@ class SparseCrossOp : public OpKernel {
|
|
int64 signed_hash_key_;
|
|
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
|
|
hash_key_ = static_cast<uint64>(signed_hash_key_);
|
|
+ OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
|
|
}
|
|
|
|
void Compute(OpKernelContext* context) override {
|
|
@@ -711,8 +743,10 @@ class SparseCrossOp : public OpKernel {
|
|
OP_REQUIRES_OK(context,
|
|
context->input_list("dense_inputs", &dense_list_in));
|
|
|
|
- OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
|
|
- shapes_list_in, dense_list_in));
|
|
+ DataType internal_type = internal_type_;
|
|
+ OP_REQUIRES_OK(
|
|
+ context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
|
|
+ dense_list_in, internal_type));
|
|
|
|
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
|
|
GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
|
|
@@ -756,6 +790,7 @@ class SparseCrossOp : public OpKernel {
|
|
private:
|
|
int64 num_buckets_;
|
|
uint64 hash_key_;
|
|
+ DataType internal_type_;
|
|
};
|
|
|
|
class SparseCrossV2Op : public OpKernel {
|
|
@@ -773,8 +808,11 @@ class SparseCrossV2Op : public OpKernel {
|
|
OP_REQUIRES_OK(context,
|
|
context->input_list("dense_inputs", &dense_list_in));
|
|
|
|
- OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
|
|
- shapes_list_in, dense_list_in));
|
|
+ // Pass DT_INVALID so that ValidateInput skips its dtype checks.
|
|
+ DataType internal_type = DT_INVALID;
|
|
+ OP_REQUIRES_OK(
|
|
+ context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
|
|
+ dense_list_in, internal_type));
|
|
|
|
const Tensor* sep_t;
|
|
OP_REQUIRES_OK(context, context->input("sep", &sep_t));
|
|
@@ -832,8 +870,11 @@ class SparseCrossHashedOp : public OpKernel {
|
|
OP_REQUIRES_OK(context,
|
|
context->input_list("dense_inputs", &dense_list_in));
|
|
|
|
- OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
|
|
- shapes_list_in, dense_list_in));
|
|
+ // Pass DT_INVALID so that ValidateInput skips its dtype checks.
|
|
+ DataType internal_type = DT_INVALID;
|
|
+ OP_REQUIRES_OK(
|
|
+ context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
|
|
+ dense_list_in, internal_type));
|
|
|
|
const Tensor* num_buckets_t;
|
|
OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
|