200 lines
6.5 KiB
Diff
200 lines
6.5 KiB
Diff
From ce47a396ff795bdb6cf48eb53dbcba46cb51fa7d Mon Sep 17 00:00:00 2001
|
|
From: Katherine Tian <kattian@google.com>
|
|
Date: Tue, 30 Jun 2020 04:12:11 +0000
|
|
Subject: [PATCH 1/1] TensorKey class and TensorMap tests
|
|
|
|
---
|
|
tensorflow/core/BUILD | 1 +
|
|
tensorflow/core/framework/BUILD | 70 ++++++++++++++++++++++++++
|
|
tensorflow/core/framework/tensor_key.h | 64 +++++++++++++++++++++++
|
|
tensorflow/core/kernels/BUILD | 1 +
|
|
4 files changed, 136 insertions(+)
|
|
create mode 100644 tensorflow/core/framework/tensor_key.h
|
|
|
|
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
|
|
index d0be6ee9..6e745b4e 100644
|
|
--- a/tensorflow/core/BUILD
|
|
+++ b/tensorflow/core/BUILD
|
|
@@ -495,6 +495,7 @@ tf_cuda_library(
|
|
"//tensorflow/core/framework:shared_ptr_variant.h",
|
|
"//tensorflow/core/framework:stats_aggregator.h",
|
|
"//tensorflow/core/framework:tensor.h",
|
|
+ "//tensorflow/core/framework:tensor_key.h",
|
|
"//tensorflow/core/framework:tensor_shape.h",
|
|
"//tensorflow/core/framework:tensor_slice.h",
|
|
"//tensorflow/core/framework:tensor_types.h",
|
|
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
|
|
index 9b6ddb2a..093f0545 100644
|
|
--- a/tensorflow/core/framework/BUILD
|
|
+++ b/tensorflow/core/framework/BUILD
|
|
@@ -209,6 +209,7 @@ filegroup(
|
|
"shared_ptr_variant.h",
|
|
"stats_aggregator.h",
|
|
"tensor.h",
|
|
+ "tensor_key.h",
|
|
"tensor_reference.h",
|
|
"tensor_shape.h",
|
|
"tensor_slice.h",
|
|
@@ -760,6 +761,75 @@ tf_cuda_library(
|
|
alwayslink = 1,
|
|
)
|
|
|
|
+# Header-only library exposing TensorKey (a hashable Tensor wrapper).
+# It depends on ":tensor" for the Tensor implementation rather than
+# re-listing tensor.cc, variant_op_registry.cc, etc. as srcs: compiling those
+# translation units into a second alwayslink library would link them twice
+# (duplicate symbols and duplicate static registrations).
+cc_library(
+    name = "tensor_key",
+    hdrs = ["tensor_key.h"],
+    visibility = [
+        "//tensorflow/core:__pkg__",
+        "//tensorflow/core/util:__pkg__",
+    ],
+    deps = [":tensor"],
+)
+
|
|
cc_library(
|
|
name = "shape_inference",
|
|
srcs = ["shape_inference.cc"],
|
|
diff --git a/tensorflow/core/framework/tensor_key.h b/tensorflow/core/framework/tensor_key.h
|
|
new file mode 100644
|
|
index 00000000..8eff58b2
|
|
--- /dev/null
|
|
+++ b/tensorflow/core/framework/tensor_key.h
|
|
@@ -0,0 +1,64 @@
|
|
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
+
|
|
+Licensed under the Apache License, Version 2.0 (the "License");
|
|
+you may not use this file except in compliance with the License.
|
|
+You may obtain a copy of the License at
|
|
+
|
|
+ http://www.apache.org/licenses/LICENSE-2.0
|
|
+
|
|
+Unless required by applicable law or agreed to in writing, software
|
|
+distributed under the License is distributed on an "AS IS" BASIS,
|
|
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
+See the License for the specific language governing permissions and
|
|
+limitations under the License.
|
|
+==============================================================================*/
|
|
+
|
|
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
+
+#include <utility>
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+// A Tensor that can be used as the key of an absl hash container, e.g.
+// absl::flat_hash_map<TensorKey, Tensor>. Two keys are equal when they have
+// the same dtype, the same shape and the same element values.
+class TensorKey : public Tensor {
+ public:
+  using Tensor::Tensor;
+
+  // Deliberately implicit so a plain Tensor can be used directly for
+  // insertions and lookups.
+  TensorKey(const Tensor& t) : Tensor(t) {}
+
+  // Equality operator. Needed for absl hashing.
+  //
+  // Element values are compared byte-wise for memcpy-able dtypes and by
+  // string contents for DT_STRING. For every other dtype this returns false.
+  // NOTE(review): returning false even when a key is compared with itself
+  // breaks reflexivity, so keys of such dtypes can be inserted into a hash
+  // container but never found again. Consider rejecting them up front.
+  friend bool operator==(const TensorKey& t1, const TensorKey& t2) {
+    if (t1.dtype() != t2.dtype() || t1.shape() != t2.shape()) {
+      return false;
+    }
+    if (DataTypeCanUseMemcpy(t1.dtype())) {
+      return t1.tensor_data() == t2.tensor_data();
+    }
+    if (t1.dtype() == DT_STRING) {
+      const auto s1 = t1.unaligned_flat<tstring>();
+      const auto s2 = t2.unaligned_flat<tstring>();
+      for (int64 i = 0, n = t1.NumElements(); i < n; ++i) {
+        if (TF_PREDICT_FALSE(s1(i) != s2(i))) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
+  friend bool operator!=(const TensorKey& t1, const TensorKey& t2) {
+    return !(t1 == t2);
+  }
+
+  // AbslHashValue() function, needed for absl hashing.
+  //
+  // Hashes exactly what operator== compares (dtype, shape and element
+  // values), so equal keys always produce equal hashes. The raw buffer is not
+  // hashed for DT_STRING because it holds tstring objects, not characters.
+  template <typename H>
+  friend H AbslHashValue(H h, const TensorKey& k) {
+    h = H::combine(std::move(h), static_cast<int>(k.dtype()), k.dims());
+    for (int d = 0; d < k.dims(); ++d) {
+      h = H::combine(std::move(h), k.dim_size(d));
+    }
+    if (DataTypeCanUseMemcpy(k.dtype())) {
+      // Same byte range that operator== compares.
+      return H::combine(std::move(h), k.tensor_data());
+    }
+    if (k.dtype() == DT_STRING) {
+      const auto strings = k.unaligned_flat<tstring>();
+      for (int64 i = 0, n = k.NumElements(); i < n; ++i) {
+        h = H::combine(std::move(h),
+                       StringPiece(strings(i).data(), strings(i).size()));
+      }
+    }
+    return h;
+  }
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
|
|
\ No newline at end of file
|
|
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
|
|
index f5a480b3..4ef86efb 100644
|
|
--- a/tensorflow/core/kernels/BUILD
|
|
+++ b/tensorflow/core/kernels/BUILD
|
|
@@ -3219,6 +3219,7 @@ tf_cc_tests(
|
|
],
|
|
deps = [
|
|
":eigen_helpers",
|
|
+ "//tensorflow/core/framework:tensor_testutil",
|
|
"//tensorflow/core:test",
|
|
"//tensorflow/core:test_main",
|
|
"@com_google_absl//absl/strings",
|
|
--
|
|
2.27.0
|
|
|