pin-server/0033-Pin-server-Add-DataFlow-APIs.patch
Mingchuan Wu cfa9847203 [sync] Sync patch from openEuler/pin-server
(cherry picked from commit dc67116851cf691609f9677eb8a40ce7e1bf99e1)
2024-04-10 15:52:07 +08:00

1280 lines
45 KiB
Diff
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

From e8adf83db9cc977fe0145555715d083805636757 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=91=E6=99=A8=E5=8D=89?= <zhengchenhui1@huawei.com>
Date: Mon, 25 Dec 2023 10:24:23 +0800
Subject: [PATCH 6/7] [Pin-server] Add DataFlow APIs.
---
CMakeLists.txt | 6 +-
include/PluginAPI/BasicPluginOpsAPI.h | 5 +
include/PluginAPI/ControlFlowAPI.h | 1 +
include/PluginAPI/DataFlowAPI.h | 65 +++++
include/PluginAPI/PluginServerAPI.h | 5 +
include/PluginServer/ManagerSetup.h | 1 +
include/PluginServer/PluginCom.h | 2 +
include/PluginServer/PluginJson.h | 1 +
include/PluginServer/PluginServer.h | 14 +-
include/user/SimpleLICMPass.h | 45 ++++
lib/Dialect/PluginOps.cpp | 8 +
lib/PluginAPI/ControlFlowAPI.cpp | 17 +-
lib/PluginAPI/DataFlowAPI.cpp | 171 ++++++++++++
lib/PluginAPI/PluginServerAPI.cpp | 40 ++-
lib/PluginServer/PluginCom.cpp | 13 +-
lib/PluginServer/PluginJson.cpp | 37 ++-
lib/PluginServer/PluginServer.cpp | 19 +-
user/SimpleLICMPass.cpp | 359 ++++++++++++++++++++++++++
user/user.cpp | 9 +-
19 files changed, 792 insertions(+), 26 deletions(-)
create mode 100644 include/PluginAPI/DataFlowAPI.h
create mode 100644 include/user/SimpleLICMPass.h
create mode 100644 lib/PluginAPI/DataFlowAPI.cpp
create mode 100644 user/SimpleLICMPass.cpp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index df8b634..925f80e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -85,10 +85,11 @@ target_link_libraries(plg_grpc_proto
add_subdirectory(include)
add_subdirectory(lib)
add_library(pin_user SHARED
- # "user/ArrayWidenPass.cpp"
- "user/StructReorder.cpp"
+ "user/ArrayWidenPass.cpp"
+ #"user/StructReorder.cpp"
"user/InlineFunctionPass.cpp"
"user/LocalVarSummeryPass.cpp"
+ "user/SimpleLICMPass.cpp"
"user/user.cpp")
target_link_libraries(pin_user
@@ -101,6 +102,7 @@ add_custom_command(TARGET pin_user COMMAND sha256sum libpin_user.so > libpin_use
add_executable(pin_server
"lib/PluginServer/PluginServer.cpp"
"lib/PluginAPI/ControlFlowAPI.cpp"
+ "lib/PluginAPI/DataFlowAPI.cpp"
"lib/PluginServer/PluginGrpc.cpp"
"lib/PluginServer/PluginJson.cpp"
"lib/PluginServer/PluginCom.cpp"
diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h
index 065274d..4e05e6f 100644
--- a/include/PluginAPI/BasicPluginOpsAPI.h
+++ b/include/PluginAPI/BasicPluginOpsAPI.h
@@ -86,6 +86,7 @@ public:
virtual pair<mlir::Block*, mlir::Block*> LoopSingleExit(uint64_t) = 0;
virtual vector<pair<mlir::Block*, mlir::Block*> > GetLoopExitEdges(uint64_t) = 0;
virtual LoopOp GetBlockLoopFather(mlir::Block*) = 0;
+ virtual LoopOp FindCommonLoop(LoopOp*, LoopOp*) = 0;
virtual PhiOp GetPhiOp(uint64_t) = 0;
virtual CallOp GetCallOp(uint64_t) = 0;
virtual bool SetLhsInCallOp(uint64_t, uint64_t) = 0;
@@ -98,6 +99,8 @@ public:
virtual uint32_t AddArgInPhiOp(uint64_t, uint64_t, uint64_t, uint64_t) = 0;
virtual PhiOp CreatePhiOp(uint64_t, uint64_t) = 0;
virtual void DebugValue(uint64_t) = 0;
+ virtual void DebugOperation(uint64_t) = 0;
+ virtual void DebugBlock(mlir::Block*) = 0;
virtual bool IsLtoOptimize() = 0;
virtual bool IsWholeProgram() = 0;
@@ -109,6 +112,8 @@ public:
virtual mlir::Value BuildMemRef(PluginIR::PluginTypeBase, mlir::Value, mlir::Value) = 0;
virtual bool RedirectFallthroughTarget(FallThroughOp&, mlir::Block*, mlir::Block*) = 0;
virtual mlir::Operation* GetSSADefOperation(uint64_t) = 0;
+
+ virtual bool IsVirtualOperand(uint64_t) = 0;
}; // class BasicPluginOpsAPI
} // namespace PluginAPI
diff --git a/include/PluginAPI/ControlFlowAPI.h b/include/PluginAPI/ControlFlowAPI.h
index 429ba11..8c4c38d 100644
--- a/include/PluginAPI/ControlFlowAPI.h
+++ b/include/PluginAPI/ControlFlowAPI.h
@@ -40,6 +40,7 @@ public:
bool UpdateSSA(void);
vector<PhiOp> GetAllPhiOpInsideBlock(mlir::Block *b);
+ vector<mlir::Operation*> GetAllOpsInsideBlock(mlir::Block *b);
mlir::Block* CreateBlock(mlir::Block*, FunctionOp*);
void DeleteBlock(mlir::Block*, FunctionOp*);
diff --git a/include/PluginAPI/DataFlowAPI.h b/include/PluginAPI/DataFlowAPI.h
new file mode 100644
index 0000000..54b6b50
--- /dev/null
+++ b/include/PluginAPI/DataFlowAPI.h
@@ -0,0 +1,65 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+ Author: Chenhui Zheng
+ Create: 2023-11-01
+ Description:
+ This file contains the declaration of the DataFlowAPI class.
+*/
+#ifndef PLUGIN_FRAMEWORK_DATA_FLOW_API_H
+#define PLUGIN_FRAMEWORK_DATA_FLOW_API_H
+
+#include "BasicPluginOpsAPI.h"
+#include "PluginServer/PluginServer.h"
+#include "Dialect/PluginTypes.h"
+#include "PluginServerAPI.h"
+
+namespace PluginAPI {
+
+using std::string;
+using std::vector;
+
+using namespace PinServer;
+using namespace mlir::Plugin;
+/*
+ * Server-side data-flow queries. Each method serializes its arguments to JSON,
+ * forwards the request to the plugin client via PluginServer
+ * (RemoteCallClientWithAPI and the Get*Result helpers), and decodes the reply.
+ */
+class DataFlowAPI {
+public:
+ DataFlowAPI() = default;
+ ~DataFlowAPI() = default;
+
+ // Dominance information: (dir, funcId). The implementation documents dir as
+ // 1 or 2 — presumably dominators vs. post-dominators; TODO confirm.
+ void CalDominanceInfo(uint64_t, uint64_t);
+
+ // USE-DEF queries. A uint64_t argument is the id of the queried statement.
+ vector<mlir::Operation*> GetImmUseStmts(mlir::Value);
+ mlir::Value GetGimpleVuse(uint64_t);
+ mlir::Value GetGimpleVdef(uint64_t);
+ vector<mlir::Value> GetSsaUseOperand(uint64_t);
+ vector<mlir::Value> GetSsaDefOperand(uint64_t);
+ vector<mlir::Value> GetPhiOrStmtUse(uint64_t);
+ vector<mlir::Value> GetPhiOrStmtDef(uint64_t);
+
+ // Alias analysis: (ref1, ref2, flag); the implementation documents flag as 0 or 1.
+ bool RefsMayAlias(mlir::Value, mlir::Value, uint64_t);
+
+ // Points-to analysis.
+ bool PTIncludesDecl(mlir::Value, uint64_t);
+ bool PTsIntersect(mlir::Value, mlir::Value);
+
+private:
+ // NOTE(review): not referenced by lib/PluginAPI/DataFlowAPI.cpp in this patch.
+ PluginServerAPI pluginAPI;
+};
+
+} // namespace PluginAPI
+
+#endif // PLUGIN_FRAMEWORK_DATA_FLOW_API_H
diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h
index 9cc8498..236c8a1 100644
--- a/include/PluginAPI/PluginServerAPI.h
+++ b/include/PluginAPI/PluginServerAPI.h
@@ -86,6 +86,7 @@ public:
void SetLatch(LoopOp*, mlir::Block*);
vector<mlir::Block*> GetLoopBody(uint64_t loopID);
LoopOp GetBlockLoopFather(mlir::Block*);
+ LoopOp FindCommonLoop(LoopOp*, LoopOp*);
mlir::Block* FindBlock(uint64_t);
uint64_t FindBasicBlock(mlir::Block*);
bool InsertValue(uint64_t, mlir::Value);
@@ -104,6 +105,8 @@ public:
/* Plugin API for ConstOp. */
mlir::Value CreateConstOp(mlir::Attribute, mlir::Type) override;
void DebugValue(uint64_t) override;
+ void DebugOperation(uint64_t) override;
+ void DebugBlock(mlir::Block*) override;
bool IsLtoOptimize() override;
bool IsWholeProgram() override;
@@ -117,6 +120,8 @@ public:
mlir::Operation* GetSSADefOperation(uint64_t) override;
void InsertCreatedBlock(uint64_t id, mlir::Block* block);
+ bool IsVirtualOperand(uint64_t) override;
+
int64_t GetInjectDataAddress() override; // 获取注入点返回数据的地址
string GetDeclSourceFile(int64_t) override;
string VariableName(int64_t) override;
diff --git a/include/PluginServer/ManagerSetup.h b/include/PluginServer/ManagerSetup.h
index e264237..1fe3f41 100755
--- a/include/PluginServer/ManagerSetup.h
+++ b/include/PluginServer/ManagerSetup.h
@@ -28,6 +28,7 @@ enum RefPassName {
PASS_PHIOPT,
PASS_SSA,
PASS_LOOP,
+ PASS_LAD,
PASS_MAC,
};
diff --git a/include/PluginServer/PluginCom.h b/include/PluginServer/PluginCom.h
index 7deb615..2897742 100755
--- a/include/PluginServer/PluginCom.h
+++ b/include/PluginServer/PluginCom.h
@@ -74,6 +74,7 @@ public:
uint64_t GetIdResult(void);
vector<uint64_t> GetIdsResult(void);
mlir::Value GetValueResult(void);
+ vector<mlir::Value> GetValuesResult(void);
vector<mlir::Plugin::PhiOp> GetPhiOpsResult(void);
private:
@@ -94,6 +95,7 @@ private:
uint64_t idResult;
vector<uint64_t> idsResult;
mlir::Value valueResult;
+ vector<mlir::Value> valuesResult;
PluginIR::PluginTypeBase pTypeResult;
mlir::Plugin::DeclBaseOp declOp;
mlir::Plugin::FieldDeclOp fielddeclOp;
diff --git a/include/PluginServer/PluginJson.h b/include/PluginServer/PluginJson.h
index 1d4bfbc..ede7b53 100755
--- a/include/PluginServer/PluginJson.h
+++ b/include/PluginServer/PluginJson.h
@@ -88,6 +88,7 @@ public:
/* 将json格式数据解析成map<string, string>格式 */
void GetAttributes(Json::Value node, map<string, string>& attributes);
mlir::Value ValueJsonDeSerialize(Json::Value valueJson);
+ void ValuesJsonDeSerialize(const string&, vector<mlir::Value>&);
Json::Value TypeJsonSerialize(PluginIR::PluginTypeBase type);
mlir::Value MemRefDeSerialize(const string& data);
bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&);
diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h
index 81f92e9..b44c1d6 100644
--- a/include/PluginServer/PluginServer.h
+++ b/include/PluginServer/PluginServer.h
@@ -244,6 +244,11 @@ public:
RemoteCallClientWithAPI(funName, params);
return pluginCom.GetValueResult();
}
+ /* Call client API funName with params and return the value list decoded from its "ValuesResult" reply. */
+ vector<mlir::Value> GetValuesResult(const string& funName, const string& params)
+ {
+ RemoteCallClientWithAPI(funName, params);
+ return pluginCom.GetValuesResult();
+ }
vector<mlir::Plugin::PhiOp> GetPhiOpsResult(const string& funName, const string& params)
{
RemoteCallClientWithAPI(funName, params);
@@ -264,15 +269,15 @@ public:
valueMaps.clear();
blockMaps.clear();
basicblockMaps.clear();
- defOpMaps.clear();
+ opMaps.clear();
}
uint64_t FindBasicBlock(mlir::Block*);
bool InsertValue(uint64_t, mlir::Value);
bool HaveValue(uint64_t);
mlir::Value GetValue(uint64_t);
- mlir::Operation* FindDefOperation(uint64_t);
- bool InsertDefOperation(uint64_t, mlir::Operation*);
+ mlir::Operation* FindOperation(uint64_t);
+ bool InsertOperation(uint64_t, mlir::Operation*);
void RemoteCallClientWithAPI(const string& api, const string& params);
private:
@@ -296,7 +301,8 @@ private:
// process Block.
std::map<uint64_t, mlir::Block*> blockMaps;
std::map<mlir::Block*, uint64_t> basicblockMaps;
- std::map<uint64_t, mlir::Operation*> defOpMaps;
+ // std::map<uint64_t, mlir::Operation*> defOpMaps;
+ std::map<uint64_t, mlir::Operation*> opMaps; // 保存所有的<gimpleid, op>键值对
/* 解析客户端发送过来的-fplugin-arg参数并保存在私有变量args中 */
void ParseArgv(const string& data);
diff --git a/include/user/SimpleLICMPass.h b/include/user/SimpleLICMPass.h
new file mode 100644
index 0000000..aa6ffa7
--- /dev/null
+++ b/include/user/SimpleLICMPass.h
@@ -0,0 +1,45 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+ Author: Mingchuan Wu and Yancheng Li
+ Create: 2022-08-18
+ Description:
+ This file contains the declaration of the SimpleLICMPass class.
+*/
+
+#ifndef SIMPLE_LICM_PASS_H
+#define SIMPLE_LICM_PASS_H
+
+#include "PluginServer/PluginOptBase.h"
+
+namespace PluginOpt {
+/*
+ * User pass constructed with the HANDLE_MANAGER_SETUP injection point
+ * (presumably a simple loop-invariant code motion, per its name).
+ */
+class SimpleLICMPass : public PluginOptBase {
+public:
+    SimpleLICMPass() : PluginOptBase(HANDLE_MANAGER_SETUP) {}
+
+    /* The pass is always enabled. */
+    bool Gate()
+    {
+        return true;
+    }
+
+    /* Run the pass on the function whose address GetFuncAddr() reports. */
+    int DoOptimize()
+    {
+        return DoOptimize((uint64_t)GetFuncAddr());
+    }
+
+    int DoOptimize(uint64_t fun);
+};
+} // namespace PluginOpt
+
+#endif
\ No newline at end of file
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 3602856..d4b061f 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -22,6 +22,7 @@
#include "PluginAPI/PluginServerAPI.h"
#include "PluginAPI/ControlFlowAPI.h"
+#include "PluginAPI/DataFlowAPI.h"
#include "Dialect/PluginDialect.h"
#include "Dialect/PluginOps.h"
#include "Dialect/PluginTypes.h"
@@ -265,6 +266,13 @@ bool LoopOp::IsLoopFather(mlir::Block* b)
return id == loopId;
}
+/* Delegate to PluginServerAPI::FindCommonLoop, which queries the plugin client. */
+LoopOp LoopOp::FindCommonLoop(LoopOp* loop_1, LoopOp* loop_2)
+{
+    PluginAPI::PluginServerAPI serverApi;
+    return serverApi.FindCommonLoop(loop_1, loop_2);
+}
+
vector<pair<mlir::Block*, mlir::Block*> > LoopOp::GetExitEdges()
{
PluginAPI::PluginServerAPI pluginAPI;
diff --git a/lib/PluginAPI/ControlFlowAPI.cpp b/lib/PluginAPI/ControlFlowAPI.cpp
index b356ae9..3cf72f6 100644
--- a/lib/PluginAPI/ControlFlowAPI.cpp
+++ b/lib/PluginAPI/ControlFlowAPI.cpp
@@ -106,10 +106,25 @@ vector<PhiOp> ControlFlowAPI::GetAllPhiOpInsideBlock(mlir::Block *b)
string funName = __func__;
root["bbAddr"] = std::to_string(server->FindBasicBlock(b));
string params = root.toStyledString();
-
return GetPhiOperationResult(funName, params);
}
+/*
+ * Return the operations of block b, in the order the client lists their ids.
+ * Ids with no operation registered on the server map to a null pointer.
+ */
+vector<mlir::Operation*> ControlFlowAPI::GetAllOpsInsideBlock(mlir::Block *b)
+{
+    string funName = __func__;
+    PluginServer *server = PluginServer::GetInstance();
+    Json::Value request;
+    request["bbAddr"] = std::to_string(server->FindBasicBlock(b));
+    string params = request.toStyledString();
+    vector<mlir::Operation*> blockOps;
+    for (uint64_t opId : server->GetIdsResult(funName, params)) {
+        blockOps.push_back(server->FindOperation(opId));
+    }
+    return blockOps;
+}
+
mlir::Block* ControlFlowAPI::CreateBlock(mlir::Block* b, FunctionOp *funcOp)
{
Json::Value root;
diff --git a/lib/PluginAPI/DataFlowAPI.cpp b/lib/PluginAPI/DataFlowAPI.cpp
new file mode 100644
index 0000000..2fec19f
--- /dev/null
+++ b/lib/PluginAPI/DataFlowAPI.cpp
@@ -0,0 +1,171 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License"); you may
+ not use this file except in compliance with the License. You may obtain
+ a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+
+*/
+#include "PluginAPI/DataFlowAPI.h"
+
+namespace PluginAPI {
+using namespace PinServer;
+using namespace mlir::Plugin;
+
+/*
+ * Return the id stored on the op that defines v.
+ * Returns 0 when v has no defining op (e.g. a block argument) or when the
+ * defining op is not one of the id-carrying ops handled below.
+ */
+static uint64_t GetValueId(mlir::Value v)
+{
+    mlir::Operation *op = v.getDefiningOp();
+    // llvm::dyn_cast must not be given a null pointer; block arguments have no defining op.
+    if (op == nullptr) {
+        return 0;
+    }
+    if (auto mOp = llvm::dyn_cast<MemOp>(op)) {
+        return mOp.id();
+    } else if (auto ssaOp = llvm::dyn_cast<SSAOp>(op)) {
+        return ssaOp.id();
+    } else if (auto cstOp = llvm::dyn_cast<ConstOp>(op)) {
+        return cstOp.id();
+    } else if (auto treelistop = llvm::dyn_cast<ListOp>(op)) {
+        return treelistop.id();
+    } else if (auto strop = llvm::dyn_cast<StrOp>(op)) {
+        return strop.id();
+    } else if (auto arrayop = llvm::dyn_cast<ArrayOp>(op)) {
+        return arrayop.id();
+    } else if (auto declop = llvm::dyn_cast<DeclBaseOp>(op)) {
+        return declop.id();
+    } else if (auto fieldop = llvm::dyn_cast<FieldDeclOp>(op)) {
+        return fieldop.id();
+    } else if (auto addressop = llvm::dyn_cast<AddressOp>(op)) {
+        return addressop.id();
+    } else if (auto constructorop = llvm::dyn_cast<ConstructorOp>(op)) {
+        return constructorop.id();
+    } else if (auto vecop = llvm::dyn_cast<VecOp>(op)) {
+        return vecop.id();
+    } else if (auto blockop = llvm::dyn_cast<BlockOp>(op)) {
+        return blockop.id();
+    } else if (auto compop = llvm::dyn_cast<ComponentOp>(op)) {
+        return compop.id();
+    } else if (auto phOp = llvm::dyn_cast<PlaceholderOp>(op)) {
+        return phOp.id();
+    }
+    return 0;
+}
+
+/*
+ * Ask the client to compute dominance information for function funcId.
+ * dir is forwarded unchanged and is expected to be 1 or 2 (presumably
+ * dominators vs. post-dominators — TODO confirm against the client).
+ */
+void DataFlowAPI::CalDominanceInfo(uint64_t dir, uint64_t funcId)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["funcId"] = std::to_string(funcId);
+    request["dir"] = std::to_string(dir);
+    string params = request.toStyledString();
+    PluginServer *server = PluginServer::GetInstance();
+    server->RemoteCallClientWithAPI(funName, params);
+}
+
+/*
+ * Return the operations the client reports as immediate uses of v.
+ * An id with no operation registered on the server maps to a null pointer.
+ */
+vector<mlir::Operation*> DataFlowAPI::GetImmUseStmts(mlir::Value v)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["varId"] = std::to_string(GetValueId(v));
+    string params = request.toStyledString();
+    PluginServer *server = PluginServer::GetInstance();
+    vector<uint64_t> useIds = server->GetIdsResult(funName, params);
+    vector<mlir::Operation*> uses;
+    uses.reserve(useIds.size());
+    for (uint64_t useId : useIds) {
+        uses.push_back(server->FindOperation(useId));
+    }
+    return uses;
+}
+
+/* Build the {"opId": "<decimal id>"} request shared by the statement-level USE-DEF queries below. */
+static string OpIdParams(uint64_t opId)
+{
+    Json::Value root;
+    root["opId"] = std::to_string(opId);
+    return root.toStyledString();
+}
+
+/*
+ * Each query below sends its own function name (__func__) as the remote API
+ * name, together with the id of the queried statement.
+ */
+
+/* Virtual use (VUSE) of statement opId, as reported by the client. */
+mlir::Value DataFlowAPI::GetGimpleVuse(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValueResult(funName, params);
+}
+
+/* Virtual definition (VDEF) of statement opId, as reported by the client. */
+mlir::Value DataFlowAPI::GetGimpleVdef(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValueResult(funName, params);
+}
+
+/* SSA operands used by statement opId. */
+vector<mlir::Value> DataFlowAPI::GetSsaUseOperand(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+/* SSA operands defined by statement opId. */
+vector<mlir::Value> DataFlowAPI::GetSsaDefOperand(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+/* Uses of phi-or-statement opId. */
+vector<mlir::Value> DataFlowAPI::GetPhiOrStmtUse(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+/* Definitions of phi-or-statement opId. */
+vector<mlir::Value> DataFlowAPI::GetPhiOrStmtDef(uint64_t opId)
+{
+    string funName = __func__;
+    string params = OpIdParams(opId);
+    return PluginServer::GetInstance()->GetValuesResult(funName, params);
+}
+
+/*
+ * Ask the client whether the references v1 and v2 may alias.
+ * Both values are sent by id (see GetValueId). flag is forwarded unchanged
+ * and is expected to be 0 or 1.
+ */
+bool DataFlowAPI::RefsMayAlias(mlir::Value v1, mlir::Value v2, uint64_t flag)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["id1"] = std::to_string(GetValueId(v1));
+    request["id2"] = std::to_string(GetValueId(v2));
+    request["flag"] = std::to_string(flag);
+    string params = request.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+/* Ask the client whether the points-to set of ptr includes the declaration declId. */
+bool DataFlowAPI::PTIncludesDecl(mlir::Value ptr, uint64_t declId)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["declId"] = std::to_string(declId);
+    request["ptrId"] = std::to_string(GetValueId(ptr));
+    string params = request.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+/* Ask the client whether the points-to sets of ptr_1 and ptr_2 intersect. */
+bool DataFlowAPI::PTsIntersect(mlir::Value ptr_1, mlir::Value ptr_2)
+{
+    string funName = __func__;
+    Json::Value request;
+    request["ptrId_2"] = std::to_string(GetValueId(ptr_2));
+    request["ptrId_1"] = std::to_string(GetValueId(ptr_1));
+    string params = request.toStyledString();
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
+}
\ No newline at end of file
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index 09bd358..0b7d8ab 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -372,6 +372,15 @@ mlir::Value PluginServerAPI::ConfirmValue(mlir::Value v)
return PluginServer::GetInstance()->GetValueResult(funName, params);
}
+/* Ask the client whether the operand with the given id is a virtual operand. */
+bool PluginServerAPI::IsVirtualOperand(uint64_t id)
+{
+    Json::Value request;
+    request["id"] = std::to_string(id);
+    string params = request.toStyledString();
+    string funName = "IsVirtualOperand";
+    return PluginServer::GetInstance()->GetBoolResult(funName, params);
+}
+
mlir::Value PluginServerAPI::BuildMemRef(PluginIR::PluginTypeBase type,
mlir::Value base, mlir::Value offset)
{
@@ -811,6 +820,16 @@ LoopOp PluginServerAPI::GetBlockLoopFather(mlir::Block* b)
return PluginServer::GetInstance()->LoopOpResult(funName, params);
}
+/*
+ * Ask the client for the loop common to loop_1 and loop_2 (presumably the
+ * innermost loop enclosing both, like GCC's find_common_loop — confirm).
+ * Both pointers are dereferenced and must be non-null.
+ */
+LoopOp PluginServerAPI::FindCommonLoop(LoopOp* loop_1, LoopOp* loop_2)
+{
+    Json::Value request;
+    // Loop ids are sent as JSON integers here, not as decimal strings.
+    request["loopId_1"] = loop_1->idAttr().getInt();
+    request["loopId_2"] = loop_2->idAttr().getInt();
+    string params = request.toStyledString();
+    string funName("FindCommonLoop");
+    return PluginServer::GetInstance()->LoopOpResult(funName, params);
+}
+
mlir::Block* PluginServerAPI::FindBlock(uint64_t b)
{
PluginServer *server = PluginServer::GetInstance();
@@ -845,7 +864,7 @@ bool PluginServerAPI::RedirectFallthroughTarget(
mlir::Operation* PluginServerAPI::GetSSADefOperation(uint64_t addr)
{
- return PluginServer::GetInstance()->FindDefOperation(addr);
+ return PluginServer::GetInstance()->FindOperation(addr);
}
void PluginServerAPI::InsertCreatedBlock(uint64_t id, mlir::Block* block)
@@ -862,6 +881,25 @@ void PluginServerAPI::DebugValue(uint64_t valId)
PluginServer::GetInstance()->RemoteCallClientWithAPI(funName, params);
}
+/* Ask the client to report on operation opId (presumably a debug dump). */
+void PluginServerAPI::DebugOperation(uint64_t opId)
+{
+    string funName = __func__;
+    Json::Value request;
+    // NOTE(review): opId is sent as a JSON number, whereas DebugBlock and most
+    // sibling requests send ids as decimal strings — confirm what the client expects.
+    request["opId"] = opId;
+    string params = request.toStyledString();
+    PluginServer::GetInstance()->RemoteCallClientWithAPI(funName, params);
+}
+
+/* Ask the client to report on block b (presumably a debug dump); b is sent by its basic-block id. */
+void PluginServerAPI::DebugBlock(mlir::Block* b)
+{
+    string funName = __func__;
+    PluginServer *server = PluginServer::GetInstance();
+    Json::Value request;
+    request["bbAddr"] = std::to_string(server->FindBasicBlock(b));
+    string params = request.toStyledString();
+    server->RemoteCallClientWithAPI(funName, params);
+}
+
bool PluginServerAPI::IsLtoOptimize()
{
Json::Value root;
diff --git a/lib/PluginServer/PluginCom.cpp b/lib/PluginServer/PluginCom.cpp
index aaf4ba8..8717540 100755
--- a/lib/PluginServer/PluginCom.cpp
+++ b/lib/PluginServer/PluginCom.cpp
@@ -148,6 +148,13 @@ mlir::Value PluginCom::GetValueResult()
return this->valueResult;
}
+/*
+ * Hand the values decoded from the last "ValuesResult" reply to the caller and
+ * leave the internal buffer empty for the next call.
+ */
+vector<mlir::Value> PluginCom::GetValuesResult()
+{
+    vector<mlir::Value> retVals;
+    // swap transfers the storage instead of copying every element and then clearing.
+    retVals.swap(valuesResult);
+    return retVals;
+}
+
vector<mlir::Plugin::PhiOp> PluginCom::GetPhiOpsResult()
{
vector<mlir::Plugin::PhiOp> retOps;
@@ -196,9 +203,11 @@ void PluginCom::JsonDeSerialize(const string& key, const string& data)
this->declOp = llvm::dyn_cast<mlir::Plugin::DeclBaseOp>(decl.getDefiningOp());
} else if (key == "GetFieldsOpResult") {
json.FieldOpsJsonDeSerialize(data, this->fieldsOps);
- } else if (key == "OpsResult") {
+ } else if (key == "OpResult") {
json.OpJsonDeSerialize(data.c_str(), this->opData);
- } else if (key == "ValueResult") {
+ } else if (key == "ValuesResult") {
+ json.ValuesJsonDeSerialize(data, this->valuesResult);
+ }else if (key == "ValueResult") {
Json::Value node;
Json::Reader reader;
reader.parse(data, node);
diff --git a/lib/PluginServer/PluginJson.cpp b/lib/PluginServer/PluginJson.cpp
index 0392b75..864a444 100755
--- a/lib/PluginServer/PluginJson.cpp
+++ b/lib/PluginServer/PluginJson.cpp
@@ -189,6 +189,20 @@ mlir::Value PluginJson::ValueJsonDeSerialize(Json::Value valueJson)
return opValue;
}
+/*
+ * Decode a {"Value0": ..., "Value1": ..., ...} object and append each value to
+ * valuesData in index order. Malformed JSON or a non-object payload appends
+ * nothing, and decoding stops at the first missing "Value<i>" key.
+ */
+void PluginJson::ValuesJsonDeSerialize(const string& data, vector<mlir::Value>& valuesData)
+{
+    Json::Value root;
+    Json::Reader reader;
+    if (!reader.parse(data, root) || !root.isObject()) {
+        return;
+    }
+    const Json::ArrayIndex count = root.size();
+    for (Json::ArrayIndex i = 0; i < count; ++i) {
+        const string key = "Value" + std::to_string(i);
+        // Check with isMember: non-const operator[] would insert (and we would decode) a null member.
+        if (!root.isMember(key)) {
+            break;
+        }
+        valuesData.push_back(ValueJsonDeSerialize(root[key]));
+    }
+}
+
mlir::Value PluginJson::MemRefDeSerialize(const string& data)
{
Json::Value root;
@@ -666,6 +680,7 @@ mlir::Operation *PluginJson::CallOpJsonDeSerialize(const string& data)
op = opBuilder->create<CallOp>(opBuilder->getUnknownLoc(),
id, callName, ops);
}
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -690,6 +705,7 @@ mlir::Operation *PluginJson::CondOpJsonDeSerialize(const string& data)
CondOp op = opBuilder->create<CondOp>(
opBuilder->getUnknownLoc(), id, address, iCode, LHS,
RHS, tb, fb, tbaddr, fbaddr, trueLabel, falseLabel);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -736,7 +752,7 @@ mlir::Operation *PluginJson::AssignOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
AssignOp op = opBuilder->create<AssignOp>(opBuilder->getUnknownLoc(),
ops, id, iCode);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -760,7 +776,7 @@ mlir::Operation *PluginJson::PhiOpJsonDeSerialize(const string& data)
PhiOp op = opBuilder->create<PhiOp>(opBuilder->getUnknownLoc(),
ops, id, capacity, nArgs);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1003,6 +1019,7 @@ mlir::Operation *PluginJson::GotoOpJsonDeSerialize(const string& data)
mlir::Block* success = PluginServer::GetInstance()->FindBlock(successaddr);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
GotoOp op = opBuilder->create<GotoOp>(opBuilder->getUnknownLoc(), id, address, dest, success, successaddr);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1030,6 +1047,7 @@ mlir::Operation *PluginJson::TransactionOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
TransactionOp op = opBuilder->create<TransactionOp>(opBuilder->getUnknownLoc(), id, address, stmtaddr, labelNorm,
labelUninst, labelOver, fallthrough, fallthroughaddr, abort, abortaddr);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1044,6 +1062,7 @@ mlir::Operation *PluginJson::ResxOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
ResxOp op = opBuilder->create<ResxOp>(opBuilder->getUnknownLoc(),
id, address, region);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1056,6 +1075,7 @@ mlir::Operation *PluginJson::EHMntOpJsonDeSerialize(const string& data)
mlir::Value decl = ValueJsonDeSerialize(node["decl"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHMntOp op = opBuilder->create<EHMntOp>(opBuilder->getUnknownLoc(), id, decl);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1082,6 +1102,7 @@ mlir::Operation *PluginJson::EHDispatchOpJsonDeSerialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHDispatchOp op = opBuilder->create<EHDispatchOp>(opBuilder->getUnknownLoc(),
id, address, region, ehHandlers, ehHandlersaddrs);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1094,6 +1115,7 @@ mlir::Operation *PluginJson::LabelOpJsonDeSerialize(const string& data)
mlir::Value label = ValueJsonDeSerialize(node["label"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
LabelOp op = opBuilder->create<LabelOp>(opBuilder->getUnknownLoc(), id, label);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1116,6 +1138,7 @@ mlir::Operation *PluginJson::BindOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
BindOp op = opBuilder->create<BindOp>(opBuilder->getUnknownLoc(), id, vars, bodyaddrs, block);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1144,6 +1167,7 @@ mlir::Operation *PluginJson::TryOpJsonDeSerialize(const string& data)
int64_t kind = GetID(node["kind"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
TryOp op = opBuilder->create<TryOp>(opBuilder->getUnknownLoc(), id, evaladdrs, cleanupaddrs, kind);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1165,6 +1189,7 @@ mlir::Operation *PluginJson::CatchOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
CatchOp op = opBuilder->create<CatchOp>(opBuilder->getUnknownLoc(), id, types, handleraddrs);
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1176,7 +1201,7 @@ mlir::Operation *PluginJson::NopOpJsonDeSerialize(const string& data)
uint64_t id = GetID(node["id"]);
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
NopOp op = opBuilder->create<NopOp>(opBuilder->getUnknownLoc(), id);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1208,7 +1233,7 @@ mlir::Operation *PluginJson::EHElseOpJsonDeSerialize(const string& data)
}
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
EHElseOp op = opBuilder->create<EHElseOp>(opBuilder->getUnknownLoc(), id, nbody, ebody);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1233,7 +1258,7 @@ mlir::Operation *PluginJson::AsmOpJsonDeserialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
AsmOp op = opBuilder->create<AsmOp>(opBuilder->getUnknownLoc(), id, statement, nInputs, nOutputs,
nClobbers, ops);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
@@ -1278,7 +1303,7 @@ mlir::Operation *PluginJson::SwitchOpJsonDeserialize(const string& data)
mlir::OpBuilder *opBuilder = PluginServer::GetInstance()->GetOpBuilder();
SwitchOp op = opBuilder->create<SwitchOp>(opBuilder->getUnknownLoc(), id, index, address, defaultLabel, ops, defaultDest,
defaultDestAddr, caseDest, caseaddr);
- PluginServer::GetInstance()->InsertDefOperation(id, op.getOperation());
+ PluginServer::GetInstance()->InsertOperation(id, op.getOperation());
return op.getOperation();
}
diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp
index 9a8d5e5..6424b94 100644
--- a/lib/PluginServer/PluginServer.cpp
+++ b/lib/PluginServer/PluginServer.cpp
@@ -95,17 +95,21 @@ mlir::Block* PluginServer::FindBlock(uint64_t id)
return iter->second;
}
-mlir::Operation* PluginServer::FindDefOperation(uint64_t id)
+mlir::Operation* PluginServer::FindOperation(uint64_t id)
{
- auto iter = this->defOpMaps.find(id);
- assert(iter != this->defOpMaps.end());
- return iter->second;
+    // Return the operation registered under statement id, or nullptr when none is; callers must check.
+    auto iter = this->opMaps.find(id);
+    if (iter == this->opMaps.end()) {
+        return nullptr;
+    }
+    return iter->second;
}
-bool PluginServer::InsertDefOperation(uint64_t id, mlir::Operation* op)
+bool PluginServer::InsertOperation(uint64_t id, mlir::Operation* op)
{
- auto iter = this->defOpMaps.find(id);
- this->defOpMaps.insert({id, op});
+    // std::map::insert keeps the existing mapping if id is already registered.
+    this->opMaps.insert({id, op});
return true;
}
@@ -312,6 +316,7 @@ void PluginServer::RunServer()
}
log->LOGI("Server ppid:%d listening on port:%s\n", getppid(), port.c_str());
ServerSemPost(port);
+
register_mutex.lock();
RegisterCallbacks();
register_mutex.unlock();
diff --git a/user/SimpleLICMPass.cpp b/user/SimpleLICMPass.cpp
new file mode 100644
index 0000000..b78f86f
--- /dev/null
+++ b/user/SimpleLICMPass.cpp
@@ -0,0 +1,359 @@
+#include "PluginAPI/ControlFlowAPI.h"
+#include "PluginAPI/PluginServerAPI.h"
+#include "user/SimpleLICMPass.h"
+#include "PluginAPI/DataFlowAPI.h"
+#include "mlir/Support/LLVM.h"
+
+namespace PluginOpt {
+using std::cout;
+using std::endl;
+using std::string;
+using std::vector;
+using namespace mlir;
+using namespace PluginAPI;
+
+// Plugin API handles shared by the helpers in this file.
+PluginServerAPI pluginAPI;
+ControlFlowAPI cfAPI;
+DataFlowAPI dfAPI;
+
+// LICM bookkeeping:
+//   move_stmt - assignments judged loop invariant by compute_invariantness;
+//   visited   - statements already examined by compute_invariantness;
+//   not_move  - statements that must stay where they are.
+// These live at namespace scope and persist until explicitly cleared.
+vector<AssignOp> move_stmt;
+std::map<mlir::Operation*, bool> visited;
+std::map<mlir::Operation*, bool> not_move;
+
+// Kind of a CFG edge, derived from the source block's terminator.
+enum EDGE_FLAG {
+    EDGE_FALLTHRU,
+    EDGE_TRUE_VALUE,
+    EDGE_FALSE_VALUE
+};
+
+// A CFG edge src -> dest; destIdx is src's position among dest's predecessors.
+struct edgeDef {
+    Block *src;
+    Block *dest;
+    unsigned destIdx;
+    EDGE_FLAG flag;
+};
+using edge = edgeDef;
+
+/* Collect the predecessor blocks of bb, in MLIR's predecessor order. */
+static vector<Block *> getPredecessors(Block *bb)
+{
+    vector<Block *> result;
+    for (Block *p : bb->getPredecessors()) {
+        result.push_back(p);
+    }
+    return result;
+}
+
+/* Classify the CFG edge src -> dest from src's terminator.
+ * For a CondOp, the edge to successor 0 is the true edge and any other
+ * successor is the false edge. Every other terminator (FallThroughOp included)
+ * is reported as EDGE_FALLTHRU, so the result is always initialized. */
+static enum EDGE_FLAG GetEdgeFlag(Block *src, Block *dest)
+{
+    Operation *op = src->getTerminator();
+    if (isa<CondOp>(op)) {
+        return op->getSuccessor(0) == dest ? EDGE_TRUE_VALUE : EDGE_FALSE_VALUE;
+    }
+    return EDGE_FALLTHRU;
+}
+
+/* Build one edge pred -> bb for every predecessor of bb, in predecessor
+ * order; destIdx records that position. */
+static vector<edge> GetPredEdges(Block *bb)
+{
+    vector<edge> edges;
+    unsigned idx = 0;
+    for (auto it = bb->pred_begin(); it != bb->pred_end(); ++it, ++idx) {
+        Block *pred = *it;
+        edge e;
+        e.src = pred;
+        e.dest = bb;
+        e.destIdx = idx;
+        // The edge runs pred -> bb, so its kind comes from pred's terminator.
+        e.flag = GetEdgeFlag(pred, bb);
+        edges.push_back(e);
+    }
+    return edges;
+}
+
+/* Return the first recorded edge src -> dest among dest's predecessor edges.
+ * If src is not a predecessor of dest, the returned edge has null src/dest
+ * (instead of being left uninitialized). */
+static edge GetEdge(Block *src, Block *dest)
+{
+    for (const edge &elm : GetPredEdges(dest)) {
+        if (elm.src == src) {
+            return elm;
+        }
+    }
+    edge none = {nullptr, nullptr, 0, EDGE_FALLTHRU};
+    return none;
+}
+
+/* Return the plugin-side id of the operation defining v.
+ * Returns 0 when v is null, when v has no defining op (e.g. a block
+ * argument), or when the defining op is none of the kinds handled below.
+ * Null checks come first because dyn_cast<> must not be given a null op. */
+static uint64_t getValueId(Value v)
+{
+    if (!v) {
+        return 0;
+    }
+    Operation *defOp = v.getDefiningOp();
+    if (defOp == nullptr) {
+        return 0;
+    }
+    if (auto ssaop = dyn_cast<SSAOp>(defOp)) {
+        return ssaop.id();
+    }
+    if (auto memop = dyn_cast<MemOp>(defOp)) {
+        return memop.id();
+    }
+    if (auto constop = dyn_cast<ConstOp>(defOp)) {
+        return constop.id();
+    }
+    if (auto holderop = dyn_cast<PlaceholderOp>(defOp)) {
+        return holderop.id();
+    }
+    if (auto componentop = dyn_cast<ComponentOp>(defOp)) {
+        return componentop.id();
+    }
+    if (auto declop = dyn_cast<DeclBaseOp>(defOp)) {
+        return declop.id();
+    }
+    return 0;
+}
+
+/* Return the defCode attribute of the op defining v.
+ * SSAOp, MemOp and ConstOp are checked explicitly; anything else is treated
+ * as a PlaceholderOp.
+ * NOTE(review): for other defining ops (e.g. ComponentOp/DeclBaseOp) or a
+ * block argument, the final cast yields a null op and is dereferenced. */
+static IDefineCode getValueDefCode(Value v)
+{
+    Operation *def = v.getDefiningOp();
+    if (auto ssa = dyn_cast<SSAOp>(def)) {
+        return ssa.defCode().getValue();
+    }
+    if (auto mem = dyn_cast<MemOp>(def)) {
+        return mem.defCode().getValue();
+    }
+    if (auto cst = dyn_cast<ConstOp>(def)) {
+        return cst.defCode().getValue();
+    }
+    auto placeholder = dyn_cast<PlaceholderOp>(def);
+    return placeholder.defCode().getValue();
+}
+
+/* True when getValueId resolves v to a non-zero id. */
+static bool isValueExist(Value v)
+{
+    return getValueId(v) != 0;
+}
+
+/* Return true if v is defined by an SSAOp whose defCode is SSA and whose
+ * nameVarId is non-zero (i.e. the SSA name refers to a named variable).
+ * Any value not defined by an SSAOp (null values, block arguments, Mem/Const/
+ * Placeholder/Component/Decl values) returns false. This check runs before
+ * anything that would otherwise cast or dereference a non-SSA defining op. */
+static bool isSSANameVar(Value v)
+{
+    if (!v) {
+        return false;
+    }
+    auto ssaOp = dyn_cast_or_null<SSAOp>(v.getDefiningOp());
+    if (!ssaOp || !isValueExist(v) || getValueDefCode(v) != IDefineCode::SSA) {
+        return false;
+    }
+    return ssaOp.nameVarId() != 0;
+}
+
+/* Return the AssignOp or PhiOp recorded as the SSA definition of v.
+ * Returns nullptr when v is null, is not defined by an SSAOp (including block
+ * arguments, which have no defining op), has no recorded definition, or is
+ * defined by some other kind of statement. */
+static Operation *getSSADefStmtofValue(Value v)
+{
+    if (!v) {
+        return nullptr;
+    }
+    // dyn_cast_or_null tolerates a missing defining op; isa<> would assert.
+    auto ssaOp = dyn_cast_or_null<SSAOp>(v.getDefiningOp());
+    if (!ssaOp) {
+        return nullptr;
+    }
+    Operation *op = ssaOp.GetSSADefOperation();
+    if (!op || !isa<AssignOp, PhiOp>(op)) {
+        return nullptr;
+    }
+    return op;
+}
+
+/* Get the edge that first entered the loop: the edge into the loop header
+ * from its first predecessor that is not the loop latch.
+ * If the header has no such predecessor, an edge with null src/dest is
+ * returned instead of one built from an uninitialized block pointer. */
+static edge getLoopPreheaderEdge(LoopOp loop)
+{
+    Block *header = loop.GetHeader();
+    Block *latch = loop.GetLatch();
+
+    Block *src = nullptr;
+    for (Block *bb : getPredecessors(header)) {
+        if (bb != latch) {
+            src = bb;
+            break;
+        }
+    }
+    if (src == nullptr) {
+        edge none = {nullptr, nullptr, 0, EDGE_FALLTHRU};
+        return none;
+    }
+    return GetEdge(src, header);
+}
+
+/* Classify the assignments of block bb (called on loop headers) as movable
+ * or not.
+ * - Returns immediately when bb's loop father has outerLoopId 0.
+ * - Statements that must stay are recorded in the global not_move map;
+ *   every examined AssignOp is recorded in the global visited map.
+ * - The classification is iterated until no new statement is marked.
+ * - Every AssignOp of bb left unmarked is appended to the global move_stmt.
+ * Debug output (DebugOperation, cout) is printed along the way. */
+void compute_invariantness(Block* bb)
+{
+    LoopOp loop_father = pluginAPI.GetBlockLoopFather(bb);
+    if (!loop_father.outerLoopId().getValue()) {
+        return;
+    }
+    // The block address is currently unused; the query is kept so the
+    // sequence of plugin API calls stays the same.
+    (void)pluginAPI.FindBasicBlock(bb);
+
+    // Handle the phi statements in bb. For a non-virtual phi with at most two
+    // arguments, an SSA-name argument defined by a not-yet-visited statement
+    // of bb pins that defining statement.
+    // NOTE(review): `break` on a missing definition stops scanning the
+    // remaining arguments of this phi; confirm `continue` was not intended.
+    vector<PhiOp> phis = cfAPI.GetAllPhiOpInsideBlock(bb);
+    for (auto phi : phis) {
+        Value result = phi.GetResult();
+        uint64_t varId = getValueId(result);
+        int n_args = phi.nArgs();
+        if (n_args <= 2 && !pluginAPI.IsVirtualOperand(varId)) {
+            for (int i = 0; i < n_args; i++) {
+                Value v = phi.GetArgDef(i);
+                if (isSSANameVar(v)) {
+                    Operation* def = getSSADefStmtofValue(v);
+                    if (!def) break;
+                    Block *def_bb = def->getBlock();
+                    if (def_bb == bb && visited.find(def) == visited.end()) {
+                        pluginAPI.DebugOperation(phi.id());
+                        not_move[def] = true;
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    vector<mlir::Operation*> ops = cfAPI.GetAllOpsInsideBlock(bb);
+    // LHS values of statements that carry a virtual definition; candidates
+    // whose operands may alias one of them are not movable.
+    vector<mlir::Value> variants;
+    bool change = false;
+    do {
+        change = false;
+        for (auto op : ops) {
+            bool may_move = true;
+            if (not_move.find(op) != not_move.end()) continue;
+            if (!isa<AssignOp>(op)) continue;
+            visited[op] = true;
+            auto assign = dyn_cast<AssignOp>(op);
+            int num = assign.getNumOperands();
+            Value lhs = assign.GetLHS();
+            Value rhs1 = assign.GetRHS1();
+
+            // A statement with a virtual definition (presumably a memory
+            // write) is pinned and its LHS recorded as a variant.
+            Value vdef = dfAPI.GetGimpleVdef(assign.id());
+            uint64_t vdef_id = getValueId(vdef);
+            if (vdef_id) {
+                variants.push_back(lhs);
+                not_move[op] = true;
+                change = true;
+                continue;
+            }
+            // An SSA use defined by an unvisited statement of bb, or by a
+            // pinned statement, pins this statement too.
+            vector<mlir::Value> vals = dfAPI.GetSsaUseOperand(assign.id());
+            for (auto val : vals) {
+                Operation* def = getSSADefStmtofValue(val);
+                if (!def) continue;
+                Block *def_bb = def->getBlock();
+                if (def_bb == bb && visited.find(def) == visited.end()) {
+                    not_move[def] = true;
+                    not_move[op] = true;
+                    change = true;
+                } else if (not_move.find(def) != not_move.end()) {
+                    not_move[op] = true;
+                    change = true;
+                }
+            }
+            // Operands that may alias a recorded variant block motion.
+            for (auto v : variants) {
+                if (dfAPI.RefsMayAlias(lhs, v, 1) || dfAPI.RefsMayAlias(rhs1, v, 1)) {
+                    may_move = false;
+                    break;
+                }
+            }
+            // Three operands: presumably a binary assignment with a second RHS.
+            if (num == 3 && may_move) {
+                Value rhs2 = assign.GetRHS2();
+                for (auto v : variants) {
+                    if (dfAPI.RefsMayAlias(rhs2, v, 1)) {
+                        may_move = false;
+                        break;
+                    }
+                }
+            }
+            if (!may_move) {
+                not_move[op] = true;
+                change = true;
+            }
+        }
+    } while (change);
+
+    cout << "move statements: " << endl;
+    for (auto op : ops) {
+        if (not_move.find(op) != not_move.end()) continue;
+        if (!isa<AssignOp>(op)) continue;
+        auto assign = dyn_cast<AssignOp>(op);
+        pluginAPI.DebugOperation(assign.id());
+        move_stmt.push_back(assign);
+    }
+    cout << " " << endl;
+    cout << move_stmt.size() << endl;
+}
+
+/* Pick the loop that assign could be hoisted to, based on the loops that
+ * contain the AssignOp/PhiOp definitions of its SSA use operands.
+ * Returns a default-constructed (null) LoopOp when no operand has such a
+ * definition; move_worker treats that as "nothing to do".
+ * NOTE(review): maxloop is recomputed from scratch for each operand, so the
+ * result reflects only the last operand with a definition rather than all of
+ * them. Confirm this is intended.
+ * NOTE(review): "none" is printed when a definition IS found; this, together
+ * with DebugValue, looks like leftover debugging output. */
+LoopOp get_outermost_loop(AssignOp assign)
+{
+    Block *bb = assign->getBlock();
+    // Loop father of the statement's own block.
+    LoopOp loop = pluginAPI.GetBlockLoopFather(bb);
+    LoopOp maxloop;
+    vector<mlir::Value> vals = dfAPI.GetSsaUseOperand(assign.id());
+    for (auto val: vals)
+    {
+        uint64_t id = getValueId(val);
+        pluginAPI.DebugValue(id);
+        Operation* defOp = getSSADefStmtofValue(val);
+        // Operands without an AssignOp/PhiOp definition are skipped.
+        if (!defOp) continue;
+        cout<<"none"<<endl;
+        Block *def_bb = defOp->getBlock();
+        LoopOp def_loop = pluginAPI.GetBlockLoopFather(def_bb);
+        maxloop = pluginAPI.FindCommonLoop(&loop, &def_loop);
+        // When the common loop is the definition's own loop, fall back to the
+        // statement's loop father.
+        if (maxloop == def_loop) {
+            maxloop = loop;
+        }
+    }
+
+    return maxloop;
+}
+
+/* Presumably meant to hoist assign out of its loop. Currently it only
+ * locates the target loop (get_outermost_loop) and that loop's preheader edge;
+ * the motion itself is still a TODO. */
+void move_worker(AssignOp assign)
+{
+    LoopOp target = get_outermost_loop(assign);
+    if (target) {
+        edge preheader = getLoopPreheaderEdge(target);
+        (void)preheader;
+        // TODO.
+    }
+}
+/* Run the simple LICM analysis on the function with id fun:
+ * 1. compute dominance info;
+ * 2. look up the FunctionOp, returning early if there is none;
+ * 3. classify the statements of every loop header (compute_invariantness);
+ * 4. hand each candidate statement to move_worker. */
+static void ProcessSimpleLICM(uint64_t fun)
+{
+    cout << "Running first pass, Loop Invariant Code Motion\n";
+
+    // The namespace-scope bookkeeping holds Operation pointers of the
+    // function being processed. Reset it so statements left over from a
+    // previously processed function (possibly freed) are never revisited or
+    // matched by pointer.
+    move_stmt.clear();
+    visited.clear();
+    not_move.clear();
+
+    // NOTE(review): these locals shadow the namespace-scope pluginAPI/dfAPI.
+    // Helpers called below use the namespace-scope instances.
+    PluginServerAPI pluginAPI;
+    DataFlowAPI dfAPI;
+
+    dfAPI.CalDominanceInfo(1, fun);
+    mlir::Plugin::FunctionOp funcOp = pluginAPI.GetFunctionOpById(fun);
+    if (funcOp == nullptr) return;
+
+    string name = funcOp.funcNameAttr().getValue().str();
+    fprintf(stderr, "Now process func : %s \n", name.c_str());
+    vector<LoopOp> allLoop = funcOp.GetAllLoops();
+    for (auto &loop : allLoop) {
+        compute_invariantness(loop.GetHeader());
+    }
+    for (auto stmt : move_stmt) {
+        // TODO.
+        move_worker(stmt);
+    }
+}
+/* Pass hook: run simple LICM on the function with id fun. Always returns 0. */
+int SimpleLICMPass::DoOptimize(uint64_t fun)
+{
+    const int status = 0;
+    ProcessSimpleLICM(fun);
+    return status;
+}
+
+}
diff --git a/user/user.cpp b/user/user.cpp
index 6bfa524..2f4c6d1 100644
--- a/user/user.cpp
+++ b/user/user.cpp
@@ -23,12 +23,15 @@
#include "user/InlineFunctionPass.h"
#include "user/LocalVarSummeryPass.h"
#include "user/StructReorder.h"
+#include "user/SimpleLICMPass.h"
void RegisterCallbacks(void)
{
- PinServer::PluginServer *pluginServer = PinServer::PluginServer::GetInstance();
- pluginServer->RegisterOpt(std::make_shared<PluginOpt::InlineFunctionPass>(PluginOpt::HANDLE_BEFORE_IPA));
- pluginServer->RegisterOpt(std::make_shared<PluginOpt::LocalVarSummeryPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ PinServer::PluginServer *pluginServer = PinServer::PluginServer::GetInstance();
+ // pluginServer->RegisterOpt(std::make_shared<PluginOpt::InlineFunctionPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ // pluginServer->RegisterOpt(std::make_shared<PluginOpt::LocalVarSummeryPass>(PluginOpt::HANDLE_BEFORE_IPA));
+ PluginOpt::ManagerSetup setupData(PluginOpt::PASS_LAD, 1, PluginOpt::PASS_INSERT_AFTER);
+ pluginServer->RegisterPassManagerOpt(setupData, std::make_shared<PluginOpt::SimpleLICMPass>());
// PluginOpt::ManagerSetup setupData(PluginOpt::PASS_PHIOPT, 1, PluginOpt::PASS_INSERT_AFTER);
// pluginServer->RegisterPassManagerOpt(setupData, std::make_shared<PluginOpt::ArrayWidenPass>());
// PluginOpt::ManagerSetup setupData(PluginOpt::PASS_MAC, 1, PluginOpt::PASS_INSERT_AFTER);
--
2.33.0