From 50801883d5d5e89d44026ead6e38e149caa346b2 Mon Sep 17 00:00:00 2001
From: d00573793
Date: Sat, 25 Feb 2023 15:31:02 +0800
Subject: [PATCH 12/23] [Pin-server] Support Vectortype

diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h
index 3f7f14b..9693294 100644
--- a/include/Dialect/PluginTypes.h
+++ b/include/Dialect/PluginTypes.h
@@ -55,8 +55,7 @@ enum PluginTypeID {
     PointerTyID, ///< Pointers
     StructTyID, ///< Structures
     ArrayTyID, ///< Arrays
-    FixedVectorTyID, ///< Fixed width SIMD vector type
-    ScalableVectorTyID ///< Scalable SIMD vector type
+    VectorTyID, ///< SIMD vector type
 };
 
 class PluginTypeBase : public Type {
@@ -146,6 +145,21 @@ public:
     unsigned getNumElements();
 }; // class PluginArrayType
 
+class PluginVectorType : public Type::TypeBase<PluginVectorType, PluginTypeBase, detail::PluginTypeAndSizeStorage> {
+public:
+    using Base::Base;
+
+    PluginTypeID getPluginTypeID ();
+
+    static bool isValidElementType(Type type);
+
+    static PluginVectorType get(MLIRContext *context, Type elementType, unsigned numElements);
+
+    Type getElementType();
+
+    unsigned getNumElements();
+}; // class PluginVectorType
+
 class PluginFunctionType : public Type::TypeBase<PluginFunctionType, PluginTypeBase, detail::PluginFunctionTypeStorage> {
 public:
     using Base::Base;
diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp
index ba8e4fe..95a78da 100644
--- a/lib/Dialect/PluginDialect.cpp
+++ b/lib/Dialect/PluginDialect.cpp
@@ -38,6 +38,7 @@ void PluginDialect::initialize()
         PluginIR::PluginFloatType,
         PluginIR::PluginPointerType,
         PluginIR::PluginArrayType,
+        PluginIR::PluginVectorType,
         PluginIR::PluginFunctionType,
         PluginIR::PluginStructType,
         PluginIR::PluginBooleanType,
diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp
index 337fc49..89e4b1a 100644
--- a/lib/Dialect/PluginTypes.cpp
+++ b/lib/Dialect/PluginTypes.cpp
@@ -199,6 +199,9 @@ PluginTypeID PluginTypeBase::getPluginTypeID ()
     if (auto Ty = dyn_cast<PluginIR::PluginArrayType>()) {
         return Ty.getPluginTypeID ();
     }
+    if (auto Ty = dyn_cast<PluginIR::PluginVectorType>()) {
+        return Ty.getPluginTypeID ();
+    }
     if (auto Ty = dyn_cast<PluginIR::PluginFunctionType>()) {
         return Ty.getPluginTypeID ();
     }
@@ -406,6 +409,35 @@ unsigned PluginArrayType::getNumElements()
     return getImpl()->numElements;
 }
 
+// ===----------------------------------------------------------------------===//
+// Plugin Vector Type
+// ===----------------------------------------------------------------------===//
+
+PluginTypeID PluginVectorType::getPluginTypeID()
+{
+    return PluginTypeID::VectorTyID;
+}
+
+bool PluginVectorType::isValidElementType(Type type)
+{
+    return type.isa<PluginIR::PluginIntegerType, PluginIR::PluginFloatType>();
+}
+
+PluginVectorType PluginVectorType::get(MLIRContext *context, Type elementType, unsigned numElements)
+{
+    return Base::get(context, elementType, numElements);
+}
+
+Type PluginVectorType::getElementType()
+{
+    return getImpl()->elementType;
+}
+
+unsigned PluginVectorType::getNumElements()
+{
+    return getImpl()->numElements;
+}
+
 // ===----------------------------------------------------------------------===//
 // Plugin Function Type
 // ===----------------------------------------------------------------------===//
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index e3435b0..a471ddf 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -347,6 +347,8 @@ PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type)
         return PluginIR::PluginTypeID::PointerTyID;
     } else if (type == "ArrayTy") {
         return PluginIR::PluginTypeID::ArrayTyID;
+    } else if (type == "VectorTy") {
+        return PluginIR::PluginTypeID::VectorTyID;
     } else if (type == "FunctionTy") {
         return PluginIR::PluginTypeID::FunctionTyID;
     } else if (type == "StructTy") {
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
index e1beddf..4d41351 100755
--- a/lib/PluginServer/PluginJson.cpp
+++ b/lib/PluginServer/PluginJson.cpp
@@ -343,6 +343,10 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data)
         mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
         uint64_t elemNum = GetID(type["arraysize"]);
         baseType = PluginIR::PluginArrayType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
+    } else if (id == static_cast<uint64_t>(PluginIR::VectorTyID)) {
+        mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
+        uint64_t elemNum = GetID(type["vectorelemnum"]);
+        baseType = PluginIR::PluginVectorType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
     } else if (id == static_cast<uint64_t>(PluginIR::FunctionTyID)) {
         mlir::Type returnTy = TypeJsonDeSerialize(type["fnreturntype"].toStyledString());
         llvm::SmallVector<Type> typelist;
diff --git a/user/LocalVarSummeryPass.cpp b/user/LocalVarSummeryPass.cpp
index 2e157e3..ccee9f7 100755
--- a/user/LocalVarSummeryPass.cpp
+++ b/user/LocalVarSummeryPass.cpp
@@ -61,6 +61,13 @@ static void LocalVarSummery(void)
                 printf("\n struct argname is : %s\n", pName.c_str());
             }
         }
+        if (auto vecTy = ty.getReturnType().dyn_cast<PluginIR::PluginVectorType>()) {
+            printf("func return type is PluginVectorType\n");
+            printf(" vector elem num : %u\n", vecTy.getNumElements());
+            if (auto elemTy = vecTy.getElementType().dyn_cast<PluginIR::PluginTypeBase>()) {
+                printf(" vector elem type id : %d\n", static_cast<int>(elemTy.getPluginTypeID()));
+            }
+        }
         size_t paramIndex = 0;
         llvm::ArrayRef<Type> paramsType = ty.getParams();
         for (auto ty : ty.getParams()) {
-- 
2.33.0