From ce47a396ff795bdb6cf48eb53dbcba46cb51fa7d Mon Sep 17 00:00:00 2001
From: Katherine Tian
Date: Tue, 30 Jun 2020 04:12:11 +0000
Subject: [PATCH 1/1] TensorKey class and TensorMap tests

---
 tensorflow/core/BUILD                  |  1 +
 tensorflow/core/framework/BUILD        | 13 ++++++
 tensorflow/core/framework/tensor_key.h | 64 ++++++++++++++++++++++++++
 tensorflow/core/kernels/BUILD          |  1 +
 4 files changed, 79 insertions(+)
 create mode 100644 tensorflow/core/framework/tensor_key.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d0be6ee9..6e745b4e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -495,6 +495,7 @@ tf_cuda_library(
         "//tensorflow/core/framework:shared_ptr_variant.h",
         "//tensorflow/core/framework:stats_aggregator.h",
         "//tensorflow/core/framework:tensor.h",
+        "//tensorflow/core/framework:tensor_key.h",
         "//tensorflow/core/framework:tensor_shape.h",
         "//tensorflow/core/framework:tensor_slice.h",
         "//tensorflow/core/framework:tensor_types.h",
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 9b6ddb2a..093f0545 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -209,6 +209,7 @@ filegroup(
         "shared_ptr_variant.h",
         "stats_aggregator.h",
         "tensor.h",
+        "tensor_key.h",
         "tensor_reference.h",
         "tensor_shape.h",
         "tensor_slice.h",
@@ -760,6 +761,18 @@ tf_cuda_library(
     alwayslink = 1,
 )
 
+# Header-only: TensorKey just wraps Tensor, so depend on :tensor rather than
+# recompiling its sources (which would duplicate alwayslink symbols).
+cc_library(
+    name = "tensor_key",
+    hdrs = ["tensor_key.h"],
+    visibility = [
+        "//tensorflow/core:__pkg__",
+        "//tensorflow/core/util:__pkg__",
+    ],
+    deps = [":tensor"],
+)
+
 cc_library(
     name = "shape_inference",
     srcs = ["shape_inference.cc"],
diff --git a/tensorflow/core/framework/tensor_key.h b/tensorflow/core/framework/tensor_key.h
new file mode 100644
index 00000000..8eff58b2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_key.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+// A Tensor usable as an absl hash-container key. Keys compare and hash by
+// dtype, shape and contents; only memcpy-able dtypes and DT_STRING are keyed.
+class TensorKey : public Tensor {
+ public:
+  using Tensor::Tensor;
+  TensorKey(const Tensor& t) : Tensor(t) {}  // Implicit: Tensor -> key.
+  friend bool operator==(const TensorKey& t1, const TensorKey& t2) {
+    if (t1.dtype() != t2.dtype() || t1.shape() != t2.shape()) return false;
+    if (DataTypeCanUseMemcpy(t1.dtype())) {
+      return t1.tensor_data() == t2.tensor_data();
+    }
+    if (t1.dtype() == DT_STRING) {
+      const auto s1 = t1.unaligned_flat<tstring>();
+      const auto s2 = t2.unaligned_flat<tstring>();
+      for (int64 i = 0, n = t1.NumElements(); i < n; ++i) {
+        if (TF_PREDICT_FALSE(s1(i) != s2(i))) return false;
+      }
+      return true;
+    }
+    return false;  // Other dtypes are unsupported and never compare equal.
+  }
+  friend bool operator!=(const TensorKey& t1, const TensorKey& t2) {
+    return !(t1 == t2);
+  }
+
+  // Hashes the contents compared by operator==, so equal keys hash equally.
+  template <typename H>
+  friend H AbslHashValue(H h, const TensorKey& k) {
+    if (DataTypeCanUseMemcpy(k.dtype())) {
+      return H::combine(std::move(h), k.tensor_data());
+    }
+    if (k.dtype() == DT_STRING) {
+      const auto s = k.unaligned_flat<tstring>();
+      for (int64 i = 0, n = k.NumElements(); i < n; ++i) {
+        h = H::combine(std::move(h), StringPiece(s(i).data(), s(i).size()));
+      }
+    }
+    return h;
+  }
+};
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f5a480b3..4ef86efb 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3219,6 +3219,7 @@ tf_cc_tests(
     ],
     deps = [
         ":eigen_helpers",
+        "//tensorflow/core/framework:tensor_testutil",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
-- 
2.27.0