pin-gcc-client/0009-Pin-gcc-client-Support-functiontype-structtype.eg.patch
d00573793 fdcab1229a [sync] Sync patch from openeuler/pin-gcc-client
(cherry picked from commit c479a0e1cdb0f7672c180d94662776012d7bc37a)
2023-02-27 09:22:08 +08:00

722 lines
26 KiB
Diff

From 307831e1210fd9aa4453d774308f63812198b555 Mon Sep 17 00:00:00 2001
From: d00573793 <dingguangya1@huawei.com>
Date: Tue, 21 Feb 2023 11:09:39 +0800
Subject: [PATCH 9/9] [Pin-gcc-client] Support functiontype structtype.eg.
diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td
index c48d002..64a0c3d 100644
--- a/include/Dialect/PluginOps.td
+++ b/include/Dialect/PluginOps.td
@@ -32,15 +32,19 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
let arguments = (ins UI64Attr:$id,
StrAttr:$funcName,
- OptionalAttr<BoolAttr>:$declaredInline);
+ OptionalAttr<BoolAttr>:$declaredInline,
+ TypeAttr:$type); // Function signature type. Required: not wrapped in OptionalAttr.
let regions = (region AnyRegion:$bodyRegion);
// Add custom build methods for the operation. These method populates
// the `state` that MLIR uses to create operations, i.e. these are used when
// using `builder.create<Op>(...)`.
let builders = [
- OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$funcName, "bool":$declaredInline)>
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$funcName, "bool":$declaredInline, "Type":$type)>
];
+ let extraClassDeclaration = [{
+ Type getResultType(); // $type dyn_cast to PluginFunctionType (whole signature), or a null Type.
+ }];
}
def LocalDeclOp : Plugin_Op<"declaration", [NoSideEffect]> {
diff --git a/include/Dialect/PluginTypes.h b/include/Dialect/PluginTypes.h
index 157b868..da16886 100644
--- a/include/Dialect/PluginTypes.h
+++ b/include/Dialect/PluginTypes.h
@@ -78,6 +78,9 @@ namespace Detail {
struct PluginIntegerTypeStorage;
struct PluginFloatTypeStorage;
struct PluginPointerTypeStorage;
+ struct PluginTypeAndSizeStorage; // element type + element count (PluginArrayType)
+ struct PluginFunctionTypeStorage; // result type + parameter types
+ struct PluginStructTypeStorage; // name + element types + element names
}
class PluginIntegerType : public Type::TypeBase<PluginIntegerType, PluginTypeBase, Detail::PluginIntegerTypeStorage> {
@@ -128,6 +131,60 @@ public:
unsigned isReadOnlyElem();
}; // class PluginPointerType
+class PluginArrayType : public Type::TypeBase<PluginArrayType, PluginTypeBase, Detail::PluginTypeAndSizeStorage> {
+public:
+ using Base::Base;
+ // Array type: getNumElements() elements of getElementType().
+ PluginTypeID getPluginTypeID ();
+ // Rejects void, function and undef element types.
+ static bool isValidElementType(Type type);
+ // Returns the uniqued array type for (elementType, numElements).
+ static PluginArrayType get(MLIRContext *context, Type elementType, unsigned numElements);
+
+ Type getElementType();
+
+ unsigned getNumElements();
+}; // class PluginArrayType
+
+class PluginFunctionType : public Type::TypeBase<PluginFunctionType, PluginTypeBase, Detail::PluginFunctionTypeStorage> {
+public:
+ using Base::Base;
+ // Function signature: a return type plus an ordered list of parameter types.
+ PluginTypeID getPluginTypeID ();
+ // Rejects void and function parameter types.
+ static bool isValidArgumentType(Type type);
+ // Rejects function result types.
+ static bool isValidResultType(Type type);
+ // Returns the uniqued function type for (result, arguments).
+ static PluginFunctionType get(MLIRContext *context, Type result, ArrayRef<Type> arguments);
+
+ Type getReturnType();
+
+ unsigned getNumParams();
+ // i must be less than getNumParams().
+ Type getParamType(unsigned i);
+
+ ArrayRef<Type> getParams();
+}; // class PluginFunctionType
+
+class PluginStructType : public Type::TypeBase<PluginStructType, PluginTypeBase, Detail::PluginStructTypeStorage> {
+public:
+ using Base::Base;
+ // Struct type: a name plus lists of element types and element names.
+ PluginTypeID getPluginTypeID ();
+ // Rejects void and function element types.
+ static bool isValidElementType(Type type);
+ // Returns the uniqued struct type for (name, elements, elemNames).
+ static PluginStructType get(MLIRContext *context, StringRef name, ArrayRef<Type> elements, ArrayRef<StringRef> elemNames);
+
+ StringRef getName();
+ // Element types of the struct body.
+ ArrayRef<Type> getBody();
+ // Element names; expected to be index-aligned with getBody().
+ ArrayRef<StringRef> getElementNames();
+
+}; // class PluginStructType
+
class PluginVoidType : public Type::TypeBase<PluginVoidType, PluginTypeBase, TypeStorage> {
public:
using Base::Base;
diff --git a/include/PluginClient/PluginJson.h b/include/PluginClient/PluginJson.h
index 91bd925..3c9fe1f 100755
--- a/include/PluginClient/PluginJson.h
+++ b/include/PluginClient/PluginJson.h
@@ -59,7 +59,7 @@ public:
Json::Value MemOpJsonSerialize(mlir::Plugin::MemOp& data);
Json::Value SSAOpJsonSerialize(mlir::Plugin::SSAOp& data);
/* 将Type类型数据序列化 */
- Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase& type);
+ Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase type); // by value so temporaries (e.g. dyn_cast<> results) can be passed directly
PluginIR::PluginTypeBase TypeJsonDeSerialize(const string& data, mlir::MLIRContext &context);
/* 将整数型数据序列化 */
void IntegerSerialize(int64_t data, string& out);
diff --git a/lib/Dialect/PluginDialect.cpp b/lib/Dialect/PluginDialect.cpp
index 527c076..69a0aa5 100644
--- a/lib/Dialect/PluginDialect.cpp
+++ b/lib/Dialect/PluginDialect.cpp
@@ -38,6 +38,9 @@ void PluginDialect::initialize()
addTypes<PluginIR::PluginIntegerType,
PluginIR::PluginFloatType,
PluginIR::PluginPointerType,
+ PluginIR::PluginArrayType, // composite types: registered so their ::get() works in this dialect
+ PluginIR::PluginFunctionType,
+ PluginIR::PluginStructType,
PluginIR::PluginBooleanType,
PluginIR::PluginVoidType,
PluginIR::PluginUndefType>();
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 052ebfd..db2574c 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -24,6 +24,7 @@
#include "Dialect/PluginDialect.h"
#include "Dialect/PluginOps.h"
+#include "Dialect/PluginTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -32,13 +33,27 @@
using namespace mlir;
using namespace mlir::Plugin;
-void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
- uint64_t id, StringRef funcName, bool declaredInline)
+// Builds a FunctionOp with an empty body region. The `type` attribute is
+// required by the op definition (TypeAttr, not OptionalAttr), so a null
+// `type` -- e.g. a signature that did not translate to a
+// PluginFunctionType -- is stored as PluginUndefType instead of omitted.
+void FunctionOp::build(OpBuilder &builder, OperationState &state,
+                       uint64_t id, StringRef funcName, bool declaredInline, Type type)
{
- FunctionOp::build(builder, state,
- builder.getI64IntegerAttr(id),
- builder.getStringAttr(funcName),
- builder.getBoolAttr(declaredInline));
+    state.addRegion();
+    state.addAttribute("id", builder.getI64IntegerAttr(id));
+    state.addAttribute("funcName", builder.getStringAttr(funcName));
+    state.addAttribute("declaredInline", builder.getBoolAttr(declaredInline));
+    Type fnType = type ? type : PluginIR::PluginUndefType::get(builder.getContext());
+    state.addAttribute("type", TypeAttr::get(fnType));
+}
+
+// Returns the function's signature as a PluginFunctionType (the whole
+// signature, not only its return type), or a null Type when the stored
+// `type` attribute is not a PluginFunctionType.
+Type FunctionOp::getResultType()
+{
+    return type().dyn_cast<PluginIR::PluginFunctionType>();
}
void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
diff --git a/lib/Dialect/PluginTypes.cpp b/lib/Dialect/PluginTypes.cpp
index 396bf0f..329a2b6 100644
--- a/lib/Dialect/PluginTypes.cpp
+++ b/lib/Dialect/PluginTypes.cpp
@@ -98,6 +98,92 @@ namespace Detail {
Type pointee;
unsigned readOnlyPointee;
};
+
+    struct PluginTypeAndSizeStorage : public TypeStorage {
+        using KeyTy = std::tuple<Type, unsigned>;
+
+        PluginTypeAndSizeStorage(const KeyTy &key)
+            : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {}
+
+        static PluginTypeAndSizeStorage *construct(TypeStorageAllocator &allocator, KeyTy key)
+        {
+            return new (allocator.allocate<PluginTypeAndSizeStorage>())
+                PluginTypeAndSizeStorage(key);
+        }
+
+        bool operator==(const KeyTy &key) const
+        {
+            return std::make_tuple(elementType, numElements) == key;
+        }
+
+        Type elementType;
+        unsigned numElements;
+    };
+
+    struct PluginFunctionTypeStorage : public TypeStorage {
+        using KeyTy = std::tuple<Type, ArrayRef<Type>>;
+
+        PluginFunctionTypeStorage(Type resultType, ArrayRef<Type> argumentTypes)
+            : resultType(resultType), argumentTypes(argumentTypes) {}
+
+        /* The parameter list is copied into uniquer-owned memory, so the
+           caller's array does not need to outlive the type. */
+        static PluginFunctionTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key)
+        {
+            return new (allocator.allocate<PluginFunctionTypeStorage>())
+                PluginFunctionTypeStorage(std::get<0>(key), allocator.copyInto(std::get<1>(key)));
+        }
+
+        static unsigned hashKey(const KeyTy &key)
+        {
+            return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+        }
+
+        bool operator==(const KeyTy &key) const
+        {
+            return std::make_tuple(resultType, argumentTypes) == key;
+        }
+
+        Type resultType;
+        ArrayRef<Type> argumentTypes;
+    };
+
+    struct PluginStructTypeStorage : public TypeStorage {
+        using KeyTy = std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StringRef>>;
+
+        PluginStructTypeStorage(StringRef name, ArrayRef<Type> elements, ArrayRef<StringRef> elemNames)
+            : name(name), elements(elements), elemNames(elemNames) {}
+
+        /* Deep-copies the whole key into uniquer-owned memory: the element
+           type list, the struct name, the element-name list and the
+           characters of every element name. Copying only the ArrayRef of
+           names would leave StringRefs pointing into the caller's buffers. */
+        static PluginStructTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key)
+        {
+            llvm::SmallVector<StringRef, 8> ownedNames;
+            for (StringRef elemName : std::get<2>(key)) {
+                ownedNames.push_back(allocator.copyInto(elemName));
+            }
+            return new (allocator.allocate<PluginStructTypeStorage>())
+                PluginStructTypeStorage(allocator.copyInto(std::get<0>(key)),
+                                        allocator.copyInto(std::get<1>(key)),
+                                        allocator.copyInto(ArrayRef<StringRef>(ownedNames)));
+        }
+
+        static unsigned hashKey(const KeyTy &key)
+        {
+            return llvm::hash_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key));
+        }
+
+        bool operator==(const KeyTy &key) const
+        {
+            return std::make_tuple(name, elements, elemNames) == key;
+        }
+
+        StringRef name;
+        ArrayRef<Type> elements;
+        ArrayRef<StringRef> elemNames;
+    };
} // namespace Detail
} // namespace PluginIR
@@ -123,6 +209,15 @@ PluginTypeID PluginTypeBase::getPluginTypeID ()
if (auto Ty = dyn_cast<PluginIR::PluginPointerType>()) {
return Ty.getPluginTypeID ();
}
+    if (auto Ty = dyn_cast<PluginIR::PluginArrayType>()) {
+        return Ty.getPluginTypeID ();
+    }
+    if (auto Ty = dyn_cast<PluginIR::PluginFunctionType>()) {
+        return Ty.getPluginTypeID ();
+    }
+    if (auto Ty = dyn_cast<PluginIR::PluginStructType>()) {
+        return Ty.getPluginTypeID ();
+    }
return PluginTypeID::UndefTyID;
}
@@ -295,3 +390,112 @@ PluginPointerType PluginPointerType::get (MLIRContext *context, Type pointee, un
{
return Base::get(context, pointee, readOnlyPointee);
}
+
+// ===----------------------------------------------------------------------===//
+// Plugin Array Type
+// ===----------------------------------------------------------------------===//
+
+PluginTypeID PluginArrayType::getPluginTypeID()
+{
+    return PluginTypeID::ArrayTyID;
+}
+
+/* Rejects void, function and undef element types. */
+bool PluginArrayType::isValidElementType(Type type)
+{
+    return !type.isa<PluginVoidType, PluginFunctionType, PluginUndefType>();
+}
+
+PluginArrayType PluginArrayType::get(MLIRContext *context, Type elementType, unsigned numElements)
+{
+    return Base::get(context, elementType, numElements);
+}
+
+Type PluginArrayType::getElementType()
+{
+    return getImpl()->elementType;
+}
+
+unsigned PluginArrayType::getNumElements()
+{
+    return getImpl()->numElements;
+}
+
+// ===----------------------------------------------------------------------===//
+// Plugin Function Type
+// ===----------------------------------------------------------------------===//
+
+PluginTypeID PluginFunctionType::getPluginTypeID()
+{
+    return PluginTypeID::FunctionTyID;
+}
+
+bool PluginFunctionType::isValidArgumentType(Type type)
+{
+    return !type.isa<PluginVoidType, PluginFunctionType>();
+}
+
+bool PluginFunctionType::isValidResultType(Type type)
+{
+    return !type.isa<PluginFunctionType>();
+}
+
+PluginFunctionType PluginFunctionType::get(MLIRContext *context, Type result, ArrayRef<Type> arguments)
+{
+    return Base::get(context, result, arguments);
+}
+
+Type PluginFunctionType::getReturnType()
+{
+    return getImpl()->resultType;
+}
+
+unsigned PluginFunctionType::getNumParams()
+{
+    return getImpl()->argumentTypes.size();
+}
+
+/* I must be smaller than getNumParams(). */
+Type PluginFunctionType::getParamType(unsigned i)
+{
+    return getImpl()->argumentTypes[i];
+}
+
+ArrayRef<Type> PluginFunctionType::getParams()
+{
+    return getImpl()->argumentTypes;
+}
+
+// ===----------------------------------------------------------------------===//
+// Plugin Struct Type
+// ===----------------------------------------------------------------------===//
+
+PluginTypeID PluginStructType::getPluginTypeID()
+{
+    return PluginTypeID::StructTyID;
+}
+
+bool PluginStructType::isValidElementType(Type type)
+{
+    return !type.isa<PluginVoidType, PluginFunctionType>();
+}
+
+PluginStructType PluginStructType::get(MLIRContext *context, StringRef name, ArrayRef<Type> elements, ArrayRef<StringRef> elemNames)
+{
+    return Base::get(context, name, elements, elemNames);
+}
+
+StringRef PluginStructType::getName()
+{
+    return getImpl()->name;
+}
+
+ArrayRef<Type> PluginStructType::getBody()
+{
+    return getImpl()->elements;
+}
+
+ArrayRef<StringRef> PluginStructType::getElementNames()
+{
+    return getImpl()->elemNames;
+}
diff --git a/lib/PluginClient/PluginJson.cpp b/lib/PluginClient/PluginJson.cpp
index 22cd489..a9db475 100755
--- a/lib/PluginClient/PluginJson.cpp
+++ b/lib/PluginClient/PluginJson.cpp
@@ -36,7 +36,7 @@ static uintptr_t GetID(Json::Value node)
return atol(id.c_str());
}
-Json::Value PluginJson::TypeJsonSerialize(PluginIR::PluginTypeBase& type)
+Json::Value PluginJson::TypeJsonSerialize(PluginIR::PluginTypeBase type)
{
Json::Value root;
Json::Value operationObj;
@@ -48,6 +48,39 @@ Json::Value PluginJson::TypeJsonSerialize(PluginIR::PluginTypeBase& type)
ReTypeId = static_cast<uint64_t>(type.getPluginTypeID());
item["id"] = std::to_string(ReTypeId);
+    /* Composite types: components are serialized recursively under the
+       keys that TypeJsonDeSerialize reads back. */
+    if (auto Ty = type.dyn_cast<PluginIR::PluginStructType>()) {
+        item["structtype"] = Ty.getName().str();
+        size_t paramIndex = 0;
+        for (Type elemTy : Ty.getBody()) {
+            string paramStr = "elemType" + std::to_string(paramIndex++);
+            item["structelemType"][paramStr] = TypeJsonSerialize(elemTy.dyn_cast<PluginIR::PluginTypeBase>());
+        }
+        paramIndex = 0;
+        for (StringRef elemName : Ty.getElementNames()) {
+            string paramStr = "elemName" + std::to_string(paramIndex++);
+            item["structelemName"][paramStr] = elemName.str();
+        }
+    }
+
+    if (auto Ty = type.dyn_cast<PluginIR::PluginFunctionType>()) {
+        auto fnrestype = Ty.getReturnType().dyn_cast<PluginIR::PluginTypeBase>();
+        item["fnreturntype"] = TypeJsonSerialize(fnrestype);
+        size_t paramIndex = 0;
+        for (Type argTy : Ty.getParams()) {
+            string paramStr = "argType" + std::to_string(paramIndex++);
+            item["fnargsType"][paramStr] = TypeJsonSerialize(argTy.dyn_cast<PluginIR::PluginTypeBase>());
+        }
+    }
+
+    if (auto Ty = type.dyn_cast<PluginIR::PluginArrayType>()) {
+        auto elemTy = Ty.getElementType().dyn_cast<PluginIR::PluginTypeBase>();
+        item["elementType"] = TypeJsonSerialize(elemTy);
+        uint64_t elemNum = Ty.getNumElements();
+        item["arraysize"] = std::to_string(elemNum);
+    }
+
if (auto elemTy = type.dyn_cast<PluginIR::PluginPointerType>()) {
auto baseTy = elemTy.getElementType().dyn_cast<PluginIR::PluginTypeBase>();
item["elementType"] = TypeJsonSerialize(baseTy);
@@ -104,7 +137,42 @@ PluginIR::PluginTypeBase PluginJson::TypeJsonDeSerialize(const string& data, mli
} else if (id == static_cast<uint64_t>(PluginIR::PointerTyID)) {
mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString(), context);
baseType = PluginIR::PluginPointerType::get(&context, elemTy, type["elemConst"].asString() == "1" ? 1 : 0);
- } else {
+    } else if (id == static_cast<uint64_t>(PluginIR::ArrayTyID)) {
+        mlir::Type elemTy = TypeJsonDeSerialize(type["elementType"].toStyledString(), context);
+        uint64_t elemNum = GetID(type["arraysize"]);
+        baseType = PluginIR::PluginArrayType::get(&context, elemTy, elemNum);
+    } else if (id == static_cast<uint64_t>(PluginIR::FunctionTyID)) {
+        mlir::Type returnTy = TypeJsonDeSerialize(type["fnreturntype"].toStyledString(), context);
+        llvm::SmallVector<Type> typelist;
+        size_t argsNum = type["fnargsType"].getMemberNames().size();
+        for (size_t paramIndex = 0; paramIndex < argsNum; paramIndex++) {
+            string Key = "argType" + std::to_string(paramIndex);
+            mlir::Type paramTy = TypeJsonDeSerialize(type["fnargsType"][Key].toStyledString(), context);
+            typelist.push_back(paramTy);
+        }
+        baseType = PluginIR::PluginFunctionType::get(&context, returnTy, typelist);
+    } else if (id == static_cast<uint64_t>(PluginIR::StructTyID)) {
+        /* asString() yields the raw text; toStyledString() would add quotes
+           and a trailing newline, and binding its temporary to a StringRef
+           would dangle. The std::strings below must stay alive while the
+           StringRefs are handed to PluginStructType::get. */
+        std::string tyName = type["structtype"].asString();
+        llvm::SmallVector<Type> typelist;
+        size_t elemTypeNum = type["structelemType"].getMemberNames().size();
+        for (size_t paramIndex = 0; paramIndex < elemTypeNum; paramIndex++) {
+            string Key = "elemType" + std::to_string(paramIndex);
+            mlir::Type paramTy = TypeJsonDeSerialize(type["structelemType"][Key].toStyledString(), context);
+            typelist.push_back(paramTy);
+        }
+        std::vector<std::string> nameStorage;
+        size_t elemNameNum = type["structelemName"].getMemberNames().size();
+        for (size_t paramIndex = 0; paramIndex < elemNameNum; paramIndex++) {
+            string Key = "elemName" + std::to_string(paramIndex);
+            nameStorage.push_back(type["structelemName"][Key].asString());
+        }
+        llvm::SmallVector<StringRef> names(nameStorage.begin(), nameStorage.end());
+        baseType = PluginIR::PluginStructType::get(&context, tyName, typelist, names);
+    } else {
if (PluginTypeId == PluginIR::VoidTyID) {
baseType = PluginIR::PluginVoidType::get(&context);
}
@@ -127,7 +195,6 @@ void PluginJson::FunctionOpJsonSerialize(vector<FunctionOp>& data, string& out)
int i = 0;
string operation;
-
for (auto& d: data) {
item["id"] = std::to_string(d.idAttr().getInt());
if (d.declaredInlineAttr().getValue()) {
@@ -136,6 +203,16 @@ void PluginJson::FunctionOpJsonSerialize(vector<FunctionOp>& data, string& out)
item["attributes"]["declaredInline"] = "0";
}
item["attributes"]["funcName"] = d.funcNameAttr().getValue().str().c_str();
+
+        /* Despite the key name, "retType" holds the whole PluginFunctionType
+           of the function (return and parameter types). It is emitted only
+           when the op's type attribute is a function type; the key is cleared
+           first in case `item` is reused across loop iterations. */
+        item.removeMember("retType");
+        if (auto fnTy = d.type().dyn_cast<PluginIR::PluginFunctionType>()) {
+            item["retType"] = TypeJsonSerialize(fnTy);
+        }
+
auto &region = d.getRegion();
size_t bbIdx = 0;
for (auto &b : region) {
diff --git a/lib/Translate/GimpleToPluginOps.cpp b/lib/Translate/GimpleToPluginOps.cpp
index b5974aa..3e2321a 100644
--- a/lib/Translate/GimpleToPluginOps.cpp
+++ b/lib/Translate/GimpleToPluginOps.cpp
@@ -536,9 +536,12 @@ FunctionOp GimpleToPluginOps::BuildFunctionOp(uint64_t functionId)
bool declaredInline = false;
if (DECL_DECLARED_INLINE_P(fn->decl))
declaredInline = true;
+    // TREE_TYPE of the FUNCTION_DECL is its whole signature, not its return type.
+    PluginTypeBase signature = typeTranslator.translateType((intptr_t)TREE_TYPE(fn->decl));
auto location = builder.getUnknownLoc();
+    PluginFunctionType fnType = signature.dyn_cast<PluginFunctionType>(); // null unless a function type
FunctionOp retOp = builder.create<FunctionOp>(location, functionId,
- funcName, declaredInline);
+        funcName, declaredInline, fnType);
auto& fr = retOp.bodyRegion();
if (fn->cfg == nullptr) return retOp;
if (!ProcessBasicBlock((intptr_t)ENTRY_BLOCK_PTR_FOR_FN(fn), fr)) {
diff --git a/lib/Translate/TypeTranslation.cpp b/lib/Translate/TypeTranslation.cpp
index 7c7cff4..ad840f7 100644
--- a/lib/Translate/TypeTranslation.cpp
+++ b/lib/Translate/TypeTranslation.cpp
@@ -45,10 +45,11 @@
#include "ssa.h"
#include "output.h"
#include "langhooks.h"
+#include "stor-layout.h"
-using namespace mlir;
namespace PluginIR {
+using namespace mlir;
namespace Detail {
/* Support for translating Plugin IR types to MLIR Plugin dialect types. */
class TypeFromPluginIRTranslatorImpl {
@@ -77,6 +78,79 @@ private:
return false;
}
+    /* Returns the element count of ARRAY_TYPE TYPE (MAX - MIN + 1 of its
+       domain), or 0 when the bound is missing or not a compile-time
+       constant, e.g. `int a[]` or a variable-length array.  */
+    unsigned getDomainIndex (tree type)
+    {
+        tree domain = TYPE_DOMAIN (type);
+        if (domain == NULL_TREE)
+            return 0;
+        tree maxVal = TYPE_MAX_VALUE (domain);
+        if (maxVal == NULL_TREE || !tree_fits_shwi_p (maxVal))
+            return 0;
+        tree minVal = TYPE_MIN_VALUE (domain);
+        HOST_WIDE_INT low = (minVal != NULL_TREE && tree_fits_shwi_p (minVal)) ? tree_to_shwi (minVal) : 0;
+        HOST_WIDE_INT count = tree_to_shwi (maxVal) - low + 1;
+        return count > 0 ? (unsigned) count : 0;
+    }
+
+    /* Translates the parameter types of FUNCTION_TYPE TYPE.  The void entry
+       that terminates a prototyped parameter list is not a parameter and
+       is skipped.  */
+    llvm::SmallVector<Type> getArgsType (tree type)
+    {
+        llvm::SmallVector<Type> typelist;
+        for (tree parm = TYPE_ARG_TYPES (type); parm; parm = TREE_CHAIN (parm)) {
+            tree parmtype = TREE_VALUE (parm);
+            if (VOID_TYPE_P (parmtype))
+                break;
+            typelist.push_back (translatePrimitiveType (parmtype));
+        }
+        return typelist;
+    }
+
+    /* Returns the name of TYPE (its identifier, or the name of its
+       TYPE_DECL), or NULL when TYPE is NULL or anonymous.  */
+    const char *getTypeName (tree type)
+    {
+        if (type == NULL || TYPE_NAME (type) == NULL)
+            return NULL;
+        tree name = TYPE_NAME (type);
+        if (TREE_CODE (name) == IDENTIFIER_NODE)
+            return IDENTIFIER_POINTER (name);
+        if (DECL_NAME (name) != NULL)
+            return IDENTIFIER_POINTER (DECL_NAME (name));
+        return NULL;
+    }
+
+    /* Translates the types of the FIELD_DECLs of RECORD_TYPE TYPE, in
+       declaration order.  */
+    llvm::SmallVector<Type> getElemType (tree type)
+    {
+        llvm::SmallVector<Type> typelist;
+        for (tree field = TYPE_FIELDS (type); field; field = DECL_CHAIN (field)) {
+            if (TREE_CODE (field) == FIELD_DECL)
+                typelist.push_back (translatePrimitiveType (TREE_TYPE (field)));
+        }
+        return typelist;
+    }
+
+    /* Returns the names of the FIELD_DECLs of RECORD_TYPE TYPE, index-aligned
+       with getElemType.  Unnamed fields (anonymous struct/union members,
+       unnamed bit-fields) have no DECL_NAME and are reported as "".  */
+    llvm::SmallVector<StringRef> getElemNames (tree type)
+    {
+        llvm::SmallVector<StringRef> names;
+        for (tree field = TYPE_FIELDS (type); field; field = DECL_CHAIN (field)) {
+            if (TREE_CODE (field) != FIELD_DECL)
+                continue;
+            tree fieldName = DECL_NAME (field);
+            names.push_back (fieldName ? StringRef (IDENTIFIER_POINTER (fieldName)) : StringRef ());
+        }
+        return names;
+    }
+
/* Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
type. */
PluginTypeBase translatePrimitiveType (tree type)
@@ -93,12 +167,37 @@ private:
if (TREE_CODE(type) == POINTER_TYPE)
return PluginPointerType::get(&context, translatePrimitiveType(TREE_TYPE(type)),
TYPE_READONLY(TREE_TYPE(type)) ? 1 : 0);
+        if (TREE_CODE(type) == ARRAY_TYPE)
+            return PluginArrayType::get(&context, translatePrimitiveType(TREE_TYPE(type)), getDomainIndex(type));
+        if (TREE_CODE(type) == FUNCTION_TYPE) {
+            llvm::SmallVector<Type> argsType = getArgsType(type);
+            return PluginFunctionType::get(&context, translatePrimitiveType(TREE_TYPE(type)), argsType);
+        }
+        if (TREE_CODE(type) == RECORD_TYPE) {
+            const char *tyName = getTypeName(type);
+            StringRef name = tyName ? StringRef(tyName) : StringRef();
+            /* A struct that reaches itself through its fields (e.g. a
+               `struct node *next` member) would recurse forever.  While a
+               record is being translated, nested references to it become a
+               name-only struct with no elements.  */
+            tree record = TYPE_MAIN_VARIANT(type);
+            if (std::find(recordsInProgress.begin(), recordsInProgress.end(), record) != recordsInProgress.end())
+                return PluginStructType::get(&context, name, ArrayRef<Type>(), ArrayRef<StringRef>());
+            recordsInProgress.push_back(record);
+            llvm::SmallVector<Type> elemTypes = getElemType(type);
+            llvm::SmallVector<StringRef> elemNames = getElemNames(type);
+            recordsInProgress.pop_back();
+            return PluginStructType::get(&context, name, elemTypes, elemNames);
+        }
return PluginUndefType::get(&context);
}
/* The context in which MLIR types are created. */
mlir::MLIRContext &context;
-};
+    /* Main variants of the RECORD_TYPEs currently being translated, used to
+       stop recursion through self-referential structs.  */
+    llvm::SmallVector<tree, 8> recordsInProgress;
+}; // class TypeFromPluginIRTranslatorImpl
/* Support for translating MLIR Plugin dialect types to Plugin IR types . */
class TypeToPluginIRTranslatorImpl {
@@ -126,6 +225,22 @@ private:
return false;
}
+    /* Translates the parameter types of TY and appends them to PARAMTYPES.
+       Returns false if some parameter has no GCC equivalent.  The vector is
+       supplied by the caller because returning an auto_vec by value is not
+       safe on every GCC release (older ones give auto_vec only an implicit,
+       shallow copy constructor).  */
+    bool getParamsType(PluginFunctionType Ty, auto_vec<tree> &paramTypes)
+    {
+        for (Type paramTy : Ty.getParams()) {
+            tree gccTy = translatePrimitiveType(paramTy.dyn_cast<PluginTypeBase>());
+            if (gccTy == NULL_TREE)
+                return false;
+            paramTypes.safe_push(gccTy);
+        }
+        return true;
+    }
+
tree translatePrimitiveType(PluginTypeBase type)
{
if (auto Ty = type.dyn_cast<PluginIntegerType>()) {
@@ -179,9 +294,57 @@ private:
TYPE_READONLY(elmTy) = elmConst ? 1 : 0;
return build_pointer_type(elmTy);
}
+        if (auto Ty = type.dyn_cast<PluginArrayType>()) {
+            tree elmTy = translatePrimitiveType(Ty.getElementType().dyn_cast<PluginTypeBase>());
+            if (elmTy == NULL_TREE)
+                return NULL;
+            unsigned elmNum = Ty.getNumElements();
+            /* build_index_type (MAX) creates the domain [0, MAX], so an array
+               of elmNum elements needs MAX = elmNum - 1.  A count of 0 (bound
+               unknown when translated) becomes an array without a domain.  */
+            if (elmNum == 0)
+                return build_array_type(elmTy, NULL_TREE);
+            return build_array_type(elmTy, build_index_type(size_int(elmNum - 1)));
+        }
+        if (auto Ty = type.dyn_cast<PluginFunctionType>()) {
+            tree returnType = translatePrimitiveType(Ty.getReturnType().dyn_cast<PluginTypeBase>());
+            auto_vec<tree> paramTypes;
+            if (returnType == NULL_TREE || !getParamsType(Ty, paramTypes))
+                return NULL;
+            return build_function_type_array(returnType, paramTypes.length(), paramTypes.address());
+        }
+        if (auto Ty = type.dyn_cast<PluginStructType>()) {
+            ArrayRef<Type> elemTypes = Ty.getBody();
+            ArrayRef<StringRef> elemNames = Ty.getElementNames();
+            StringRef tyName = Ty.getName();
+
+            tree ret = make_node(RECORD_TYPE);
+            /* Fields are linked through CHAIN, so no fixed-size array is
+               needed and a struct without elements is handled too.  */
+            tree *chain = &TYPE_FIELDS(ret);
+            for (size_t i = 0; i < elemTypes.size(); i++) {
+                tree elmType = translatePrimitiveType(elemTypes[i].dyn_cast<PluginTypeBase>());
+                if (elmType == NULL_TREE)
+                    return NULL;
+                /* An empty (or missing) element name yields an unnamed field.  */
+                tree fieldName = (i < elemNames.size() && !elemNames[i].empty())
+                    ? get_identifier(elemNames[i].str().c_str()) : NULL_TREE;
+                tree field = build_decl(UNKNOWN_LOCATION, FIELD_DECL, fieldName, elmType);
+                DECL_CONTEXT(field) = ret;
+                *chain = field;
+                chain = &DECL_CHAIN(field);
+            }
+            if (!tyName.empty()) {
+                tree typeDecl = build_decl(input_location, TYPE_DECL, get_identifier(tyName.str().c_str()), ret);
+                DECL_ARTIFICIAL(typeDecl) = 1;
+                TYPE_NAME(ret) = typeDecl;
+            }
+            layout_type(ret);
+            return ret;
+        }
return NULL;
}
-};
+}; // class TypeToPluginIRTranslatorImpl
} // namespace Detail
} // namespace PluginIR
--
2.27.0.windows.1