From 7e56b74990e5f6ccb9a2896eeea0d60e1e2b1f21 Mon Sep 17 00:00:00 2001 From: dingguangya Date: Thu, 8 Dec 2022 20:50:23 +0800 Subject: [PATCH 4/4] [Pin-server] Support Plugin Pointer Type --- include/Dialect/PluginTypes.h | 12 +++++++++ include/PluginServer/PluginServer.h | 2 +- lib/Dialect/PluginDialect.cpp | 1 + lib/Dialect/PluginTypes.cpp | 41 +++++++++++++++++++++++++++++ lib/PluginServer/PluginServer.cpp | 28 +++++++++++++------- 5 files changed, 73 insertions(+), 11 deletions(-) diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h index bb8d81a..1329b8d 100644 --- a/include/Dialect/PluginTypes.h +++ b/include/Dialect/PluginTypes.h @@ -80,6 +80,7 @@ private: namespace detail { struct PluginIntegerTypeStorage; struct PluginFloatTypeStorage; + struct PluginPointerTypeStorage; /* storage for PluginPointerType: holds the pointee Type */ } class PluginIntegerType : public Type::TypeBase { @@ -117,6 +118,17 @@ public: unsigned getWidth() const; }; +class PluginPointerType : public Type::TypeBase { /* Plugin IR pointer type; the pointee is fixed when the type is obtained via get(). */ +public: + using Base::Base; + + PluginTypeID getPluginTypeID (); /* always PluginTypeID::PointerTyID */ + + static PluginPointerType get(MLIRContext *context, Type pointee); /* uniqued on pointee: equal pointees give the same type */ + + Type getElementType(); /* returns the pointee passed to get() */ +}; // class PluginPointerType + class PluginVoidType : public Type::TypeBase { public: using Base::Base; diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h index b28c6d0..207c018 100644 --- a/include/PluginServer/PluginServer.h +++ b/include/PluginServer/PluginServer.h @@ -187,7 +187,7 @@ public: timeout = time; } void FuncOpJsonDeSerialize(const string& data); - void TypeJsonDeSerialize(const string& data); + PluginIR::PluginTypeBase TypeJsonDeSerialize(const string& data); /* returns the decoded type and also stores it in pluginType */ void LocalDeclOpJsonDeSerialize(const string& data); void LoopOpsJsonDeSerialize(const string& data); void LoopOpJsonDeSerialize(const string& data); diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp index 63ba167..001fdab 100644 --- a/lib/Dialect/PluginDialect.cpp +++ b/lib/Dialect/PluginDialect.cpp @@ -35,6 +35,7 @@ 
void PluginDialect::initialize() { addTypes< PluginIR::PluginIntegerType, PluginIR::PluginFloatType, + PluginIR::PluginPointerType, PluginIR::PluginBooleanType, PluginIR::PluginVoidType, PluginIR::PluginUndefType>(); diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp index def302b..0fa003a 100644 --- a/lib/Dialect/PluginTypes.cpp +++ b/lib/Dialect/PluginTypes.cpp @@ -72,6 +72,25 @@ namespace detail { unsigned width : 30; }; + + struct PluginPointerTypeStorage : public TypeStorage { /* Uniqued storage for PluginPointerType: the pointee Type is the whole key, so equal pointees share one storage instance. */ + using KeyTy = Type; + + explicit PluginPointerTypeStorage(const KeyTy &key) + : pointee(key) {} + + static PluginPointerTypeStorage *construct(TypeStorageAllocator &allocator, + KeyTy key) { + return new (allocator.allocate()) + PluginPointerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return pointee == key; + } + + Type pointee; /* element type, returned by PluginPointerType::getElementType() */ + }; } } @@ -94,6 +113,9 @@ PluginTypeID PluginTypeBase::getPluginTypeID () if (auto Ty = dyn_cast()) { return Ty.getPluginTypeID (); } + if (auto Ty = dyn_cast()) { + return Ty.getPluginTypeID (); + } return PluginTypeID::UndefTyID; } @@ -252,4 +274,23 @@ PluginTypeID PluginVoidType::getPluginTypeID() PluginTypeID PluginUndefType::getPluginTypeID() { return PluginTypeID::UndefTyID; +} + +//===----------------------------------------------------------------------===// +// Plugin Pointer Type +//===----------------------------------------------------------------------===// + +PluginTypeID PluginPointerType::getPluginTypeID() +{ + return PluginTypeID::PointerTyID; +} + +Type PluginPointerType::getElementType() +{ + return getImpl()->pointee; +} + +PluginPointerType PluginPointerType::get (MLIRContext *context, Type pointee) /* uniqued in context, keyed on pointee */ +{ + return Base::get(context, pointee); } \ No newline at end of file diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp index 2eac906..74974dc 100644 --- a/lib/PluginServer/PluginServer.cpp +++ b/lib/PluginServer/PluginServer.cpp @@ -225,13 +225,15 @@ void 
PluginServer::JsonDeSerialize(const string& key, const string& data) } } -void PluginServer::TypeJsonDeSerialize(const string& data) +PluginIR::PluginTypeBase PluginServer::TypeJsonDeSerialize(const string& data) /* Decodes the "type" object of the JSON in data; returns the type and also stores it in pluginType. */ { Json::Value root; Json::Reader reader; Json::Value node; reader.parse(data, root); + PluginIR::PluginTypeBase baseType = PluginIR::PluginUndefType::get(&context); /* Default for unrecognised type IDs so setReadOnlyFlag() and the return value never see a null type (matches the UndefTyID fallback of PluginTypeBase::getPluginTypeID). */ + Json::Value type = root["type"]; uint64_t id = GetID(type["id"]); PluginIR::PluginTypeID PluginTypeId = static_cast(id); @@ -240,28 +242,34 @@ void PluginServer::TypeJsonDeSerialize(const string& data) string s = type["signed"].asString(); uint64_t width = GetID(type["width"]); if (s == "1") { - pluginType = PluginIR::PluginIntegerType::get(&context, width, PluginIR::PluginIntegerType::Signed); + baseType = PluginIR::PluginIntegerType::get(&context, width, PluginIR::PluginIntegerType::Signed); } else { - pluginType = PluginIR::PluginIntegerType::get(&context, width, PluginIR::PluginIntegerType::Unsigned); + baseType = PluginIR::PluginIntegerType::get(&context, width, PluginIR::PluginIntegerType::Unsigned); } } else if (type["width"] && (id == static_cast(PluginIR::FloatTyID) || id == static_cast(PluginIR::DoubleTyID)) ) { uint64_t width = GetID(type["width"]); - pluginType = PluginIR::PluginFloatType::get(&context, width); + baseType = PluginIR::PluginFloatType::get(&context, width); + } else if (id == static_cast(PluginIR::PointerTyID)) { + /* Decode the pointee recursively: it is parsed exactly like a top-level payload, i.e. its "type" member is read. */ + mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString()); + baseType = PluginIR::PluginPointerType::get(&context, elemTy); }else { if (PluginTypeId == PluginIR::VoidTyID) - pluginType = PluginIR::PluginVoidType::get(&context); + baseType = PluginIR::PluginVoidType::get(&context); if (PluginTypeId == PluginIR::BooleanTyID) - pluginType = PluginIR::PluginBooleanType::get(&context); + baseType = PluginIR::PluginBooleanType::get(&context); if (PluginTypeId == PluginIR::UndefTyID) - pluginType = PluginIR::PluginUndefType::get(&context); + baseType = 
PluginIR::PluginUndefType::get(&context); } if (type["readonly"] == "1") - pluginType.setReadOnlyFlag(1); + baseType.setReadOnlyFlag(1); else - pluginType.setReadOnlyFlag(0); - return; + baseType.setReadOnlyFlag(0); + /* Callers written against the former void signature read the pluginType member; keep it in sync with the return value. */ + pluginType = baseType; + return baseType; } bool PluginServer::ProcessBlock(mlir::Block* block, mlir::Region& rg, -- 2.27.0.windows.1