From e2a6f729f4ce40542fccec997529b43d25a6d5ae Mon Sep 17 00:00:00 2001 From: d00573793 Date: Tue, 21 Feb 2023 21:53:44 +0800 Subject: [PATCH 09/23] [Pin-server] Support FunctionType, StructType and ArrayType. Add PluginArrayType, PluginFunctionType and PluginStructType to the Plugin dialect, map PointerTy/ArrayTy/FunctionTy/StructTy in GetTypeCodeFromString, (de)serialize the new types in PluginJson, and attach the function's type to FunctionOp through a new `type` attribute (FunctionOp::getResultType() returns its return type). diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td index 3a88846..1083141 100644 --- a/include/Dialect/PluginOps.td +++ b/include/Dialect/PluginOps.td @@ -30,7 +30,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> { let arguments = (ins UI64Attr:$id, StrAttr:$funcName, - OptionalAttr:$declaredInline); + OptionalAttr:$declaredInline, + TypeAttr:$type); let regions = (region AnyRegion:$bodyRegion); // Add custom build methods for the operation. These method populates @@ -39,13 +40,15 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> { let builders = [ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$funcName, - "bool":$declaredInline)> + "bool":$declaredInline, + "Type":$type)> ]; let extraClassDeclaration = [{ std::vector GetAllLoops(); LoopOp AllocateNewLoop(); bool IsDomInfoAvailable(); + Type getResultType(); }]; } diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h index 7fb1ff9..3f7f14b 100644 --- a/include/Dialect/PluginTypes.h +++ b/include/Dialect/PluginTypes.h @@ -78,6 +78,9 @@ namespace detail { struct PluginIntegerTypeStorage; struct PluginFloatTypeStorage; struct PluginPointerTypeStorage; + struct PluginTypeAndSizeStorage; + struct PluginFunctionTypeStorage; + struct PluginStructTypeStorage; } class PluginIntegerType : public Type::TypeBase { @@ -128,6 +131,61 @@ public: unsigned isReadOnlyElem(); }; // class PluginPointerType +class PluginArrayType : public Type::TypeBase { +public: + using Base::Base; + + PluginTypeID getPluginTypeID (); + + static bool isValidElementType(Type type); + + static PluginArrayType get(MLIRContext *context, Type elementType, unsigned numElements); + + Type getElementType(); + + unsigned getNumElements(); +}; // class PluginArrayType + +class PluginFunctionType : public
Type::TypeBase { +public: + using Base::Base; + + PluginTypeID getPluginTypeID (); + + static bool isValidArgumentType(Type type); + + static bool isValidResultType(Type type); + + static PluginFunctionType get(MLIRContext *context, Type result, ArrayRef arguments); + + Type getReturnType(); + + unsigned getNumParams(); + + Type getParamType(unsigned i); + + ArrayRef getParams(); + +}; // class PluginFunctionType + +class PluginStructType : public Type::TypeBase { +public: + using Base::Base; + + PluginTypeID getPluginTypeID (); + + static bool isValidElementType(Type type); + + static PluginStructType get(MLIRContext *context, std::string name, ArrayRef elements, ArrayRef elemNames); + + std::string getName(); + + ArrayRef getBody(); + + ArrayRef getElementNames(); + +}; // class PluginStructType + class PluginVoidType : public Type::TypeBase { public: using Base::Base; diff --git a/include/PluginServer/PluginJson.h b/include/PluginServer/PluginJson.h index 6f46187..fd2f05b 100755 --- a/include/PluginServer/PluginJson.h +++ b/include/PluginServer/PluginJson.h @@ -59,7 +59,7 @@ public: /* 将json格式数据解析成map格式 */ void GetAttributes(Json::Value node, map& attributes); mlir::Value ValueJsonDeSerialize(Json::Value valueJson); - Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase& type); + Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase type); mlir::Value MemRefDeSerialize(const string& data); bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&); }; diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp index 95b38cf..ba8e4fe 100644 --- a/lib/Dialect/PluginDialect.cpp +++ b/lib/Dialect/PluginDialect.cpp @@ -37,6 +37,9 @@ void PluginDialect::initialize() PluginIR::PluginIntegerType, PluginIR::PluginFloatType, PluginIR::PluginPointerType, + PluginIR::PluginArrayType, + PluginIR::PluginFunctionType, + PluginIR::PluginStructType, PluginIR::PluginBooleanType, PluginIR::PluginVoidType, PluginIR::PluginUndefType>(); diff --git
a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp index a30e9ed..1c4fb2d 100644 --- a/lib/Dialect/PluginOps.cpp +++ b/lib/Dialect/PluginOps.cpp @@ -64,12 +64,19 @@ static uint64_t getBlockAddress(mlir::Block* b) } void FunctionOp::build(OpBuilder &builder, OperationState &state, - uint64_t id, StringRef funcName, bool declaredInline) + uint64_t id, StringRef funcName, bool declaredInline, Type type) { - FunctionOp::build(builder, state, - builder.getI64IntegerAttr(id), - builder.getStringAttr(funcName), - builder.getBoolAttr(declaredInline)); + state.addRegion(); + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("funcName", builder.getStringAttr(funcName)); + state.addAttribute("declaredInline", builder.getBoolAttr(declaredInline)); + if (type) state.addAttribute("type", TypeAttr::get(type)); +} + +Type FunctionOp::getResultType() +{ + auto fnType = type().dyn_cast<PluginIR::PluginFunctionType>(); + return fnType ? fnType.getReturnType() : Type(); } vector FunctionOp::GetAllLoops() diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp index c0a58c2..337fc49 100644 --- a/lib/Dialect/PluginTypes.cpp +++ b/lib/Dialect/PluginTypes.cpp @@ -97,6 +97,80 @@ namespace detail { Type pointee; unsigned readOnlyPointee; }; + + struct PluginTypeAndSizeStorage : public TypeStorage { + using KeyTy = std::tuple; + + PluginTypeAndSizeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {} + + static PluginTypeAndSizeStorage *construct(TypeStorageAllocator &allocator, KeyTy key) + { + return new (allocator.allocate()) + PluginTypeAndSizeStorage(key); + } + + bool operator==(const KeyTy &key) const + { + return std::make_tuple(elementType, numElements) == key; + } + + Type elementType; + unsigned numElements; + }; + + struct PluginFunctionTypeStorage : public TypeStorage { + using KeyTy = std::tuple>; + + PluginFunctionTypeStorage(Type resultType, ArrayRef argumentTypes) + : resultType(resultType), argumentTypes(argumentTypes)
{} + + static PluginFunctionTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key) + { + return new (allocator.allocate()) + PluginFunctionTypeStorage(std::get<0>(key), allocator.copyInto(std::get<1>(key))); + } + + static unsigned hashKey(const KeyTy &key) { + // Hash the result type together with the argument types (ArrayRef is hashed element-wise). + return llvm::hash_combine(std::get<0>(key), std::get<1>(key)); + } + + bool operator==(const KeyTy &key) const + { + return std::make_tuple(resultType, argumentTypes) == key; + } + + Type resultType; + ArrayRef argumentTypes; + }; + + struct PluginStructTypeStorage : public TypeStorage { + using KeyTy = std::tuple, ArrayRef>; + + PluginStructTypeStorage(std::string name, ArrayRef elements, ArrayRef elemNames) + : name(name), elements(elements), elemNames(elemNames) {} + + static PluginStructTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key) + { + return new (allocator.allocate()) + PluginStructTypeStorage(std::get<0>(key), allocator.copyInto(std::get<1>(key)), allocator.copyInto(std::get<2>(key))); + } + + static unsigned hashKey(const KeyTy &key) { + // Hash the name together with the element types and element names (ArrayRefs are hashed element-wise).
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + + bool operator==(const KeyTy &key) const + { + return std::make_tuple(name, elements, elemNames) == key; + } + + std::string name; + ArrayRef elements; + ArrayRef elemNames; + }; } } @@ -122,6 +196,15 @@ PluginTypeID PluginTypeBase::getPluginTypeID () if (auto Ty = dyn_cast()) { return Ty.getPluginTypeID (); } + if (auto Ty = dyn_cast()) { + return Ty.getPluginTypeID (); + } + if (auto Ty = dyn_cast()) { + return Ty.getPluginTypeID (); + } + if (auto Ty = dyn_cast()) { + return Ty.getPluginTypeID (); + } return PluginTypeID::UndefTyID; } @@ -292,4 +375,108 @@ unsigned PluginPointerType::isReadOnlyElem() PluginPointerType PluginPointerType::get (MLIRContext *context, Type pointee, unsigned readOnlyPointee) { return Base::get(context, pointee, readOnlyPointee); +} + +// ===----------------------------------------------------------------------===// +// Plugin Array Type +// ===----------------------------------------------------------------------===// + +PluginTypeID PluginArrayType::getPluginTypeID() +{ + return PluginTypeID::ArrayTyID; +} + +bool PluginArrayType::isValidElementType(Type type) +{ + return !type.isa(); +} + +PluginArrayType PluginArrayType::get(MLIRContext *context, Type elementType, unsigned numElements) +{ + return Base::get(context, elementType, numElements); +} + +Type PluginArrayType::getElementType() +{ + return getImpl()->elementType; +} + +unsigned PluginArrayType::getNumElements() +{ + return getImpl()->numElements; +} + +// ===----------------------------------------------------------------------===// +// Plugin Function Type +// ===----------------------------------------------------------------------===// + +PluginTypeID PluginFunctionType::getPluginTypeID() +{ + return PluginTypeID::FunctionTyID; +} + +bool PluginFunctionType::isValidArgumentType(Type type) +{ + return !type.isa(); +} + +bool PluginFunctionType::isValidResultType(Type type) { +
return !type.isa(); +} + +PluginFunctionType PluginFunctionType::get(MLIRContext *context, Type result, ArrayRef arguments) +{ + return Base::get(context, result, arguments); +} + +Type PluginFunctionType::getReturnType() +{ + return getImpl()->resultType; +} + +unsigned PluginFunctionType::getNumParams() +{ + return getImpl()->argumentTypes.size(); +} + +Type PluginFunctionType::getParamType(unsigned i) { + return getImpl()->argumentTypes[i]; +} + +ArrayRef PluginFunctionType::getParams() +{ + return getImpl()->argumentTypes; +} + +// ===----------------------------------------------------------------------===// +// Plugin Struct Type +// ===----------------------------------------------------------------------===// + +PluginTypeID PluginStructType::getPluginTypeID() +{ + return PluginTypeID::StructTyID; +} + +bool PluginStructType::isValidElementType(Type type) { + return !type.isa(); +} + +PluginStructType PluginStructType::get(MLIRContext *context, std::string name, ArrayRef elements, ArrayRef elemNames) +{ + return Base::get(context, name, elements, elemNames); +} + +std::string PluginStructType::getName() +{ + return getImpl()->name; +} + +ArrayRef PluginStructType::getBody() +{ + return getImpl()->elements; +} + +ArrayRef PluginStructType::getElementNames() +{ + return getImpl()->elemNames; } \ No newline at end of file diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp index f81a3ad..e3435b0 100644 --- a/lib/PluginAPI/PluginServerAPI.cpp +++ b/lib/PluginAPI/PluginServerAPI.cpp @@ -343,6 +343,14 @@ PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type) return PluginIR::PluginTypeID::FloatTyID; } else if (type == "DoubleTy") { return PluginIR::PluginTypeID::DoubleTyID; + } else if (type == "PointerTy") { + return PluginIR::PluginTypeID::PointerTyID; + } else if (type == "ArrayTy") { + return PluginIR::PluginTypeID::ArrayTyID; + } else if (type == "FunctionTy") { + return PluginIR::PluginTypeID::FunctionTyID;
+ } else if (type == "StructTy") { + return PluginIR::PluginTypeID::StructTyID; } return PluginIR::PluginTypeID::UndefTyID; diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp index 7bbf681..e1beddf 100755 --- a/lib/PluginServer/PluginJson.cpp +++ b/lib/PluginServer/PluginJson.cpp @@ -23,6 +23,7 @@ namespace PinJson { using namespace PinServer; +using namespace mlir; using namespace mlir::Plugin; static uintptr_t GetID(Json::Value node) @@ -41,7 +42,7 @@ static void JsonGetAttributes(Json::Value node, map& attributes) } } -Json::Value PluginJson::TypeJsonSerialize (PluginIR::PluginTypeBase& type) +Json::Value PluginJson::TypeJsonSerialize (PluginIR::PluginTypeBase type) { Json::Value root; Json::Value operationObj; @@ -53,6 +54,41 @@ Json::Value PluginJson::TypeJsonSerialize (PluginIR::PluginTypeBase& type) ReTypeId = static_cast(type.getPluginTypeID()); item["id"] = std::to_string(ReTypeId); + if (auto Ty = type.dyn_cast()) { + std::string tyName = Ty.getName(); + item["structtype"] = tyName; + size_t paramIndex = 0; + ArrayRef paramsType = Ty.getBody(); + for (auto ty :paramsType) { + std::string paramStr = "elemType" + std::to_string(paramIndex++); + item["structelemType"][paramStr] = TypeJsonSerialize(ty.dyn_cast()); + } + paramIndex = 0; + ArrayRef paramsNames = Ty.getElementNames(); + for (const auto &name : paramsNames) { + std::string paramStr = "elemName" + std::to_string(paramIndex++); + item["structelemName"][paramStr] = name; + } + } + + if (auto Ty = type.dyn_cast()) { + auto fnrestype = Ty.getReturnType().dyn_cast(); + item["fnreturntype"] = TypeJsonSerialize(fnrestype); + size_t paramIndex = 0; + ArrayRef paramsType = Ty.getParams(); + for (auto ty : paramsType) { + string paramStr = "argType" + std::to_string(paramIndex++); + item["fnargsType"][paramStr] = TypeJsonSerialize(ty.dyn_cast()); + } + } + + if (auto Ty = type.dyn_cast()) { + auto elemTy = Ty.getElementType().dyn_cast(); + item["elementType"] =
TypeJsonSerialize(elemTy); + uint64_t elemNum = Ty.getNumElements(); + item["arraysize"] = std::to_string(elemNum); + } + if (auto elemTy = type.dyn_cast()) { auto baseTy = elemTy.getElementType().dyn_cast(); item["elementType"] = TypeJsonSerialize(baseTy); @@ -247,8 +283,9 @@ void PluginJson::FuncOpJsonDeSerialize( bool declaredInline = false; if (funcAttributes["declaredInline"] == "1") declaredInline = true; auto location = opBuilder.getUnknownLoc(); + PluginIR::PluginTypeBase retType = TypeJsonDeSerialize(node["retType"].toStyledString()); FunctionOp fOp = opBuilder.create( - location, id, funcAttributes["funcName"], declaredInline); + location, id, funcAttributes["funcName"], declaredInline, retType); mlir::Region &bodyRegion = fOp.bodyRegion(); Json::Value regionJson = node["region"]; Json::Value::Members bbMember = regionJson.getMemberNames(); @@ -302,7 +339,40 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data) mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString()); baseType = PluginIR::PluginPointerType::get( PluginServer::GetInstance()->GetContext(), elemTy, type["elemConst"].asString() == "1" ?
1 : 0); - } else { + } else if (id == static_cast(PluginIR::ArrayTyID)) { + mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString()); + uint64_t elemNum = GetID(type["arraysize"]); + baseType = PluginIR::PluginArrayType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum); + } else if (id == static_cast(PluginIR::FunctionTyID)) { + mlir::Type returnTy = TypeJsonDeSerialize(type["fnreturntype"].toStyledString()); + llvm::SmallVector typelist; + Json::Value::Members fnTypeNum = type["fnargsType"].getMemberNames(); + uint64_t argsNum = fnTypeNum.size(); + for (size_t paramIndex = 0; paramIndex < argsNum; paramIndex++) { + string Key = "argType" + std::to_string(paramIndex); + mlir::Type paramTy = TypeJsonDeSerialize(type["fnargsType"][Key].toStyledString()); + typelist.push_back(paramTy); + } + baseType = PluginIR::PluginFunctionType::get(PluginServer::GetInstance()->GetContext(), returnTy, typelist); + } else if (id == static_cast(PluginIR::StructTyID)) { + std::string tyName = type["structtype"].asString(); + llvm::SmallVector typelist; + Json::Value::Members elemTypeNum = type["structelemType"].getMemberNames(); + for (size_t paramIndex = 0; paramIndex < elemTypeNum.size(); paramIndex++) { + string Key = "elemType" + std::to_string(paramIndex); + mlir::Type paramTy = TypeJsonDeSerialize(type["structelemType"][Key].toStyledString()); + typelist.push_back(paramTy); + } + llvm::SmallVector names; + Json::Value::Members elemNameNum = type["structelemName"].getMemberNames(); + for (size_t paramIndex = 0; paramIndex < elemNameNum.size(); paramIndex++) { + std::string Key = "elemName" + std::to_string(paramIndex); + std::string elemName = type["structelemName"][Key].asString(); + names.push_back(elemName); + } + baseType = PluginIR::PluginStructType::get(PluginServer::GetInstance()->GetContext(), tyName, typelist, names); + } + else { if (PluginTypeId == PluginIR::VoidTyID) { baseType =
PluginIR::PluginVoidType::get(PluginServer::GetInstance()->GetContext()); } diff --git a/user/LocalVarSummeryPass.cpp b/user/LocalVarSummeryPass.cpp index 4fc4985..2e157e3 100755 --- a/user/LocalVarSummeryPass.cpp +++ b/user/LocalVarSummeryPass.cpp @@ -23,6 +23,10 @@ #include "user/LocalVarSummeryPass.h" namespace PluginOpt { +using std::string; +using std::vector; +using std::cout; +using namespace mlir; using namespace PluginAPI; static void LocalVarSummery(void) @@ -38,6 +42,32 @@ static void LocalVarSummery(void) if (args.find("type_code") != args.end()) { typeFilter = (int64_t)pluginAPI.GetTypeCodeFromString(args["type_code"]); } + mlir::Plugin::FunctionOp funcOp = allFunction[i]; + printf("func name is :%s\n", funcOp.funcNameAttr().getValue().str().c_str()); + mlir::Type dgyty = funcOp.type(); + if (auto ty = dgyty.dyn_cast()) { + if(auto stTy = ty.getReturnType().dyn_cast()) { + printf("func return type is PluginStructType\n"); + std::string tyName = stTy.getName(); + printf(" struct name is : %s\n", tyName.c_str()); + + llvm::ArrayRef paramsType = stTy.getBody(); + for (auto tty :paramsType) { + printf("\n struct arg id : %d\n", tty.dyn_cast().getPluginTypeID()); + } + llvm::ArrayRef paramsNames = stTy.getElementNames(); + for (auto name :paramsNames) { + std::string pName = name; + printf("\n struct argname is : %s\n", pName.c_str()); + } + } + size_t paramIndex = 0; + llvm::ArrayRef paramsType = ty.getParams(); + for (auto paramTy : paramsType) { + printf("\n Param index : %zu\n", paramIndex++); + printf("\n Param type id : %d\n", paramTy.dyn_cast<PluginIR::PluginTypeBase>().getPluginTypeID()); + } + } for (size_t j = 0; j < decls.size(); j++) { auto decl = decls[j]; string name = decl.symNameAttr().getValue().str(); -- 2.33.0