pin-server/0003-Pin-server-Support-build-CFG-CondOp-CallOp-PhiOp-and.patch
benniaobufeijiushiji 43622cf48f [sync] Sync patch from openEuler/pin-server
Sync patch from openEuler/pin-server 20221208

(cherry picked from commit 5b1dfc11cd542269226338319a38315ba6f5cf7e)
2022-12-09 09:52:15 +08:00

985 lines
39 KiB
Diff
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

From b721e3a4292503115329d6380b590b3bab1cd856 Mon Sep 17 00:00:00 2001
From: Mingchuan Wu <wumingchuan1992@foxmail.com>
Date: Thu, 8 Dec 2022 20:14:09 +0800
Subject: [PATCH 3/4] [Pin-server] Support building CFG, CondOp, CallOp, PhiOp and
AssignOp. Add CondOp, CallOp, PhiOp and AssignOp to PluginDialect. With these
ops the server can now represent a CFG.
---
include/Dialect/CMakeLists.txt | 3 +
include/Dialect/PluginDialect.td | 38 ++++
include/Dialect/PluginOps.h | 4 +
include/Dialect/PluginOps.td | 142 ++++++++++++++-
include/PluginAPI/BasicPluginOpsAPI.h | 5 +
include/PluginAPI/PluginServerAPI.h | 10 +-
include/PluginServer/PluginServer.h | 22 +++
lib/Dialect/PluginOps.cpp | 150 +++++++++++++++-
lib/PluginAPI/PluginServerAPI.cpp | 60 ++++++-
lib/PluginServer/PluginServer.cpp | 250 ++++++++++++++++++++++++--
user.cpp | 2 +-
11 files changed, 663 insertions(+), 23 deletions(-)
diff --git a/include/Dialect/CMakeLists.txt b/include/Dialect/CMakeLists.txt
index f0dfc58..f44e77c 100644
--- a/include/Dialect/CMakeLists.txt
+++ b/include/Dialect/CMakeLists.txt
@@ -1,5 +1,8 @@
# Add for the dialect operations. arg:(dialect dialect_namespace)
file(COPY /usr/bin/mlir-tblgen DESTINATION ./)
+set(LLVM_TARGET_DEFINITIONS PluginOps.td) # tablegen input for the two enum rules below
+mlir_tablegen(PluginOpsEnums.h.inc -gen-enum-decls) # enum declarations; included by Dialect/PluginOps.h
+mlir_tablegen(PluginOpsEnums.cpp.inc -gen-enum-defs) # matching enum definitions
add_mlir_dialect(PluginOps Plugin)
# Necessary to generate documentation. arg:(doc_filename command output_file output_directory)
diff --git a/include/Dialect/PluginDialect.td b/include/Dialect/PluginDialect.td
index 362a5c8..31b338d 100644
--- a/include/Dialect/PluginDialect.td
+++ b/include/Dialect/PluginDialect.td
@@ -46,4 +46,42 @@ def Plugin_Dialect : Dialect {
class Plugin_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Plugin_Dialect, mnemonic, traits>;
+//===----------------------------------------------------------------------===//
+// PluginDialect enum definitions: I32 enum attributes emitted as C++ enums in ::mlir::Plugin.
+//===----------------------------------------------------------------------===//
+
+def IComparisonLT : I32EnumAttrCase<"lt", 0>; // integer values are exchanged with the client; keep them stable
+def IComparisonLE : I32EnumAttrCase<"le", 1>;
+def IComparisonGT : I32EnumAttrCase<"gt", 2>;
+def IComparisonGE : I32EnumAttrCase<"ge", 3>;
+def IComparisonLTGT : I32EnumAttrCase<"ltgt", 4>;
+def IComparisonEQ : I32EnumAttrCase<"eq", 5>;
+def IComparisonNE : I32EnumAttrCase<"ne", 6>;
+def IComparisonUNDEF : I32EnumAttrCase<"UNDEF", 7>;
+def IComparisonAttr : I32EnumAttr< // C++ enum IComparisonCode, used for CondOp's $condCode
+ "IComparisonCode", "plugin comparison code",
+ [IComparisonLT, IComparisonLE, IComparisonGT, IComparisonGE,
+ IComparisonLTGT, IComparisonEQ, IComparisonNE, IComparisonUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
+def IDefineCodeMemRef : I32EnumAttrCase<"MemRef", 0>;
+def IDefineCodeIntCST : I32EnumAttrCase<"IntCST", 1>;
+def IDefineCodeUNDEF : I32EnumAttrCase<"UNDEF", 2>;
+def IDefineCodeAttr : I32EnumAttr< // C++ enum IDefineCode, used for PlaceholderOp's $defCode
+ "IDefineCode", "plugin define code",
+ [IDefineCodeMemRef, IDefineCodeIntCST, IDefineCodeUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
+def IExprCodePlus : I32EnumAttrCase<"Plus", 0>;
+def IExprCodeMinus : I32EnumAttrCase<"Minus", 1>;
+def IExprCodeNop : I32EnumAttrCase<"Nop", 2>;
+def IExprCodeUNDEF : I32EnumAttrCase<"UNDEF", 3>;
+def IExprCodeAttr : I32EnumAttr< // C++ enum IExprCode, used for AssignOp's $exprCode
+ "IExprCode", "plugin expr code",
+ [IExprCodePlus, IExprCodeMinus, IExprCodeNop, IExprCodeUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
#endif // PLUGIN_DIALECT_TD
\ No newline at end of file
diff --git a/include/Dialect/PluginOps.h b/include/Dialect/PluginOps.h
index be2a3ef..8bab877 100644
--- a/include/Dialect/PluginOps.h
+++ b/include/Dialect/PluginOps.h
@@ -27,6 +27,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+
+// Pull in all enum type definitions and utility function declarations.
+#include "Dialect/PluginOpsEnums.h.inc"
#define GET_OP_CLASSES
#include "Dialect/PluginOps.h.inc"
diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td
index 02d9bdd..e91cd84 100644
--- a/include/Dialect/PluginOps.td
+++ b/include/Dialect/PluginOps.td
@@ -28,8 +28,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
TODO.
}];
- let arguments = (ins OptionalAttr<UI64Attr>:$id,
- OptionalAttr<StrAttr>:$funcName,
+ let arguments = (ins UI64Attr:$id,
+ StrAttr:$funcName,
OptionalAttr<BoolAttr>:$declaredInline);
let regions = (region AnyRegion:$bodyRegion);
@@ -89,4 +89,142 @@ def LoopOp : Plugin_Op<"loop", [NoSideEffect]> {
}];
}
+def CallOp : Plugin_Op<"call", [
+ DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let summary = "call operation";
+ let description = [{
+ CallOp represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments.
+ The callee name is attached as a symbol reference via an attribute.
+ The arguments list must match the arguments expected by the callee.
+ }];
+ let arguments = (ins UI64Attr:$id, // handle sent back to the client, e.g. by SetLHS()
+ FlatSymbolRefAttr:$callee,
+ Variadic<AnyType>:$inputs);
+ let results = (outs Optional<AnyType>:$result); // NOTE(review): the custom build() adds no result type, so built ops have no result
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$callee,
+ "ArrayRef<Value>":$arguments)>
+ ];
+ let extraClassDeclaration = [{
+ bool SetLHS(Value lhs);
+ }];
+ let assemblyFormat = [{
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ }];
+}
+
+def PhiOp : Plugin_Op<"phi", [NoSideEffect]> {
+ let summary = "phi op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id, // client-side handle used by GetResult()
+ UI32Attr:$capacity, // NOTE(review): build() stores capacity/nArgs as signless i32, not ui32
+ UI32Attr:$nArgs,
+ Variadic<AnyType>:$operands);
+ let results = (outs AnyType:$result); // only present when build() receives a non-null resultType
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "uint32_t":$capacity,
+ "uint32_t":$nArgs, "ArrayRef<Value>":$operands,
+ "Type":$resultType)>
+ ];
+ let extraClassDeclaration = [{
+ Value GetResult();
+ Value GetArgDef(int i) { return getOperand(i); }
+ }];
+}
+
+def AssignOp : Plugin_Op<"assign", [NoSideEffect]> {
+ let summary = "assign op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id,
+ IExprCodeAttr:$exprCode, // kind of expression, see IExprCode
+ Variadic<AnyType>:$operands); // accessors below assume the order LHS, RHS1, RHS2
+ let results = (outs AnyType:$result);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "IExprCode":$exprCode,
+ "ArrayRef<Value>":$operands, "Type":$resultType)>
+ ];
+ let extraClassDeclaration = [{
+ Value GetLHS() { return getOperand(0); }
+ Value GetRHS1() { return getOperand(1); }
+ Value GetRHS2() { return getOperand(2); }
+ }];
+}
+
+def PlaceholderOp : Plugin_Op<"placeholder", [NoSideEffect]> {
+ let summary = "PlaceHolder.";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id,
+ OptionalAttr<IDefineCodeAttr>:$defCode); // always set by the custom build()
+ let results = (outs AnyType); // omitted when build() is given a null retType
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "IDefineCode":$defCode, "Type":$retType)>
+ ];
+}
+
+def BaseOp : Plugin_Op<"statement_base", [NoSideEffect]> {
+ let summary = "Base operation, just like placeholder for statement.";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id, StrAttr:$opCode);
+ let results = (outs AnyType); // NOTE(review): the custom build() adds no result type, so built ops have 0 results — confirm intent
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$opCode)>
+ ];
+}
+
+// Terminators
+// Base class for terminator ops: appends the Terminator trait to the given traits.
+
+class Plugin_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
+ Plugin_Op<mnemonic, !listconcat(traits, [Terminator])>;
+
+def FallThroughOp : Plugin_TerminatorOp<"fallthrough", [NoSideEffect]> {
+ let summary = "FallThroughOp";
+ let description = [{TODO}];
+ let successors = (successor AnySuccessor:$dest);
+ // $address: address of the enclosing basic block; $destaddr: address of $dest
+ let arguments = (ins UI64Attr:$address, UI64Attr:$destaddr);
+ let results = (outs AnyType); // NOTE(review): the custom build() adds no result type, so built ops have 0 results — confirm intent
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$address, "Block*":$dest, "uint64_t":$destaddr)>
+ ];
+}
+
+def CondOp : Plugin_TerminatorOp<"condition", [NoSideEffect]> {
+ let summary = "condition op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id, UI64Attr:$address,
+ IComparisonAttr:$condCode,
+ AnyType:$LHS, AnyType:$RHS,
+ UI64Attr:$tbaddr,
+ UI64Attr:$fbaddr,
+ OptionalAttr<TypeAttr>:$trueLabel, // NOTE(review): the C++ builder takes labels as Values and adds them as operands
+ OptionalAttr<TypeAttr>:$falseLabel);
+ let successors = (successor AnySuccessor:$tb, AnySuccessor:$fb);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "uint64_t":$address, "IComparisonCode":$condCode,
+ "Value":$lhs, "Value":$rhs, "Block*":$tb, "Block*":$fb,
+ "uint64_t":$tbaddr, "uint64_t":$fbaddr, "Value":$trueLabel,
+ "Value":$falseLabel)>,
+ // Only for server: the client creates the op (CreateCondOp); no successors are attached.
+ OpBuilderDAG<(ins "IComparisonCode":$condCode,
+ "Value":$lhs, "Value":$rhs)>
+ ];
+ let extraClassDeclaration = [{
+ Value GetLHS() { return getOperand(0); }
+ Value GetRHS() { return getOperand(1); }
+ }];
+}
+
+// TODO: RetOp does not have a correct assemblyFormat yet.
+def RetOp : Plugin_TerminatorOp<"ret", [NoSideEffect]> {
+ let summary = "RetOp";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$address); // for bb address
+ let results = (outs AnyType); // NOTE(review): the custom build() adds no result type, so built ops have 0 results — confirm intent
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$address)>
+ ];
+}
+
#endif // PLUGIN_OPS_TD
\ No newline at end of file
diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h
index fde5f63..6f5fbb0 100644
--- a/include/PluginAPI/BasicPluginOpsAPI.h
+++ b/include/PluginAPI/BasicPluginOpsAPI.h
@@ -48,6 +48,11 @@ public:
virtual pair<uint64_t, uint64_t> LoopSingleExit(uint64_t) = 0;
virtual vector<pair<uint64_t, uint64_t> > GetLoopExitEdges(uint64_t) = 0;
virtual LoopOp GetBlockLoopFather(uint64_t) = 0;
+ virtual PhiOp GetPhiOp(uint64_t) = 0;
+ virtual CallOp GetCallOp(uint64_t) = 0;
+ virtual bool SetLhsInCallOp(uint64_t, uint64_t) = 0;
+ virtual uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) = 0;
+ virtual mlir::Value GetResultFromPhi(uint64_t) = 0;
}; // class BasicPluginOpsAPI
} // namespace PluginAPI
diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h
index 5b14f2e..425b946 100644
--- a/include/PluginAPI/PluginServerAPI.h
+++ b/include/PluginAPI/PluginServerAPI.h
@@ -38,6 +38,8 @@ public:
vector<FunctionOp> GetAllFunc() override;
vector<LocalDeclOp> GetDecls(uint64_t) override;
+ PhiOp GetPhiOp(uint64_t) override;
+ CallOp GetCallOp(uint64_t) override;
PluginIR::PluginTypeID GetTypeCodeFromString(string type);
LoopOp AllocateNewLoop(uint64_t funcID);
LoopOp GetLoopById(uint64_t loopID);
@@ -51,8 +53,14 @@ public:
uint64_t GetLatch(uint64_t loopID);
vector<uint64_t> GetLoopBody(uint64_t loopID);
LoopOp GetBlockLoopFather(uint64_t blockID);
+ /* Plugin API for CallOp. */
+ bool SetLhsInCallOp(uint64_t, uint64_t) override;
+ /* Plugin API for CondOp and PhiOp. */
+ uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) override;
+ mlir::Value GetResultFromPhi(uint64_t) override;
+
private:
- vector<FunctionOp> GetOperationResult(const string& funName, const string& params);
+ vector<FunctionOp> GetFunctionOpResult(const string& funName, const string& params);
vector<LocalDeclOp> GetDeclOperationResult(const string& funName, const string& params);
LoopOp GetLoopResult(const string&funName, const string& params);
vector<LoopOp> GetLoopsResult(const string& funName, const string& params);
diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h
index 610ecb9..b28c6d0 100644
--- a/include/PluginServer/PluginServer.h
+++ b/include/PluginServer/PluginServer.h
@@ -33,6 +33,7 @@
#include "Dialect/PluginOps.h"
#include "plugin.grpc.pb.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Builders.h"
#include "Dialect/PluginTypes.h"
namespace PinServer {
@@ -120,6 +121,7 @@ private:
class PluginServer final : public PluginService::Service {
public:
+ PluginServer() : opBuilder(&context){}
/* 定义的grpc服务端和客户端通信的接口函数 */
Status ReceiveSendMsg(ServerContext* context, ServerReaderWriter<ServerMsg, ClientMsg>* stream) override;
/* 服务端发送数据给client接口 */
@@ -137,6 +139,10 @@ public:
vector<std::pair<uint64_t, uint64_t> > EdgesResult(void);
std::pair<uint64_t, uint64_t> EdgeResult(void);
bool BoolResult(void);
+ vector<mlir::Operation *> GetOpResult(void);
+ bool GetBoolResult(void);
+ uint64_t GetIdResult(void);
+ mlir::Value GetValueResult(void);
/* 回调函数接口用于向server注册用户需要执行的函数 */
int RegisterUserFunc(InjectPoint inject, UserFunc func);
int RegisterPassManagerSetup(InjectPoint inject, const ManagerSetupData& passData, UserFunc func);
@@ -190,6 +196,13 @@ public:
void EdgeJsonDeSerialize(const string& data);
void BlocksJsonDeSerialize(const string& data);
void BlockJsonDeSerialize(const string& data);
+ void CallOpJsonDeSerialize(const string& data);
+ void CondOpJsonDeSerialize(const string& data);
+ void RetOpJsonDeSerialize(const string& data);
+ void FallThroughOpJsonDeSerialize(const string& data);
+ void PhiOpJsonDeSerialize(const string& data);
+ void AssignOpJsonDeSerialize(const string& data);
+ mlir::Value ValueJsonDeSerialize(Json::Value valueJson);
/* json反序列化根据key值分别调用Operation/Decl/Type反序列化接口函数 */
void JsonDeSerialize(const string& key, const string& data);
/* 解析客户端发送过来的-fplugin-arg参数并保存在私有变量args中 */
@@ -234,6 +247,7 @@ private:
/* 用户函数执行状态client返回结果后为STATE_RETURN,开始执行下一个函数 */
volatile UserFunStateEnum userFunState;
mlir::MLIRContext context;
+ mlir::OpBuilder opBuilder;
vector<mlir::Plugin::FunctionOp> funcOpData;
PluginIR::PluginTypeBase pluginType;
vector<mlir::Plugin::LocalDeclOp> decls;
@@ -244,6 +258,10 @@ private:
vector<uint64_t> blockIds;
uint64_t blockId;
bool boolRes;
+ vector<mlir::Operation *> opData;
+ bool boolResult = false; // defined default until a client reply sets it
+ uint64_t idResult = 0; // must hold a full 64-bit id, as returned by GetIdResult()
+ mlir::Value valueResult;
/* 保存用户注册的回调函数,它们将在注入点事件触发后调用 */
map<InjectPoint, vector<RecordedUserFunc>> userFunc;
string apiFuncName; // 保存用户调用PluginAPI的函数名
@@ -252,6 +270,10 @@ private:
timer_t timerId;
map<string, string> args; // 保存gcc编译时用户传入参数
sem_t sem[2];
+
+ // process Block.
+ std::map<uint64_t, mlir::Block*> blockMaps;
+ bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&);
}; // class PluginServer
void RunServer(int timeout, string& port);
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 481d51a..d647fc4 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -33,8 +33,9 @@ using namespace mlir::Plugin;
using std::vector;
using std::pair;
-void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
- uint64_t id, StringRef funcName, bool declaredInline) {
+void FunctionOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef funcName, bool declaredInline)
+{
FunctionOp::build(builder, state,
builder.getI64IntegerAttr(id),
builder.getStringAttr(funcName),
@@ -48,9 +49,10 @@ vector<LoopOp> FunctionOp::GetAllLoops()
return pluginAPI.GetLoopsFromFunc(funcId);
}
-void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+void LocalDeclOp::build(OpBuilder &builder, OperationState &state,
uint64_t id, StringRef symName,
- int64_t typeID, uint64_t typeWidth) {
+ int64_t typeID, uint64_t typeWidth)
+{
LocalDeclOp::build(builder, state,
builder.getI64IntegerAttr(id),
builder.getStringAttr(symName),
@@ -142,6 +144,146 @@ void LoopOp::AddLoop(uint64_t outerId, uint64_t funcId)
return pluginAPI.AddLoop(loopId, outerId, funcId);
}
+//===----------------------------------------------------------------------===//
+// PlaceholderOp: stand-in Value for an operand the server only knows by its id.
+
+void PlaceholderOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, IDefineCode defCode, Type retType) {
+ state.addAttribute("id", builder.getI64IntegerAttr(id)); // NOTE(review): $id is declared UI64Attr but stored as a signless i64
+ state.addAttribute("defCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(defCode)));
+ if (retType) state.addTypes(retType); // a null retType yields an op without results
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp: call to the `callee` symbol, referenced towards the client by `id`.
+
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef callee,
+ ArrayRef<Value> arguments)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addOperands(arguments); // no result type is added, so the optional result stays absent
+ state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+}
+
+/// Returns the callee symbol (the "callee" attribute set in build());
+/// required by CallOpInterface.
+CallInterfaceCallable CallOp::getCallableForCallee()
+{
+ return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+}
+
+/// Returns the operands passed to the callee ($inputs); required by
+/// CallOpInterface.
+Operation::operand_range CallOp::getArgOperands() { return inputs(); }
+
+bool CallOp::SetLHS(Value lhs) // false unless lhs is the result of a PlaceholderOp
+{
+ PlaceholderOp phOp = lhs ? lhs.getDefiningOp<PlaceholderOp>() : PlaceholderOp();
+ if (!phOp) return false; // only placeholder values carry an id the client can resolve
+ PluginAPI::PluginServerAPI pluginAPI;
+ return pluginAPI.SetLhsInCallOp(this->idAttr().getInt(), phOp.idAttr().getInt());
+}
+
+//===----------------------------------------------------------------------===//
+// CondOp: conditional terminator comparing lhs and rhs with an IComparisonCode.
+
+void CondOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, uint64_t address, IComparisonCode condCode,
+ Value lhs, Value rhs, Block* tb, Block* fb, uint64_t tbaddr,
+ uint64_t fbaddr, Value trueLabel, Value falseLabel) {
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+ state.addAttribute("tbaddr", builder.getI64IntegerAttr(tbaddr));
+ state.addAttribute("fbaddr", builder.getI64IntegerAttr(fbaddr));
+ state.addOperands({lhs, rhs});
+ state.addSuccessors(tb); // successor $tb
+ state.addSuccessors(fb); // successor $fb
+ state.addAttribute("condCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(condCode)));
+ if (trueLabel != nullptr) state.addOperands(trueLabel); // NOTE(review): the .td declares the labels as TypeAttr, not operands
+ if (falseLabel != nullptr) state.addOperands(falseLabel);
+}
+
+void CondOp::build(OpBuilder &builder, OperationState &state,
+ IComparisonCode condCode, Value lhs, Value rhs)
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ PlaceholderOp lhsOp = lhs.getDefiningOp<PlaceholderOp>(); // NOTE(review): lhs/rhs must be PlaceholderOp results, else idAttr() below uses a null op
+ uint64_t lhsId = lhsOp.idAttr().getInt();
+ PlaceholderOp rhsOp = rhs.getDefiningOp<PlaceholderOp>();
+ uint64_t rhsId = rhsOp.idAttr().getInt();
+ uint64_t id = pluginAPI.CreateCondOp(condCode, lhsId, rhsId); // the client creates the op and returns its id
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addOperands({lhs, rhs}); // NOTE(review): no successors are attached although CondOp declares two
+ state.addAttribute("condCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(condCode)));
+}
+
+//===----------------------------------------------------------------------===//
+// PhiOp: phi node whose result Value is fetched from the client on demand.
+
+void PhiOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, uint32_t capacity, uint32_t nArgs,
+ ArrayRef<Value> operands, Type resultType)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("capacity", builder.getI32IntegerAttr(capacity));
+ state.addAttribute("nArgs", builder.getI32IntegerAttr(nArgs));
+ state.addOperands(operands);
+ if (resultType) state.addTypes(resultType);
+}
+
+Value PhiOp::GetResult()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ return pluginAPI.GetResultFromPhi(this->idAttr().getInt()); // round-trip to the client keyed by this op's id
+}
+
+//===----------------------------------------------------------------------===//
+// AssignOp: assignment whose expression kind is given by an IExprCode.
+
+void AssignOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, IExprCode exprCode,
+ ArrayRef<Value> operands, Type resultType)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("exprCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(exprCode)));
+ state.addOperands(operands);
+ if (resultType) state.addTypes(resultType); // a null resultType yields an op without results
+}
+
+//===----------------------------------------------------------------------===//
+// BaseOp: generic placeholder for a statement, identified by id and opCode.
+
+void BaseOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef opCode)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("opCode", builder.getStringAttr(opCode)); // no result type is added
+}
+
+//===----------------------------------------------------------------------===//
+// FallThroughOp: terminator with a single successor, the block at destaddr.
+
+void FallThroughOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t address, Block* dest, uint64_t destaddr)
+{
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+ state.addAttribute("destaddr", builder.getI64IntegerAttr(destaddr));
+ state.addSuccessors(dest);
+}
+
+//===----------------------------------------------------------------------===//
+// RetOp: return terminator; records only the address of its basic block.
+
+void RetOp::build(OpBuilder &builder, OperationState &state, uint64_t address)
+{
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index 41626f9..65eafa7 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -54,7 +54,7 @@ void PluginServerAPI::WaitClientResult(const string& funName, const string& para
}
}
-vector<FunctionOp> PluginServerAPI::GetOperationResult(const string& funName, const string& params)
+vector<FunctionOp> PluginServerAPI::GetFunctionOpResult(const string& funName, const string& params)
{
WaitClientResult(funName, params);
vector<FunctionOp> retOps = PluginServer::GetInstance()->GetFunctionOpResult();
@@ -67,7 +67,63 @@ vector<FunctionOp> PluginServerAPI::GetAllFunc()
string funName = __func__;
string params = root.toStyledString();
- return GetOperationResult(funName, params);
+ return GetFunctionOpResult(funName, params);
+}
+
+PhiOp PluginServerAPI::GetPhiOp(uint64_t id) // null PhiOp if the reply holds no op or the first op is not a phi
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(id);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ vector<mlir::Operation*> opRet = PluginServer::GetInstance()->GetOpResult();
+ return opRet.empty() ? PhiOp() : llvm::dyn_cast<PhiOp>(opRet[0]);
+}
+
+CallOp PluginServerAPI::GetCallOp(uint64_t id) // null CallOp if the reply holds no op or the first op is not a call
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(id);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ vector<mlir::Operation*> opRet = PluginServer::GetInstance()->GetOpResult();
+ return opRet.empty() ? CallOp() : llvm::dyn_cast<CallOp>(opRet[0]);
+}
+
+bool PluginServerAPI::SetLhsInCallOp(uint64_t callId, uint64_t lhsId)
+{
+ Json::Value root;
+ string funName = __func__;
+ root["callId"] = std::to_string(callId);
+ root["lhsId"] = std::to_string(lhsId);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetBoolResult(); // NOTE(review): reads PluginServer::boolResult; confirm the client's reply is deserialized into it
+}
+
+uint64_t PluginServerAPI::CreateCondOp(IComparisonCode iCode,
+ uint64_t lhs, uint64_t rhs)
+{
+ Json::Value root;
+ string funName = __func__;
+ root["condCode"] = std::to_string(static_cast<int32_t>(iCode)); // enum sent as its integer value
+ root["lhsId"] = std::to_string(lhs);
+ root["rhsId"] = std::to_string(rhs);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetIdResult(); // id from the client's "IdResult" reply
+}
+
+mlir::Value PluginServerAPI::GetResultFromPhi(uint64_t phiId)
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(phiId);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetValueResult(); // Value decoded from the client's "ValueResult" reply
}
PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type)
diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp
index 92159e8..2eac906 100644
--- a/lib/PluginServer/PluginServer.cpp
+++ b/lib/PluginServer/PluginServer.cpp
@@ -27,7 +27,6 @@
#include "Dialect/PluginDialect.h"
#include "PluginAPI/PluginServerAPI.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "user.h"
@@ -37,7 +36,6 @@
namespace PinServer {
using namespace mlir::Plugin;
-
using std::cout;
using std::endl;
using std::pair;
@@ -75,6 +73,13 @@ int PluginServer::RegisterPassManagerSetup(InjectPoint inject, const ManagerSetu
return 0;
}
+vector<mlir::Operation *> PluginServer::GetOpResult(void)
+{
+ vector<mlir::Operation *> retOps = opData;
+ opData.clear(); // the buffered ops are handed out once, then dropped
+ return retOps;
+}
+
vector<mlir::Plugin::FunctionOp> PluginServer::GetFunctionOpResult(void)
{
vector<mlir::Plugin::FunctionOp> retOps = funcOpData;
@@ -135,6 +140,21 @@ vector<pair<uint64_t, uint64_t> > PluginServer::EdgesResult()
return retEdges;
}
+bool PluginServer::GetBoolResult()
+{
+ return this->boolResult; // NOTE(review): confirm some JsonDeSerialize branch assigns boolResult
+}
+
+uint64_t PluginServer::GetIdResult()
+{
+ return this->idResult; // set by JsonDeSerialize for key "IdResult"
+}
+
+mlir::Value PluginServer::GetValueResult()
+{
+ return this->valueResult; // set by JsonDeSerialize for key "ValueResult"
+}
+
void PluginServer::JsonGetAttributes(Json::Value node, map<string, string>& attributes)
{
Json::Value::Members attMember = node.getMemberNames();
@@ -151,6 +171,27 @@ static uintptr_t GetID(Json::Value node)
return atol(id.c_str());
}
+mlir::Value PluginServer::ValueJsonDeSerialize(Json::Value valueJson) // expects an object with "id" and "defCode" members
+{
+ uint64_t opId = GetID(valueJson["id"]);
+ IDefineCode defCode = IDefineCode(
+ atoi(valueJson["defCode"].asString().c_str()));
+ switch (defCode) { // NOTE(review): every case is currently a no-op
+ case IDefineCode::MemRef : {
+ break;
+ }
+ case IDefineCode::IntCST : {
+ break;
+ }
+ default:
+ break;
+ }
+ mlir::Type retType = PluginIR::PluginUndefType::get(&context); // FIXME: real type not decoded yet; always PluginUndefType
+ mlir::Value opValue = opBuilder.create<PlaceholderOp>( // inserted at opBuilder's current insertion point, if any
+ opBuilder.getUnknownLoc(), opId, defCode, retType);
+ return opValue;
+}
+
void PluginServer::JsonDeSerialize(const string& key, const string& data)
{
if (key == "FuncOpResult") {
@@ -173,6 +214,12 @@ void PluginServer::JsonDeSerialize(const string& key, const string& data)
EdgesJsonDeSerialize(data);
} else if (key == "BlockIdsResult") {
BlocksJsonDeSerialize(data);
+ } else if (key == "IdResult") {
+ this->idResult = atol(data.c_str());
+ } else if (key == "ValueResult") { // data is JSON object text: parse it, don't wrap it as a JSON string
+ context.getOrLoadDialect<PluginDialect>();
+ opBuilder = mlir::OpBuilder(&context);
+ this->valueResult = ValueJsonDeSerialize([&data]() { Json::Value v; Json::Reader().parse(data, v); return v; }());
} else {
cout << "not Json,key:" << key << ",value:" << data << endl;
}
@@ -217,6 +264,46 @@ void PluginServer::TypeJsonDeSerialize(const string& data)
return;
}
+bool PluginServer::ProcessBlock(mlir::Block* block, mlir::Region& rg,
+ const Json::Value& blockJson) // NOTE(review): rg is currently unused
+{
+ if (blockJson.isNull()) {
+ return false;
+ }
+ // todo process func return type
+ // todo isDeclaration
+
+ // process each stmt
+ opBuilder.setInsertionPointToStart(block);
+ Json::Value::Members opMember = blockJson.getMemberNames(); // NOTE(review): keys come back in sorted string order ("10" < "2"); confirm that preserves statement order
+ for (Json::Value::Members::iterator opIdx = opMember.begin();
+ opIdx != opMember.end(); opIdx++) {
+ string baseOpKey = *opIdx;
+ Json::Value opJson = blockJson[baseOpKey];
+ if (opJson.isNull()) continue;
+ // opCode must own its text: asString() returns a temporary, so a StringRef to it would dangle.
+ string opCode = opJson["OperationName"].asString();
+ if (opCode == PhiOp::getOperationName().str()) {
+ PhiOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == CallOp::getOperationName().str()) {
+ CallOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == AssignOp::getOperationName().str()) {
+ AssignOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == CondOp::getOperationName().str()) {
+ CondOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == RetOp::getOperationName().str()) {
+ RetOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == FallThroughOp::getOperationName().str()) {
+ FallThroughOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == BaseOp::getOperationName().str()) {
+ uint64_t opID = GetID(opJson["id"]);
+ opBuilder.create<BaseOp>(opBuilder.getUnknownLoc(), opID, opCode);
+ }
+ }
+ opData.clear(); // ops replayed into the block are not an API reply; don't leave them for GetOpResult()
+ return true;
+}
+
void PluginServer::FuncOpJsonDeSerialize(const string& data)
{
Json::Value root;
@@ -227,8 +314,9 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data)
Json::Value::Members operation = root.getMemberNames();
context.getOrLoadDialect<PluginDialect>();
- mlir::OpBuilder builder(&context);
- for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
+ opBuilder = mlir::OpBuilder(&context);
+ for (Json::Value::Members::iterator iter = operation.begin();
+ iter != operation.end(); iter++) {
string operationKey = *iter;
node = root[operationKey];
int64_t id = GetID(node["id"]);
@@ -237,9 +325,30 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data)
JsonGetAttributes(attributes, funcAttributes);
bool declaredInline = false;
if (funcAttributes["declaredInline"] == "1") declaredInline = true;
- auto location = builder.getUnknownLoc();
- FunctionOp op = builder.create<FunctionOp>(location, id, funcAttributes["funcName"], declaredInline);
- funcOpData.push_back(op);
+ auto location = opBuilder.getUnknownLoc();
+ FunctionOp fOp = opBuilder.create<FunctionOp>(
+ location, id, funcAttributes["funcName"], declaredInline);
+ mlir::Region &bodyRegion = fOp.bodyRegion();
+ Json::Value regionJson = node["region"];
+ Json::Value::Members bbMember = regionJson.getMemberNames();
+ // Create all blocks first so CondOp/FallThroughOp successors can be resolved by address.
+ for (Json::Value::Members::iterator bbIdx = bbMember.begin();
+ bbIdx != bbMember.end(); bbIdx++) {
+ string blockKey = *bbIdx;
+ Json::Value blockJson = regionJson[blockKey];
+ mlir::Block* block = opBuilder.createBlock(&bodyRegion);
+ this->blockMaps[GetID(blockJson["address"])] = block; // assign so a reused address maps to this function's block
+ }
+
+ for (Json::Value::Members::iterator bbIdx = bbMember.begin();
+ bbIdx != bbMember.end(); bbIdx++) {
+ string blockKey = *bbIdx;
+ Json::Value blockJson = regionJson[blockKey];
+ uint64_t bbAddress = GetID(blockJson["address"]);
+ ProcessBlock(this->blockMaps[bbAddress], bodyRegion, blockJson["ops"]);
+ }
+ funcOpData.push_back(fOp);
+ opBuilder.clearInsertionPoint(); // fOp sits in no block; don't leave the builder inside its body
}
}
@@ -253,7 +362,7 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data)
Json::Value::Members operation = root.getMemberNames();
context.getOrLoadDialect<PluginDialect>();
- mlir::OpBuilder builder(&context);
+ opBuilder = mlir::OpBuilder(&context);
for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
string operationKey = *iter;
node = root[operationKey];
@@ -261,11 +370,11 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data)
Json::Value attributes = node["attributes"];
map<string, string> declAttributes;
JsonGetAttributes(attributes, declAttributes);
- string symName = declAttributes["symName"];
- uint64_t typeID = atol(declAttributes["typeID"].c_str());
- uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
- auto location = builder.getUnknownLoc();
- LocalDeclOp op = builder.create<LocalDeclOp>(location, id, symName, typeID, typeWidth);
+ string symName = declAttributes["symName"];
+ uint64_t typeID = atol(declAttributes["typeID"].c_str());
+ uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
+ auto location = opBuilder.getUnknownLoc();
+ LocalDeclOp op = opBuilder.create<LocalDeclOp>(location, id, symName, typeID, typeWidth);
decls.push_back(op);
}
}
@@ -384,6 +493,121 @@ void PluginServer::BlockJsonDeSerialize(const string& data)
blockId = (uint64_t)atol(root["id"].asString().c_str());
}
+// Rebuild a CallOp from its JSON form and append it to opData.
+void PluginServer::CallOpJsonDeSerialize(const string& data)
+{
+    Json::Value node;
+    Json::Reader reader;
+    reader.parse(data, node);
+    Json::Value operandJson = node["operands"];
+    llvm::SmallVector<mlir::Value, 4> ops;
+    for (const string& key : operandJson.getMemberNames()) {
+        ops.push_back(ValueJsonDeSerialize(operandJson[key.c_str()]));
+    }
+    int64_t id = GetID(node["id"]);
+    // asString() returns a temporary; keep an owning copy so the StringRef
+    // handed to create() does not point at freed memory.
+    string callee = node["callee"].asString();
+    mlir::StringRef callName(callee);
+    CallOp op = opBuilder.create<CallOp>(opBuilder.getUnknownLoc(),
+        id, callName, ops);
+    opData.push_back(op.getOperation());
+}
+
+// Rebuild a conditional branch (CondOp) from JSON and append it to opData.
+void PluginServer::CondOpJsonDeSerialize(const string& data)
+{
+    Json::Reader parser;
+    Json::Value root;
+    parser.parse(data, root);
+    mlir::Value lhsVal = ValueJsonDeSerialize(root["lhs"]);
+    mlir::Value rhsVal = ValueJsonDeSerialize(root["rhs"]);
+    int64_t opId = GetID(root["id"]);
+    int64_t opAddr = GetID(root["address"]);
+    int64_t trueAddr = GetID(root["tbaddr"]);
+    int64_t falseAddr = GetID(root["fbaddr"]);
+    // Successor blocks are expected to be registered in blockMaps already.
+    assert(this->blockMaps.find(trueAddr) != this->blockMaps.end());
+    assert(this->blockMaps.find(falseAddr) != this->blockMaps.end());
+    mlir::Block* trueDest = this->blockMaps[trueAddr];
+    mlir::Block* falseDest = this->blockMaps[falseAddr];
+    auto cmpCode = IComparisonCode(atoi(root["condCode"].asString().c_str()));
+    mlir::Value noLabel = nullptr; // trueLabel/falseLabel are passed as null
+    CondOp op = opBuilder.create<CondOp>(opBuilder.getUnknownLoc(), opId,
+        opAddr, cmpCode, lhsVal, rhsVal, trueDest, falseDest, trueAddr,
+        falseAddr, noLabel, noLabel);
+    opData.push_back(op.getOperation());
+}
+
+// Rebuild a RetOp (only its "address" is read) and append it to opData.
+void PluginServer::RetOpJsonDeSerialize(const string& data)
+{
+    Json::Reader parser;
+    Json::Value root;
+    parser.parse(data, root);
+    const int64_t retAddr = GetID(root["address"]);
+    opData.push_back(opBuilder.create<RetOp>(opBuilder.getUnknownLoc(), retAddr).getOperation());
+}
+
+// Rebuild a FallThroughOp whose successor is the block at "destaddr".
+void PluginServer::FallThroughOpJsonDeSerialize(const string& data)
+{
+    Json::Reader parser;
+    Json::Value root;
+    parser.parse(data, root);
+    int64_t opAddr = GetID(root["address"]);
+    int64_t succAddr = GetID(root["destaddr"]);
+    assert(this->blockMaps.find(succAddr) != this->blockMaps.end());
+    FallThroughOp ftOp = opBuilder.create<FallThroughOp>(
+        opBuilder.getUnknownLoc(), opAddr, this->blockMaps[succAddr], succAddr);
+    opData.push_back(ftOp.getOperation());
+}
+
+// Rebuild a PhiOp from JSON and append it to opData. Incoming values come
+// from "operands"; "capacity" and "nArgs" are parsed from decimal strings.
+void PluginServer::PhiOpJsonDeSerialize(const string& data)
+{
+    Json::Reader parser;
+    Json::Value root;
+    parser.parse(data, root);
+    Json::Value operandsJson = root["operands"];
+    llvm::SmallVector<mlir::Value, 4> phiArgs;
+    // Arguments follow the key order returned by getMemberNames().
+    for (const string& argKey : operandsJson.getMemberNames()) {
+        phiArgs.push_back(ValueJsonDeSerialize(operandsJson[argKey.c_str()]));
+    }
+    int64_t phiId = GetID(root["id"]);
+    uint32_t phiCapacity = atoi(root["capacity"].asString().c_str());
+    uint32_t phiArgCount = atoi(root["nArgs"].asString().c_str());
+    // No result type is read from the JSON; a null type is passed through.
+    mlir::Type resultType = nullptr;
+    PhiOp phi = opBuilder.create<PhiOp>(opBuilder.getUnknownLoc(), phiId,
+        phiCapacity, phiArgCount, phiArgs, resultType);
+    opData.push_back(phi.getOperation());
+}
+
+// Rebuild an AssignOp from JSON and append it to opData. Operand values come
+// from "operands" and the expression kind from the integer string "exprCode".
+void PluginServer::AssignOpJsonDeSerialize(const string& data)
+{
+    Json::Reader parser;
+    Json::Value root;
+    parser.parse(data, root);
+    Json::Value operandsJson = root["operands"];
+    llvm::SmallVector<mlir::Value, 4> operandVals;
+    // Operands follow the key order returned by getMemberNames().
+    for (const string& operandKey : operandsJson.getMemberNames()) {
+        operandVals.push_back(ValueJsonDeSerialize(operandsJson[operandKey.c_str()]));
+    }
+    int64_t assignId = GetID(root["id"]);
+    auto exprCode = IExprCode(atoi(root["exprCode"].asString().c_str()));
+    // No result type is read from the JSON; a null type is passed through.
+    mlir::Type resultType = nullptr;
+    AssignOp assign = opBuilder.create<AssignOp>(opBuilder.getUnknownLoc(),
+        assignId, exprCode, operandVals, resultType);
+    opData.push_back(assign.getOperation());
+}
+
/* 线程函数,执行用户注册函数,客户端返回数据后退出 */
static void ExecCallbacks(const string& name)
{
diff --git a/user.cpp b/user.cpp
index efc4158..605e6f8 100644
--- a/user.cpp
+++ b/user.cpp
@@ -117,7 +117,7 @@ ProcessArrayWiden(void)
void RegisterCallbacks(void)
{
- PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc);
+ // PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc);
PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, LocalVarSummery);
ManagerSetupData setupData;
setupData.refPassName = PASS_PHIOPT;
--
2.27.0.windows.1