pin-server/0003-Pin-server-Support-build-CFG-CondOp-CallOp-PhiOp-and.patch

985 lines
39 KiB
Diff
Raw Normal View History

From b721e3a4292503115329d6380b590b3bab1cd856 Mon Sep 17 00:00:00 2001
From: Mingchuan Wu <wumingchuan1992@foxmail.com>
Date: Thu, 8 Dec 2022 20:14:09 +0800
Subject: [PATCH 3/4] [Pin-server] Support building CFG, CondOp, CallOp, PhiOp
and AssignOp. Add CondOp, CallOp, PhiOp and AssignOp to PluginDialect, which
lets the server rebuild a function's CFG.
---
include/Dialect/CMakeLists.txt | 3 +
include/Dialect/PluginDialect.td | 38 ++++
include/Dialect/PluginOps.h | 4 +
include/Dialect/PluginOps.td | 142 ++++++++++++++-
include/PluginAPI/BasicPluginOpsAPI.h | 5 +
include/PluginAPI/PluginServerAPI.h | 10 +-
include/PluginServer/PluginServer.h | 22 +++
lib/Dialect/PluginOps.cpp | 150 +++++++++++++++-
lib/PluginAPI/PluginServerAPI.cpp | 60 ++++++-
lib/PluginServer/PluginServer.cpp | 250 ++++++++++++++++++++++++--
user.cpp | 2 +-
11 files changed, 663 insertions(+), 23 deletions(-)
diff --git a/include/Dialect/CMakeLists.txt b/include/Dialect/CMakeLists.txt
index f0dfc58..f44e77c 100644
--- a/include/Dialect/CMakeLists.txt
+++ b/include/Dialect/CMakeLists.txt
@@ -1,5 +1,8 @@
# Add for the dialect operations. arg:(dialect dialect_namespace)
file(COPY /usr/bin/mlir-tblgen DESTINATION ./)
+set(LLVM_TARGET_DEFINITIONS PluginOps.td) # tablegen input for the two enum targets below
+mlir_tablegen(PluginOpsEnums.h.inc -gen-enum-decls) # enum declarations; included by PluginOps.h
+mlir_tablegen(PluginOpsEnums.cpp.inc -gen-enum-defs) # matching enum definitions
add_mlir_dialect(PluginOps Plugin)
# Necessary to generate documentation. arg:(doc_filename command output_file output_directory)
diff --git a/include/Dialect/PluginDialect.td b/include/Dialect/PluginDialect.td
index 362a5c8..31b338d 100644
--- a/include/Dialect/PluginDialect.td
+++ b/include/Dialect/PluginDialect.td
@@ -46,4 +46,42 @@ def Plugin_Dialect : Dialect {
class Plugin_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Plugin_Dialect, mnemonic, traits>;
+//===----------------------------------------------------------------------===//
+// PluginDialect enum definitions
+//===----------------------------------------------------------------------===//
+
+def IComparisonLT : I32EnumAttrCase<"lt", 0>; // enum values cross the client boundary as plain ints; keep numbering in sync
+def IComparisonLE : I32EnumAttrCase<"le", 1>;
+def IComparisonGT : I32EnumAttrCase<"gt", 2>;
+def IComparisonGE : I32EnumAttrCase<"ge", 3>;
+def IComparisonLTGT : I32EnumAttrCase<"ltgt", 4>;
+def IComparisonEQ : I32EnumAttrCase<"eq", 5>;
+def IComparisonNE : I32EnumAttrCase<"ne", 6>;
+def IComparisonUNDEF : I32EnumAttrCase<"UNDEF", 7>; // NOTE(review): presumably a catch-all for unmapped codes — confirm with the client
+def IComparisonAttr : I32EnumAttr< // used by CondOp's $condCode
+ "IComparisonCode", "plugin comparison code",
+ [IComparisonLT, IComparisonLE, IComparisonGT, IComparisonGE,
+ IComparisonLTGT, IComparisonEQ, IComparisonNE, IComparisonUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
+def IDefineCodeMemRef : I32EnumAttrCase<"MemRef", 0>;
+def IDefineCodeIntCST : I32EnumAttrCase<"IntCST", 1>;
+def IDefineCodeUNDEF : I32EnumAttrCase<"UNDEF", 2>;
+def IDefineCodeAttr : I32EnumAttr< // used by PlaceholderOp's $defCode
+ "IDefineCode", "plugin define code",
+ [IDefineCodeMemRef, IDefineCodeIntCST, IDefineCodeUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
+def IExprCodePlus : I32EnumAttrCase<"Plus", 0>;
+def IExprCodeMinus : I32EnumAttrCase<"Minus", 1>;
+def IExprCodeNop : I32EnumAttrCase<"Nop", 2>;
+def IExprCodeUNDEF : I32EnumAttrCase<"UNDEF", 3>;
+def IExprCodeAttr : I32EnumAttr< // used by AssignOp's $exprCode
+ "IExprCode", "plugin expr code",
+ [IExprCodePlus, IExprCodeMinus, IExprCodeNop, IExprCodeUNDEF]>{
+ let cppNamespace = "::mlir::Plugin";
+}
+
#endif // PLUGIN_DIALECT_TD
\ No newline at end of file
diff --git a/include/Dialect/PluginOps.h b/include/Dialect/PluginOps.h
index be2a3ef..8bab877 100644
--- a/include/Dialect/PluginOps.h
+++ b/include/Dialect/PluginOps.h
@@ -27,6 +27,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/CallInterfaces.h" // CallOpInterface, implemented by CallOp
+
+// Pull in the generated enum declarations; they must precede PluginOps.h.inc, whose builders use them.
+#include "Dialect/PluginOpsEnums.h.inc"
#define GET_OP_CLASSES
#include "Dialect/PluginOps.h.inc"
diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td
index 02d9bdd..e91cd84 100644
--- a/include/Dialect/PluginOps.td
+++ b/include/Dialect/PluginOps.td
@@ -28,8 +28,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> {
TODO.
}];
- let arguments = (ins OptionalAttr<UI64Attr>:$id,
- OptionalAttr<StrAttr>:$funcName,
+ let arguments = (ins UI64Attr:$id, // id and funcName are now required attributes
+ StrAttr:$funcName,
OptionalAttr<BoolAttr>:$declaredInline);
let regions = (region AnyRegion:$bodyRegion);
@@ -89,4 +89,142 @@ def LoopOp : Plugin_Op<"loop", [NoSideEffect]> {
}];
}
+def CallOp : Plugin_Op<"call", [
+ DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let summary = "call operation"; // NOTE(review): the description below reads like MLIR Toy-tutorial boilerplate
+ let description = [{
+ CallOp represent calls to a user defined function that needs to
+ be specialized for the shape of its arguments.
+ The callee name is attached as a symbol reference via an attribute.
+ The arguments list must match the arguments expected by the callee.
+ }];
+ let arguments = (ins UI64Attr:$id,
+ FlatSymbolRefAttr:$callee,
+ Variadic<AnyType>:$inputs);
+ let results = (outs Optional<AnyType>:$result); // the custom builder adds no result type
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$callee,
+ "ArrayRef<Value>":$arguments)>
+ ];
+ let extraClassDeclaration = [{
+ bool SetLHS(Value lhs); // asks the client to set the call's LHS; lhs should come from a PlaceholderOp
+ }];
+ let assemblyFormat = [{
+ $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
+ }];
+}
+
+def PhiOp : Plugin_Op<"phi", [NoSideEffect]> {
+ let summary = "phi op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id,
+ UI32Attr:$capacity,
+ UI32Attr:$nArgs,
+ Variadic<AnyType>:$operands);
+ let results = (outs AnyType:$result);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "uint32_t":$capacity,
+ "uint32_t":$nArgs, "ArrayRef<Value>":$operands,
+ "Type":$resultType)> // a null resultType adds no result
+ ];
+ let extraClassDeclaration = [{
+ Value GetResult(); // asks the client for this phi's result by id
+ Value GetArgDef(int i) { return getOperand(i); } // operand i, no bounds check
+ }];
+}
+
+def AssignOp : Plugin_Op<"assign", [NoSideEffect]> {
+ let summary = "assign op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id,
+ IExprCodeAttr:$exprCode,
+ Variadic<AnyType>:$operands);
+ let results = (outs AnyType:$result);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "IExprCode":$exprCode,
+ "ArrayRef<Value>":$operands, "Type":$resultType)> // a null resultType adds no result
+ ];
+ let extraClassDeclaration = [{
+ Value GetLHS() { return getOperand(0); }
+ Value GetRHS1() { return getOperand(1); }
+ Value GetRHS2() { return getOperand(2); } // NOTE(review): assumes a third operand exists
+ }];
+}
+
+def PlaceholderOp : Plugin_Op<"placeholder", [NoSideEffect]> {
+ let summary = "PlaceHolder.";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id,
+ OptionalAttr<IDefineCodeAttr>:$defCode);
+ let results = (outs AnyType);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "IDefineCode":$defCode, "Type":$retType)> // a null retType adds no result
+ ];
+}
+
+def BaseOp : Plugin_Op<"statement_base", [NoSideEffect]> {
+ let summary = "Base operation, just like placeholder for statement.";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id, StrAttr:$opCode);
+ let results = (outs AnyType); // NOTE(review): BaseOp's builder adds no result type
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$opCode)>
+ ];
+}
+
+// Terminators
+// Opaque builder used for terminator operations that contain successors.
+
+class Plugin_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
+ Plugin_Op<mnemonic, !listconcat(traits, [Terminator])>;
+
+def FallThroughOp : Plugin_TerminatorOp<"fallthrough", [NoSideEffect]> {
+ let summary = "FallThroughOp";
+ let description = [{TODO}];
+ let successors = (successor AnySuccessor:$dest);
+ // for bb address
+ let arguments = (ins UI64Attr:$address, UI64Attr:$destaddr);
+ let results = (outs AnyType); // NOTE(review): the builder adds no result type
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$address, "Block*":$dest, "uint64_t":$destaddr)>
+ ];
+}
+
+def CondOp : Plugin_TerminatorOp<"condition", [NoSideEffect]> {
+ let summary = "condition op";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$id, UI64Attr:$address,
+ IComparisonAttr:$condCode,
+ AnyType:$LHS, AnyType:$RHS,
+ UI64Attr:$tbaddr,
+ UI64Attr:$fbaddr,
+ OptionalAttr<TypeAttr>:$trueLabel, // NOTE(review): the first builder takes the labels as Values, not TypeAttrs
+ OptionalAttr<TypeAttr>:$falseLabel);
+ let successors = (successor AnySuccessor:$tb, AnySuccessor:$fb);
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$id, "uint64_t":$address, "IComparisonCode":$condCode,
+ "Value":$lhs, "Value":$rhs, "Block*":$tb, "Block*":$fb,
+ "uint64_t":$tbaddr, "uint64_t":$fbaddr, "Value":$trueLabel,
+ "Value":$falseLabel)>,
+ // Only for server: asks the client to create the condition and takes the returned id.
+ OpBuilderDAG<(ins "IComparisonCode":$condCode,
+ "Value":$lhs, "Value":$rhs)>
+ ];
+ let extraClassDeclaration = [{
+ Value GetLHS() { return getOperand(0); }
+ Value GetRHS() { return getOperand(1); }
+ }];
+}
+
+// todo: RetOp does not have a correct assemblyFormat yet
+def RetOp : Plugin_TerminatorOp<"ret", [NoSideEffect]> {
+ let summary = "RetOp";
+ let description = [{TODO}];
+ let arguments = (ins UI64Attr:$address); // for bb address
+ let results = (outs AnyType); // NOTE(review): the builder adds no result type
+ let builders = [
+ OpBuilderDAG<(ins "uint64_t":$address)>
+ ];
+}
+
#endif // PLUGIN_OPS_TD
\ No newline at end of file
diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h
index fde5f63..6f5fbb0 100644
--- a/include/PluginAPI/BasicPluginOpsAPI.h
+++ b/include/PluginAPI/BasicPluginOpsAPI.h
@@ -48,6 +48,11 @@ public:
virtual pair<uint64_t, uint64_t> LoopSingleExit(uint64_t) = 0;
virtual vector<pair<uint64_t, uint64_t> > GetLoopExitEdges(uint64_t) = 0;
virtual LoopOp GetBlockLoopFather(uint64_t) = 0;
+ virtual PhiOp GetPhiOp(uint64_t) = 0; // arg: op id
+ virtual CallOp GetCallOp(uint64_t) = 0; // arg: op id
+ virtual bool SetLhsInCallOp(uint64_t, uint64_t) = 0; // args: (callId, lhsId)
+ virtual uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) = 0; // args: (code, lhsId, rhsId); returns the new op id
+ virtual mlir::Value GetResultFromPhi(uint64_t) = 0; // arg: phi id
}; // class BasicPluginOpsAPI
} // namespace PluginAPI
diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h
index 5b14f2e..425b946 100644
--- a/include/PluginAPI/PluginServerAPI.h
+++ b/include/PluginAPI/PluginServerAPI.h
@@ -38,6 +38,8 @@ public:
vector<FunctionOp> GetAllFunc() override;
vector<LocalDeclOp> GetDecls(uint64_t) override;
+ PhiOp GetPhiOp(uint64_t) override;
+ CallOp GetCallOp(uint64_t) override;
PluginIR::PluginTypeID GetTypeCodeFromString(string type);
LoopOp AllocateNewLoop(uint64_t funcID);
LoopOp GetLoopById(uint64_t loopID);
@@ -51,8 +53,14 @@ public:
uint64_t GetLatch(uint64_t loopID);
vector<uint64_t> GetLoopBody(uint64_t loopID);
LoopOp GetBlockLoopFather(uint64_t blockID);
+ /* Plugin API for CallOp. */
+ bool SetLhsInCallOp(uint64_t, uint64_t) override;
+ /* Plugin API for CondOp. */
+ uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) override;
+ mlir::Value GetResultFromPhi(uint64_t) override; /* Plugin API for PhiOp. */
+
private:
- vector<FunctionOp> GetOperationResult(const string& funName, const string& params);
+ vector<FunctionOp> GetFunctionOpResult(const string& funName, const string& params);
vector<LocalDeclOp> GetDeclOperationResult(const string& funName, const string& params);
LoopOp GetLoopResult(const string&funName, const string& params);
vector<LoopOp> GetLoopsResult(const string& funName, const string& params);
diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h
index 610ecb9..b28c6d0 100644
--- a/include/PluginServer/PluginServer.h
+++ b/include/PluginServer/PluginServer.h
@@ -33,6 +33,7 @@
#include "Dialect/PluginOps.h"
#include "plugin.grpc.pb.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Builders.h"
#include "Dialect/PluginTypes.h"
namespace PinServer {
@@ -120,6 +121,7 @@ private:
class PluginServer final : public PluginService::Service {
public:
+ PluginServer() : opBuilder(&context){} // context is declared before opBuilder, so it is constructed first
/* 定义的grpc服务端和客户端通信的接口函数 */
Status ReceiveSendMsg(ServerContext* context, ServerReaderWriter<ServerMsg, ClientMsg>* stream) override;
/* 服务端发送数据给client接口 */
@@ -137,6 +139,10 @@ public:
vector<std::pair<uint64_t, uint64_t> > EdgesResult(void);
std::pair<uint64_t, uint64_t> EdgeResult(void);
bool BoolResult(void);
+ vector<mlir::Operation *> GetOpResult(void); // returns and clears opData
+ bool GetBoolResult(void);
+ uint64_t GetIdResult(void);
+ mlir::Value GetValueResult(void);
/* 回调函数接口用于向server注册用户需要执行的函数 */
int RegisterUserFunc(InjectPoint inject, UserFunc func);
int RegisterPassManagerSetup(InjectPoint inject, const ManagerSetupData& passData, UserFunc func);
@@ -190,6 +196,13 @@ public:
void EdgeJsonDeSerialize(const string& data);
void BlocksJsonDeSerialize(const string& data);
void BlockJsonDeSerialize(const string& data);
+ void CallOpJsonDeSerialize(const string& data);
+ void CondOpJsonDeSerialize(const string& data);
+ void RetOpJsonDeSerialize(const string& data);
+ void FallThroughOpJsonDeSerialize(const string& data);
+ void PhiOpJsonDeSerialize(const string& data);
+ void AssignOpJsonDeSerialize(const string& data);
+ mlir::Value ValueJsonDeSerialize(Json::Value valueJson); // wraps a serialized value in a PlaceholderOp
/* json反序列化根据key值分别调用Operation/Decl/Type反序列化接口函数 */
void JsonDeSerialize(const string& key, const string& data);
/* 解析客户端发送过来的-fplugin-arg参数并保存在私有变量args中 */
@@ -234,6 +247,7 @@ private:
/* 用户函数执行状态client返回结果后为STATE_RETURN,开始执行下一个函数 */
volatile UserFunStateEnum userFunState;
mlir::MLIRContext context;
+ mlir::OpBuilder opBuilder; // must stay declared after context (see the constructor)
vector<mlir::Plugin::FunctionOp> funcOpData;
PluginIR::PluginTypeBase pluginType;
vector<mlir::Plugin::LocalDeclOp> decls;
@@ -244,6 +258,10 @@ private:
vector<uint64_t> blockIds;
uint64_t blockId;
bool boolRes;
+ vector<mlir::Operation *> opData; // ops collected by the *OpJsonDeSerialize helpers
+ bool boolResult = false; // NOTE(review): confirm which JsonDeSerialize key assigns this
+ uint64_t idResult = 0; // op id from "IdResult"; must be as wide as GetIdResult()'s return type
+ mlir::Value valueResult;
/* 保存用户注册的回调函数,它们将在注入点事件触发后调用 */
map<InjectPoint, vector<RecordedUserFunc>> userFunc;
string apiFuncName; // 保存用户调用PluginAPI的函数名
@@ -252,6 +270,10 @@ private:
timer_t timerId;
map<string, string> args; // 保存gcc编译时用户传入参数
sem_t sem[2];
+
+ // process Block.
+ std::map<uint64_t, mlir::Block*> blockMaps; // bb address -> Block created for it
+ bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&);
}; // class PluginServer
void RunServer(int timeout, string& port);
diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp
index 481d51a..d647fc4 100644
--- a/lib/Dialect/PluginOps.cpp
+++ b/lib/Dialect/PluginOps.cpp
@@ -33,8 +33,9 @@ using namespace mlir::Plugin;
using std::vector;
using std::pair;
-void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
- uint64_t id, StringRef funcName, bool declaredInline) {
+void FunctionOp::build(OpBuilder &builder, OperationState &state, // id -> i64 attr, funcName -> string attr
+ uint64_t id, StringRef funcName, bool declaredInline)
+{
FunctionOp::build(builder, state,
builder.getI64IntegerAttr(id),
builder.getStringAttr(funcName),
@@ -48,9 +49,10 @@ vector<LoopOp> FunctionOp::GetAllLoops()
return pluginAPI.GetLoopsFromFunc(funcId);
}
-void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+void LocalDeclOp::build(OpBuilder &builder, OperationState &state, // id -> i64 attr, symName -> string attr
uint64_t id, StringRef symName,
- int64_t typeID, uint64_t typeWidth) {
+ int64_t typeID, uint64_t typeWidth)
+{
LocalDeclOp::build(builder, state,
builder.getI64IntegerAttr(id),
builder.getStringAttr(symName),
@@ -142,6 +144,171 @@ void LoopOp::AddLoop(uint64_t outerId, uint64_t funcId)
return pluginAPI.AddLoop(loopId, outerId, funcId);
}
+//===----------------------------------------------------------------------===//
+// PlaceholderOp
+
+/// Build a PlaceholderOp: `id` and `defCode` are stored as attributes and a
+/// non-null `retType` becomes the op's single result type.
+void PlaceholderOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, IDefineCode defCode, Type retType)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("defCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(defCode)));
+ if (retType) state.addTypes(retType);
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+
+/// Build a CallOp without a result type: `callee` is stored as a symbol
+/// reference attribute and `arguments` become the operands.
+void CallOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef callee,
+ ArrayRef<Value> arguments)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addOperands(arguments);
+ state.addAttribute("callee", builder.getSymbolRefAttr(callee));
+}
+
+/// Return the callee of the generic call operation, this is required by the
+/// call interface.
+CallInterfaceCallable CallOp::getCallableForCallee()
+{
+ return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+}
+
+/// Get the argument operands to the called function, this is required by the
+/// call interface.
+Operation::operand_range CallOp::getArgOperands() { return inputs(); }
+
+/// Ask the client to make `lhs` this call's LHS. Returns false, without
+/// contacting the client, when `lhs` is not produced by a PlaceholderOp.
+bool CallOp::SetLHS(Value lhs)
+{
+ PlaceholderOp phOp = lhs.getDefiningOp<PlaceholderOp>();
+ if (!phOp) return false; // only placeholder values carry a client-side id
+ uint64_t lhsId = phOp.idAttr().getInt();
+ PluginAPI::PluginServerAPI pluginAPI;
+ return pluginAPI.SetLhsInCallOp(this->idAttr().getInt(), lhsId);
+}
+
+//===----------------------------------------------------------------------===//
+// CondOp
+
+/// Build a CondOp from deserialized client data, with `tb`/`fb` as successors.
+/// NOTE(review): ODS declares trueLabel/falseLabel as OptionalAttr<TypeAttr>,
+/// but non-null labels are appended here as extra operands.
+void CondOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, uint64_t address, IComparisonCode condCode,
+ Value lhs, Value rhs, Block* tb, Block* fb, uint64_t tbaddr,
+ uint64_t fbaddr, Value trueLabel, Value falseLabel)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+ state.addAttribute("tbaddr", builder.getI64IntegerAttr(tbaddr));
+ state.addAttribute("fbaddr", builder.getI64IntegerAttr(fbaddr));
+ state.addOperands({lhs, rhs});
+ state.addSuccessors(tb);
+ state.addSuccessors(fb);
+ state.addAttribute("condCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(condCode)));
+ if (trueLabel != nullptr) state.addOperands(trueLabel);
+ if (falseLabel != nullptr) state.addOperands(falseLabel);
+}
+
+/// Server-side builder: asks the client to create the condition and records
+/// the id it returns. Both operands must be PlaceholderOp results.
+/// NOTE(review): unlike the builder above, this sets no successors or addresses.
+void CondOp::build(OpBuilder &builder, OperationState &state,
+ IComparisonCode condCode, Value lhs, Value rhs)
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ PlaceholderOp lhsOp = lhs.getDefiningOp<PlaceholderOp>();
+ PlaceholderOp rhsOp = rhs.getDefiningOp<PlaceholderOp>();
+ // The client identifies operands by placeholder id, so both must exist.
+ assert(lhsOp && rhsOp && "CondOp operands must be defined by PlaceholderOp");
+ uint64_t lhsId = lhsOp.idAttr().getInt();
+ uint64_t rhsId = rhsOp.idAttr().getInt();
+ uint64_t id = pluginAPI.CreateCondOp(condCode, lhsId, rhsId);
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addOperands({lhs, rhs});
+ state.addAttribute("condCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(condCode)));
+}
+
+//===----------------------------------------------------------------------===//
+// PhiOp
+
+/// Build a PhiOp over `operands`; a non-null `resultType` becomes its single
+/// result type.
+void PhiOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, uint32_t capacity, uint32_t nArgs,
+ ArrayRef<Value> operands, Type resultType)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("capacity", builder.getI32IntegerAttr(capacity));
+ state.addAttribute("nArgs", builder.getI32IntegerAttr(nArgs));
+ state.addOperands(operands);
+ if (resultType) state.addTypes(resultType);
+}
+
+/// Fetch this phi's result value from the client, keyed by the phi's id.
+Value PhiOp::GetResult()
+{
+ PluginAPI::PluginServerAPI pluginAPI;
+ return pluginAPI.GetResultFromPhi(this->idAttr().getInt());
+}
+
+//===----------------------------------------------------------------------===//
+// AssignOp
+
+/// Build an AssignOp; `exprCode` is stored as an i32 attribute.
+void AssignOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, IExprCode exprCode,
+ ArrayRef<Value> operands, Type resultType)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("exprCode",
+ builder.getI32IntegerAttr(static_cast<int32_t>(exprCode)));
+ state.addOperands(operands);
+ if (resultType) state.addTypes(resultType);
+}
+
+//===----------------------------------------------------------------------===//
+// BaseOp
+
+/// Build a BaseOp carrying only its `id` and `opCode` attributes.
+void BaseOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t id, StringRef opCode)
+{
+ state.addAttribute("id", builder.getI64IntegerAttr(id));
+ state.addAttribute("opCode", builder.getStringAttr(opCode));
+}
+
+//===----------------------------------------------------------------------===//
+// FallThroughOp
+
+/// Build a FallThroughOp with `dest` as its only successor; `address` and
+/// `destaddr` record the source and destination block addresses.
+void FallThroughOp::build(OpBuilder &builder, OperationState &state,
+ uint64_t address, Block* dest, uint64_t destaddr)
+{
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+ state.addAttribute("destaddr", builder.getI64IntegerAttr(destaddr));
+ state.addSuccessors(dest);
+}
+
+//===----------------------------------------------------------------------===//
+// RetOp
+
+/// Build a RetOp recording only its block `address`.
+void RetOp::build(OpBuilder &builder, OperationState &state, uint64_t address)
+{
+ state.addAttribute("address", builder.getI64IntegerAttr(address));
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp
index 41626f9..65eafa7 100644
--- a/lib/PluginAPI/PluginServerAPI.cpp
+++ b/lib/PluginAPI/PluginServerAPI.cpp
@@ -54,7 +54,7 @@ void PluginServerAPI::WaitClientResult(const string& funName, const string& para
}
}
-vector<FunctionOp> PluginServerAPI::GetOperationResult(const string& funName, const string& params)
+vector<FunctionOp> PluginServerAPI::GetFunctionOpResult(const string& funName, const string& params) // waits for the client, then returns the deserialized FunctionOps
{
WaitClientResult(funName, params);
vector<FunctionOp> retOps = PluginServer::GetInstance()->GetFunctionOpResult();
@@ -67,7 +67,63 @@ vector<FunctionOp> PluginServerAPI::GetAllFunc()
string funName = __func__;
string params = root.toStyledString();
- return GetOperationResult(funName, params);
+ return GetFunctionOpResult(funName, params);
+}
+
+PhiOp PluginServerAPI::GetPhiOp(uint64_t id) // returns a null PhiOp if the client sends no op back
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(id);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ vector<mlir::Operation*> opRet = PluginServer::GetInstance()->GetOpResult();
+ return opRet.empty() ? PhiOp() : llvm::dyn_cast<PhiOp>(opRet[0]); // guard: opRet[0] on an empty vector is UB
+}
+
+CallOp PluginServerAPI::GetCallOp(uint64_t id) // returns a null CallOp if the client sends no op back
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(id);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ vector<mlir::Operation*> opRet = PluginServer::GetInstance()->GetOpResult();
+ return opRet.empty() ? CallOp() : llvm::dyn_cast<CallOp>(opRet[0]); // guard: opRet[0] on an empty vector is UB
+}
+
+bool PluginServerAPI::SetLhsInCallOp(uint64_t callId, uint64_t lhsId) // ids go out as decimal strings; result via GetBoolResult()
+{
+ Json::Value root;
+ string funName = __func__;
+ root["callId"] = std::to_string(callId);
+ root["lhsId"] = std::to_string(lhsId);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetBoolResult();
+}
+
+uint64_t PluginServerAPI::CreateCondOp(IComparisonCode iCode, // returns the new op id reported by the client
+ uint64_t lhs, uint64_t rhs)
+{
+ Json::Value root;
+ string funName = __func__;
+ root["condCode"] = std::to_string(static_cast<int32_t>(iCode));
+ root["lhsId"] = std::to_string(lhs);
+ root["rhsId"] = std::to_string(rhs);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetIdResult();
+}
+
+mlir::Value PluginServerAPI::GetResultFromPhi(uint64_t phiId) // returns the value the client reports for this phi
+{
+ Json::Value root;
+ string funName = __func__;
+ root["id"] = std::to_string(phiId);
+ string params = root.toStyledString();
+ WaitClientResult(funName, params);
+ return PluginServer::GetInstance()->GetValueResult();
}
PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type)
diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp
index 92159e8..2eac906 100644
--- a/lib/PluginServer/PluginServer.cpp
+++ b/lib/PluginServer/PluginServer.cpp
@@ -27,7 +27,6 @@
#include "Dialect/PluginDialect.h"
#include "PluginAPI/PluginServerAPI.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "user.h"
@@ -37,7 +36,6 @@
namespace PinServer {
using namespace mlir::Plugin;
-
using std::cout;
using std::endl;
using std::pair;
@@ -75,6 +73,13 @@ int PluginServer::RegisterPassManagerSetup(InjectPoint inject, const ManagerSetu
return 0;
}
+vector<mlir::Operation *> PluginServer::GetOpResult(void) // hands over, then clears, the ops gathered by the *OpJsonDeSerialize helpers
+{
+ vector<mlir::Operation *> retOps = opData;
+ opData.clear();
+ return retOps;
+}
+
vector<mlir::Plugin::FunctionOp> PluginServer::GetFunctionOpResult(void)
{
vector<mlir::Plugin::FunctionOp> retOps = funcOpData;
@@ -135,6 +140,21 @@ vector<pair<uint64_t, uint64_t> > PluginServer::EdgesResult()
return retEdges;
}
+bool PluginServer::GetBoolResult() // NOTE(review): confirm which JsonDeSerialize key assigns boolResult
+{
+ return this->boolResult;
+}
+
+uint64_t PluginServer::GetIdResult() // last id received under the "IdResult" key
+{
+ return this->idResult;
+}
+
+mlir::Value PluginServer::GetValueResult() // last value received under the "ValueResult" key
+{
+ return this->valueResult;
+}
+
void PluginServer::JsonGetAttributes(Json::Value node, map<string, string>& attributes)
{
Json::Value::Members attMember = node.getMemberNames();
@@ -151,6 +171,27 @@ static uintptr_t GetID(Json::Value node)
return atol(id.c_str());
}
+mlir::Value PluginServer::ValueJsonDeSerialize(Json::Value valueJson) // wraps a client value (id + defCode) in a new PlaceholderOp
+{
+ uint64_t opId = GetID(valueJson["id"]);
+ IDefineCode defCode = IDefineCode(
+ atoi(valueJson["defCode"].asString().c_str()));
+ switch (defCode) { // NOTE(review): every case is a placeholder; all codes currently take the same path
+ case IDefineCode::MemRef : {
+ break;
+ }
+ case IDefineCode::IntCST : {
+ break;
+ }
+ default:
+ break;
+ }
+ mlir::Type retType = PluginIR::PluginUndefType::get(&context); // FIXME! the real type is not deserialized yet
+ mlir::Value opValue = opBuilder.create<PlaceholderOp>( // inserted at opBuilder's current insertion point, if any
+ opBuilder.getUnknownLoc(), opId, defCode, retType);
+ return opValue;
+}
+
void PluginServer::JsonDeSerialize(const string& key, const string& data)
{
if (key == "FuncOpResult") {
@@ -173,6 +214,16 @@ void PluginServer::JsonDeSerialize(const string& key, const string& data)
EdgesJsonDeSerialize(data);
} else if (key == "BlockIdsResult") {
BlocksJsonDeSerialize(data);
+ } else if (key == "IdResult") {
+ this->idResult = atol(data.c_str());
+ } else if (key == "ValueResult") {
+ // data is a serialized JSON object; parse it before reading its fields.
+ Json::Value valueJson;
+ Json::Reader reader;
+ reader.parse(data, valueJson);
+ context.getOrLoadDialect<PluginDialect>();
+ opBuilder = mlir::OpBuilder(&context);
+ this->valueResult = ValueJsonDeSerialize(valueJson);
} else {
cout << "not Json,key:" << key << ",value:" << data << endl;
}
@@ -217,6 +264,46 @@ void PluginServer::TypeJsonDeSerialize(const string& data)
return;
}
+bool PluginServer::ProcessBlock(mlir::Block* block, mlir::Region& rg, // builds blockJson's ops into block; false if blockJson is null. NOTE(review): rg is unused
+ const Json::Value& blockJson)
+{
+ if (blockJson.isNull()) {
+ return false;
+ }
+ // todo process func return type
+ // todo isDeclaration
+
+ // process each stmt
+ opBuilder.setInsertionPointToStart(block);
+ Json::Value::Members opMember = blockJson.getMemberNames(); // NOTE(review): jsoncpp returns names sorted by key; confirm keys sort in statement order
+ for (Json::Value::Members::iterator opIdx = opMember.begin();
+ opIdx != opMember.end(); opIdx++) {
+ string baseOpKey = *opIdx;
+ Json::Value opJson = blockJson[baseOpKey];
+ if (opJson.isNull()) continue;
+ // llvm::StringRef opCode(opJson["OperationName"].asString().c_str()); (would dangle: asString() returns a temporary)
+ string opCode = opJson["OperationName"].asString();
+ if (opCode == PhiOp::getOperationName().str()) {
+ PhiOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == CallOp::getOperationName().str()) {
+ CallOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == AssignOp::getOperationName().str()) {
+ AssignOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == CondOp::getOperationName().str()) {
+ CondOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == RetOp::getOperationName().str()) {
+ RetOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == FallThroughOp::getOperationName().str()) {
+ FallThroughOpJsonDeSerialize(opJson.toStyledString());
+ } else if (opCode == BaseOp::getOperationName().str()) {
+ uint64_t opID = GetID(opJson["id"]);
+ opBuilder.create<BaseOp>(opBuilder.getUnknownLoc(), opID, opCode);
+ } // any other OperationName is skipped silently
+ }
+ // fprintf(stderr, "[bb] op:%ld, succ: %d\n", block->getOperations().size(), block->getNumSuccessors());
+ return true;
+}
+
void PluginServer::FuncOpJsonDeSerialize(const string& data)
{
Json::Value root;
@@ -227,8 +314,9 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data)
Json::Value::Members operation = root.getMemberNames();
context.getOrLoadDialect<PluginDialect>();
- mlir::OpBuilder builder(&context);
- for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
+ opBuilder = mlir::OpBuilder(&context); // fresh builder: FunctionOps are created outside any block
+ for (Json::Value::Members::iterator iter = operation.begin();
+ iter != operation.end(); iter++) {
string operationKey = *iter;
node = root[operationKey];
int64_t id = GetID(node["id"]);
@@ -237,9 +325,30 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data)
JsonGetAttributes(attributes, funcAttributes);
bool declaredInline = false;
if (funcAttributes["declaredInline"] == "1") declaredInline = true;
- auto location = builder.getUnknownLoc();
- FunctionOp op = builder.create<FunctionOp>(location, id, funcAttributes["funcName"], declaredInline);
- funcOpData.push_back(op);
+ auto location = opBuilder.getUnknownLoc();
+ FunctionOp fOp = opBuilder.create<FunctionOp>(
+ location, id, funcAttributes["funcName"], declaredInline);
+ mlir::Region &bodyRegion = fOp.bodyRegion();
+ Json::Value regionJson = node["region"];
+ Json::Value::Members bbMember = regionJson.getMemberNames();
+ // Create every Block up front so branch ops can resolve successors via blockMaps.
+ for (Json::Value::Members::iterator bbIdx = bbMember.begin();
+ bbIdx != bbMember.end(); bbIdx++) {
+ string blockKey = *bbIdx;
+ Json::Value blockJson = regionJson[blockKey];
+ mlir::Block* block = opBuilder.createBlock(&bodyRegion);
+ this->blockMaps[GetID(blockJson["address"])] = block; // overwrite: insert() would keep a stale Block from an earlier function
+ }
+
+ for (Json::Value::Members::iterator bbIdx = bbMember.begin();
+ bbIdx != bbMember.end(); bbIdx++) {
+ string blockKey = *bbIdx;
+ Json::Value blockJson = regionJson[blockKey];
+ uint64_t bbAddress = GetID(blockJson["address"]);
+ ProcessBlock(this->blockMaps[bbAddress], bodyRegion, blockJson["ops"]);
+ }
+ funcOpData.push_back(fOp);
+ opBuilder.clearInsertionPoint(); // fOp lives in no block, so stop inserting into its body
}
}
@@ -253,7 +362,7 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data)
Json::Value::Members operation = root.getMemberNames();
context.getOrLoadDialect<PluginDialect>();
- mlir::OpBuilder builder(&context);
+ opBuilder = mlir::OpBuilder(&context); // reuse the member builder, reset with no insertion point
for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) {
string operationKey = *iter;
node = root[operationKey];
@@ -261,11 +370,11 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data)
Json::Value attributes = node["attributes"];
map<string, string> declAttributes;
JsonGetAttributes(attributes, declAttributes);
- string symName = declAttributes["symName"];
- uint64_t typeID = atol(declAttributes["typeID"].c_str());
- uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
- auto location = builder.getUnknownLoc();
- LocalDeclOp op = builder.create<LocalDeclOp>(location, id, symName, typeID, typeWidth);
+ string symName = declAttributes["symName"];
+ uint64_t typeID = atol(declAttributes["typeID"].c_str()); // NOTE(review): LocalDeclOp::build takes int64_t typeID
+ uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str());
+ auto location = opBuilder.getUnknownLoc();
+ LocalDeclOp op = opBuilder.create<LocalDeclOp>(location, id, symName, typeID, typeWidth);
decls.push_back(op);
}
}
@@ -384,6 +493,121 @@ void PluginServer::BlockJsonDeSerialize(const string& data)
blockId = (uint64_t)atol(root["id"].asString().c_str());
}
+void PluginServer::CallOpJsonDeSerialize(const string& data) // rebuilds a CallOp; operands become PlaceholderOps
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ Json::Value operandJson = node["operands"];
+ Json::Value::Members operandMember = operandJson.getMemberNames(); // NOTE(review): sorted key order; confirm it matches argument order
+ llvm::SmallVector<mlir::Value, 4> ops;
+ for (Json::Value::Members::iterator opIter = operandMember.begin();
+ opIter != operandMember.end(); opIter++) {
+ string key = *opIter;
+ mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]);
+ ops.push_back(opValue);
+ }
+ int64_t id = GetID(node["id"]);
+ string callName = node["callee"].asString(); // own the string: a StringRef to asString()'s temporary would dangle
+ CallOp op = opBuilder.create<CallOp>(opBuilder.getUnknownLoc(),
+ id, callName, ops);
+ opData.push_back(op.getOperation());
+}
+
+void PluginServer::CondOpJsonDeSerialize(const string& data) // rebuilds a CondOp; successors are looked up in blockMaps by address
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ mlir::Value LHS = ValueJsonDeSerialize(node["lhs"]);
+ mlir::Value RHS = ValueJsonDeSerialize(node["rhs"]);
+ mlir::Value trueLabel = nullptr; // labels are not deserialized; the CondOp builder skips null labels
+ mlir::Value falseLabel = nullptr;
+ int64_t id = GetID(node["id"]);
+ int64_t address = GetID(node["address"]);
+ int64_t tbaddr = GetID(node["tbaddr"]);
+ int64_t fbaddr = GetID(node["fbaddr"]);
+ assert (this->blockMaps.find(tbaddr) != this->blockMaps.end());
+ assert (this->blockMaps.find(fbaddr) != this->blockMaps.end());
+ mlir::Block* tb = this->blockMaps[tbaddr]; // NOTE(review): under NDEBUG a missing address inserts and yields nullptr
+ mlir::Block* fb = this->blockMaps[fbaddr];
+ IComparisonCode iCode = IComparisonCode(
+ atoi(node["condCode"].asString().c_str()));
+ CondOp op = opBuilder.create<CondOp>(opBuilder.getUnknownLoc(), id,
+ address, iCode, LHS, RHS, tb, fb, tbaddr, fbaddr,
+ trueLabel, falseLabel);
+ opData.push_back(op.getOperation());
+}
+
+void PluginServer::RetOpJsonDeSerialize(const string& data) // RetOp carries only its block address
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ int64_t address = GetID(node["address"]);
+ RetOp op = opBuilder.create<RetOp>(opBuilder.getUnknownLoc(), address);
+ opData.push_back(op.getOperation());
+}
+
+void PluginServer::FallThroughOpJsonDeSerialize(const string& data) // single successor resolved through blockMaps
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ int64_t address = GetID(node["address"]);
+ int64_t destaddr = GetID(node["destaddr"]);
+ assert (this->blockMaps.find(destaddr) != this->blockMaps.end());
+ mlir::Block* succ = this->blockMaps[destaddr]; // NOTE(review): under NDEBUG a missing address yields nullptr
+ FallThroughOp op = opBuilder.create<FallThroughOp>(opBuilder.getUnknownLoc(),
+ address, succ, destaddr);
+ opData.push_back(op.getOperation());
+}
+
+void PluginServer::PhiOpJsonDeSerialize(const string& data) // rebuilds a PhiOp; operands become PlaceholderOps
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ Json::Value operandJson = node["operands"];
+ Json::Value::Members operandMember = operandJson.getMemberNames(); // NOTE(review): sorted key order; confirm it matches incoming-edge order
+ llvm::SmallVector<mlir::Value, 4> ops;
+ for (Json::Value::Members::iterator opIter = operandMember.begin();
+ opIter != operandMember.end(); opIter++) {
+ string key = *opIter;
+ mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]);
+ ops.push_back(opValue);
+ }
+ int64_t id = GetID(node["id"]);
+ uint32_t capacity = atoi(node["capacity"].asString().c_str());
+ uint32_t nArgs = atoi(node["nArgs"].asString().c_str());
+ mlir::Type retType = PluginIR::PluginUndefType::get(&context); // FIXME: ODS requires one result; real type is not serialized yet
+ PhiOp op = opBuilder.create<PhiOp>(opBuilder.getUnknownLoc(),
+ id, capacity, nArgs, ops, retType);
+ opData.push_back(op.getOperation());
+}
+
+void PluginServer::AssignOpJsonDeSerialize(const string& data) // rebuilds an AssignOp; operands become PlaceholderOps
+{
+ Json::Value node;
+ Json::Reader reader;
+ reader.parse(data, node);
+ Json::Value operandJson = node["operands"];
+ Json::Value::Members operandMember = operandJson.getMemberNames(); // NOTE(review): sorted key order; GetLHS/GetRHS1/GetRHS2 rely on it
+ llvm::SmallVector<mlir::Value, 4> ops;
+ for (Json::Value::Members::iterator opIter = operandMember.begin();
+ opIter != operandMember.end(); opIter++) {
+ string key = *opIter;
+ mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]);
+ ops.push_back(opValue);
+ }
+ int64_t id = GetID(node["id"]);
+ IExprCode iCode = IExprCode(atoi(node["exprCode"].asString().c_str()));
+ mlir::Type retType = PluginIR::PluginUndefType::get(&context); // FIXME: ODS requires one result; real type is not serialized yet
+ AssignOp op = opBuilder.create<AssignOp>(opBuilder.getUnknownLoc(),
+ id, iCode, ops, retType);
+ opData.push_back(op.getOperation());
+}
+
/* 线程函数,执行用户注册函数,客户端返回数据后退出 */
static void ExecCallbacks(const string& name)
{
diff --git a/user.cpp b/user.cpp
index efc4158..605e6f8 100644
--- a/user.cpp
+++ b/user.cpp
@@ -117,7 +117,7 @@ ProcessArrayWiden(void)
void RegisterCallbacks(void)
{
- PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc);
+ // PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc); NOTE(review): disabled here — confirm this is intentional
PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, LocalVarSummery);
ManagerSetupData setupData;
setupData.refPassName = PASS_PHIOPT;
--
2.27.0.windows.1