152 lines
6.0 KiB
Diff
152 lines
6.0 KiB
Diff
From 50801883d5d5e89d44026ead6e38e149caa346b2 Mon Sep 17 00:00:00 2001
|
|
From: d00573793 <dingguangya1@huawei.com>
|
|
Date: Sat, 25 Feb 2023 15:31:02 +0800
|
|
Subject: [PATCH 3/5] [Pin-server] Support VectorType
|
|
|
|
|
|
diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h
|
|
index 3f7f14b..9693294 100644
|
|
--- a/include/Dialect/PluginTypes.h
|
|
+++ b/include/Dialect/PluginTypes.h
|
|
@@ -55,8 +55,7 @@ enum PluginTypeID {
|
|
PointerTyID, ///< Pointers
|
|
StructTyID, ///< Structures
|
|
ArrayTyID, ///< Arrays
|
|
- FixedVectorTyID, ///< Fixed width SIMD vector type
|
|
- ScalableVectorTyID ///< Scalable SIMD vector type
|
|
+ VectorTyID, ///< SIMD vector type (element type + element count)
|
|
};
|
|
|
|
class PluginTypeBase : public Type {
|
|
@@ -146,6 +145,21 @@ public:
|
|
unsigned getNumElements();
|
|
}; // class PluginArrayType
|
|
|
|
+class PluginVectorType : public Type::TypeBase<PluginVectorType, PluginTypeBase, detail::PluginTypeAndSizeStorage> {
|
|
+public:
|
|
+ using Base::Base;
|
|
+ // Plugin type ID reported for this type (implemented in PluginTypes.cpp).
|
|
+ PluginTypeID getPluginTypeID ();
|
|
+ // True if `type` may be used as a vector element: PluginIntegerType or PluginFloatType.
|
|
+ static bool isValidElementType(Type type);
|
|
+ // Returns the vector type for (elementType, numElements); does not itself call isValidElementType.
|
|
+ static PluginVectorType get(MLIRContext *context, Type elementType, unsigned numElements);
|
|
+ // Element type held in the type storage.
|
|
+ Type getElementType();
|
|
+ // Number of elements held in the type storage.
|
|
+ unsigned getNumElements();
|
|
+}; // class PluginVectorType
|
|
+
|
|
class PluginFunctionType : public Type::TypeBase<PluginFunctionType, PluginTypeBase, detail::PluginFunctionTypeStorage> {
|
|
public:
|
|
using Base::Base;
|
|
diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp
|
|
index ba8e4fe..95a78da 100644
|
|
--- a/lib/Dialect/PluginDialect.cpp
|
|
+++ b/lib/Dialect/PluginDialect.cpp
|
|
@@ -38,6 +38,7 @@ void PluginDialect::initialize()
|
|
PluginIR::PluginFloatType,
|
|
PluginIR::PluginPointerType,
|
|
PluginIR::PluginArrayType,
|
|
+ PluginIR::PluginVectorType,
|
|
PluginIR::PluginFunctionType,
|
|
PluginIR::PluginStructType,
|
|
PluginIR::PluginBooleanType,
|
|
diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp
|
|
index 337fc49..89e4b1a 100644
|
|
--- a/lib/Dialect/PluginTypes.cpp
|
|
+++ b/lib/Dialect/PluginTypes.cpp
|
|
@@ -199,6 +199,9 @@ PluginTypeID PluginTypeBase::getPluginTypeID ()
|
|
if (auto Ty = dyn_cast<PluginIR::PluginArrayType>()) {
|
|
return Ty.getPluginTypeID ();
|
|
}
|
|
+ if (auto Ty = dyn_cast<PluginIR::PluginVectorType>()) {
|
|
+ return Ty.getPluginTypeID ();
|
|
+ }
|
|
if (auto Ty = dyn_cast<PluginIR::PluginFunctionType>()) {
|
|
return Ty.getPluginTypeID ();
|
|
}
|
|
@@ -406,6 +409,35 @@ unsigned PluginArrayType::getNumElements()
|
|
return getImpl()->numElements;
|
|
}
|
|
|
|
+// ===----------------------------------------------------------------------===//
|
|
+// Plugin Vector Type (element type + element count, PluginTypeID::VectorTyID)
|
|
+// ===----------------------------------------------------------------------===//
|
|
+
|
|
+PluginTypeID PluginVectorType::getPluginTypeID()
|
|
+{
|
|
+ return PluginTypeID::VectorTyID; // distinct from ArrayTyID so ID-based dispatch (e.g. JSON deserialization) can tell vectors from arrays
|
|
+}
|
|
+
|
|
+bool PluginVectorType::isValidElementType(Type type)
|
|
+{
|
|
+ return type.isa<PluginIntegerType, PluginFloatType>(); // only integer and float element types are accepted
|
|
+}
|
|
+
|
|
+PluginVectorType PluginVectorType::get(MLIRContext *context, Type elementType, unsigned numElements)
|
|
+{
|
|
+ return Base::get(context, elementType, numElements); // NOTE(review): elementType is not checked with isValidElementType here
|
|
+}
|
|
+
|
|
+Type PluginVectorType::getElementType()
|
|
+{
|
|
+ return getImpl()->elementType;
|
|
+}
|
|
+
|
|
+unsigned PluginVectorType::getNumElements()
|
|
+{
|
|
+ return getImpl()->numElements;
|
|
+}
|
|
+
|
|
// ===----------------------------------------------------------------------===//
|
|
// Plugin Function Type
|
|
// ===----------------------------------------------------------------------===//
|
|
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
|
|
index e3435b0..a471ddf 100644
|
|
--- a/lib/PluginAPI/PluginServerAPI.cpp
|
|
+++ b/lib/PluginAPI/PluginServerAPI.cpp
|
|
@@ -347,6 +347,8 @@ PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type)
|
|
return PluginIR::PluginTypeID::PointerTyID;
|
|
} else if (type == "ArrayTy") {
|
|
return PluginIR::PluginTypeID::ArrayTyID;
|
|
+ } else if (type == "VectorTy") {
|
|
+ return PluginIR::PluginTypeID::VectorTyID;
|
|
} else if (type == "FunctionTy") {
|
|
return PluginIR::PluginTypeID::FunctionTyID;
|
|
} else if (type == "StructTy") {
|
|
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
|
|
index e1beddf..4d41351 100755
|
|
--- a/lib/PluginServer/PluginJson.cpp
|
|
+++ b/lib/PluginServer/PluginJson.cpp
|
|
@@ -343,6 +343,10 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data)
|
|
mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
|
|
uint64_t elemNum = GetID(type["arraysize"]);
|
|
baseType = PluginIR::PluginArrayType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
|
|
+ } else if (id == static_cast<uint64_t>(PluginIR::VectorTyID)) {
|
|
+ mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString());
|
|
+ uint64_t elemNum = GetID(type["vectorelemnum"]);
|
|
+ baseType = PluginIR::PluginVectorType::get(PluginServer::GetInstance()->GetContext(), elemTy, elemNum);
|
|
} else if (id == static_cast<uint64_t>(PluginIR::FunctionTyID)) {
|
|
mlir::Type returnTy = TypeJsonDeSerialize(type["fnreturntype"].toStyledString());
|
|
llvm::SmallVector<Type> typelist;
|
|
diff --git a/user/LocalVarSummeryPass.cpp b/user/LocalVarSummeryPass.cpp
|
|
index 2e157e3..ccee9f7 100755
|
|
--- a/user/LocalVarSummeryPass.cpp
|
|
+++ b/user/LocalVarSummeryPass.cpp
|
|
@@ -61,6 +61,11 @@ static void LocalVarSummery(void)
|
|
printf("\n struct argname is : %s\n", pName.c_str());
|
|
}
|
|
}
|
|
+ if (auto vecTy = ty.getReturnType().dyn_cast<PluginIR::PluginVectorType>()) {
|
|
+ printf("func return type is PluginVectorType\n");
|
|
+ printf(" vector elem num : %u\n", vecTy.getNumElements()); // %u: getNumElements() returns unsigned
|
|
+ if (auto elemTy = vecTy.getElementType().dyn_cast<PluginIR::PluginTypeBase>()) { printf(" vector elem type id : %d\n", static_cast<int>(elemTy.getPluginTypeID())); }
|
|
+ }
|
|
size_t paramIndex = 0;
|
|
llvm::ArrayRef<mlir::Type> paramsType = ty.getParams();
|
|
for (auto ty : ty.getParams()) {
|
|
--
|
|
2.27.0.windows.1
|
|
|