pin-server/0033-Pin-server-Add-DataFlow-APIs.patch

1280 lines
45 KiB
Diff
Raw Normal View History

From e8adf83db9cc977fe0145555715d083805636757 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=91=E6=99=A8=E5=8D=89?= <zhengchenhui1@huawei.com>
Date: Mon, 25 Dec 2023 10:24:23 +0800
Subject: [PATCH 6/7] [Pin-server] Add DataFlow APIs.
---
CMakeLists.txt | 6 +-
include/PluginAPI/BasicPluginOpsAPI.h | 5 +
include/PluginAPI/ControlFlowAPI.h | 1 +
include/PluginAPI/DataFlowAPI.h | 65 +++++
include/PluginAPI/PluginServerAPI.h | 5 +
include/PluginServer/ManagerSetup.h | 1 +
include/PluginServer/PluginCom.h | 2 +
include/PluginServer/PluginJson.h | 1 +
include/PluginServer/PluginServer.h | 14 +-
include/user/SimpleLICMPass.h | 45 ++++
lib/Dialect/PluginOps.cpp | 8 +
lib/PluginAPI/ControlFlowAPI.cpp | 17 +-
lib/PluginAPI/DataFlowAPI.cpp | 171 ++++++++++++
lib/PluginAPI/PluginServerAPI.cpp | 40 ++-
lib/PluginServer/PluginCom.cpp | 13 +-
lib/PluginServer/PluginJson.cpp | 37 ++-
lib/PluginServer/PluginServer.cpp | 19 +-
user/SimpleLICMPass.cpp | 359 ++++++++++++++++++++++++++
user/user.cpp | 9 +-
19 files changed, 792 insertions(+), 26 deletions(-)
create mode 100644 include/PluginAPI/DataFlowAPI.h
create mode 100644 include/user/SimpleLICMPass.h
create mode 100644 lib/PluginAPI/DataFlowAPI.cpp
create mode 100644 user/SimpleLICMPass.cpp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index df8b634..925f80e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -85,10 +85,11 @@ target_link_libraries(plg_grpc_proto
add_subdirectory(include)
add_subdirectory(lib)
add_library(pin_user SHARED
- # "user/ArrayWidenPass.cpp"
- "user/StructReorder.cpp"
+ "user/ArrayWidenPass.cpp"
+ #"user/StructReorder.cpp"
"user/InlineFunctionPass.cpp"
"user/LocalVarSummeryPass.cpp"
+ "user/SimpleLICMPass.cpp"
"user/user.cpp")
target_link_libraries(pin_user
@@ -101,6 +102,7 @@ add_custom_command(TARGET pin_user COMMAND sha256sum libpin_user.so > libpin_use
add_executable(pin_server
"lib/PluginServer/PluginServer.cpp"
"lib/PluginAPI/ControlFlowAPI.cpp"
+ "lib/PluginAPI/DataFlowAPI.cpp"
"lib/PluginServer/PluginGrpc.cpp"
"lib/PluginServer/PluginJson.cpp"
"lib/PluginServer/PluginCom.cpp"
diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h
index 065274d..4e05e6f 100644
--- a/include/PluginAPI/BasicPluginOpsAPI.h
+++ b/include/PluginAPI/BasicPluginOpsAPI.h
@@ -86,6 +86,7 @@ public:
virtual pair<mlir::Block*, mlir::Block*> LoopSingleExit(uint64_t) = 0;
virtual vector<pair<mlir::Block*, mlir::Block*> > GetLoopExitEdges(uint64_t) = 0;
virtual LoopOp GetBlockLoopFather(mlir::Block*) = 0;
+ virtual LoopOp FindCommonLoop(LoopOp*, LoopOp*) = 0;
virtual PhiOp GetPhiOp(uint64_t) = 0;
virtual CallOp GetCallOp(uint64_t) = 0;
virtual bool SetLhsInCallOp(uint64_t, uint64_t) = 0;
@@ -98,6 +99,8 @@ public:
virtual uint32_t AddArgInPhiOp(uint64_t, uint64_t, uint64_t, uint64_t) = 0;
virtual PhiOp CreatePhiOp(uint64_t, uint64_t) = 0;
virtual void DebugValue(uint64_t) = 0;
+ virtual void DebugOperation(uint64_t) = 0;
+ virtual void DebugBlock(mlir::Block*) = 0;
virtual bool IsLtoOptimize() = 0;
virtual bool IsWholeProgram() = 0;
@@ -109,6 +112,8 @@ public:
virtual mlir::Value BuildMemRef(PluginIR::PluginTypeBase, mlir::Value, mlir::Value) = 0;
virtual bool RedirectFallthroughTarget(FallThroughOp&, mlir::Block*, mlir::Block*) = 0;
virtual mlir::Operation* GetSSADefOperation(uint64_t) = 0;
+
+ virtual bool IsVirtualOperand(uint64_t) = 0;
}; // class BasicPluginOpsAPI
} // namespace PluginAPI
diff --git a/include/PluginAPI/ControlFlowAPI.h b/include/PluginAPI/ControlFlowAPI.h
index 429ba11..8c4c38d 100644
--- a/include/PluginAPI/ControlFlowAPI.h
+++ b/include/PluginAPI/ControlFlowAPI.h
@@ -40,6 +40,7 @@ public:
bool UpdateSSA(void);
vector<PhiOp> GetAllPhiOpInsideBlock(mlir::Block *b);
+ vector<mlir::Operation*> GetAllOpsInsideBlock(mlir::Block *b);
mlir::Block* CreateBlock(mlir::Block*, FunctionOp*);
void DeleteBlock(mlir::Block*, FunctionOp*);
diff --git a/include/PluginAPI/DataFlowAPI.h b/include/PluginAPI/DataFlowAPI.h
new file mode 100644
index 0000000..54b6b50
--- /dev/null
+++ b/include/PluginAPI/DataFlowAPI.h
@@ -0,0 +1,65 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+ Author: Chenhui Zheng
+ Create: 2023-11-01
+ Description:
+ This file contains the declaration of the DataFlowAPI class.
+*/
+#ifndef PLUGIN_FRAMEWORK_DATA_FLOW_API_H
+#define PLUGIN_FRAMEWORK_DATA_FLOW_API_H
+
+#include "BasicPluginOpsAPI.h"
+#include "PluginServer/PluginServer.h"
+#include "Dialect/PluginTypes.h"
+#include "PluginServerAPI.h"
+
+namespace PluginAPI {
+using std::string;
+using std::vector;
+
+using namespace PinServer;
+using namespace mlir::Plugin;
+// Data-flow queries (dominance, USE-DEF, alias, pointer analysis) that are answered by the plugin client.
+class DataFlowAPI {
+public:
+    DataFlowAPI() = default;
+    ~DataFlowAPI() = default;
+
+    // Compute dominance information.
+    void CalDominanceInfo(uint64_t dir, uint64_t funcId);
+
+    // USE-DEF queries.
+    vector<mlir::Operation*> GetImmUseStmts(mlir::Value v);
+    mlir::Value GetGimpleVuse(uint64_t opId);
+    mlir::Value GetGimpleVdef(uint64_t opId);
+    vector<mlir::Value> GetSsaUseOperand(uint64_t opId);
+    vector<mlir::Value> GetSsaDefOperand(uint64_t opId);
+    vector<mlir::Value> GetPhiOrStmtUse(uint64_t opId);
+    vector<mlir::Value> GetPhiOrStmtDef(uint64_t opId);
+
+    // Alias analysis.
+    bool RefsMayAlias(mlir::Value v1, mlir::Value v2, uint64_t flag);
+
+    // Pointer (points-to) analysis.
+    bool PTIncludesDecl(mlir::Value ptr, uint64_t declId);
+    bool PTsIntersect(mlir::Value ptr_1, mlir::Value ptr_2);
+
+private:
+    PluginServerAPI pluginAPI;
+};
+
+} // namespace PluginAPI
+
+#endif // PLUGIN_FRAMEWORK_DATA_FLOW_API_H
diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h
index 9cc8498..236c8a1 100644
--- a/include/PluginAPI/PluginServerAPI.h
+++ b/include/PluginAPI/PluginServerAPI.h
@@ -86,6 +86,7 @@ public:
void SetLatch(LoopOp*, mlir::Block*);
vector<mlir::Block*> GetLoopBody(uint64_t loopID);
LoopOp GetBlockLoopFather(mlir::Block*);
+ LoopOp FindCommonLoop(LoopOp*, LoopOp*);
mlir::Block* FindBlock(uint64_t);
uint64_t FindBasicBlock(mlir::Block*);
bool InsertValue(uint64_t, mlir::Value);
@@ -104,6 +105,8 @@ public:
/* Plugin API for ConstOp. */
mlir::Value CreateConstOp(mlir::Attribute, mlir::Type) override;
void DebugValue(uint64_t) override;
+ void DebugOperation(uint64_t) override;
+ void DebugBlock(mlir::Block*) override;
bool IsLtoOptimize() override;
bool IsWholeProgram() override;
@@ -117,6 +120,8 @@ public:
mlir::Operation* GetSSADefOperation(uint64_t) override;
void InsertCreatedBlock(uint64_t id, mlir::Block* block);
+ bool IsVirtualOperand(uint64_t) override;
+
int64_t GetInjectDataAddress() override; // 获取注入点返回数据的地址
string GetDeclSourceFile(int64_t) override;
string VariableName(int64_t) override;
diff --git a/include/PluginServer/ManagerSetup.h b/include/PluginServer/ManagerSetup.h
index e264237..1fe3f41 100755
--- a/include/PluginServer/ManagerSetup.h
+++ b/include/PluginServer/ManagerSetup.h
@@ -28,6 +28,7 @@ enum RefPassName {
PASS_PHIOPT,
PASS_SSA,
PASS_LOOP,
+ PASS_LAD,
PASS_MAC,
};
diff --git a/include/PluginServer/PluginCom.h b/include/PluginServer/PluginCom.h
index 7deb615..2897742 100755
--- a/include/PluginServer/PluginCom.h
+++ b/include/PluginServer/PluginCom.h
@@ -74,6 +74,7 @@ public:
uint64_t GetIdResult(void);
vector<uint64_t> GetIdsResult(void);
mlir::Value GetValueResult(void);
+ vector<mlir::Value> GetValuesResult(void);
vector<mlir::Plugin::PhiOp> GetPhiOpsResult(void);
private:
@@ -94,6 +95,7 @@ private:
uint64_t idResult;
vector<uint64_t> idsResult;
mlir::Value valueResult;
+ vector<mlir::Value> valuesResult;
PluginIR::PluginTypeBase pTypeResult;
mlir::Plugin::DeclBaseOp declOp;
mlir::Plugin::FieldDeclOp fielddeclOp;
diff --git a/include/PluginServer/PluginJson.h b/include/PluginServer/PluginJson.h
index 1d4bfbc..ede7b53 100755
--- a/include/PluginServer/PluginJson.h
+++ b/include/PluginServer/PluginJson.h
@@ -88,6 +88,7 @@ public:
/* 将json格式数据解析成map<string, string>格式 */
void GetAttributes(Json::Value node, map<string, string>& attributes);
mlir::Value ValueJsonDeSerialize(Json::Value valueJson);
+ void ValuesJsonDeSerialize(const string&, vector<mlir::Value>&);
Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase type);
mlir::Value MemRefDeSerialize(const string& data);
bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&);
diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h
index 81f92e9..b44c1d6 100644
--- a/include/PluginServer/PluginServer.h
+++ b/include/PluginServer/PluginServer.h
@@ -244,6 +244,11 @@ public:
RemoteCallClientWithAPI(funName, params);
return pluginCom.GetValueResult();
}
+    vector<mlir::Value> GetValuesResult(const string& funName, const string& params)
+    {
+        RemoteCallClientWithAPI(funName, params); // reply presumably lands in pluginCom via key "ValuesResult"
+        return pluginCom.GetValuesResult();
+    }
vector<mlir::Plugin::PhiOp> GetPhiOpsResult(const string& funName, const string& params)
{
RemoteCallClientWithAPI(funName, params);
@@ -264,15 +269,15 @@ public:
valueMaps.clear();
blockMaps.clear();
basicblockMaps.clear();
- defOpMaps.clear();
+ opMaps.clear();
}
uint64_t FindBasicBlock(mlir::Block*);
bool InsertValue(uint64_t, mlir::Value);
bool HaveValue(uint64_t);
mlir::Value GetValue(uint64_t);
- mlir::Operation* FindDefOperation(uint64_t);
- bool InsertDefOperation(uint64_t, mlir::Operation*);
+ mlir::Operation* FindOperation(uint64_t);
+ bool InsertOperation(uint64_t, mlir::Operation*);
void RemoteCallClientWithAPI(const string& api, const string& params);
private:
@@ -296,7 +301,8 @@ private:
// process Block.
std::map<uint64_t, mlir::Block*> blockMaps;
std::map<mlir::Block*, uint64_t> basicblockMaps;
- std::map<uint64_t, mlir::Operation*> defOpMaps;
+ // std::map<uint64_t, mlir::Operation*> defOpMaps;
+ std::map<uint64_t, mlir::Operation*> opMaps; // 保存所有的<gimpleid, op>键值对
/* 解析客户端发送过来的-fplugin-arg参数并保存在私有变量args中 */
void ParseArgv(const string& data);
diff --git a/include/user/SimpleLICMPass.h b/include/user/SimpleLICMPass.h
new file mode 100644
index 0000000..aa6ffa7
--- /dev/null
+++ b/include/user/SimpleLICMPass.h
@@ -0,0 +1,45 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+ Author: Mingchuan Wu and Yancheng Li
+ Create: 2022-08-18
+ Description:
+ This file contains the declaration of the SimpleLICMPass class.
+*/
+
+#ifndef SIMPLE_LICM_PASS_H
+#define SIMPLE_LICM_PASS_H
+
+#include "PluginServer/PluginOptBase.h"
+
+namespace PluginOpt {
+// Simple LICM pass, registered with HANDLE_MANAGER_SETUP; implemented in user/SimpleLICMPass.cpp.
+class SimpleLICMPass : public PluginOptBase {
+public:
+    SimpleLICMPass() : PluginOptBase(HANDLE_MANAGER_SETUP) {}
+    bool Gate()
+    {
+        return true;
+    }
+    // Runs the pass on the function whose address GetFuncAddr() reports.
+    int DoOptimize()
+    {
+        uint64_t fun = (uint64_t)GetFuncAddr();
+        return DoOptimize(fun);
+    }
+    int DoOptimize(uint64_t fun);
+};
+} // namespace PluginOpt
+
+#endif // SIMPLE_LICM_PASS_H
\ No newline at end of file
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 3602856..d4b061f 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -22,6 +22,7 @@
#include "PluginAPI/PluginServerAPI.h"
#include "PluginAPI/ControlFlowAPI.h"
+#include "PluginAPI/DataFlowAPI.h"
#include "Dialect/PluginDialect.h"
#include "Dialect/PluginOps.h"
#include "Dialect/PluginTypes.h"
@@ -265,6 +266,13 @@ bool LoopOp::IsLoopFather(mlir::Block* b)
return id == loopId;
}
+LoopOp LoopOp::FindCommonLoop(LoopOp* loop_1, LoopOp* loop_2)
+{
+    // Thin forwarder: PluginServerAPI::FindCommonLoop performs the remote query.
+    PluginAPI::PluginServerAPI pluginAPI;
+    return pluginAPI.FindCommonLoop(loop_1, loop_2);
+}
+
vector<pair<mlir::Block*, mlir::Block*> > LoopOp::GetExitEdges()
{
PluginAPI::PluginServerAPI pluginAPI;
diff --git a/lib/PluginAPI/ControlFlowAPI.cpp b/lib/PluginAPI/ControlFlowAPI.cpp
index b356ae9..3cf72f6 100644
--- a/lib/PluginAPI/ControlFlowAPI.cpp
+++ b/lib/PluginAPI/ControlFlowAPI.cpp
@@ -106,10 +106,25 @@ vector<PhiOp> ControlFlowAPI::GetAllPhiOpInsideBlock(mlir::Block *b)
string funName = __func__;
root["bbAddr"] = std::to_string(server->FindBasicBlock(b));
string params = root.toStyledString();
-
return GetPhiOperationResult(funName, params);
}
+// Returns the operations of block `b` as reported by the client; an entry is
+// nullptr when the server has no operation registered under the returned id.
+vector<mlir::Operation*> ControlFlowAPI::GetAllOpsInsideBlock(mlir::Block *b)
+{
+    PluginServer *server = PluginServer::GetInstance();
+    Json::Value root;
+    string funName = __func__;
+    root["bbAddr"] = std::to_string(server->FindBasicBlock(b));
+    string params = root.toStyledString();
+    vector<mlir::Operation*> ops;
+    for (uint64_t opId : server->GetIdsResult(funName, params)) {
+        ops.push_back(server->FindOperation(opId));
+    }
+    return ops;
+}
+
mlir::Block* ControlFlowAPI::CreateBlock(mlir::Block* b, FunctionOp *funcOp)
{
Json::Value root;
diff --git a/lib/PluginAPI/DataFlowAPI.cpp b/lib/PluginAPI/DataFlowAPI.cpp
new file mode 100644
index 0000000..2fec19f
--- /dev/null
+++ b/lib/PluginAPI/DataFlowAPI.cpp
@@ -0,0 +1,186 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+ Description:
+ This file contains the implementation of the DataFlowAPI class.
+*/
+#include "PluginAPI/DataFlowAPI.h"
+
+namespace PluginAPI {
+using namespace PinServer;
+using namespace mlir::Plugin;
+
+// Returns the plugin-side id carried by the op that defines `v`, or 0 when `v`
+// has no defining op (e.g. a block argument) or the op is not one of the
+// id-carrying value ops handled below.
+static uint64_t GetValueId(mlir::Value v)
+{
+    mlir::Operation *op = v.getDefiningOp();
+    if (op == nullptr) {
+        return 0; // llvm::dyn_cast must not be called on a null pointer
+    }
+    if (auto mOp = llvm::dyn_cast<MemOp>(op)) {
+        return mOp.id();
+    } else if (auto ssaOp = llvm::dyn_cast<SSAOp>(op)) {
+        return ssaOp.id();
+    } else if (auto cstOp = llvm::dyn_cast<ConstOp>(op)) {
+        return cstOp.id();
+    } else if (auto treelistop = llvm::dyn_cast<ListOp>(op)) {
+        return treelistop.id();
+    } else if (auto strop = llvm::dyn_cast<StrOp>(op)) {
+        return strop.id();
+    } else if (auto arrayop = llvm::dyn_cast<ArrayOp>(op)) {
+        return arrayop.id();
+    } else if (auto declop = llvm::dyn_cast<DeclBaseOp>(op)) {
+        return declop.id();
+    } else if (auto fieldop = llvm::dyn_cast<FieldDeclOp>(op)) {
+        return fieldop.id();
+    } else if (auto addressop = llvm::dyn_cast<AddressOp>(op)) {
+        return addressop.id();
+    } else if (auto constructorop = llvm::dyn_cast<ConstructorOp>(op)) {
+        return constructorop.id();
+    } else if (auto vecop = llvm::dyn_cast<VecOp>(op)) {
+        return vecop.id();
+    } else if (auto blockop = llvm::dyn_cast<BlockOp>(op)) {
+        return blockop.id();
+    } else if (auto compop = llvm::dyn_cast<ComponentOp>(op)) {
+        return compop.id();
+    } else if (auto phOp = llvm::dyn_cast<PlaceholderOp>(op)) {
+        return phOp.id();
+    }
+    return 0;
+}
+
+// Serializes the {"opId": "<opId>"} request shared by the per-operation
+// USE-DEF queries below.
+static string OpIdParams(uint64_t opId)
+{
+    Json::Value root;
+    root["opId"] = std::to_string(opId);
+    return root.toStyledString();
+}
+
+// Asks the client to compute dominance information for function `funcId`.
+// dir: 1 or 2 (presumably GCC's cdi_direction: 1 = dominators,
+// 2 = post-dominators -- TODO confirm against the client).
+void DataFlowAPI::CalDominanceInfo(uint64_t dir, uint64_t funcId)
+{
+    Json::Value root;
+    string funName = __func__;
+    root["dir"] = std::to_string(dir);
+    root["funcId"] = std::to_string(funcId);
+    string params = root.toStyledString();
+    PluginServer::GetInstance()->RemoteCallClientWithAPI(funName, params);
+}
+
+// Returns the operations the client reports as (immediate) uses of `v`. An
+// entry is nullptr when the server has no operation registered for that id.
+vector<mlir::Operation*> DataFlowAPI::GetImmUseStmts(mlir::Value v)
+{
+    PluginServer *server = PluginServer::GetInstance();
+    Json::Value root;
+    string funName = __func__;
+    root["varId"] = std::to_string(GetValueId(v));
+    string params = root.toStyledString();
+    vector<uint64_t> retIds = server->GetIdsResult(funName, params);
+    vector<mlir::Operation*> ops;
+    ops.reserve(retIds.size());
+    for (auto id : retIds) {
+        ops.push_back(server->FindOperation(id));
+    }
+    return ops;
+}
+
+// The six queries below send {"opId": "<opId>"} and return what the client
+// reports for that operation. Their names mirror GCC helpers (gimple_vuse,
+// gimple_vdef, SSA operand/PHI iteration); presumably the client implements
+// them with those -- confirm on the client side.
+mlir::Value DataFlowAPI::GetGimpleVuse(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValueResult(funName, params);
+}
+
+mlir::Value DataFlowAPI::GetGimpleVdef(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValueResult(funName, params);
+}
+
+vector<mlir::Value> DataFlowAPI::GetSsaUseOperand(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+vector<mlir::Value> DataFlowAPI::GetSsaDefOperand(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+vector<mlir::Value> DataFlowAPI::GetPhiOrStmtUse(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+vector<mlir::Value> DataFlowAPI::GetPhiOrStmtDef(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+// Asks the client whether references `v1` and `v2` may alias.
+// flag: 0 or 1 (presumably the tbaa_p argument of GCC's refs_may_alias_p).
+bool DataFlowAPI::RefsMayAlias(mlir::Value v1, mlir::Value v2, uint64_t flag)
+{
+    Json::Value root;
+    string funName = __func__;
+    root["id1"] = std::to_string(GetValueId(v1));
+    root["id2"] = std::to_string(GetValueId(v2));
+    root["flag"] = std::to_string(flag);
+    string params = root.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+// Pointer analysis: asks the client whether `ptr` may point to declaration `declId`.
+bool DataFlowAPI::PTIncludesDecl(mlir::Value ptr, uint64_t declId)
+{
+    Json::Value root;
+    string funName = __func__;
+    root["ptrId"] = std::to_string(GetValueId(ptr));
+    root["declId"] = std::to_string(declId);
+    string params = root.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+// Pointer analysis: asks the client whether the points-to sets of `ptr_1` and `ptr_2` intersect.
+bool DataFlowAPI::PTsIntersect(mlir::Value ptr_1, mlir::Value ptr_2)
+{
+    Json::Value root;
+    string funName = __func__;
+    root["ptrId_1"] = std::to_string(GetValueId(ptr_1));
+    root["ptrId_2"] = std::to_string(GetValueId(ptr_2));
+    string params = root.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+} // namespace PluginAPI
\ No newline at end of file
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index 09bd358..0b7d8ab 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -372,6 +372,15 @@ mlir::Value PluginServerAPI::ConfirmValue(mlir::Value v)
return PluginServer::GetInstance()->GetValueResult(funName, params);
}
+bool PluginServerAPI::IsVirtualOperand(uint64_t id)
+{
+    string funName = "IsVirtualOperand";
+    Json::Value request;
+    request["id"] = std::to_string(id);
+    string params = request.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params); // answered by the client
+}
+
mlir::Value PluginServerAPI::BuildMemRef(PluginIR::PluginTypeBase type,
mlir::Value base, mlir::Value offset)
{
@@ -811,6 +820,16 @@ LoopOp PluginServerAPI::GetBlockLoopFather(mlir::Block* b)
return PluginServer::GetInstance()->LoopOpResult(funName, params);
}
+LoopOp PluginServerAPI::FindCommonLoop(LoopOp* loop_1, LoopOp* loop_2)
+{
+    string funName("FindCommonLoop");
+    Json::Value request;
+    request["loopId_1"] = loop_1->idAttr().getInt(); // loop ids go out as JSON integers, not strings
+    request["loopId_2"] = loop_2->idAttr().getInt();
+    string params = request.toStyledString();
+    return PluginServer::GetInstance()->LoopOpResult(funName, params);
+}
+
mlir::Block* PluginServerAPI::FindBlock(uint64_t b)
{
PluginServer *server = PluginServer::GetInstance();
@@ -845,7 +864,7 @@ bool PluginServerAPI::RedirectFallthroughTarget(
mlir::Operation* PluginServerAPI::GetSSADefOperation(uint64_t addr)
{
- return PluginServer::GetInstance()->FindDefOperation(addr);
+ return PluginServer::GetInstance()->FindOperation(addr); // nullptr if addr is not registered (no longer asserts)
}
void PluginServerAPI::InsertCreatedBlock(uint64_t id, mlir::Block* block)
@@ -862,6 +881,25 @@ void PluginServerAPI::DebugValue(uint64_t valId)
PluginServer::GetInstance()->RemoteCallClientWithAPI(funName, params);
}
+void PluginServerAPI::DebugOperation(uint64_t opId)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["opId"] = opId; // stored as a JSON number, unlike the string-encoded ids elsewhere
+    string params = request.toStyledString();
+    PluginServer::GetInstance()->RemoteCallClientWithAPI(funName, params);
+}
+
+void PluginServerAPI::DebugBlock(mlir::Block* b)
+{
+    // Debug hook: the client is handed the basic-block address mapped from `b`.
+    PluginServer *server = PluginServer::GetInstance();
+    string funName = __func__;
+    Json::Value request;
+    request["bbAddr"] = std::to_string(server->FindBasicBlock(b));
+    server->RemoteCallClientWithAPI(funName, request.toStyledString());
+}
+
bool PluginServerAPI::IsLtoOptimize()
{
Json::Value root;
diff --git a/lib/PluginServer/PluginCom.cpp b/lib/PluginServer/PluginCom.cpp
index aaf4ba8..8717540 100755
--- a/lib/PluginServer/PluginCom.cpp
+++ b/lib/PluginServer/PluginCom.cpp
@@ -148,6 +148,13 @@ mlir::Value PluginCom::GetValueResult()
return this->valueResult;
}
+vector<mlir::Value> PluginCom::GetValuesResult()
+{
+    vector<mlir::Value> retVals;
+    retVals.swap(valuesResult); // take the values without copying; valuesResult is left empty for the next reply
+    return retVals;
+}
+
vector<mlir::Plugin::PhiOp> PluginCom::GetPhiOpsResult()
{
vector<mlir::Plugin::PhiOp> retOps;
@@ -196,9 +203,11 @@ void PluginCom::JsonDeSerialize(const string& key, const string& data)
this->declOp = llvm::dyn_cast<mlir::Plugin::DeclBaseOp>(decl.getDefiningOp());
} else if (key == "GetFieldsOpResult") {
json.FieldOpsJsonDeSerialize(data, this->fieldsOps);
- } else if (key == "OpsResult") {
+ } else if (key == "OpResult") {
json.OpJsonDeSerialize(data.c_str(), this->opData);
- } else if (key == "ValueResult") {
+ } else if (key == "ValuesResult") {
+ json.ValuesJsonDeSerialize(data, this->valuesResult);
+ }else if (key == "ValueResult") {
Json::Value node;
Json::Reader reader;
reader.parse(data, node);
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
index 0392b75..864a444 100755
--- a/lib/PluginServer/PluginJson.cpp
+++ b/lib/PluginServer/PluginJson.cpp
@@ -189,6 +189,20 @@ mlir::Value PluginJson::ValueJsonDeSerialize(Json::Value valueJson)
return opValue;
}
+void PluginJson::ValuesJsonDeSerialize(const string& data, vector<mlir::Value>& valuesData)
+{
+    Json::Value root;
+    Json::Reader reader;
+    if (!reader.parse(data, root) || !root.isObject()) {
+        return; // malformed or non-object payload: append nothing
+    }
+    // Keys are expected to be "Value0".."Value<N-1>"; a missing key yields a null node for ValueJsonDeSerialize.
+    const size_t count = root.size();
+    for (size_t i = 0; i < count; ++i) {
+        valuesData.push_back(ValueJsonDeSerialize(root["Value" + std::to_string(i)]));
+    }
+}
+
mlir::Value PluginJson::MemRefDeSerialize(const string& data)
{
Json::Value root;
@@ -666,6 +680,7 @@ mlir::Operation *PluginJson::CallOpJsonDeSerialize(const string& data)
op = opBuilder->create<CallOp>(opBuilder->getUnknownLoc(),
id, callName, ops);
}
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -690,6 +705,7 @@ mlir::Operation *PluginJson::CondOpJsonDeSerialize(const string& data)
CondOp op = opBuilder->create<CondOp>(
opBuilder->getUnknownLoc(), id, address, iCode, LHS,
RHS, tb, fb, tbaddr, fbaddr, trueLabel, falseLabel);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -736,7 +752,7 @@ mlir::Operation *PluginJson::AssignOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
AssignOp op = opBuilder->create<AssignOp>(opBuilder->getUnknownLoc(),
ops, id, iCode);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -760,7 +776,7 @@ mlir::Operation *PluginJson::PhiOpJsonDeSerialize(const string& data)
PhiOp op = opBuilder->create<PhiOp>(opBuilder->getUnknownLoc(),
ops, id, capacity, nArgs);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1003,6 +1019,7 @@ mlir::Operation *PluginJson::GotoOpJsonDeSerialize(const string& data)
mlir::Block* success = PluginServer::GetInstance()->FindBlock(successaddr);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
GotoOp op = opBuilder->create<GotoOp>(opBuilder->getUnknownLoc(), id, address, dest, success, successaddr);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1030,6 +1047,7 @@ mlir::Operation *PluginJson::TransactionOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
TransactionOp op = opBuilder->create<TransactionOp>(opBuilder->getUnknownLoc(), id, address, stmtaddr, labelNorm,
labelUninst, labelOver, fallthrough, fallthroughaddr, abort, abortaddr);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1044,6 +1062,7 @@ mlir::Operation *PluginJson::ResxOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
ResxOp op = opBuilder->create<ResxOp>(opBuilder->getUnknownLoc(),
id, address, region);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1056,6 +1075,7 @@ mlir::Operation *PluginJson::EHMntOpJsonDeSerialize(const string& data)
mlir::Value decl = ValueJsonDeSerialize(node["decl"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHMntOp op = opBuilder->create<EHMntOp>(opBuilder->getUnknownLoc(), id, decl);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1082,6 +1102,7 @@ mlir::Operation *PluginJson::EHDispatchOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHDispatchOp op = opBuilder->create<EHDispatchOp>(opBuilder->getUnknownLoc(),
id, address, region, ehHandlers, ehHandlersaddrs);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1094,6 +1115,7 @@ mlir::Operation *PluginJson::LabelOpJsonDeSerialize(const string& data)
mlir::Value label = ValueJsonDeSerialize(node["label"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
LabelOp op = opBuilder->create<LabelOp>(opBuilder->getUnknownLoc(), id, label);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1116,6 +1138,7 @@ mlir::Operation *PluginJson::BindOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
BindOp op = opBuilder->create<BindOp>(opBuilder->getUnknownLoc(), id, vars, bodyaddrs, block);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1144,6 +1167,7 @@ mlir::Operation *PluginJson::TryOpJsonDeSerialize(const string& data)
int64_t kind = GetID(node["kind"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
TryOp op = opBuilder->create<TryOp>(opBuilder->getUnknownLoc(), id, evaladdrs, cleanupaddrs, kind);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1165,6 +1189,7 @@ mlir::Operation *PluginJson::CatchOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
CatchOp op = opBuilder->create<CatchOp>(opBuilder->getUnknownLoc(), id, types, handleraddrs);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1176,7 +1201,7 @@ mlir::Operation *PluginJson::NopOpJsonDeSerialize(const string& data)
uint64_t id = GetID(node["id"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
NopOp op = opBuilder->create<NopOp>(opBuilder->getUnknownLoc(), id);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1208,7 +1233,7 @@ mlir::Operation *PluginJson::EHElseOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHElseOp op = opBuilder->create<EHElseOp>(opBuilder->getUnknownLoc(), id, nbody, ebody);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1233,7 +1258,7 @@ mlir::Operation *PluginJson::AsmOpJsonDeserialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
AsmOp op = opBuilder->create<AsmOp>(opBuilder->getUnknownLoc(), id, statement, nInputs, nOutputs,
nClobbers, ops);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1278,7 +1303,7 @@ mlir::Operation *PluginJson::SwitchOpJsonDeserialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
SwitchOp op = opBuilder->create<SwitchOp>(opBuilder->getUnknownLoc(), id, index, address, defaultLabel, ops, defaultDest,
defaultDestAddr, caseDest, caseaddr);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp
index 9a8d5e5..6424b94 100644
--- a/lib/PluginServer/PluginServer.cpp
+++ b/lib/PluginServer/PluginServer.cpp
@@ -95,17 +95,21 @@ mlir::Block* PluginServer::FindBlock(uint64_t id)
return iter->second;
}
-mlir::Operation* PluginServer::FindDefOperation(uint64_t id)
+mlir::Operation* PluginServer::FindOperation(uint64_t id)
{
- auto iter = this->defOpMaps.find(id);
- assert(iter != this->defOpMaps.end());
- return iter->second;
+    // Looks up an operation registered via InsertOperation. Unlike the old
+    // FindDefOperation this does not assert: unknown ids yield nullptr.
+    auto iter = this->opMaps.find(id);
+    if (iter == this->opMaps.end()) {
+        return nullptr;
+    }
+    return iter->second;
}
-bool PluginServer::InsertDefOperation(uint64_t id, mlir::Operation* op)
+bool PluginServer::InsertOperation(uint64_t id, mlir::Operation* op)
{
- auto iter = this->defOpMaps.find(id);
- this->defOpMaps.insert({id, op});
+    // std::map::insert keeps an existing entry, so the first op registered for an id wins.
+    this->opMaps.insert({id, op});
return true;
}
@@ -312,6 +316,7 @@ void PluginServer::RunServer()
}
log->LOGI("Server ppid:%d listening on port:%s\n", getppid(), port.c_str());
ServerSemPost(port);
+
register_mutex.lock();
RegisterCallbacks();
register_mutex.unlock();
diff --git a/user/SimpleLICMPass.cpp b/user/SimpleLICMPass.cpp
new file mode 100644
index 0000000..b78f86f
--- /dev/null
+++ b/user/SimpleLICMPass.cpp
@@ -0,0 +1,359 @@
+#include "PluginAPI/ControlFlowAPI.h"
+#include "PluginAPI/PluginServerAPI.h"
+#include "user/SimpleLICMPass.h"
+#include "PluginAPI/DataFlowAPI.h"
+#include "mlir/Support/LLVM.h"
+
+namespace PluginOpt {
+using std::string;
+using std::vector;
+using std::cout;
+using std::endl;
+using namespace mlir;
+using namespace PluginAPI;
+// File-wide API handles and analysis state shared by the helpers below.
+PluginServerAPI pluginAPI;  // NOTE(review): non-static, so external linkage in PluginOpt -- confirm no other TU defines it
+ControlFlowAPI cfAPI;
+DataFlowAPI dfAPI;
+vector<AssignOp> move_stmt;  // AssignOps found movable by compute_invariantness
+// Operation sets keyed by pointer; membership (find != end), not the bool, is what is tested.
+std::map<mlir::Operation*, bool> visited;   // AssignOps already scanned by compute_invariantness
+std::map<mlir::Operation*, bool> not_move;  // ops that must stay where they are
+
+enum EDGE_FLAG {  // kind of a CFG edge, as classified by GetEdgeFlag
+    EDGE_FALLTHRU,     // edge leaving a FallThroughOp terminator
+    EDGE_TRUE_VALUE,   // successor 0 of a CondOp terminator
+    EDGE_FALSE_VALUE   // the other successor of a CondOp terminator
+};
+
+struct edgeDef {
+    Block *src = nullptr;                 // source block (the predecessor)
+    Block *dest = nullptr;                // destination block
+    unsigned destIdx = 0;                 // position of src among dest's predecessors
+    enum EDGE_FLAG flag = EDGE_FALLTHRU;  // edge kind; defaults keep a plain `edge e;` well-defined
+};
+typedef struct edgeDef edge;
+
+/* Collect the CFG predecessors of bb, in MLIR's predecessor-list order. */
+static vector<Block *> getPredecessors(Block *bb)
+{
+    vector<Block *> predList;
+    auto cur = bb->pred_begin();
+    auto last = bb->pred_end();
+    for (; cur != last; ++cur) predList.push_back(*cur);
+    return predList;
+}
+
+/* Classify edge src->dest from src's terminator. Anything other than a CondOp
+ * (FallThroughOp included) is reported as EDGE_FALLTHRU, never uninitialized. */
+static enum EDGE_FLAG GetEdgeFlag(Block *src, Block *dest)
+{
+    Operation *op = src->getTerminator();
+    enum EDGE_FLAG flag = EDGE_FALLTHRU;
+    if (isa<CondOp>(op)) {
+        // Successor 0 is treated as the "true" arm of the conditional.
+        if (op->getSuccessor(0) == dest) {
+            flag = EDGE_TRUE_VALUE;
+        } else {
+            flag = EDGE_FALSE_VALUE;
+        }
+    }
+    return flag;
+}
+
+/* One edge record per predecessor of bb (pred -> bb); destIdx = index in bb's pred list. */
+static vector<edge> GetPredEdges(Block *bb)
+{
+    unsigned i = 0;
+    vector<edge> edges;
+    for (auto it = bb->pred_begin(); it != bb->pred_end(); ++it, ++i) {
+        edge e;
+        e.src = *it;
+        e.dest = bb;
+        e.destIdx = i;
+        // The edge is pred -> bb, so classify it from the pred's terminator.
+        e.flag = GetEdgeFlag(e.src, bb);
+        edges.push_back(e);
+    }
+    return edges;
+}
+
+/* Return edge src -> dest from dest's incoming edges; edge{} (null src) if absent. */
+static edge GetEdge(Block *src, Block *dest)
+{
+    edge e{};  // value-initialized so a miss never returns indeterminate fields
+    for (const edge &elm : GetPredEdges(dest)) {
+        if (elm.src == src) {
+            e = elm;
+            break;
+        }
+    }
+    return e;
+}
+
+/* Return the plugin id carried by the op that defines v (SSA, Mem, Const,
+ * Placeholder, Component or DeclBase op). Returns 0 for a null Value, for a
+ * value with no defining op (where dyn_cast would assert on null), and for
+ * any other kind of defining op. */
+static uint64_t getValueId(Value v)
+{
+    if (!v) return 0;
+    Operation *defOp = v.getDefiningOp();
+    if (defOp == nullptr) return 0;
+    if (auto ssaop = dyn_cast<SSAOp>(defOp)) return ssaop.id();
+    if (auto memop = dyn_cast<MemOp>(defOp)) return memop.id();
+    if (auto constop = dyn_cast<ConstOp>(defOp)) return constop.id();
+    if (auto holderop = dyn_cast<PlaceholderOp>(defOp)) return holderop.id();
+    if (auto componentop = dyn_cast<ComponentOp>(defOp)) return componentop.id();
+    if (auto declop = dyn_cast<DeclBaseOp>(defOp)) return declop.id();
+    // Any other defining op carries no plugin id.
+    return 0;
+}
+
+static IDefineCode getValueDefCode(Value v)  // defCode of the SSA/Mem/Const/Placeholder op that defines v
+{
+    IDefineCode rescode;
+    if (auto ssaop = dyn_cast<SSAOp>(v.getDefiningOp())) {  // NOTE(review): dyn_cast asserts if v has no defining op
+        rescode = ssaop.defCode().getValue();
+    } else if (auto memop = dyn_cast<MemOp>(v.getDefiningOp())) {
+        rescode = memop.defCode().getValue();
+    } else if (auto constop = dyn_cast<ConstOp>(v.getDefiningOp())) {
+        rescode = constop.defCode().getValue();
+    } else {  // every other defining op is assumed to be a PlaceholderOp
+        auto holderop = dyn_cast<PlaceholderOp>(v.getDefiningOp());
+        rescode = holderop.defCode().getValue();  // NOTE(review): holderop is null for e.g. ComponentOp/DeclBaseOp values
+    }
+    return rescode;
+}
+
+/* A value "exists" for the plugin when its defining op carries a nonzero id
+ * (see getValueId); id 0 serves as the "no id" sentinel. */
+static bool isValueExist(Value v)
+{
+    const uint64_t valueId = getValueId(v);
+    bool hasId = (valueId != 0);
+    return hasId;
+}
+
+/* True iff v is an SSAOp result with a nonzero id, SSA define code and a
+ * nonzero nameVarId. Values without an SSAOp definition return false up front
+ * instead of reaching a null-op dereference in getValueDefCode or nameVarId. */
+static bool isSSANameVar(Value v)
+{
+    Operation *defOp = v.getDefiningOp();
+    if (defOp == nullptr || !isa<SSAOp>(defOp)) return false;
+    if (!isValueExist(v) || getValueDefCode(v) != IDefineCode::SSA) {
+        return false;
+    }
+    return dyn_cast<SSAOp>(defOp).nameVarId() != 0;
+}
+
+/* Return the AssignOp/PhiOp defining SSA value v, or nullptr if v is not an SSAOp
+ * result (incl. no defining op, where isa<> would assert) or has no such def. */
+static Operation *getSSADefStmtofValue(Value v)
+{
+    auto ssaOp = dyn_cast_or_null<SSAOp>(v.getDefiningOp());
+    if (!ssaOp) {
+        return nullptr;
+    }
+    Operation *op = ssaOp.GetSSADefOperation();
+    if (!op || !isa<AssignOp, PhiOp>(op)) {
+        return nullptr;
+    }
+    return op;
+}
+
+/* Get the edge that first entered the loop: the header's first predecessor
+ * other than the latch. Returns a value-initialized edge if there is none. */
+static edge getLoopPreheaderEdge(LoopOp loop)
+{
+    Block *header = loop.GetHeader();
+    Block *latch = loop.GetLatch();
+    Block *src = nullptr;
+    for (Block *bb : getPredecessors(header)) {
+        if (bb != latch) {
+            src = bb;
+            break;
+        }
+    }
+    if (src == nullptr) {
+        return edge{};  // no non-latch predecessor: never use an uninitialized src
+    }
+    return GetEdge(src, header);
+}
+
+void compute_invariantness(Block* bb)  // Mark ops of loop block bb that must stay (not_move);
+{                                      // every remaining AssignOp of bb is appended to move_stmt.
+    LoopOp loop_father = pluginAPI.GetBlockLoopFather(bb);
+    if (!loop_father.outerLoopId().getValue()){  // NOTE(review): presumably "bb's loop has no outer loop, i.e. bb is not in a real loop" -- confirm
+        return ;
+    }
+    // NOTE(review): bbAddr below is computed but never used.
+    uint64_t bbAddr = pluginAPI.FindBasicBlock(bb);
+    // Process the phi statements in bb.
+    vector<PhiOp> phis = cfAPI.GetAllPhiOpInsideBlock(bb);
+    for (auto phi : phis) {
+        Value result = phi.GetResult();
+        uint64_t varId = getValueId(result);
+        // Only phis with at most 2 args whose result is not a virtual operand are inspected.
+        int n_args = phi.nArgs();
+        if (n_args <= 2 && !pluginAPI.IsVirtualOperand(varId)) {
+            for (int i = 0 ; i < n_args; i++) {
+                Value v = phi.GetArgDef(i);
+                if (isSSANameVar(v)) {
+                    Operation* def = getSSADefStmtofValue(v);
+                    if (!def) break;  // NOTE(review): stops scanning this phi's remaining args rather than skipping one -- confirm
+                    Block *def_bb = def->getBlock();
+                    if (def_bb == bb && visited.find(def) == visited.end()) {  // arg defined in bb and not yet visited
+                        pluginAPI.DebugOperation(phi.id());  // debug output
+                        not_move[def] = true;  // keep that defining statement in place
+                        break;
+                    }
+                }
+            }
+        }
+    }
+    vector<mlir::Operation*> ops = cfAPI.GetAllOpsInsideBlock(bb);
+    map<uint64_t, bool> isinvariant;  // NOTE(review): unused
+    vector<mlir::Value> variants;     // LHS of AssignOps that have a vdef; used for the alias checks below
+    int i = 0 ;                       // NOTE(review): unused
+    bool change = false;
+    int n = ops.size();               // NOTE(review): unused
+    do {  // repeat until a full scan marks no new op as not_move
+        change = false;
+        for (auto op : ops) {
+            bool may_move = true;
+            if (not_move.find(op)!=not_move.end()) continue;  // already marked
+            if(!isa<AssignOp>(op)) continue;  // only AssignOps are candidates
+            visited[op] = true;
+            auto assign = dyn_cast<AssignOp>(op);
+            int num = assign.getNumOperands();
+            Value lhs = assign.GetLHS();
+            Value rhs1 = assign.GetRHS1();
+            // An AssignOp with a virtual def (presumably a memory store) stays put and its LHS is recorded.
+            Value vdef = dfAPI.GetGimpleVdef(assign.id());
+            uint64_t vdef_id = getValueId(vdef);
+            if(vdef_id) {
+                variants.push_back(lhs);
+                not_move[op]=true;
+                change = true;
+                continue;
+            }
+            vector<mlir::Value> vals = dfAPI.GetSsaUseOperand(assign.id());
+            for (auto val : vals) {
+                Operation* def = getSSADefStmtofValue(val);
+                if(!def) continue;
+                Block *def_bb = def->getBlock();
+                if(def_bb == bb && visited.find(def) == visited.end()) {  // operand defined in bb, not yet visited
+                    not_move[def] = true;
+                    not_move[op] = true;
+                    change = true;
+                    continue;  // NOTE(review): continues the operand loop, not the op loop
+                } else if (not_move.find(def)!=not_move.end()) {  // operand's def is itself not movable
+                    not_move[op] = true;
+                    change = true;
+                    continue;
+                }
+            }
+            for (auto v : variants) {  // not movable if lhs/rhs1 may alias a recorded variant
+                if (dfAPI.RefsMayAlias(lhs, v, 1) || dfAPI.RefsMayAlias(rhs1, v, 1)) {
+                    may_move = false;
+                    break;
+                }
+            }
+            if(num == 3 && may_move) {  // three operands: rhs2 is checked as well
+                Value rhs2 = assign.GetRHS2();
+                for (auto v : variants) {
+                    if (dfAPI.RefsMayAlias(rhs2, v, 1)) {
+                        may_move = false;
+                        break;
+                    }
+                }
+            }
+            if(!may_move) {
+                not_move[op] = true;
+                change = true;
+            }
+        }
+    } while(change);
+    // Print and collect every AssignOp of bb that was never marked not_move.
+    cout<<"move statements: "<<endl;
+    for (auto op : ops) {
+        if (not_move.find(op) != not_move.end()) continue;
+        if(!isa<AssignOp>(op)) continue;
+        auto assign = dyn_cast<AssignOp>(op);
+        pluginAPI.DebugOperation(assign.id());  // debug output
+        move_stmt.push_back(assign);  // file-global candidate list
+    }
+    cout<<" "<<endl;  // debug output
+    cout<<move_stmt.size()<<endl;  // debug output: total candidates so far
+    return ;
+}
+
+LoopOp get_outermost_loop(AssignOp assign)  // WIP: presumably the loop `assign` could be hoisted out of
+{
+    Block *bb = assign->getBlock();
+    LoopOp loop = pluginAPI.GetBlockLoopFather(bb);
+    LoopOp maxloop;  // stays default-constructed unless some operand has an AssignOp/PhiOp def
+    vector<mlir::Value> vals = dfAPI.GetSsaUseOperand(assign.id());
+    for (auto val: vals)
+    {
+        uint64_t id = getValueId(val);
+        pluginAPI.DebugValue(id);  // debug output
+        Operation* defOp = getSSADefStmtofValue(val);
+        if (!defOp) continue;
+        cout<<"none"<<endl;  // debug output (printed when a def WAS found)
+        Block *def_bb = defOp->getBlock();
+        LoopOp def_loop = pluginAPI.GetBlockLoopFather(def_bb);
+        maxloop = pluginAPI.FindCommonLoop(&loop, &def_loop);  // NOTE(review): overwritten per operand, so only the last one counts -- confirm
+        if (maxloop == def_loop) {  // common loop is the def's own loop: fall back to assign's loop
+            maxloop = loop;
+        }
+    }
+    // A null LoopOp result makes move_worker return without doing anything.
+    return maxloop;
+}
+
+void move_worker(AssignOp assign)  // unfinished: presumably meant to hoist `assign` (see TODO)
+{
+    LoopOp level = get_outermost_loop(assign);
+    if (!level) return ;  // no target loop found
+    edge e = getLoopPreheaderEdge(level);  // entry edge into level's header; not used yet
+
+    // TODO: move `assign` onto edge e (code motion is not implemented yet).
+}
+/* Simple LICM entry for function `fun`: collect movable AssignOps from each loop
+ * header, then hand each to move_worker (actual code motion is still a TODO). */
+static void ProcessSimpleLICM(uint64_t fun)
+{
+    cout << "Running first pass, Loop Invariant Code Motion\n";
+
+    // Analysis state is file-global: reset it so candidates and pointer-keyed
+    // marks left over from a previously processed function are not reused.
+    move_stmt.clear();
+    visited.clear();
+    not_move.clear();
+
+    PluginServerAPI pluginAPI;  // NOTE: shadows the file-scope instance
+    DataFlowAPI dfAPI;          // NOTE: shadows the file-scope instance
+    dfAPI.CalDominanceInfo(1, fun);
+    mlir::Plugin::FunctionOp funcOp = pluginAPI.GetFunctionOpById(fun);
+    if (funcOp == nullptr) return;
+
+    string name = funcOp.funcNameAttr().getValue().str();
+    fprintf(stderr, "Now process func : %s \n", name.c_str());
+    vector<LoopOp> allLoop = funcOp.GetAllLoops();
+    for (auto &loop : allLoop) {
+        compute_invariantness(loop.GetHeader());
+    }
+    for (auto stmt : move_stmt) {
+        move_worker(stmt);
+    }
+}
+int SimpleLICMPass::DoOptimize(uint64_t fun)  // pass hook: runs the simple LICM analysis on `fun`
+{
+    ProcessSimpleLICM(fun);
+    return 0;  // always 0; nothing from the analysis is reported through the status
+}
+
+}
diff --git a/user/user.cpp b/user/user.cpp
index 6bfa524..2f4c6d1 100644
--- a/user/user.cpp
+++ b/user/user.cpp
@@ -23,12 +23,15 @@
#include "user/InlineFunctionPass.h"
#include "user/LocalVarSummeryPass.h"
#include "user/StructReorder.h"
+#include "user/SimpleLICMPass.h"
void RegisterCallbacks(void)
{
- PinServer::PluginServer *pluginServer = PinServer::PluginServer::GetInstance();
- pluginServer->RegisterOpt(std::make_shared<PluginOpt::InlineFunctionPass>(PluginOpt::HANDLE_BEFORE_IPA));
- pluginServer->RegisterOpt(std::make_shared<PluginOpt::LocalVarSummeryPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ PinServer::PluginServer *pluginServer = PinServer::PluginServer::GetInstance();
+ // pluginServer->RegisterOpt(std::make_shared<PluginOpt::InlineFunctionPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ // pluginServer->RegisterOpt(std::make_shared<PluginOpt::LocalVarSummeryPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ PluginOpt::ManagerSetup setupData(PluginOpt::PASS_LAD, 1, PluginOpt::PASS_INSERT_AFTER);
+ pluginServer->RegisterPassManagerOpt(setupData, std::make_shared<PluginOpt::SimpleLICMPass>());
// PluginOpt::ManagerSetup setupData(PluginOpt::PASS_PHIOPT, 1, PluginOpt::PASS_INSERT_AFTER);
// pluginServer->RegisterPassManagerOpt(setupData, std::make_shared<PluginOpt::ArrayWidenPass>());
// PluginOpt::ManagerSetup setupData(PluginOpt::PASS_MAC, 1, PluginOpt::PASS_INSERT_AFTER);
--
2.33.0