From b721e3a4292503115329d6380b590b3bab1cd856 Mon Sep 17 00:00:00 2001 From: Mingchuan Wu Date: Thu, 8 Dec 2022 20:14:09 +0800 Subject: [PATCH 3/4] [Pin-server] Support building CFG, CondOp, CallOp, PhiOp and AssignOp. Add CondOp, CallOp, PhiOp and AssignOp to PluginDialect. With these ops the server can now build a function's CFG. --- include/Dialect/CMakeLists.txt | 3 + include/Dialect/PluginDialect.td | 38 ++++ include/Dialect/PluginOps.h | 4 + include/Dialect/PluginOps.td | 142 ++++++++++++++- include/PluginAPI/BasicPluginOpsAPI.h | 5 + include/PluginAPI/PluginServerAPI.h | 10 +- include/PluginServer/PluginServer.h | 22 +++ lib/Dialect/PluginOps.cpp | 150 +++++++++++++++- lib/PluginAPI/PluginServerAPI.cpp | 60 ++++++- lib/PluginServer/PluginServer.cpp | 250 ++++++++++++++++++++++++-- user.cpp | 2 +- 11 files changed, 663 insertions(+), 23 deletions(-) diff --git a/include/Dialect/CMakeLists.txt b/include/Dialect/CMakeLists.txt index f0dfc58..f44e77c 100644 --- a/include/Dialect/CMakeLists.txt +++ b/include/Dialect/CMakeLists.txt @@ -1,5 +1,8 @@ # Add for the dialect operations. arg:(dialect dialect_namespace) file(COPY /usr/bin/mlir-tblgen DESTINATION ./) +set(LLVM_TARGET_DEFINITIONS PluginOps.td) +mlir_tablegen(PluginOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(PluginOpsEnums.cpp.inc -gen-enum-defs) add_mlir_dialect(PluginOps Plugin) # Necessary to generate documentation.
arg:(doc_filename command output_file output_directory) diff --git a/include/Dialect/PluginDialect.td b/include/Dialect/PluginDialect.td index 362a5c8..31b338d 100644 --- a/include/Dialect/PluginDialect.td +++ b/include/Dialect/PluginDialect.td @@ -46,4 +46,42 @@ def Plugin_Dialect : Dialect { class Plugin_Op traits = []> : Op; +//===----------------------------------------------------------------------===// +// PluginDialect enum definitions. The integer case values are exchanged with the client as decimal strings in JSON, so do not renumber existing cases (assumes the client uses the same numbering - confirm). +//===----------------------------------------------------------------------===// + +def IComparisonLT : I32EnumAttrCase<"lt", 0>; +def IComparisonLE : I32EnumAttrCase<"le", 1>; +def IComparisonGT : I32EnumAttrCase<"gt", 2>; +def IComparisonGE : I32EnumAttrCase<"ge", 3>; +def IComparisonLTGT : I32EnumAttrCase<"ltgt", 4>; +def IComparisonEQ : I32EnumAttrCase<"eq", 5>; +def IComparisonNE : I32EnumAttrCase<"ne", 6>; +def IComparisonUNDEF : I32EnumAttrCase<"UNDEF", 7>; +def IComparisonAttr : I32EnumAttr< + "IComparisonCode", "plugin comparison code", + [IComparisonLT, IComparisonLE, IComparisonGT, IComparisonGE, + IComparisonLTGT, IComparisonEQ, IComparisonNE, IComparisonUNDEF]>{ + let cppNamespace = "::mlir::Plugin"; +} + +def IDefineCodeMemRef : I32EnumAttrCase<"MemRef", 0>; +def IDefineCodeIntCST : I32EnumAttrCase<"IntCST", 1>; +def IDefineCodeUNDEF : I32EnumAttrCase<"UNDEF", 2>; +def IDefineCodeAttr : I32EnumAttr< + "IDefineCode", "plugin define code", + [IDefineCodeMemRef, IDefineCodeIntCST, IDefineCodeUNDEF]>{ + let cppNamespace = "::mlir::Plugin"; +} + +def IExprCodePlus : I32EnumAttrCase<"Plus", 0>; +def IExprCodeMinus : I32EnumAttrCase<"Minus", 1>; +def IExprCodeNop : I32EnumAttrCase<"Nop", 2>; +def IExprCodeUNDEF : I32EnumAttrCase<"UNDEF", 3>; +def IExprCodeAttr : I32EnumAttr< + "IExprCode", "plugin expr code", + [IExprCodePlus, IExprCodeMinus, IExprCodeNop, IExprCodeUNDEF]>{ + let cppNamespace = "::mlir::Plugin"; +} + #endif // PLUGIN_DIALECT_TD \ No newline at end of file diff --git a/include/Dialect/PluginOps.h
b/include/Dialect/PluginOps.h index be2a3ef..8bab877 100644 --- a/include/Dialect/PluginOps.h +++ b/include/Dialect/PluginOps.h @@ -27,6 +27,10 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/CallInterfaces.h" + +// Pull in the enum declarations generated by mlir-tblgen -gen-enum-decls (IComparisonCode, IDefineCode, IExprCode) and their utility functions. +#include "Dialect/PluginOpsEnums.h.inc" #define GET_OP_CLASSES #include "Dialect/PluginOps.h.inc" diff --git a/include/Dialect/PluginOps.td b/include/Dialect/PluginOps.td index 02d9bdd..e91cd84 100644 --- a/include/Dialect/PluginOps.td +++ b/include/Dialect/PluginOps.td @@ -28,8 +28,8 @@ def FunctionOp : Plugin_Op<"function", [NoSideEffect]> { TODO. }]; - let arguments = (ins OptionalAttr:$id, - OptionalAttr:$funcName, + let arguments = (ins UI64Attr:$id, + StrAttr:$funcName, OptionalAttr:$declaredInline); let regions = (region AnyRegion:$bodyRegion); @@ -89,4 +89,142 @@ def LoopOp : Plugin_Op<"loop", [NoSideEffect]> { }]; } +def CallOp : Plugin_Op<"call", [ + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + CallOp represents a call to the function named by `callee`. + `id` identifies the call statement on the client (SetLHS sends it back). + The callee name is attached as a symbol reference via an attribute. + `inputs` are the call arguments; `result` is the optional returned value.
+ }]; + let arguments = (ins UI64Attr:$id, + FlatSymbolRefAttr:$callee, + Variadic:$inputs); + let results = (outs Optional:$result); + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$callee, + "ArrayRef":$arguments)> + ]; + let extraClassDeclaration = [{ + bool SetLHS(Value lhs); + }]; + let assemblyFormat = [{ + $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def PhiOp : Plugin_Op<"phi", [NoSideEffect]> { + let summary = "phi op"; + let description = [{Phi node. `operands` are the incoming values (see GetArgDef); GetResult() asks the client for the result value.}]; + let arguments = (ins UI64Attr:$id, + UI32Attr:$capacity, + UI32Attr:$nArgs, + Variadic:$operands); + let results = (outs AnyType:$result); + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "uint32_t":$capacity, + "uint32_t":$nArgs, "ArrayRef":$operands, + "Type":$resultType)> + ]; + let extraClassDeclaration = [{ + Value GetResult(); + Value GetArgDef(int i) { return getOperand(i); } + }]; +} + +def AssignOp : Plugin_Op<"assign", [NoSideEffect]> { + let summary = "assign op"; + let description = [{Assignment statement; `exprCode` selects the operation and GetLHS/GetRHS1/GetRHS2 return operands 0, 1 and 2.}]; + let arguments = (ins UI64Attr:$id, + IExprCodeAttr:$exprCode, + Variadic:$operands); + let results = (outs AnyType:$result); + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "IExprCode":$exprCode, + "ArrayRef":$operands, "Type":$resultType)> + ]; + let extraClassDeclaration = [{ + Value GetLHS() { return getOperand(0); } + Value GetRHS1() { return getOperand(1); } + Value GetRHS2() { return getOperand(2); } + }]; +} + +def PlaceholderOp : Plugin_Op<"placeholder", [NoSideEffect]> { + let summary = "PlaceHolder."; + let description = [{Stand-in for a value defined on the client, identified by `id`; `defCode` records the kind of definition.}]; + let arguments = (ins UI64Attr:$id, + OptionalAttr:$defCode); + let results = (outs AnyType); + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "IDefineCode":$defCode, "Type":$retType)> + ]; +} + +def BaseOp : Plugin_Op<"statement_base", [NoSideEffect]> { + let summary = "Base operation, just like placeholder for statement."; + let description = [{Generic stand-in for a statement; only its `id` and an `opCode` string are recorded.}]; + let arguments = (ins UI64Attr:$id,
StrAttr:$opCode); + let results = (outs); /* BaseOp's builder never adds a result type */ + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "StringRef":$opCode)> + ]; +} + +// Terminators +// Common base class for the terminator operations below (ops that end a block). + +class Plugin_TerminatorOp traits = []> : + Plugin_Op; + +def FallThroughOp : Plugin_TerminatorOp<"fallthrough", [NoSideEffect]> { + let summary = "FallThroughOp"; + let description = [{TODO}]; + let successors = (successor AnySuccessor:$dest); + // `address`: client address of this block; `destaddr`: client address of `dest`. + let arguments = (ins UI64Attr:$address, UI64Attr:$destaddr); + let results = (outs); /* terminator; the builder never adds a result type */ + let builders = [ + OpBuilderDAG<(ins "uint64_t":$address, "Block*":$dest, "uint64_t":$destaddr)> + ]; +} + +def CondOp : Plugin_TerminatorOp<"condition", [NoSideEffect]> { + let summary = "condition op"; + let description = [{TODO}]; + let arguments = (ins UI64Attr:$id, UI64Attr:$address, + IComparisonAttr:$condCode, + AnyType:$LHS, AnyType:$RHS, + UI64Attr:$tbaddr, + UI64Attr:$fbaddr, + OptionalAttr:$trueLabel, + OptionalAttr:$falseLabel); + let successors = (successor AnySuccessor:$tb, AnySuccessor:$fb); + let builders = [ + OpBuilderDAG<(ins "uint64_t":$id, "uint64_t":$address, "IComparisonCode":$condCode, + "Value":$lhs, "Value":$rhs, "Block*":$tb, "Block*":$fb, + "uint64_t":$tbaddr, "uint64_t":$fbaddr, "Value":$trueLabel, + "Value":$falseLabel)>, + // Only for server: asks the client to create the condition. NOTE(review): adds no successors although two are declared above.
+ OpBuilderDAG<(ins "IComparisonCode":$condCode, + "Value":$lhs, "Value":$rhs)> + ]; + let extraClassDeclaration = [{ + Value GetLHS() { return getOperand(0); } + Value GetRHS() { return getOperand(1); } + }]; +} + +// TODO: RetOp does not have a correct assemblyFormat yet. +def RetOp : Plugin_TerminatorOp<"ret", [NoSideEffect]> { + let summary = "RetOp"; + let description = [{TODO}]; + let arguments = (ins UI64Attr:$address); // for bb address + let results = (outs); /* terminator; the builder never adds a result type */ + let builders = [ + OpBuilderDAG<(ins "uint64_t":$address)> + ]; +} + #endif // PLUGIN_OPS_TD \ No newline at end of file diff --git a/include/PluginAPI/BasicPluginOpsAPI.h b/include/PluginAPI/BasicPluginOpsAPI.h index fde5f63..6f5fbb0 100644 --- a/include/PluginAPI/BasicPluginOpsAPI.h +++ b/include/PluginAPI/BasicPluginOpsAPI.h @@ -48,6 +48,11 @@ public: virtual pair LoopSingleExit(uint64_t) = 0; virtual vector > GetLoopExitEdges(uint64_t) = 0; virtual LoopOp GetBlockLoopFather(uint64_t) = 0; + virtual PhiOp GetPhiOp(uint64_t) = 0; + virtual CallOp GetCallOp(uint64_t) = 0; + virtual bool SetLhsInCallOp(uint64_t, uint64_t) = 0; + virtual uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) = 0; + virtual mlir::Value GetResultFromPhi(uint64_t) = 0; }; // class BasicPluginOpsAPI } // namespace PluginAPI diff --git a/include/PluginAPI/PluginServerAPI.h b/include/PluginAPI/PluginServerAPI.h index 5b14f2e..425b946 100644 --- a/include/PluginAPI/PluginServerAPI.h +++ b/include/PluginAPI/PluginServerAPI.h @@ -38,6 +38,8 @@ public: vector GetAllFunc() override; vector GetDecls(uint64_t) override; + PhiOp GetPhiOp(uint64_t) override; + CallOp GetCallOp(uint64_t) override; PluginIR::PluginTypeID GetTypeCodeFromString(string type); LoopOp AllocateNewLoop(uint64_t funcID); LoopOp GetLoopById(uint64_t loopID); @@ -51,8 +53,14 @@ public: uint64_t GetLatch(uint64_t loopID); vector GetLoopBody(uint64_t loopID); LoopOp GetBlockLoopFather(uint64_t blockID); + /* Plugin API for CallOp.
*/ + bool SetLhsInCallOp(uint64_t, uint64_t) override; + /* Plugin API for CondOp. */ + uint64_t CreateCondOp(IComparisonCode, uint64_t, uint64_t) override; + mlir::Value GetResultFromPhi(uint64_t) override; + private: - vector GetOperationResult(const string& funName, const string& params); + vector GetFunctionOpResult(const string& funName, const string& params); vector GetDeclOperationResult(const string& funName, const string& params); LoopOp GetLoopResult(const string&funName, const string& params); vector GetLoopsResult(const string& funName, const string& params); diff --git a/include/PluginServer/PluginServer.h b/include/PluginServer/PluginServer.h index 610ecb9..b28c6d0 100644 --- a/include/PluginServer/PluginServer.h +++ b/include/PluginServer/PluginServer.h @@ -33,6 +33,7 @@ #include "Dialect/PluginOps.h" #include "plugin.grpc.pb.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Builders.h" #include "Dialect/PluginTypes.h" namespace PinServer { @@ -120,6 +121,7 @@ private: class PluginServer final : public PluginService::Service { public: + PluginServer() : opBuilder(&context){} /* 定义的grpc服务端和客户端通信的接口函数 */ Status ReceiveSendMsg(ServerContext* context, ServerReaderWriter* stream) override; /* 服务端发送数据给client接口 */ @@ -137,6 +139,10 @@ public: vector > EdgesResult(void); std::pair EdgeResult(void); bool BoolResult(void); + vector GetOpResult(void); + bool GetBoolResult(void); + uint64_t GetIdResult(void); + mlir::Value GetValueResult(void); /* 回调函数接口,用于向server注册用户需要执行的函数 */ int RegisterUserFunc(InjectPoint inject, UserFunc func); int RegisterPassManagerSetup(InjectPoint inject, const ManagerSetupData& passData, UserFunc func); @@ -190,6 +196,13 @@ public: void EdgeJsonDeSerialize(const string& data); void BlocksJsonDeSerialize(const string& data); void BlockJsonDeSerialize(const string& data); + void CallOpJsonDeSerialize(const string& data); + void CondOpJsonDeSerialize(const string& data); + void RetOpJsonDeSerialize(const string& data); + void
FallThroughOpJsonDeSerialize(const string& data); + void PhiOpJsonDeSerialize(const string& data); + void AssignOpJsonDeSerialize(const string& data); + mlir::Value ValueJsonDeSerialize(Json::Value valueJson); /* json反序列化,根据key值分别调用Operation/Decl/Type反序列化接口函数 */ void JsonDeSerialize(const string& key, const string& data); /* 解析客户端发送过来的-fplugin-arg参数,并保存在私有变量args中 */ @@ -234,6 +247,7 @@ private: /* 用户函数执行状态,client返回结果后为STATE_RETURN,开始执行下一个函数 */ volatile UserFunStateEnum userFunState; mlir::MLIRContext context; + mlir::OpBuilder opBuilder; vector funcOpData; PluginIR::PluginTypeBase pluginType; vector decls; @@ -244,6 +258,10 @@ private: vector blockIds; uint64_t blockId; bool boolRes; + vector opData; + bool boolResult; + uint64_t idResult; + mlir::Value valueResult; /* 保存用户注册的回调函数,它们将在注入点事件触发后调用 */ map> userFunc; string apiFuncName; // 保存用户调用PluginAPI的函数名 @@ -252,6 +270,10 @@ private: timer_t timerId; map args; // 保存gcc编译时用户传入参数 sem_t sem[2]; + + // Blocks of the function being deserialized, keyed by client block address. + std::map blockMaps; + bool ProcessBlock(mlir::Block*, mlir::Region&, const Json::Value&); }; // class PluginServer void RunServer(int timeout, string& port); diff --git a/lib/Dialect/PluginOps.cpp b/lib/Dialect/PluginOps.cpp index 481d51a..d647fc4 100644 --- a/lib/Dialect/PluginOps.cpp +++ b/lib/Dialect/PluginOps.cpp @@ -33,8 +33,9 @@ using namespace mlir::Plugin; using std::vector; using std::pair; -void FunctionOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - uint64_t id, StringRef funcName, bool declaredInline) { +void FunctionOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, StringRef funcName, bool declaredInline) +{ FunctionOp::build(builder, state, builder.getI64IntegerAttr(id), builder.getStringAttr(funcName), @@ -48,9 +49,10 @@ vector FunctionOp::GetAllLoops() return pluginAPI.GetLoopsFromFunc(funcId); } -void LocalDeclOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, +void LocalDeclOp::build(OpBuilder &builder, OperationState &state, uint64_t id,
StringRef symName, - int64_t typeID, uint64_t typeWidth) { + int64_t typeID, uint64_t typeWidth) +{ LocalDeclOp::build(builder, state, builder.getI64IntegerAttr(id), builder.getStringAttr(symName), @@ -142,6 +144,146 @@ void LoopOp::AddLoop(uint64_t outerId, uint64_t funcId) return pluginAPI.AddLoop(loopId, outerId, funcId); } +//===----------------------------------------------------------------------===// +// PlaceholderOp: stand-in for a client-defined value; the result type is optional. + +void PlaceholderOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, IDefineCode defCode, Type retType) { + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("defCode", + builder.getI32IntegerAttr(static_cast(defCode))); + if (retType) state.addTypes(retType); +} + +//===----------------------------------------------------------------------===// +// CallOp. NOTE(review): SetLHS() assumes `lhs` is defined by a PlaceholderOp; any other value yields a null op that is then dereferenced. + +void CallOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, StringRef callee, + ArrayRef arguments) +{ + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addOperands(arguments); + state.addAttribute("callee", builder.getSymbolRefAttr(callee)); +} + +/// Return the callee of this call (the `callee` symbol attribute); this is +/// required by the call interface. +CallInterfaceCallable CallOp::getCallableForCallee() +{ + return (*this)->getAttrOfType("callee"); +} + +/// Return the operands passed to the callee (the `inputs` operands); this is +/// required by the call interface.
+Operation::operand_range CallOp::getArgOperands() { return inputs(); } + +bool CallOp::SetLHS(Value lhs) +{ + PlaceholderOp phOp = lhs.getDefiningOp(); + uint64_t lhsId = phOp.idAttr().getInt(); + PluginAPI::PluginServerAPI pluginAPI; + return pluginAPI.SetLhsInCallOp(this->idAttr().getInt(), lhsId); +} + +//===----------------------------------------------------------------------===// +// CondOp. NOTE(review): the full builder appends trueLabel/falseLabel as operands although ODS declares them as optional attributes; the server-side builder assumes lhs/rhs are defined by PlaceholderOps and adds no successors. + +void CondOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, uint64_t address, IComparisonCode condCode, + Value lhs, Value rhs, Block* tb, Block* fb, uint64_t tbaddr, + uint64_t fbaddr, Value trueLabel, Value falseLabel) { + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("address", builder.getI64IntegerAttr(address)); + state.addAttribute("tbaddr", builder.getI64IntegerAttr(tbaddr)); + state.addAttribute("fbaddr", builder.getI64IntegerAttr(fbaddr)); + state.addOperands({lhs, rhs}); + state.addSuccessors(tb); + state.addSuccessors(fb); + state.addAttribute("condCode", + builder.getI32IntegerAttr(static_cast(condCode))); + if (trueLabel != nullptr) state.addOperands(trueLabel); + if (falseLabel != nullptr) state.addOperands(falseLabel); +} + +void CondOp::build(OpBuilder &builder, OperationState &state, + IComparisonCode condCode, Value lhs, Value rhs) +{ + PluginAPI::PluginServerAPI pluginAPI; + PlaceholderOp lhsOp = lhs.getDefiningOp(); + uint64_t lhsId = lhsOp.idAttr().getInt(); + PlaceholderOp rhsOp = rhs.getDefiningOp(); + uint64_t rhsId = rhsOp.idAttr().getInt(); + uint64_t id = pluginAPI.CreateCondOp(condCode, lhsId, rhsId); + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addOperands({lhs, rhs}); + state.addAttribute("condCode", + builder.getI32IntegerAttr(static_cast(condCode))); +} + +//===----------------------------------------------------------------------===// +// PhiOp. GetResult() queries the client for the phi result by id. + +void PhiOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, uint32_t capacity, uint32_t nArgs, + ArrayRef
operands, Type resultType) +{ + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("capacity", builder.getI32IntegerAttr(capacity)); + state.addAttribute("nArgs", builder.getI32IntegerAttr(nArgs)); + state.addOperands(operands); + if (resultType) state.addTypes(resultType); +} + +Value PhiOp::GetResult() +{ + PluginAPI::PluginServerAPI pluginAPI; + return pluginAPI.GetResultFromPhi(this->idAttr().getInt()); +} + +//===----------------------------------------------------------------------===// +// AssignOp + +void AssignOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, IExprCode exprCode, + ArrayRef operands, Type resultType) +{ + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("exprCode", + builder.getI32IntegerAttr(static_cast(exprCode))); + state.addOperands(operands); + if (resultType) state.addTypes(resultType); +} + +//===----------------------------------------------------------------------===// +// BaseOp + +void BaseOp::build(OpBuilder &builder, OperationState &state, + uint64_t id, StringRef opCode) +{ + state.addAttribute("id", builder.getI64IntegerAttr(id)); + state.addAttribute("opCode", builder.getStringAttr(opCode)); +} + +//===----------------------------------------------------------------------===// +// FallThroughOp + +void FallThroughOp::build(OpBuilder &builder, OperationState &state, + uint64_t address, Block* dest, uint64_t destaddr) +{ + state.addAttribute("address", builder.getI64IntegerAttr(address)); + state.addAttribute("destaddr", builder.getI64IntegerAttr(destaddr)); + state.addSuccessors(dest); +} + +//===----------------------------------------------------------------------===// +// RetOp + +void RetOp::build(OpBuilder &builder, OperationState &state, uint64_t address) +{ + state.addAttribute("address", builder.getI64IntegerAttr(address)); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions
//===----------------------------------------------------------------------===// diff --git a/lib/PluginAPI/PluginServerAPI.cpp b/lib/PluginAPI/PluginServerAPI.cpp index 41626f9..65eafa7 100644 --- a/lib/PluginAPI/PluginServerAPI.cpp +++ b/lib/PluginAPI/PluginServerAPI.cpp @@ -54,7 +54,7 @@ void PluginServerAPI::WaitClientResult(const string& funName, const string& para } } -vector PluginServerAPI::GetOperationResult(const string& funName, const string& params) +vector PluginServerAPI::GetFunctionOpResult(const string& funName, const string& params) { WaitClientResult(funName, params); vector retOps = PluginServer::GetInstance()->GetFunctionOpResult(); @@ -67,7 +67,63 @@ vector PluginServerAPI::GetAllFunc() string funName = __func__; string params = root.toStyledString(); - return GetOperationResult(funName, params); + return GetFunctionOpResult(funName, params); +} + +PhiOp PluginServerAPI::GetPhiOp(uint64_t id) +{ + Json::Value root; + string funName = __func__; + root["id"] = std::to_string(id); + string params = root.toStyledString(); + WaitClientResult(funName, params); + vector opRet = PluginServer::GetInstance()->GetOpResult(); + return opRet.empty() ? PhiOp() : llvm::dyn_cast(opRet[0]); /* null op instead of reading past an empty result */ +} + +CallOp PluginServerAPI::GetCallOp(uint64_t id) +{ + Json::Value root; + string funName = __func__; + root["id"] = std::to_string(id); + string params = root.toStyledString(); + WaitClientResult(funName, params); + vector opRet = PluginServer::GetInstance()->GetOpResult(); + return opRet.empty() ? CallOp() : llvm::dyn_cast(opRet[0]); /* null op instead of reading past an empty result */ +} + +bool PluginServerAPI::SetLhsInCallOp(uint64_t callId, uint64_t lhsId) +{ + Json::Value root; + string funName = __func__; + root["callId"] = std::to_string(callId); + root["lhsId"] = std::to_string(lhsId); + string params = root.toStyledString(); + WaitClientResult(funName, params); + return PluginServer::GetInstance()->GetBoolResult(); +} + +uint64_t PluginServerAPI::CreateCondOp(IComparisonCode iCode, + uint64_t lhs, uint64_t rhs) +{ + Json::Value root; + string funName = __func__; +
root["condCode"] = std::to_string(static_cast(iCode)); + root["lhsId"] = std::to_string(lhs); + root["rhsId"] = std::to_string(rhs); + string params = root.toStyledString(); + WaitClientResult(funName, params); + return PluginServer::GetInstance()->GetIdResult(); +} + +mlir::Value PluginServerAPI::GetResultFromPhi(uint64_t phiId) +{ + Json::Value root; + string funName = __func__; + root["id"] = std::to_string(phiId); + string params = root.toStyledString(); + WaitClientResult(funName, params); + return PluginServer::GetInstance()->GetValueResult(); } PluginIR::PluginTypeID PluginServerAPI::GetTypeCodeFromString(string type) diff --git a/lib/PluginServer/PluginServer.cpp b/lib/PluginServer/PluginServer.cpp index 92159e8..2eac906 100644 --- a/lib/PluginServer/PluginServer.cpp +++ b/lib/PluginServer/PluginServer.cpp @@ -27,7 +27,6 @@ #include "Dialect/PluginDialect.h" #include "PluginAPI/PluginServerAPI.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "user.h" @@ -37,7 +36,6 @@ namespace PinServer { using namespace mlir::Plugin; - using std::cout; using std::endl; using std::pair; @@ -75,6 +73,13 @@ int PluginServer::RegisterPassManagerSetup(InjectPoint inject, const ManagerSetu return 0; } +vector PluginServer::GetOpResult(void) +{ + vector retOps = opData; + opData.clear(); + return retOps; +} + vector PluginServer::GetFunctionOpResult(void) { vector retOps = funcOpData; @@ -135,6 +140,21 @@ vector > PluginServer::EdgesResult() return retEdges; } +bool PluginServer::GetBoolResult() +{ + return BoolResult(); /* boolResult is never assigned; reuse the existing bool result */ +} + +uint64_t PluginServer::GetIdResult() +{ + return this->idResult; +} + +mlir::Value PluginServer::GetValueResult() +{ + return this->valueResult; +} + void PluginServer::JsonGetAttributes(Json::Value node, map& attributes) { Json::Value::Members attMember = node.getMemberNames(); @@ -151,6 +171,27 @@ static uintptr_t GetID(Json::Value node) return
atol(id.c_str()); } +mlir::Value PluginServer::ValueJsonDeSerialize(Json::Value valueJson) +{ + uint64_t opId = GetID(valueJson["id"]); + IDefineCode defCode = IDefineCode( + atoi(valueJson["defCode"].asString().c_str())); + switch (defCode) { + case IDefineCode::MemRef : { + break; + } + case IDefineCode::IntCST : { + break; + } + default: + break; + } + mlir::Type retType = PluginIR::PluginUndefType::get(&context); // FIXME: the real type is not deserialized yet. + mlir::Value opValue = opBuilder.create( + opBuilder.getUnknownLoc(), opId, defCode, retType); + return opValue; +} + void PluginServer::JsonDeSerialize(const string& key, const string& data) { if (key == "FuncOpResult") { @@ -173,6 +214,12 @@ void PluginServer::JsonDeSerialize(const string& key, const string& data) EdgesJsonDeSerialize(data); } else if (key == "BlockIdsResult") { BlocksJsonDeSerialize(data); + } else if (key == "IdResult") { + this->idResult = atol(data.c_str()); + } else if (key == "ValueResult") { + context.getOrLoadDialect(); + opBuilder = mlir::OpBuilder(&context); + Json::Value valueJson; Json::Reader reader; reader.parse(data, valueJson); this->valueResult = ValueJsonDeSerialize(valueJson); } else { cout << "not Json,key:" << key << ",value:" << data << endl; } @@ -217,6 +264,46 @@ void PluginServer::TypeJsonDeSerialize(const string& data) return; } +bool PluginServer::ProcessBlock(mlir::Block* block, mlir::Region& rg, + const Json::Value& blockJson) +{ + if (blockJson.isNull()) { + return false; + } + // todo process func return type + // todo isDeclaration + + // process each stmt. NOTE(review): getMemberNames() returns keys in sorted string order, so ops are created in key order, not necessarily client order - confirm the key format. + opBuilder.setInsertionPointToStart(block); + Json::Value::Members opMember = blockJson.getMemberNames(); + for (Json::Value::Members::iterator opIdx = opMember.begin(); + opIdx != opMember.end(); opIdx++) { + string baseOpKey = *opIdx; + Json::Value opJson = blockJson[baseOpKey]; + if (opJson.isNull()) continue; + // llvm::StringRef opCode(opJson["OperationName"].asString().c_str()); + string opCode = opJson["OperationName"].asString(); + if (opCode == PhiOp::getOperationName().str()) { +
PhiOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == CallOp::getOperationName().str()) { + CallOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == AssignOp::getOperationName().str()) { + AssignOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == CondOp::getOperationName().str()) { + CondOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == RetOp::getOperationName().str()) { + RetOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == FallThroughOp::getOperationName().str()) { + FallThroughOpJsonDeSerialize(opJson.toStyledString()); + } else if (opCode == BaseOp::getOperationName().str()) { + uint64_t opID = GetID(opJson["id"]); + opBuilder.create(opBuilder.getUnknownLoc(), opID, opCode); + } + } + // fprintf(stderr, "[bb] op:%ld, succ: %d\n", block->getOperations().size(), block->getNumSuccessors()); + return true; +} + void PluginServer::FuncOpJsonDeSerialize(const string& data) { Json::Value root; @@ -227,8 +314,9 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data) Json::Value::Members operation = root.getMemberNames(); context.getOrLoadDialect(); - mlir::OpBuilder builder(&context); - for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) { + opBuilder = mlir::OpBuilder(&context); + for (Json::Value::Members::iterator iter = operation.begin(); + iter != operation.end(); iter++) { string operationKey = *iter; node = root[operationKey]; int64_t id = GetID(node["id"]); @@ -237,9 +325,30 @@ void PluginServer::FuncOpJsonDeSerialize(const string& data) JsonGetAttributes(attributes, funcAttributes); bool declaredInline = false; if (funcAttributes["declaredInline"] == "1") declaredInline = true; - auto location = builder.getUnknownLoc(); - FunctionOp op = builder.create(location, id, funcAttributes["funcName"], declaredInline); - funcOpData.push_back(op); + auto location = opBuilder.getUnknownLoc(); + FunctionOp fOp = opBuilder.create( + location,
id, funcAttributes["funcName"], declaredInline); + mlir::Region &bodyRegion = fOp.bodyRegion(); + Json::Value regionJson = node["region"]; + Json::Value::Members bbMember = regionJson.getMemberNames(); + this->blockMaps.clear(); /* Create all blocks before processing ops so successor addresses resolve; drop stale (possibly reused) addresses from earlier functions, since insert() would keep the old Block*. */ + for (Json::Value::Members::iterator bbIdx = bbMember.begin(); + bbIdx != bbMember.end(); bbIdx++) { + string blockKey = *bbIdx; + Json::Value blockJson = regionJson[blockKey]; + mlir::Block* block = opBuilder.createBlock(&bodyRegion); + this->blockMaps.insert({GetID(blockJson["address"]), block}); + } + + for (Json::Value::Members::iterator bbIdx = bbMember.begin(); + bbIdx != bbMember.end(); bbIdx++) { + string blockKey = *bbIdx; + Json::Value blockJson = regionJson[blockKey]; + uint64_t bbAddress = GetID(blockJson["address"]); + ProcessBlock(this->blockMaps[bbAddress], bodyRegion, blockJson["ops"]); + } + funcOpData.push_back(fOp); + opBuilder.clearInsertionPoint(); /* fOp is detached (created with no insertion block), so create the next FunctionOp detached as well */ } } @@ -253,7 +362,7 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data) Json::Value::Members operation = root.getMemberNames(); context.getOrLoadDialect(); - mlir::OpBuilder builder(&context); + opBuilder = mlir::OpBuilder(&context); for (Json::Value::Members::iterator iter = operation.begin(); iter != operation.end(); iter++) { string operationKey = *iter; node = root[operationKey]; @@ -261,11 +370,11 @@ void PluginServer::LocalDeclOpJsonDeSerialize(const string& data) Json::Value attributes = node["attributes"]; map declAttributes; JsonGetAttributes(attributes, declAttributes); - string symName = declAttributes["symName"]; - uint64_t typeID = atol(declAttributes["typeID"].c_str()); - uint64_t typeWidth = atol(declAttributes["typeWidth"].c_str()); - auto location = builder.getUnknownLoc(); - LocalDeclOp op = builder.create(location, id, symName, typeID, typeWidth); + string symName = declAttributes["symName"]; + uint64_t typeID = atol(declAttributes["typeID"].c_str()); + uint64_t typeWidth =
atol(declAttributes["typeWidth"].c_str()); + auto location = opBuilder.getUnknownLoc(); + LocalDeclOp op = opBuilder.create(location, id, symName, typeID, typeWidth); decls.push_back(op); } } @@ -384,6 +493,121 @@ void PluginServer::BlockJsonDeSerialize(const string& data) blockId = (uint64_t)atol(root["id"].asString().c_str()); } +void PluginServer::CallOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + Json::Value operandJson = node["operands"]; + Json::Value::Members operandMember = operandJson.getMemberNames(); + llvm::SmallVector ops; + for (Json::Value::Members::iterator opIter = operandMember.begin(); + opIter != operandMember.end(); opIter++) { + string key = *opIter; + mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]); + ops.push_back(opValue); + } + int64_t id = GetID(node["id"]); + string callName = node["callee"].asString(); /* owned copy: a StringRef to the temporary returned by asString() would dangle */ + CallOp op = opBuilder.create(opBuilder.getUnknownLoc(), + id, callName, ops); + opData.push_back(op.getOperation()); +} + +void PluginServer::CondOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + mlir::Value LHS = ValueJsonDeSerialize(node["lhs"]); + mlir::Value RHS = ValueJsonDeSerialize(node["rhs"]); + mlir::Value trueLabel = nullptr; + mlir::Value falseLabel = nullptr; + int64_t id = GetID(node["id"]); + int64_t address = GetID(node["address"]); + int64_t tbaddr = GetID(node["tbaddr"]); + int64_t fbaddr = GetID(node["fbaddr"]); + assert (this->blockMaps.find(tbaddr) != this->blockMaps.end()); + assert (this->blockMaps.find(fbaddr) != this->blockMaps.end()); + mlir::Block* tb = this->blockMaps[tbaddr]; + mlir::Block* fb = this->blockMaps[fbaddr]; + IComparisonCode iCode = IComparisonCode( + atoi(node["condCode"].asString().c_str())); + CondOp op = opBuilder.create(opBuilder.getUnknownLoc(), id, + address, iCode, LHS, RHS, tb, fb, tbaddr, fbaddr, + trueLabel, falseLabel); +
opData.push_back(op.getOperation()); +} + +void PluginServer::RetOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + int64_t address = GetID(node["address"]); + RetOp op = opBuilder.create(opBuilder.getUnknownLoc(), address); + opData.push_back(op.getOperation()); +} + +void PluginServer::FallThroughOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + int64_t address = GetID(node["address"]); + int64_t destaddr = GetID(node["destaddr"]); + assert (this->blockMaps.find(destaddr) != this->blockMaps.end()); + mlir::Block* succ = this->blockMaps[destaddr]; + FallThroughOp op = opBuilder.create(opBuilder.getUnknownLoc(), + address, succ, destaddr); + opData.push_back(op.getOperation()); +} + +void PluginServer::PhiOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + Json::Value operandJson = node["operands"]; + Json::Value::Members operandMember = operandJson.getMemberNames(); + llvm::SmallVector ops; + for (Json::Value::Members::iterator opIter = operandMember.begin(); + opIter != operandMember.end(); opIter++) { + string key = *opIter; + mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]); + ops.push_back(opValue); + } + int64_t id = GetID(node["id"]); + uint32_t capacity = atoi(node["capacity"].asString().c_str()); + uint32_t nArgs = atoi(node["nArgs"].asString().c_str()); + mlir::Type retType = nullptr; /* NOTE(review): null type => PhiOp is built without the result its ODS declares */ + PhiOp op = opBuilder.create(opBuilder.getUnknownLoc(), + id, capacity, nArgs, ops, retType); + opData.push_back(op.getOperation()); +} + +void PluginServer::AssignOpJsonDeSerialize(const string& data) +{ + Json::Value node; + Json::Reader reader; + reader.parse(data, node); + Json::Value operandJson = node["operands"]; + Json::Value::Members operandMember = operandJson.getMemberNames(); + llvm::SmallVector ops; + for (Json::Value::Members::iterator opIter =
operandMember.begin(); + opIter != operandMember.end(); opIter++) { + string key = *opIter; + mlir::Value opValue = ValueJsonDeSerialize(operandJson[key.c_str()]); + ops.push_back(opValue); + } + int64_t id = GetID(node["id"]); + IExprCode iCode = IExprCode(atoi(node["exprCode"].asString().c_str())); + mlir::Type retType = nullptr; /* NOTE(review): null type => AssignOp is built without the result its ODS declares */ + AssignOp op = opBuilder.create(opBuilder.getUnknownLoc(), + id, iCode, ops, retType); + opData.push_back(op.getOperation()); +} + /* 线程函数,执行用户注册函数,客户端返回数据后退出 */ static void ExecCallbacks(const string& name) { diff --git a/user.cpp b/user.cpp index efc4158..605e6f8 100644 --- a/user.cpp +++ b/user.cpp @@ -117,7 +117,7 @@ ProcessArrayWiden(void) void RegisterCallbacks(void) { - PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc); + // PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, UserOptimizeFunc); PluginServer::GetInstance()->RegisterUserFunc(HANDLE_BEFORE_IPA, LocalVarSummery); ManagerSetupData setupData; setupData.refPassName = PASS_PHIOPT; -- 2.27.0.windows.1