pin-server/0002-Pin-server-support-LoopOp.patch
benniaobufeijiushiji 43622cf48f [sync] Sync patch from openEuler/pin-server
Sync patch from openEuler/pin-server 20221208

(cherry picked from commit 5b1dfc11cd542269226338319a38315ba6f5cf7e)
2022-12-09 09:52:15 +08:00

773 lines
26 KiB
Diff
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

From c51ee872abc92b72fa35f7ff52c2d9e506dde40a Mon Sep 17 00:00:00 2001
From: benniaobufeijiushiji <linda7@huawei.com>
Date: Wed, 7 Dec 2022 14:52:19 +0800
Subject: [PATCH 2/4] [Pin-server] support LoopOp

Add LoopOp to the Plugin Dialect.
---
include/Dialect/PluginOps.td | 35 +++++
include/PluginAPI/BasicPluginOpsAPI.h | 11 ++
include/PluginAPI/PluginServerAPI.h | 20 +++
include/PluginServer/PluginServer.h | 21 +++
lib/Dialect/PluginOps.cpp | 96 ++++++++++++-
lib/PluginAPI/PluginServerAPI.cpp | 171 +++++++++++++++++++++++-
lib/PluginServer/PluginServer.cpp | 185 +++++++++++++++++++++++++-
user.cpp | 16 +--
8 files changed, 541 insertions(+), 14 deletions(-)
diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td
index 374340c..02d9bdd 100644
--- a/include/Dialect/PluginOps.td
+++ b/include/Dialect/PluginOps.td
@@ -20,6 +20,7 @@
include "PluginDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
let summary = "function with a region";
@@ -38,6 +39,10 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
let builders = [
OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$funcName, "bool":$declaredInline)>
];
+
+ let extraClassDeclaration = [{
+ std::vector<LoopOp> GetAllLoops();
+ }];
}
def LocalDeclOp : Plugin_Op<"declaration", [NoSideEffect]> {
@@ -54,4 +59,34 @@ def LocalDeclOp : Plugin_Op<"declaration", [NoSideEffect]> {
];
}
+def LoopOp : Plugin_Op<"loop", [NoSideEffect]> {
+ let summary = "loop operation";
+ let description = [{
+ A loop reported by the plugin client and referenced by its id; it also records its index, inner/outer loop ids and block count.
+ }];
+ let arguments = (ins OptionalAttr<UI64Attr>:$id,
+ OptionalAttr<UI32Attr>:$index,
+ OptionalAttr<UI64Attr>:$innerLoopId,
+ OptionalAttr<UI64Attr>:$outerLoopId,
+ OptionalAttr<UI32Attr>:$numBlock);
+ let regions = (region AnyRegion:$bodyRegion);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "uint32_t":$index,
+ "uint64_t":$innerLoopId, "uint64_t":$outerLoopId,
+ "uint32_t":$numBlock)>
+ ];
+ let extraClassDeclaration = [{
+ uint64_t GetHeader();
+ uint64_t GetLatch();
+ std::pair<uint64_t, uint64_t> GetSingleExit();
+ void Delete();
+ LoopOp GetInnerLoop();
+ LoopOp GetOuterLoop();
+ bool IsBlockInside(uint64_t);
+ std::vector<std::pair<uint64_t, uint64_t> > GetExitEdges();
+ std::vector<uint64_t> GetLoopBody();
+ void AddLoop(uint64_t, uint64_t);
+ }];
+}
+
#endif // PLUGIN_OPS_TD
\ No newline at end of file
diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h
index b6f2b4f..fde5f63 100644
--- a/include/PluginAPI/BasicPluginOpsAPI.h
+++ b/include/PluginAPI/BasicPluginOpsAPI.h
@@ -25,6 +25,7 @@
namespace PluginAPI {
using std::vector;
using std::string;
+using std::pair;
using namespace mlir::Plugin;
/* The BasicPluginOpsAPI class defines the basic plugin API, both the plugin
@@ -37,6 +38,16 @@ public:
virtual vector<FunctionOp> GetAllFunc() = 0;
virtual vector<LocalDeclOp> GetDecls(uint64_t) = 0;
+ virtual LoopOp AllocateNewLoop(uint64_t funcID) = 0;
+ virtual vector<LoopOp> GetLoopsFromFunc(uint64_t funcID) = 0;
+ virtual LoopOp GetLoopById(uint64_t loopID) = 0;
+ virtual void AddLoop(uint64_t loopID, uint64_t outerID, uint64_t funcID) = 0;
+ virtual void DeleteLoop(uint64_t loopID) = 0;
+ virtual vector<uint64_t> GetLoopBody(uint64_t loopID) = 0;
+ virtual bool IsBlockInLoop(uint64_t loopID, uint64_t blockID) = 0;
+ virtual pair<uint64_t, uint64_t> LoopSingleExit(uint64_t loopID) = 0;
+ virtual vector<pair<uint64_t, uint64_t> > GetLoopExitEdges(uint64_t loopID) = 0;
+ virtual LoopOp GetBlockLoopFather(uint64_t blockID) = 0;
}; // class BasicPluginOpsAPI
} // namespace PluginAPI
diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h
index 7cb4fa7..5b14f2e 100644
--- a/include/PluginAPI/PluginServerAPI.h
+++ b/include/PluginAPI/PluginServerAPI.h
@@ -29,6 +29,7 @@ namespace PluginAPI {
using std::vector;
using std::string;
+using std::pair;
using namespace mlir::Plugin;
class PluginServerAPI : public BasicPluginOpsAPI {
public:
@@ -38,9 +39,28 @@ public:
vector<FunctionOp> GetAllFunc() override;
vector<LocalDeclOp> GetDecls(uint64_t) override;
PluginIR::PluginTypeID GetTypeCodeFromString(string type);
+ LoopOp AllocateNewLoop(uint64_t funcID) override;
+ LoopOp GetLoopById(uint64_t loopID) override;
+ vector<LoopOp> GetLoopsFromFunc(uint64_t funcID) override;
+ bool IsBlockInLoop(uint64_t loopID, uint64_t blockID) override;
+ void DeleteLoop(uint64_t loopID) override;
+ void AddLoop(uint64_t loopID, uint64_t outerID, uint64_t funcID) override;
+ pair<uint64_t, uint64_t> LoopSingleExit(uint64_t loopID) override;
+ vector<pair<uint64_t, uint64_t> > GetLoopExitEdges(uint64_t loopID) override;
+ uint64_t GetHeader(uint64_t loopID);
+ uint64_t GetLatch(uint64_t loopID);
+ vector<uint64_t> GetLoopBody(uint64_t loopID) override;
+ LoopOp GetBlockLoopFather(uint64_t blockID) override;
private:
vector<FunctionOp> GetOperationResult(const string& funName, const string& params);
vector<LocalDeclOp> GetDeclOperationResult(const string& funName, const string& params);
+ LoopOp GetLoopResult(const string& funName, const string& params);
+ vector<LoopOp> GetLoopsResult(const string& funName, const string& params);
+ bool GetBoolResult(const string& funName, const string& params);
+ pair<uint64_t, uint64_t> EdgeResult(const string& funName, const string& params);
+ vector<pair<uint64_t, uint64_t> > EdgesResult(const string& funName, const string& params);
+ uint64_t BlockResult(const string& funName, const string& params);
+ vector<uint64_t> BlocksResult(const string& funName, const string& params);
void WaitClientResult(const string& funName, const string& params);
}; // class PluginServerAPI
} // namespace PluginAPI
diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h
index 2cb314c..610ecb9 100644
--- a/include/PluginServer/PluginServer.h
+++ b/include/PluginServer/PluginServer.h
@@ -130,6 +130,13 @@ public:
static PluginServer *GetInstance(void);
vector<mlir::Plugin::FunctionOp> GetFunctionOpResult(void);
vector<mlir::Plugin::LocalDeclOp> GetLocalDeclResult(void);
+ mlir::Plugin::LoopOp LoopOpResult(void);
+ vector<mlir::Plugin::LoopOp> LoopOpsResult(void);
+ vector<uint64_t> BlockIdsResult(void);
+ uint64_t BlockIdResult(void);
+ vector<std::pair<uint64_t, uint64_t> > EdgesResult(void);
+ std::pair<uint64_t, uint64_t> EdgeResult(void);
+ bool BoolResult(void);
/* 回调函数接口用于向server注册用户需要执行的函数 */
int RegisterUserFunc(InjectPoint inject, UserFunc func);
int RegisterPassManagerSetup(InjectPoint inject, const ManagerSetupData& passData, UserFunc func);
@@ -176,6 +183,13 @@ public:
void FuncOpJsonDeSerialize(const string& data);
void TypeJsonDeSerialize(const string& data);
void LocalDeclOpJsonDeSerialize(const string& data);
+ void LoopOpsJsonDeSerialize(const string& data);
+ void LoopOpJsonDeSerialize(const string& data);
+ void BoolResJsonDeSerialize(const string& data);
+ void EdgesJsonDeSerialize(const string& data);
+ void EdgeJsonDeSerialize(const string& data);
+ void BlocksJsonDeSerialize(const string& data);
+ void BlockJsonDeSerialize(const string& data);
/* json反序列化根据key值分别调用Operation/Decl/Type反序列化接口函数 */
void JsonDeSerialize(const string& key, const string& data);
/* 解析客户端发送过来的-fplugin-arg参数并保存在私有变量args中 */
@@ -223,6 +237,13 @@ private:
vector<mlir::Plugin::FunctionOp> funcOpData;
PluginIR::PluginTypeBase pluginType;
vector<mlir::Plugin::LocalDeclOp> decls;
+ vector<mlir::Plugin::LoopOp> loops;
+ mlir::Plugin::LoopOp loop;
+ vector<std::pair<uint64_t, uint64_t> > edges;
+ std::pair<uint64_t, uint64_t> edge;
+ vector<uint64_t> blockIds;
+ uint64_t blockId = 0; // last "BlockIdResult" value; 0 until a response arrives
+ bool boolRes = false; // last "BoolResult" value; false until a response arrives
/* 保存用户注册的回调函数,它们将在注入点事件触发后调用 */
map<InjectPoint, vector<RecordedUserFunc>> userFunc;
string apiFuncName; // 保存用户调用PluginAPI的函数名
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 2377875..481d51a 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
+#include "PluginAPI/PluginServerAPI.h"
#include "Dialect/PluginDialect.h"
#include "Dialect/PluginOps.h"
@@ -29,15 +30,24 @@
using namespace mlir;
using namespace mlir::Plugin;
+using std::vector;
+using std::pair;
void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
uint64_t id, StringRef funcName, bool declaredInline) {
FunctionOp::build(builder, state,
- builder.getI64IntegerAttr(id),
+ builder.getI64IntegerAttr(id),
builder.getStringAttr(funcName),
builder.getBoolAttr(declaredInline));
}
+vector<LoopOp> FunctionOp::GetAllLoops()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t funcId = idAttr().getInt();
+ return pluginAPI.GetLoopsFromFunc(funcId);
+}
+
void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
uint64_t id, StringRef symName,
int64_t typeID, uint64_t typeWidth) {
@@ -48,6 +58,90 @@ void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
builder.getI64IntegerAttr(typeWidth));
}
+void LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+ uint64_t id, uint32_t index, uint64_t innerLoopId,
+ uint64_t outerLoopId, uint32_t numBlock) {
+ LoopOp::build(builder, state,
+ builder.getI64IntegerAttr(id),
+ builder.getI32IntegerAttr(index),
+ builder.getI64IntegerAttr(innerLoopId),
+ builder.getI64IntegerAttr(outerLoopId),
+ builder.getI32IntegerAttr(numBlock));
+}
+
+// FIXME: use Block instead of uint64_t
+uint64_t LoopOp::GetHeader()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.GetHeader(loopId);
+}
+
+// FIXME: use Block instead of uint64_t
+uint64_t LoopOp::GetLatch()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.GetLatch(loopId);
+}
+
+vector<uint64_t> LoopOp::GetLoopBody()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.GetLoopBody(loopId);
+}
+
+pair<uint64_t, uint64_t> LoopOp::GetSingleExit()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.LoopSingleExit(loopId);
+}
+
+void LoopOp::Delete()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ pluginAPI.DeleteLoop(loopId);
+}
+
+LoopOp LoopOp::GetInnerLoop()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = innerLoopIdAttr().getInt();
+ return pluginAPI.GetLoopById(loopId);
+}
+
+LoopOp LoopOp::GetOuterLoop()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = outerLoopIdAttr().getInt();
+ return pluginAPI.GetLoopById(loopId);
+}
+
+// FIXME: use Block instead of uint64_t
+bool LoopOp::IsBlockInside(uint64_t b)
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.IsBlockInLoop(loopId, b);
+}
+
+vector<pair<uint64_t, uint64_t> > LoopOp::GetExitEdges()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ return pluginAPI.GetLoopExitEdges(loopId);
+}
+
+void LoopOp::AddLoop(uint64_t outerId, uint64_t funcId)
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ uint64_t loopId = idAttr().getInt();
+ pluginAPI.AddLoop(loopId, outerId, funcId);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index b6ed53f..41626f9 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -116,10 +116,179 @@ vector<LocalDeclOp> PluginServerAPI::GetDecls(uint64_t funcID)
{
Json::Value root;
string funName("GetLocalDecls");
- root[std::to_string(0)] = std::to_string(funcID);
+ root["funcId"] = std::to_string(funcID);
string params = root.toStyledString();
return GetDeclOperationResult(funName, params);
}
+vector<LoopOp> PluginServerAPI::GetLoopsResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ vector<LoopOp> loops = PluginServer::GetInstance()->LoopOpsResult();
+ return loops;
+}
+
+LoopOp PluginServerAPI::GetLoopResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ LoopOp loop = PluginServer::GetInstance()->LoopOpResult();
+ return loop;
+}
+
+bool PluginServerAPI::GetBoolResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->BoolResult();
+}
+
+pair<uint64_t, uint64_t> PluginServerAPI::EdgeResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ pair<uint64_t, uint64_t> e = PluginServer::GetInstance()->EdgeResult();
+ return e;
+}
+
+vector<pair<uint64_t, uint64_t> > PluginServerAPI::EdgesResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ vector<pair<uint64_t, uint64_t> > retEdges = PluginServer::GetInstance()->EdgesResult();
+ return retEdges;
+}
+
+uint64_t PluginServerAPI::BlockResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->BlockIdResult();
+}
+
+vector<uint64_t> PluginServerAPI::BlocksResult(const string& funName, const string& params)
+{
+ WaitClientResult(funName, params);
+ vector<uint64_t> retBlocks = PluginServer::GetInstance()->BlockIdsResult();
+ return retBlocks;
+}
+
+vector<LoopOp> PluginServerAPI::GetLoopsFromFunc(uint64_t funcID)
+{
+ Json::Value root;
+ string funName("GetLoopsFromFunc");
+ root["funcId"] = std::to_string(funcID);
+ string params = root.toStyledString();
+
+ return GetLoopsResult(funName, params);
+}
+
+// FIXME: the input parameter should be void
+LoopOp PluginServerAPI::AllocateNewLoop(uint64_t funcID)
+{
+ Json::Value root;
+ string funName("AllocateNewLoop");
+ root["funcId"] = std::to_string(funcID);
+ string params = root.toStyledString();
+
+ return GetLoopResult(funName, params);
+}
+
+LoopOp PluginServerAPI::GetLoopById(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetLoopById");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return GetLoopResult(funName, params);
+}
+
+void PluginServerAPI::DeleteLoop(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("DeleteLoop");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+}
+
+void PluginServerAPI::AddLoop(uint64_t loopID, uint64_t outerID, uint64_t funcID)
+{
+ Json::Value root;
+ string funName("AddLoop");
+ root["loopId"] = std::to_string(loopID); // ids travel as strings, like every other request
+ root["outerId"] = std::to_string(outerID);
+ root["funcId"] = std::to_string(funcID);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+}
+
+bool PluginServerAPI::IsBlockInLoop(uint64_t loopID, uint64_t blockID)
+{
+ Json::Value root;
+ string funName("IsBlockInside");
+ root["loopId"] = std::to_string(loopID);
+ root["blockId"] = std::to_string(blockID);
+ string params = root.toStyledString();
+
+ return GetBoolResult(funName, params);
+}
+
+uint64_t PluginServerAPI::GetHeader(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetHeader");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return BlockResult(funName, params);
+}
+
+uint64_t PluginServerAPI::GetLatch(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetLatch");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return BlockResult(funName, params);
+}
+
+pair<uint64_t, uint64_t> PluginServerAPI::LoopSingleExit(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetLoopSingleExit");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return EdgeResult(funName, params);
+}
+
+vector<pair<uint64_t, uint64_t> > PluginServerAPI::GetLoopExitEdges(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetExitEdges");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return EdgesResult(funName, params);
+}
+
+vector<uint64_t> PluginServerAPI::GetLoopBody(uint64_t loopID)
+{
+ Json::Value root;
+ string funName("GetBlocksInLoop");
+ root["loopId"] = std::to_string(loopID);
+ string params = root.toStyledString();
+
+ return BlocksResult(funName, params);
+}
+
+LoopOp PluginServerAPI::GetBlockLoopFather(uint64_t blockID)
+{
+ Json::Value root;
+ string funName("GetBlockLoopFather");
+ root["blockId"] = std::to_string(blockID);
+ string params = root.toStyledString();
+
+ return GetLoopResult(funName, params);
+}
+
} // namespace Plugin_IR
diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp
index 1ec47ae..92159e8 100644
--- a/lib/PluginServer/PluginServer.cpp
+++ b/lib/PluginServer/PluginServer.cpp
@@ -40,6 +40,7 @@ using namespace mlir::Plugin;
using std::cout;
using std::endl;
+using std::pair;
static std::unique_ptr<Server> g_server; // grpc对象指针
static PluginServer g_service; // 插件server对象
@@ -88,6 +89,52 @@ vector<mlir::Plugin::LocalDeclOp> PluginServer::GetLocalDeclResult()
return retOps;
}
+vector<mlir::Plugin::LoopOp> PluginServer::LoopOpsResult()
+{
+ vector<mlir::Plugin::LoopOp> retLoops = loops;
+ loops.clear();
+ return retLoops;
+}
+
+LoopOp PluginServer::LoopOpResult()
+{
+ mlir::Plugin::LoopOp retLoop = loop;
+ return retLoop;
+}
+
+bool PluginServer::BoolResult()
+{
+ return boolRes;
+}
+
+uint64_t PluginServer::BlockIdResult()
+{
+ return blockId;
+}
+
+vector<uint64_t> PluginServer::BlockIdsResult()
+{
+ vector<uint64_t> retIds = blockIds;
+ blockIds.clear();
+ return retIds;
+}
+
+pair<uint64_t, uint64_t> PluginServer::EdgeResult()
+{
+ pair<uint64_t, uint64_t> e;
+ e.first = edge.first;
+ e.second = edge.second;
+ return e;
+}
+
+vector<pair<uint64_t, uint64_t> > PluginServer::EdgesResult()
+{
+ vector<pair<uint64_t, uint64_t> > retEdges;
+ retEdges = edges;
+ edges.clear();
+ return retEdges;
+}
+
void PluginServer::JsonGetAttributes(Json::Value node, map<string, string>& attributes)
{
Json::Value::Members attMember = node.getMemberNames();
@@ -110,7 +157,23 @@ void PluginServer::JsonDeSerialize(const string& key, const string& data)
FuncOpJsonDeSerialize(data);
} else if (key == "LocalDeclOpResult") {
LocalDeclOpJsonDeSerialize(data);
- }else {
+ } else if (key == "LoopOpResult") {
+ LoopOpJsonDeSerialize(data);
+ } else if (key == "LoopOpsResult") {
+ LoopOpsJsonDeSerialize(data);
+ } else if (key == "BoolResult") {
+ BoolResJsonDeSerialize(data);
+ } else if (key == "VoidResult") {
+ ;
+ } else if (key == "BlockIdResult") {
+ BlockJsonDeSerialize(data);
+ } else if (key == "EdgeResult") {
+ EdgeJsonDeSerialize(data);
+ } else if (key == "EdgesResult") {
+ EdgesJsonDeSerialize(data);
+ } else if (key == "BlockIdsResult") {
+ BlocksJsonDeSerialize(data);
+ } else {
cout << "not Json,key:" << key << ",value:" << data << endl;
}
}
@@ -198,14 +261,128 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data)
Json::Value attributes = node["attributes"];
map<string, string> declAttributes;
JsonGetAttributes(attributes, declAttributes);
- string symName = declAttributes["symName"];
- uint64_t typeID = atol(declAttributes["typeID"].c_str());
- uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
+ string symName = declAttributes["symName"];
+ uint64_t typeID = atol(declAttributes["typeID"].c_str());
+ uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
auto location = builder.getUnknownLoc();
LocalDeclOp op = builder.create<LocalDeclOp>(location, id, symName, typeID, typeWidth);
decls.push_back(op);
}
}
+void PluginServer::LoopOpsJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ Json::Value node;
+ reader.parse(data, root);
+
+ Json::Value::Members operation = root.getMemberNames();
+ context.getOrLoadDialect<PluginDialect>();
+ mlir::OpBuilder builder(&context);
+ for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
+ string operationKey = *iter;
+ node = root[operationKey];
+ uint64_t id = GetID(node["id"]);
+ Json::Value attributes = node["attributes"];
+ map<string, string> loopAttributes;
+ JsonGetAttributes(attributes, loopAttributes);
+ uint32_t index = atoi(loopAttributes["index"].c_str());
+ uint64_t innerId = atol(loopAttributes["innerLoopId"].c_str());
+ uint64_t outerId = atol(loopAttributes["outerLoopId"].c_str());
+ uint32_t numBlock = atoi(loopAttributes["numBlock"].c_str());
+ auto location = builder.getUnknownLoc();
+ LoopOp op = builder.create<LoopOp>(location, id, index, innerId, outerId, numBlock);
+ loops.push_back(op);
+ }
+}
+
+void PluginServer::LoopOpJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ reader.parse(data, root);
+
+ context.getOrLoadDialect<PluginDialect>();
+ mlir::OpBuilder builder(&context);
+
+ uint64_t id = GetID(root["id"]);
+ Json::Value attributes = root["attributes"];
+ uint32_t index = atoi(attributes["index"].asString().c_str());
+ uint64_t innerLoopId = atol(attributes["innerLoopId"].asString().c_str());
+ uint64_t outerLoopId = atol(attributes["outerLoopId"].asString().c_str());
+ uint32_t numBlock = atoi(attributes["numBlock"].asString().c_str());
+ auto location = builder.getUnknownLoc();
+ loop = builder.create<LoopOp>(location, id, index, innerLoopId, outerLoopId, numBlock);
+}
+
+void PluginServer::BoolResJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ reader.parse(data, root);
+
+ boolRes = atoi(root["result"].asString().c_str()) != 0;
+}
+
+void PluginServer::EdgesJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ Json::Value node;
+ reader.parse(data, root);
+
+ Json::Value::Members operation = root.getMemberNames();
+ context.getOrLoadDialect<PluginDialect>();
+ mlir::OpBuilder builder(&context);
+ for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
+ string operationKey = *iter;
+ node = root[operationKey];
+ uint64_t src = atol(node["src"].asString().c_str());
+ uint64_t dest = atol(node["dest"].asString().c_str());
+ pair<uint64_t, uint64_t> e;
+ e.first = src;
+ e.second = dest;
+ edges.push_back(e);
+ }
+}
+
+void PluginServer::EdgeJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ reader.parse(data, root);
+ uint64_t src = atol(root["src"].asString().c_str());
+ uint64_t dest = atol(root["dest"].asString().c_str());
+ edge.first = src;
+ edge.second = dest;
+}
+
+void PluginServer::BlocksJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ Json::Value node;
+ reader.parse(data, root);
+
+ Json::Value::Members operation = root.getMemberNames();
+ context.getOrLoadDialect<PluginDialect>();
+ mlir::OpBuilder builder(&context);
+ for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
+ string operationKey = *iter;
+ node = root[operationKey];
+ uint64_t id = atol(node["id"].asString().c_str());
+ blockIds.push_back(id);
+ }
+}
+
+void PluginServer::BlockJsonDeSerialize(const string& data)
+{
+ Json::Value root;
+ Json::Reader reader;
+ reader.parse(data, root);
+
+ blockId = (uint64_t)atol(root["id"].asString().c_str());
+}
/* 线程函数,执行用户注册函数,客户端返回数据后退出 */
static void ExecCallbacks(const string& name)
diff --git a/user.cpp b/user.cpp
index 81958c0..efc4158 100644
--- a/user.cpp
+++ b/user.cpp
@@ -80,7 +80,7 @@ static void PassManagerSetupFunc(void)
static bool
determineLoopForm(LoopOp loop)
{
- if (loop.innerLoopIdAttr().getUInt() != 0 || loop.numBlockAttr().getUInt() != 3)
+ if (loop.innerLoopIdAttr().getInt() != 0 || loop.numBlockAttr().getInt() != 3)
{
printf ("\nWrong loop form, there is inner loop or redundant bb.\n");
return false;
@@ -97,21 +97,21 @@ determineLoopForm(LoopOp loop)
static void
ProcessArrayWiden(void)
{
- std::cout << "Running first pass, OMG\n";
+ std::cout << "Running first pass, awiden\n";
PluginServerAPI pluginAPI;
vector<FunctionOp> allFunction = pluginAPI.GetAllFunc();
-
+
for (auto &funcOp : allFunction) {
string name = funcOp.funcNameAttr().getValue().str();
printf("Now process func : %s \n", name.c_str());
vector<LoopOp> allLoop = funcOp.GetAllLoops();
- for (auto &loop : allLoop) {
- if (determineLoopForm(loop)) {
- printf("The %dth loop form is success matched, and the loop can be optimized.\n", loop.indexAttr().getUInt());
- return;
- }
+ for (auto &loop : allLoop) {
+ if (determineLoopForm(loop)) {
+ printf("The %lldth loop form is success matched, and the loop can be optimized.\n", static_cast<long long>(loop.indexAttr().getInt()));
+ return;
}
+ }
}
}
--
2.27.0.windows.1