134 lines
5.7 KiB
Diff
134 lines
5.7 KiB
Diff
From ba424dd8f16f7110eea526a8086f1a155f14f22b Mon Sep 17 00:00:00 2001
|
|
From: Mihai Maruseac <mihaimaruseac@google.com>
|
|
Date: Thu, 22 Apr 2021 13:29:54 -0700
|
|
Subject: [PATCH] Enhance validation of ngram op and handle case of 0 tokens.
|
|
|
|
PiperOrigin-RevId: 369940178
|
|
Change-Id: Ia82f42c09d14efe76e7dc013505b832a42282f0b
|
|
---
|
|
tensorflow/core/kernels/string_ngrams_op.cc | 52 +++++++++++++++----
|
|
.../core/kernels/string_ngrams_op_test.cc | 34 ++++++++++++
|
|
2 files changed, 75 insertions(+), 11 deletions(-)
|
|
|
|
diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc
|
|
index 8aed2b3831a2f..7008a1d766af2 100644
|
|
--- a/tensorflow/core/kernels/string_ngrams_op.cc
|
|
+++ b/tensorflow/core/kernels/string_ngrams_op.cc
|
|
@@ -61,16 +61,28 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
|
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
|
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
|
|
|
- // Validate that the splits are valid indices into data
|
|
+ // Validate that the splits are valid indices into data, only if there are
|
|
+ // splits specified.
|
|
const int input_data_size = data->flat<tstring>().size();
|
|
const int splits_vec_size = splits_vec.size();
|
|
- for (int i = 0; i < splits_vec_size; ++i) {
|
|
- bool valid_splits = splits_vec(i) >= 0;
|
|
- valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
|
|
- OP_REQUIRES(
|
|
- context, valid_splits,
|
|
- errors::InvalidArgument("Invalid split value ", splits_vec(i),
|
|
- ", must be in [0,", input_data_size, "]"));
|
|
+ if (splits_vec_size > 0) {
|
|
+ int prev_split = splits_vec(0);
|
|
+ OP_REQUIRES(context, prev_split == 0,
|
|
+ errors::InvalidArgument("First split value must be 0, got ",
|
|
+ prev_split));
|
|
+ for (int i = 1; i < splits_vec_size; ++i) {
|
|
+ bool valid_splits = splits_vec(i) >= prev_split;
|
|
+ valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
|
|
+ OP_REQUIRES(context, valid_splits,
|
|
+ errors::InvalidArgument(
|
|
+ "Invalid split value ", splits_vec(i), ", must be in [",
|
|
+ prev_split, ", ", input_data_size, "]"));
|
|
+ prev_split = splits_vec(i);
|
|
+ }
|
|
+ OP_REQUIRES(context, prev_split == input_data_size,
|
|
+ errors::InvalidArgument(
|
|
+ "Last split value must be data size. Expected ",
|
|
+ input_data_size, ", got ", prev_split));
|
|
}
|
|
|
|
int num_batch_items = splits_vec.size() - 1;
|
|
@@ -174,13 +186,31 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
|
ngram->append(left_pad_);
|
|
ngram->append(separator_);
|
|
}
|
|
+ // Only output first num_tokens - 1 pairs of data and separator
|
|
for (int n = 0; n < num_tokens - 1; ++n) {
|
|
ngram->append(data[data_start_index + n]);
|
|
ngram->append(separator_);
|
|
}
|
|
- ngram->append(data[data_start_index + num_tokens - 1]);
|
|
- for (int n = 0; n < right_padding; ++n) {
|
|
- ngram->append(separator_);
|
|
+ // Handle case when there are no tokens or no right padding as these can
|
|
+ // result in consecutive separators.
|
|
+ if (num_tokens > 0) {
|
|
+ // If we have tokens, then output last and then pair each separator with
|
|
+ // the right padding that follows, to ensure ngram ends either with the
|
|
+ // token or with the right pad.
|
|
+ ngram->append(data[data_start_index + num_tokens - 1]);
|
|
+ for (int n = 0; n < right_padding; ++n) {
|
|
+ ngram->append(separator_);
|
|
+ ngram->append(right_pad_);
|
|
+ }
|
|
+ } else {
|
|
+ // If we don't have tokens, then the last item inserted into the ngram
|
|
+ // has been the separator from the left padding loop above. Hence,
|
|
+ // output right pad and separator and make sure to finish with a
|
|
+ // padding, not a separator.
|
|
+ for (int n = 0; n < right_padding - 1; ++n) {
|
|
+ ngram->append(right_pad_);
|
|
+ ngram->append(separator_);
|
|
+ }
|
|
ngram->append(right_pad_);
|
|
}
|
|
|
|
diff --git a/tensorflow/core/kernels/string_ngrams_op_test.cc b/tensorflow/core/kernels/string_ngrams_op_test.cc
|
|
index b89de9ad16dab..0d52283bd8fb9 100644
|
|
--- a/tensorflow/core/kernels/string_ngrams_op_test.cc
|
|
+++ b/tensorflow/core/kernels/string_ngrams_op_test.cc
|
|
@@ -542,6 +542,40 @@ TEST_F(NgramKernelTest, TestEmptyInput) {
|
|
assert_int64_equal(expected_splits, *GetOutput(1));
|
|
}
|
|
|
|
+TEST_F(NgramKernelTest, TestNoTokens) {  // An item with no tokens still yields pad-only ngrams.
|
|
+ MakeOp("|", {3}, "L", "R", -1, false);  // NOTE(review): presumably (separator, ngram widths, left pad, right pad, pad width, preserve-short) -- confirm against MakeOp.
|
|
+ // Batch items are:
|
|
+ // 0: (empty -- its split range [0, 0) selects no tokens)
|
|
+ // 1: "a"
|
|
+ AddInputFromArray<tstring>(TensorShape({1}), {"a"});
|
|
+ AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});  // splits: item 0 = data[0, 0), item 1 = data[0, 1)
|
|
+ TF_ASSERT_OK(RunOpKernel());
|
|
+
|
|
+ std::vector<tstring> expected_values(
|
|
+ {"L|L|R", "L|R|R", // no input in first split
|
|
+ "L|L|a", "L|a|R", "a|R|R"}); // second split
|
|
+ std::vector<int64> expected_splits({0, 2, 5});  // 2 ngrams for item 0, then 3 for item 1
|
|
+
|
|
+ assert_string_equal(expected_values, *GetOutput(0));
|
|
+ assert_int64_equal(expected_splits, *GetOutput(1));
|
|
+}
|
|
+
|
|
+TEST_F(NgramKernelTest, TestNoTokensNoPad) {  // Same input as TestNoTokens but with empty pads and pad width 0: expects no ngrams at all.
|
|
+ MakeOp("|", {3}, "", "", 0, false);  // NOTE(review): presumably empty left/right pad strings and pad width 0 -- confirm against MakeOp.
|
|
+ // Batch items are:
|
|
+ // 0: (empty -- its split range [0, 0) selects no tokens)
|
|
+ // 1: "a"
|
|
+ AddInputFromArray<tstring>(TensorShape({1}), {"a"});
|
|
+ AddInputFromArray<int64>(TensorShape({3}), {0, 0, 1});  // splits: item 0 = data[0, 0), item 1 = data[0, 1)
|
|
+ TF_ASSERT_OK(RunOpKernel());
|
|
+
|
|
+ std::vector<tstring> expected_values({});  // presumably because neither item can fill a width-3 ngram without padding -- verify
|
|
+ std::vector<int64> expected_splits({0, 0, 0});  // both items contribute zero ngrams
|
|
+
|
|
+ assert_string_equal(expected_values, *GetOutput(0));
|
|
+ assert_int64_equal(expected_splits, *GetOutput(1));
|
|
+}
|
|
+
|
|
TEST_F(NgramKernelTest, ShapeFn) {
|
|
ShapeInferenceTestOp op("StringNGrams");
|
|
INFER_OK(op, "?;?", "[?];[?]");
|