From b1cc5e5a50e7cee09f2c6eb48eb40ee9c4125025 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Thu, 15 Apr 2021 13:03:19 -0700 Subject: [PATCH] Fix `tf.raw_ops.SparseCross` failing CHECK. PiperOrigin-RevId: 368701671 Change-Id: Id805729dd9ba0bda36e4bb309408129b55fb649d --- tensorflow/core/kernels/sparse_cross_op.cc | 55 +++++++++++++++++++--- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 583235b4a309b..43b3bedc74503 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" @@ -460,10 +461,19 @@ int64 CalculateBatchSize(const OpInputList& shapes_list_in, Status ValidateInput(const OpInputList& indices_list_in, const OpInputList& values_list_in, const OpInputList& shapes_list_in, - const OpInputList& dense_list_in) { + const OpInputList& dense_list_in, + const DataType& internal_type) { const auto size = indices_list_in.size(); + // Only perform internal_type check for SparseCrossOp. + // Check if the internal_type is not invalid before doing so. + bool check_type = internal_type != DT_INVALID; // Validates indices_list_in OpInputList. 
for (int i = 0; i < size; i++) { + if (check_type && indices_list_in[i].dtype() != DT_INT64) { + return errors::InvalidArgument("Input indices should be of type ", + DT_INT64, " but received ", + indices_list_in[i].dtype()); + } if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) { return errors::InvalidArgument( "Input indices should be a matrix but received shape ", @@ -482,6 +492,14 @@ Status ValidateInput(const OpInputList& indices_list_in, values_list_in.size()); } for (int i = 0; i < size; i++) { + // Make sure to avoid the expected type to be string, but input values to be + // int64. + if (check_type && internal_type == DT_STRING && + values_list_in[i].dtype() == DT_INT64) { + return errors::InvalidArgument("Input values should be of internal type ", + internal_type, " but received ", + values_list_in[i].dtype()); + } if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) { return errors::InvalidArgument( "Input values should be a vector but received shape ", @@ -502,6 +520,11 @@ Status ValidateInput(const OpInputList& indices_list_in, shapes_list_in.size()); } for (int i = 0; i < size; i++) { + if (check_type && shapes_list_in[i].dtype() != DT_INT64) { + return errors::InvalidArgument("Input shape should be of type ", DT_INT64, + " but received ", + shapes_list_in[i].dtype()); + } if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) { return errors::InvalidArgument( "Input shapes should be a vector but received shape ", @@ -517,6 +540,14 @@ Status ValidateInput(const OpInputList& indices_list_in, // Validates dense_list_in OpInputList for (int i = 0; i < dense_list_in.size(); ++i) { + // Make sure to avoid the expected type to be string, but input values to be + // int64. 
+ if (check_type && internal_type == DT_STRING && dense_list_in[i].dtype() == DT_INT64) { + return errors::InvalidArgument("Dense inputs should be of internal type ", + internal_type, " but received ", + dense_list_in[i].dtype()); + } if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) { return errors::InvalidArgument( "Dense inputs should be a matrix but received shape ", @@ -698,6 +729,7 @@ class SparseCrossOp : public OpKernel { int64 signed_hash_key_; OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_)); hash_key_ = static_cast<uint64>(signed_hash_key_); + OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_)); } void Compute(OpKernelContext* context) override { @@ -711,8 +743,10 @@ class SparseCrossOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list_in)); - OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, - shapes_list_in, dense_list_in)); + DataType internal_type = internal_type_; + OP_REQUIRES_OK( + context, ValidateInput(indices_list_in, values_list_in, shapes_list_in, + dense_list_in, internal_type)); std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns = GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in, @@ -756,6 +790,7 @@ class SparseCrossOp : public OpKernel { private: int64 num_buckets_; uint64 hash_key_; + DataType internal_type_; }; class SparseCrossV2Op : public OpKernel { @@ -773,8 +808,11 @@ class SparseCrossV2Op : public OpKernel { OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list_in)); - OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, - shapes_list_in, dense_list_in)); + // Set internal_type to invalid_type so that the check will be ignored. 
+ DataType internal_type = DT_INVALID; + OP_REQUIRES_OK( + context, ValidateInput(indices_list_in, values_list_in, shapes_list_in, + dense_list_in, internal_type)); const Tensor* sep_t; OP_REQUIRES_OK(context, context->input("sep", &sep_t)); @@ -832,8 +870,11 @@ class SparseCrossHashedOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list_in)); - OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, - shapes_list_in, dense_list_in)); + // Set internal_type to invalid_type so that the check will be ignored. + DataType internal_type = DT_INVALID; + OP_REQUIRES_OK( + context, ValidateInput(indices_list_in, values_list_in, shapes_list_in, + dense_list_in, internal_type)); const Tensor* num_buckets_t; OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));