From 35db41f1fc006aa06fb012ec942d17c93bf0f8d5 Mon Sep 17 00:00:00 2001 From: Amir Ashouri Date: Thu, 22 Aug 2024 23:58:33 -0400 Subject: [PATCH] [ACPO] Introduce MLInliner using ACPO infrastructure This change adds an ML model to the inliner for performance optimization. --- .../llvm/Analysis/ACPOCollectFeatures.h | 3 + llvm/include/llvm/Analysis/ACPOFIModel.h | 144 +++++++++ llvm/include/llvm/Analysis/ACPOMLInterface.h | 4 +- llvm/include/llvm/Analysis/ACPOModel.h | 2 + llvm/include/llvm/Analysis/ACPOModelRunner.h | 2 + llvm/include/llvm/Analysis/AOTModelRunner.h | 2 + llvm/include/llvm/Analysis/CallHeight.h | 2 + llvm/include/llvm/Analysis/DumpCallsite.h | 2 + llvm/include/llvm/Analysis/DumpFeature.h | 2 + llvm/include/llvm/Analysis/FIModelRunner.h | 277 ++++++++++++++++++ llvm/include/llvm/Analysis/InlineAdvisor.h | 30 ++ .../llvm/Analysis/InlineModelFeatureMaps.h | 30 ++ llvm/include/llvm/Analysis/MLInlineAdvisor.h | 5 +- llvm/include/llvm/InitializePasses.h | 6 + llvm/include/llvm/Transforms/IPO.h | 4 + llvm/include/llvm/Transforms/IPO/Inliner.h | 9 + llvm/lib/Analysis/ACPOCollectFeatures.cpp | 2 + llvm/lib/Analysis/ACPOFIModel.cpp | 243 +++++++++++++++ llvm/lib/Analysis/ACPOMLInterface.cpp | 2 + llvm/lib/Analysis/ACPOModel.cpp | 2 + llvm/lib/Analysis/CallHeight.cpp | 3 + llvm/lib/Analysis/DumpCallsite.cpp | 2 + llvm/lib/Analysis/DumpFeature.cpp | 3 + llvm/lib/Analysis/InlineAdvisor.cpp | 115 ++++++++ llvm/lib/Analysis/MLInlineAdvisor.cpp | 5 + llvm/lib/IR/AsmWriter.cpp | 35 ++- llvm/lib/Transforms/IPO/Inliner.cpp | 219 +++++++++++++- llvm/tools/opt/opt.cpp | 3 + 28 files changed, 1151 insertions(+), 7 deletions(-) create mode 100644 llvm/include/llvm/Analysis/ACPOFIModel.h create mode 100644 llvm/include/llvm/Analysis/FIModelRunner.h create mode 100644 llvm/lib/Analysis/ACPOFIModel.cpp diff --git a/llvm/include/llvm/Analysis/ACPOCollectFeatures.h b/llvm/include/llvm/Analysis/ACPOCollectFeatures.h index ec62b559542d..8b266b3bc756 100644 --- 
a/llvm/include/llvm/Analysis/ACPOCollectFeatures.h +++ b/llvm/include/llvm/Analysis/ACPOCollectFeatures.h @@ -10,6 +10,8 @@ // collected on a given ACPOModel class from all available features. // //===----------------------------------------------------------------------===// + +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_ACPOCOLLECTFEATURES_H #define LLVM_ANALYSIS_ACPOCOLLECTFEATURES_H #include "llvm/Analysis/InlineAdvisor.h" @@ -294,3 +296,4 @@ operator++(ACPOCollectFeatures::FeatureIndex &, int); } // namespace llvm #endif // LLVM_ANALYSIS_ACPOCOLLECTFEATURES_H +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/ACPOFIModel.h b/llvm/include/llvm/Analysis/ACPOFIModel.h new file mode 100644 index 000000000000..8753dd3d7c63 --- /dev/null +++ b/llvm/include/llvm/Analysis/ACPOFIModel.h @@ -0,0 +1,144 @@ +//===- ACPOFIModel.h - AI-Enabled Continuous Program Optimization ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#if defined(ENABLE_ACPO) +#ifndef LLVM_ANALYSIS_ACPOFIMODEL_H +#define LLVM_ANALYSIS_ACPOFIMODEL_H + +#include "llvm/Analysis/ACPOModel.h" +#include "llvm/Analysis/DumpFeature.h" +#include "llvm/Analysis/FunctionPropertiesAnalysis.h" +#include "llvm/Analysis/InlineAdvisor.h" +#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" + +#include + +namespace llvm { + +//class ACPOmodel; + +class ACPOFIModel : public ACPOModel { +public: + ACPOFIModel(CallBase *CB, InlineAdvisor *IA, OptimizationRemarkEmitter *ORE, + bool OnlyMandatory, bool UseML = true); + + ~ACPOFIModel(); + + void setMLCustomFeatures( + std::vector> FeatureValues); + + void sendCustomFeatures() override; + + InlineAdvisor *getNotACPOAdvisor(); + + // Recorders that mimic the behavior of the default InlineAdvice. + // If the model is turned off, or it was decided to fall back to the + // default inline advisor, then we need to make sure the advice returned + // is properly recorded; otherwise there would be an error. + void recordUnattemptedInlining(); + + void recordInlining(); + + void recordUnsuccessfulInlining(InlineResult &IR); + + void recordInliningWithCalleeDeleted(); + + // Interface for IRToPerf Cache system. 
+ struct FunctionFeaturesCache { + using FunctionSizeMap = DenseMap; + using FunctionFloatMap = DenseMap; + + std::array( + ACPOFIExtendedFeatures::NamedFeatureIndex::NumNamedFeatures)> + NamedFeatures; + std::array( + ACPOFIExtendedFeatures::NamedFloatFeatureIndex:: + NumNamedFloatFeatures)> + NamedFloatFeatures; + + FunctionSizeMap &operator[](ACPOFIExtendedFeatures::NamedFeatureIndex Pos) { + return NamedFeatures[static_cast(Pos)]; + } + FunctionFloatMap & + operator[](ACPOFIExtendedFeatures::NamedFloatFeatureIndex Pos) { + return NamedFloatFeatures[static_cast(Pos)]; + } + }; + + struct FunctionAnalysisMap { + DenseMap DomCache; + DenseMap LICache; + DenseMap TTICache; + }; + + // Invalidation mechanisms + static void invalidateCache(CallBase *CB); + + static void invalidateCache(const Function *F); + + static void clearCache(); + + // Getters/setters for the cache system. + static std::optional + getCachedSize(const Function *F, + ACPOFIExtendedFeatures::NamedFeatureIndex idx); + + static std::optional + getCachedFloat(const Function *F, + ACPOFIExtendedFeatures::NamedFloatFeatureIndex idx); + + static void insertSizeCache(const Function *F, + ACPOFIExtendedFeatures::NamedFeatureIndex idx, + size_t val); + + static void + insertFloatCache(const Function *F, + ACPOFIExtendedFeatures::NamedFloatFeatureIndex idx, + float val); + + static const DominatorTree *getDomCachedAnalysis(const Function *F); + + static const LoopInfo *getLICachedAnalysis(const Function *F); + + static const TargetTransformInfo *getTTICachedAnalysis(const Function *F); + + static void insertAnalysisCache(const Function *F, const DominatorTree *Tree); + + static void insertAnalysisCache(const Function *F, const LoopInfo *LI); + + static void insertAnalysisCache(const Function *F, + const TargetTransformInfo *TTI); + +protected: + // Interface to run the MLInference/default advisor and get advice from the + // model/default advisor + virtual std::unique_ptr getAdviceML() override; + + virtual 
std::unique_ptr getAdviceNoML() override; + +private: + static FunctionFeaturesCache FeatureCache; + static FunctionAnalysisMap FunctionAnalysisCache; + CallBase *CurrentCB = nullptr; + InlineAdvisor *NotACPOAdvisor = nullptr; + bool ShouldInline = false; + bool OnlyMandatory = false; + std::unique_ptr NotACPOAdvice = nullptr; + std::vector> CustomFeatureValues; +}; + +} // end namespace llvm + +#endif // LLVM_ANALYSIS_ACPOFIMODEL_H + +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/ACPOMLInterface.h b/llvm/include/llvm/Analysis/ACPOMLInterface.h index 996f27ee32ba..fbc8a46b3d9a 100644 --- a/llvm/include/llvm/Analysis/ACPOMLInterface.h +++ b/llvm/include/llvm/Analysis/ACPOMLInterface.h @@ -4,10 +4,9 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Copyright (C) 2021-2022. Huawei Technologies Co., Ltd. All rights reserved. -// //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_ACPOML_INTERFACE_H #define LLVM_ANALYSIS_ACPOML_INTERFACE_H @@ -480,3 +479,4 @@ std::shared_ptr createPersistentCompiledMLIF(); } // namespace llvm #endif // LLVM_ANALYSIS_ACPOML_INTERFACE_H +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/ACPOModel.h b/llvm/include/llvm/Analysis/ACPOModel.h index 34dbc0fdb8bf..d61ac00efaaf 100644 --- a/llvm/include/llvm/Analysis/ACPOModel.h +++ b/llvm/include/llvm/Analysis/ACPOModel.h @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_ACPOMODEL_H #define LLVM_ANALYSIS_ACPOMODEL_H @@ -120,3 +121,4 @@ private: } // namespace llvm #endif // LLVM_ANALYSIS_ACPOMODEL_H +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/ACPOModelRunner.h b/llvm/include/llvm/Analysis/ACPOModelRunner.h index 819e17f71103..044f3af15bbe 100644 --- a/llvm/include/llvm/Analysis/ACPOModelRunner.h 
+++ b/llvm/include/llvm/Analysis/ACPOModelRunner.h @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_ACPOMODEL_H #define LLVM_ANALYSIS_ACPOMODEL_H @@ -37,3 +38,4 @@ protected: } // namespace llvm #endif // LLVM_ANALYSIS_ACPOMODEL_H +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/AOTModelRunner.h b/llvm/include/llvm/Analysis/AOTModelRunner.h index abc6258c4f09..fe19a33a5a08 100644 --- a/llvm/include/llvm/Analysis/AOTModelRunner.h +++ b/llvm/include/llvm/Analysis/AOTModelRunner.h @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_AOTMODEL_H #define LLVM_ANALYSIS_AOTMODEL_H @@ -201,3 +202,4 @@ private: } // namespace llvm #endif // LLVM_ANALYSIS_AOTMODEL_H +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/CallHeight.h b/llvm/include/llvm/Analysis/CallHeight.h index c1251081f525..84e94075ea39 100644 --- a/llvm/include/llvm/Analysis/CallHeight.h +++ b/llvm/include/llvm/Analysis/CallHeight.h @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_CALLHEIGHT #define LLVM_ANALYSIS_CALLHEIGHT @@ -70,3 +71,4 @@ Pass *createCallHeightAnalysisWrapper(); } // namespace llvm #endif +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/DumpCallsite.h b/llvm/include/llvm/Analysis/DumpCallsite.h index 9f80fe1cb985..02238521580b 100644 --- a/llvm/include/llvm/Analysis/DumpCallsite.h +++ b/llvm/include/llvm/Analysis/DumpCallsite.h @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_DUMPCALLSITE #define LLVM_ANALYSIS_DUMPCALLSITE @@ -25,3 +26,4 @@ public: } // namespace llvm #endif +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/DumpFeature.h 
b/llvm/include/llvm/Analysis/DumpFeature.h index 226e06cf5600..67ca36b106cb 100644 --- a/llvm/include/llvm/Analysis/DumpFeature.h +++ b/llvm/include/llvm/Analysis/DumpFeature.h @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #ifndef LLVM_ANALYSIS_DUMPFEATURE #define LLVM_ANALYSIS_DUMPFEATURE @@ -192,3 +193,4 @@ operator++(ACPOFIExtendedFeatures::NamedFloatFeatureIndex &n, int); } // namespace llvm #endif +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/FIModelRunner.h b/llvm/include/llvm/Analysis/FIModelRunner.h new file mode 100644 index 000000000000..3685220aa074 --- /dev/null +++ b/llvm/include/llvm/Analysis/FIModelRunner.h @@ -0,0 +1,277 @@ +//===- FIModelRunner.h - AI-Enabled Continuous Program Optimization -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#if defined(ENABLE_ACPO) +#ifdef LLVM_HAVE_TF_AOT_FICOMPILEDMODEL + +#ifndef LLVM_ANALYSIS_FIMODELRUNNER_H +#define LLVM_ANALYSIS_FIMODELRUNNER_H + +#include "llvm/Analysis/AOTModelRunner.h" +#include "llvm/Analysis/FICompiledModel.h" + +namespace llvm { + +class FIModelRunner : public AOTModelRunner { + std::vector Means = { + 0.40009943697174110699421589742996729910373687744141, + 0.0, + 47.2218788212687599070704891346395015716552734375, + 0.0, + 0.07675459224122871404460966004990041255950927734375, + 5816.8243862454482950852252542972564697265625, + 1333.68016232413765465025790035724639892578125, + 321.9700210967629345759632997214794158935546875, + 0.94076781467098458122677584469784051179885864257812, + 0.0, + 0.0, + 24.57427538666200916850357316434383392333984375, + 0.72785175828753412297089653293369337916374206542969, + 
22.362582136282401990001744707114994525909423828125, + 2.3236404681600126842511144786840304732322692871094, + 219.476437468925951179699040949344635009765625, + 123.872156304169635632206336595118045806884765625, + 759.6211988873809559663641266524791717529296875, + 3.5118047810371009198604497214546427130699157714844, + 0.0, + 14.689125089022963877027905255090445280075073242188, + 0.2720138674263292699606608948670327663421630859375, + 97.33707789677367827607668004930019378662109375, + 5.4576519437240493815011177503038197755813598632812, + 222416123463299168.0, + 697004967939498496.0, + 6.2712796684314486839184610289521515369415283203125, + 1.4856427516360068974421437815180979669094085693359, + 0.0041427067953076499376430241738944459939375519752502, + 0.72785175828753412297089653293369337916374206542969, + 552.7808652140716958456323482096195220947265625, + 62.5524652090595196796130039729177951812744140625, + 385.68509386043888298445381224155426025390625, + 92.9494483935554143272383953444659709930419921875, + 24.2728066757145342080548289231956005096435546875, + 0.90531987798814816947867711860453709959983825683594, + 0.0, + 0.0, + 2.9322753597871509256833633116912096738815307617188, + 0.49584111584407208894731411419343203306198120117188, + 7.9963853317029256473347231803927570581436157226562, + 1.4571144465795025091381376114441081881523132324219, + 15.557169540036818844441768305841833353042602539062, + 9.6481678066085265754736610688269138336181640625, + 50.98738225453177363988288561813533306121826171875, + 1.3425469302194332765765238946187309920787811279297, + 0.0, + 839.271140434566405019722878932952880859375, + 0.16440693908813608370422798543586395680904388427734, + 2.8829196844891762374629706755513325333595275878906, + 132.0555906421747067724936641752719879150390625, + 92791372484119440.0, + 166968642875823456.0, + 5.5557876796248262252220229129306972026824951171875, + 1.1750766644405326033506753446999937295913696289062, + 
0.0042161570432282073628282859090177225880324840545654, + 0.49584111584407208894731411419343203306198120117188, + 41.15953665944181949498670292086899280548095703125, + 5.14903426051142787400749512016773223876953125, + 2.0527687821658449074391228350577875971794128417969, + 0.52614251736787642776960183255141600966453552246094, + 0.74523979091361081117383946548216044902801513671875, + 222.345100041656024814074044115841388702392578125, + 7.4997648449992606600744693423621356487274169921875, + 0.0, + 78.5584998454695693226312869228422641754150390625, + 0.0, + 10.409640011287439875786731136031448841094970703125, + 8.4653112780338357623577394406311213970184326171875, + 1.3630927585697201198655648113344796001911163330078, + 566.7381985783200661899172700941562652587890625, + 0.0, + 1.2066945269353257508271326514659449458122253417969, + 55.41075531786237462483768467791378498077392578125, + 0.51243634018194272883306439325679093599319458007812, + 1.1147556403606606600931172579294070601463317871094, + -31.471868743197301654390685143880546092987060546875, + 0.0, + 0.030368588666872708276001091576290491502732038497925, + 0.58478345583789081985059965518303215503692626953125, + 0.00034937314395517275094141251834400918596656993031502, + -0.23764092503258577027125397762574721127748489379883, + -62.20223330063559075142620713450014591217041015625, + 5.8952014942420616350204909394960850477218627929688, + 3339.09353794057960840291343629360198974609375, + 0.71960117711874660439974604742019437253475189208984, + -49.2720273048549444183663581497967243194580078125, + 27818.32155766672440222464501857757568359375, + 91.64824843118020680776680819690227508544921875, + 106.3296335613216996307528461329638957977294921875, + 469.83727273948858282892615534365177154541015625, + 0.30689743210739195422576131022651679813861846923828, + 1071.964175815315911677316762506961822509765625, + 1363.988766309679022015188820660114288330078125, + 14.079536139964256236112305487040430307388305664062, + 
63165365211952664.0, + 0.38502264206721403816402471420587971806526184082031, + 0.015573979763232508391479491649533883901312947273254, + 0.13859363872129429329227434664062457159161567687988, + 0.0}; + + std::vector Scales = { + 0.48991823553184549178141082848014775663614273071289, + 1.0, + 19.2517211876445770712962257675826549530029296875, + 1.0, + 0.26620166192402217042456413764739409089088439941406, + 13580.447773648038491955958306789398193359375, + 3192.7079136089387247920967638492584228515625, + 633.0586155859824657454737462103366851806640625, + 0.23605875020885080939336830851971171796321868896484, + 1.0, + 1.0, + 101.565906032925312274528550915420055389404296875, + 0.44506581113952026207414292002795264124870300292969, + 25.4451961539476627649492002092301845550537109375, + 1.8819488669919737233726664271671324968338012695312, + 399.4446922340151786556816659867763519287109375, + 253.61174866766344848656444810330867767333984375, + 1934.51814232197148157865740358829498291015625, + 9.2671206485376131922748754732310771942138671875, + 1.0, + 101.7363052307218964642743230797350406646728515625, + 0.44499699252253444026194983962341211736202239990234, + 241.819662633324895750774885527789592742919921875, + 41.0624051346520815286567085422575473785400390625, + 1810657384453411584.0, + 2590019375355715584.0, + 18.6007475145233769353581010363996028900146484375, + 0.30589376767499054654564361044322140514850616455078, + 0.021661308027730186848147653222440567333251237869263, + 0.44506581113952026207414292002795264124870300292969, + 2210.9835111177717408281750977039337158203125, + 252.28469071093292086516157723963260650634765625, + 1479.28580699818076027440838515758514404296875, + 358.2883493183543350824038498103618621826171875, + 86.4399992258764626740230596624314785003662109375, + 0.29277260204409949473358665272826328873634338378906, + 1.0, + 1.0, + 11.300678128510535103146139590535312891006469726562, + 0.49998270338340455865022704529110342264175415039062, + 
9.4889928089799600030573856201954185962677001953125, + 1.0885854822898506366612991769216023385524749755859, + 53.20529981175358358314042561687529087066650390625, + 36.65171139901388386306280153803527355194091796875, + 214.68561782216193023486994206905364990234375, + 2.8728217196022858281878598063485696911811828613281, + 1.0, + 1653.1016242378727838513441383838653564453125, + 0.37064443536603375317639574859640561044216156005859, + 20.0905336391907667348277755081653594970703125, + 288.66579115116110187955200672149658203125, + 967784087203564544.0, + 986920622098821248.0, + 17.499765511468584833210115903057157993316650390625, + 0.57797196338014200645005757905892096459865570068359, + 0.028955889395889600895772630906321865040808916091919, + 0.49998270338340455865022704529110342264175415039062, + 319.19585661999855119574931450188159942626953125, + 38.6813101625874224964718450792133808135986328125, + 39.62777871280881214488545083440840244293212890625, + 5.0871202966110988796799574629403650760650634765625, + 0.69504605038799238680979897253564558923244476318359, + 673.3477042973012203219695948064327239990234375, + 56.94168682747444876213194220326840877532958984375, + 1.0, + 261.01902251155337353338836692273616790771484375, + 1.0, + 85.0611943221388884239786420948803424835205078125, + 53.12927927294536090130350203253328800201416015625, + 21.829518414441992035790462978184223175048828125, + 1898.72146183866834689979441463947296142578125, + 1.0, + 9.7285926829767870316345579340122640132904052734375, + 174.40267892003106453557847999036312103271484375, + 0.98364895900708060327843895720434375107288360595703, + 1.1152676652901183373955973365809768438339233398438, + 18.12268289087599981712628505192697048187255859375, + 1.0, + 0.1715993516574435828747624555035145021975040435791, + 0.49275933843630442821037718204024713486433029174805, + 0.031531692879025040310292382628176710568368434906006, + 23.13033056510358420609918539412319660186767578125, + 
210.58233961820729973624111153185367584228515625, + 5.1604155410259560099461850768420845270156860351562, + 2053.87275307550726211047731339931488037109375, + 1.0834136602451556186110792623367160558700561523438, + 3840.080091990574146620929241180419921875, + 13192.047960544839952490292489528656005859375, + 348.088713237990532434196211397647857666015625, + 439.96013885313283253708505071699619293212890625, + 897.3433304220051240918110124766826629638671875, + 0.69288480487588777201324319321429356932640075683594, + 2894.596744865002619917504489421844482421875, + 3788.94162413956064483500085771083831787109375, + 94.549943427633166947998688556253910064697265625, + 649339661894085888.0, + 0.48660066498392295919472871901234611868858337402344, + 0.12382015553845396316212656984134810045361518859863, + 0.50791641118256847242662388453027233481407165527344, + 1.0}; + +public: + FIModelRunner(LLVMContext &Ctx, + std::vector> Features, + StringRef DecisionName) + : AOTModelRunner( + Ctx, + {{"input_1", "float32[" + std::to_string(Features.size()) + "]"}}, + DecisionName) {} + + // Features for this model are only floats so we only need to override the + // float method to handle feature scaling and the input type + bool setCustomFeature(int FeatureIndex, float FeatureValue) override { + // Scale the feature according to the constant mean and scale value + // Feature scaling is done to create a standard normal distribution: + // subtract mean, then divide by standard deviation ("scale") + float ScaledValue = + (FeatureValue - Means[FeatureIndex]) / Scales[FeatureIndex]; + // Assuming the Buffer at index 0 is for feature input of shape: + // (Feature.size()) + float *Location = getTensor(0) + FeatureIndex; + *Location = ScaledValue; + return true; + } + + // Outputs for this model are only int so we only need to override this + // method + int getModelResultI(std::string OutputName) override { + if (OutputName == "FI-ShouldInline") { + int Classes[] = {0, 1}; + void *ResultUntyped = 
CompiledModel->result_data(0); + float *Result = reinterpret_cast(ResultUntyped); + float Max = Result[0]; + int MaxClass = 0; + for (size_t I = 0; I < sizeof(Classes) / sizeof(int); ++I) { + if (Result[I] > Max) { + Max = Result[I]; + MaxClass = I; + } + } + + return Classes[MaxClass]; + } + assert(false && "ModelRunner received invalid result name"); + } +}; + +} // namespace llvm + +#endif // LLVM_ANALYSIS_FIMODELRUNNER_H + +#endif // LLVM_HAVE_TF_AOT_FICOMPILEDMODEL + +#endif // ENABLE_ACPO diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h index 53c018d15cd7..adf36a385725 100644 --- a/llvm/include/llvm/Analysis/InlineAdvisor.h +++ b/llvm/include/llvm/Analysis/InlineAdvisor.h @@ -200,6 +200,22 @@ public: return AnnotatedInlinePassName.c_str(); } +#if defined(ENABLE_ACPO) + /// Helper functions used by getFeatures to retrieve certain information + ///{ + CallBase *getInlinableCS(Instruction &I); + int64_t getLocalCalls(Function &F); + unsigned getCallLoopLevel(CallBase &CB) const; + uint64_t getCalleeBlockFreq(CallBase &CB) const; + unsigned getCallSiteHeight(CallBase *CB); + ///} + + // Allow ACPO infrastructure to replicate Advisor behaviour + virtual bool isForcedToStop() const { return false; } + bool neverInline(CallBase &CB) const; + bool isCSInlinable(CallBase &CB) const; +#endif + protected: InlineAdvisor(Module &M, FunctionAnalysisManager &FAM, std::optional IC = std::nullopt); @@ -213,6 +229,15 @@ protected: const std::string AnnotatedInlinePassName; std::unique_ptr ImportedFunctionsStats; +#if defined(ENABLE_ACPO) + /// Map a function to its callheight + std::map FunctionLevels; + // used by getORE() for legacy PM + static std::unique_ptr ORE; + + friend class ACPOCollectFeatures; +#endif + enum class MandatoryInliningKind { NotMandatory, Always, Never }; static MandatoryInliningKind getMandatoryKind(CallBase &CB, @@ -389,6 +414,11 @@ void emitInlinedIntoBasedOnCost(OptimizationRemarkEmitter &ORE, 
DebugLoc DLoc, bool ForProfileContext = false, const char *PassName = nullptr); +#if defined(ENABLE_ACPO) +/// get call site location as string. +std::string getCallSiteLocation(DebugLoc DLoc); +#endif + /// Add location info to ORE message. void addLocationToRemarks(OptimizationRemark &Remark, DebugLoc DLoc); diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h index 77ae60059ce9..e7ece46342fd 100644 --- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h +++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h @@ -72,6 +72,36 @@ enum class InlineCostFeatureIndex : size_t { NumberOfFeatures }; + +#if defined(ENABLE_ACPO) +const std::map InlineCostFeatureIndexToName = { + { InlineCostFeatureIndex::sroa_savings, "sroa_savings" }, + { InlineCostFeatureIndex::sroa_losses, "sroa_losses" }, + { InlineCostFeatureIndex::load_elimination, "load_elimination" }, + { InlineCostFeatureIndex::call_penalty, "call_penalty" }, + { InlineCostFeatureIndex::call_argument_setup, "call_argument_setup" }, + { InlineCostFeatureIndex::load_relative_intrinsic, "load_relative_intrinsic" }, + { InlineCostFeatureIndex::lowered_call_arg_setup, "lowered_call_arg_setup" }, + { InlineCostFeatureIndex::indirect_call_penalty, "indirect_call_penalty" }, + { InlineCostFeatureIndex::jump_table_penalty, "jump_table_penalty" }, + { InlineCostFeatureIndex::case_cluster_penalty, "case_cluster_penalty" }, + { InlineCostFeatureIndex::switch_penalty, "switch_penalty" }, + { InlineCostFeatureIndex::unsimplified_common_instructions, "unsimplified_common_instructions" }, + { InlineCostFeatureIndex::num_loops, "num_loops" }, + { InlineCostFeatureIndex::dead_blocks, "dead_blocks" }, + { InlineCostFeatureIndex::simplified_instructions, "simplified_instructions" }, + { InlineCostFeatureIndex::constant_args, "constant_args" }, + { InlineCostFeatureIndex::constant_offset_ptr_args, "constant_offset_ptr_args" }, + { 
InlineCostFeatureIndex::callsite_cost, "callsite_cost" }, + { InlineCostFeatureIndex::cold_cc_penalty, "cold_cc_penalty" }, + { InlineCostFeatureIndex::last_call_to_static_bonus, "last_call_to_static_bonus" }, + { InlineCostFeatureIndex::is_multiple_blocks, "is_multiple_blocks" }, + { InlineCostFeatureIndex::nested_inlines, "nested_inlines" }, + { InlineCostFeatureIndex::nested_inline_cost_estimate, "nested_inline_cost_estimate" }, + { InlineCostFeatureIndex::threshold, "threshold" } +}; +#endif + // clang-format on using InlineCostFeatures = diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h index f58862e53352..e302b0a979a5 100644 --- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h +++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h @@ -41,8 +41,11 @@ public: } void onSuccessfulInlining(const MLInlineAdvice &Advice, bool CalleeWasDeleted); - +#if defined(ENABLE_ACPO) + bool isForcedToStop() const override { return ForceStop; } +#else bool isForcedToStop() const { return ForceStop; } +#endif int64_t getLocalCalls(Function &F); const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } FunctionPropertiesInfo &getCachedFPI(Function &) const; diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index 7fdb5db67c16..89160cfd17d1 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -100,7 +100,9 @@ void initializeDomPrinterWrapperPassPass(PassRegistry &); void initializeDomViewerWrapperPassPass(PassRegistry &); void initializeDominanceFrontierWrapperPassPass(PassRegistry&); void initializeDominatorTreeWrapperPassPass(PassRegistry&); +#if defined(ENABLE_ACPO) void initializeDumpCallsiteLegacyPass(PassRegistry &); +#endif void initializeDwarfEHPrepareLegacyPassPass(PassRegistry &); void initializeEarlyCSELegacyPassPass(PassRegistry&); void initializeEarlyCSEMemSSALegacyPassPass(PassRegistry&); @@ -134,7 +136,9 @@ void 
initializeGlobalsAAWrapperPassPass(PassRegistry&); void initializeGuardWideningLegacyPassPass(PassRegistry&); void initializeHardwareLoopsLegacyPass(PassRegistry&); void initializeMIRProfileLoaderPassPass(PassRegistry &); +#if defined(ENABLE_ACPO) void initializeInlineAdvisorAnalysisWrapperPass(PassRegistry &); +#endif void initializeIRSimilarityIdentifierWrapperPassPass(PassRegistry&); void initializeIRTranslatorPass(PassRegistry&); void initializeIVUsersWrapperPassPass(PassRegistry&); @@ -152,11 +156,13 @@ void initializeInterleavedLoadCombinePass(PassRegistry &); void initializeIntervalPartitionPass(PassRegistry&); void initializeJMCInstrumenterPass(PassRegistry&); void initializeKCFIPass(PassRegistry &); +#if defined(ENABLE_ACPO) void initializeLegacyFAMPass(PassRegistry &); void initializeLegacyFunctionPropertiesAnalysisPass(PassRegistry &); void initializeLegacyInlinerPassPass(PassRegistry &); void initializeLegacyInlineSizeEstimatorAnalysisPass(PassRegistry &); void initializeLegacyModuleInlinerWrapperPassPass(PassRegistry &); +#endif void initializeLCSSAVerificationPassPass(PassRegistry&); void initializeLCSSAWrapperPassPass(PassRegistry&); void initializeLazyBlockFrequencyInfoPassPass(PassRegistry&); diff --git a/llvm/include/llvm/Transforms/IPO.h b/llvm/include/llvm/Transforms/IPO.h index 4995b000c454..6905acb261fe 100644 --- a/llvm/include/llvm/Transforms/IPO.h +++ b/llvm/include/llvm/Transforms/IPO.h @@ -18,6 +18,10 @@ #include #include +#if defined(ENABLE_ACPO) +#include "llvm/Analysis/InlineAdvisor.h" +#endif + namespace llvm { class ModulePass; diff --git a/llvm/include/llvm/Transforms/IPO/Inliner.h b/llvm/include/llvm/Transforms/IPO/Inliner.h index 401aa2d3a0cc..46a3468c927c 100644 --- a/llvm/include/llvm/Transforms/IPO/Inliner.h +++ b/llvm/include/llvm/Transforms/IPO/Inliner.h @@ -16,6 +16,15 @@ #include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/IR/PassManager.h" +#if defined(ENABLE_ACPO) +#include 
"llvm/ADT/STLExtras.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include +#endif + namespace llvm { /// The inliner pass for the new pass manager. diff --git a/llvm/lib/Analysis/ACPOCollectFeatures.cpp b/llvm/lib/Analysis/ACPOCollectFeatures.cpp index f9de26483c76..daa924f2cb3b 100644 --- a/llvm/lib/Analysis/ACPOCollectFeatures.cpp +++ b/llvm/lib/Analysis/ACPOCollectFeatures.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #include "llvm/Analysis/ACPOCollectFeatures.h" #include "llvm/ADT/SCCIterator.h" // The ACPOFIModel.h currently contains only the cache system for @@ -1256,3 +1257,4 @@ operator++(ACPOCollectFeatures::FeatureIndex &N, int) { } } // namespace llvm +#endif // ENABLE_ACPO diff --git a/llvm/lib/Analysis/ACPOFIModel.cpp b/llvm/lib/Analysis/ACPOFIModel.cpp new file mode 100644 index 000000000000..d9a647ec1012 --- /dev/null +++ b/llvm/lib/Analysis/ACPOFIModel.cpp @@ -0,0 +1,243 @@ +//===- ACPOFIModel.cpp - AI-Enabled Continuous Program Optimization -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the interface between ACPO and ML-guided optimizations. +// It delegates decision making to inference with a pre-trained model. 
+// +//===----------------------------------------------------------------------===// + +#if defined(ENABLE_ACPO) +#include "llvm/Analysis/ACPOFIModel.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Process.h" + +using namespace llvm; + +#define DEBUG_TYPE "acpo" +#define ACPO_ENV_VAR_DIR "ACPO_DIR" + +cl::opt + EnableACPOFI("enable-acpo-fi", cl::init(false), cl::Hidden, + cl::desc("Leverage ACPO ML model to decide inlining.")); + +cl::opt + EnableAOTFI("enable-acpo-fi-aot", cl::init(false), cl::Hidden, + cl::desc("Leverage AOT ML model to decide inlining.")); + +ACPOFIModel::ACPOFIModel(CallBase *CB, InlineAdvisor *IA, + OptimizationRemarkEmitter *ORE, bool OnlyMandatory, + bool UseML) + : ACPOModel(ORE, UseML), CurrentCB(CB), NotACPOAdvisor(IA), + OnlyMandatory(OnlyMandatory) { + Function *Caller = CB->getCaller(); + LLVMContext *Context = &(Caller->getContext()); + setContextPtr(Context); + if (EnableACPOFI) + // ACPO Python support + setMLIF(createPersistentPythonMLIF()); + else if (EnableAOTFI) + // ACPO AOT support + setMLIF(createPersistentCompiledMLIF()); +} + +ACPOFIModel::~ACPOFIModel() {} + +void ACPOFIModel::setMLCustomFeatures( + std::vector> FeatureValues) { + CustomFeatureValues = FeatureValues; +} + +void ACPOFIModel::sendCustomFeatures() { + // Get an ACPOMLInterface to communicate with the Python side + std::shared_ptr MLIF = getMLIF(); + MLIF->initializeFeatures("FI", CustomFeatureValues); +} + +void ACPOFIModel::recordUnattemptedInlining() { + if (NotACPOAdvice) + NotACPOAdvice->recordUnattemptedInlining(); +} + +void ACPOFIModel::recordInlining() { + if (NotACPOAdvice) + NotACPOAdvice->recordInlining(); +} + +void ACPOFIModel::recordUnsuccessfulInlining(InlineResult &IR) { + if (NotACPOAdvice) + NotACPOAdvice->recordUnsuccessfulInlining(IR); +} + +void ACPOFIModel::recordInliningWithCalleeDeleted() { + if (NotACPOAdvice) + NotACPOAdvice->recordInliningWithCalleeDeleted(); +} + +void ACPOFIModel::invalidateCache(CallBase *CB) { 
+ if (CB) { + invalidateCache(CB->getCaller()); + } +} + +InlineAdvisor *ACPOFIModel::getNotACPOAdvisor() { return NotACPOAdvisor; } + +void ACPOFIModel::invalidateCache(const Function *F) { + for (ACPOFIExtendedFeatures::NamedFeatureIndex feature = + ACPOFIExtendedFeatures::NamedFeatureIndex(0); + feature != ACPOFIExtendedFeatures::NamedFeatureIndex::NumNamedFeatures; + ++feature) { + FeatureCache[feature].erase(F); + } + for (ACPOFIExtendedFeatures::NamedFloatFeatureIndex feature = + ACPOFIExtendedFeatures::NamedFloatFeatureIndex(0); + feature != + ACPOFIExtendedFeatures::NamedFloatFeatureIndex::NumNamedFloatFeatures; + ++feature) { + FeatureCache[feature].erase(F); + } + FunctionAnalysisCache.DomCache.erase(F); + FunctionAnalysisCache.LICache.erase(F); + FunctionAnalysisCache.TTICache.erase(F); +} + +void ACPOFIModel::clearCache() { + for (ACPOFIExtendedFeatures::NamedFeatureIndex feature = + ACPOFIExtendedFeatures::NamedFeatureIndex(0); + feature != ACPOFIExtendedFeatures::NamedFeatureIndex::NumNamedFeatures; + ++feature) { + FeatureCache[feature].clear(); + } + for (ACPOFIExtendedFeatures::NamedFloatFeatureIndex feature = + ACPOFIExtendedFeatures::NamedFloatFeatureIndex(0); + feature != + ACPOFIExtendedFeatures::NamedFloatFeatureIndex::NumNamedFloatFeatures; + ++feature) { + FeatureCache[feature].clear(); + } + FunctionAnalysisCache.DomCache.clear(); + FunctionAnalysisCache.LICache.clear(); + FunctionAnalysisCache.TTICache.clear(); +} + +std::optional +ACPOFIModel::getCachedSize(const Function *F, + ACPOFIExtendedFeatures::NamedFeatureIndex idx) { + auto it = FeatureCache[idx].find(F); + return it != FeatureCache[idx].end() ? std::optional(it->second) + : std::nullopt; +} + +std::optional ACPOFIModel::getCachedFloat( + const Function *F, ACPOFIExtendedFeatures::NamedFloatFeatureIndex idx) { + auto it = FeatureCache[idx].find(F); + return it != FeatureCache[idx].end() ? 
std::optional(it->second) + : std::nullopt; +} + +void ACPOFIModel::insertSizeCache(const Function *F, + ACPOFIExtendedFeatures::NamedFeatureIndex idx, + size_t val) { + FeatureCache[idx].insert(std::make_pair(F, val)); +} + +void ACPOFIModel::insertFloatCache( + const Function *F, ACPOFIExtendedFeatures::NamedFloatFeatureIndex idx, + float val) { + FeatureCache[idx].insert(std::make_pair(F, val)); +} + +const DominatorTree *ACPOFIModel::getDomCachedAnalysis(const Function *F) { + auto it = FunctionAnalysisCache.DomCache.find(F); + return it != FunctionAnalysisCache.DomCache.end() ? it->second : nullptr; +} + +const LoopInfo *ACPOFIModel::getLICachedAnalysis(const Function *F) { + auto it = FunctionAnalysisCache.LICache.find(F); + return it != FunctionAnalysisCache.LICache.end() ? it->second : nullptr; +} + +const TargetTransformInfo * +ACPOFIModel::getTTICachedAnalysis(const Function *F) { + auto it = FunctionAnalysisCache.TTICache.find(F); + return it != FunctionAnalysisCache.TTICache.end() ? it->second : nullptr; +} + +void ACPOFIModel::insertAnalysisCache(const Function *F, + const DominatorTree *Tree) { + FunctionAnalysisCache.DomCache.insert(std::make_pair(F, Tree)); +} + +void ACPOFIModel::insertAnalysisCache(const Function *F, const LoopInfo *LI) { + FunctionAnalysisCache.LICache.insert(std::make_pair(F, LI)); +} + +void ACPOFIModel::insertAnalysisCache(const Function *F, + const TargetTransformInfo *TTI) { + FunctionAnalysisCache.TTICache.insert(std::make_pair(F, TTI)); +} + +std::unique_ptr ACPOFIModel::getAdviceML() { + std::shared_ptr MLIF = getMLIF(); + // Generate result. 
+ std::unique_ptr Advice = std::make_unique(); + // handle mandatory case, forcestop, never inline or not inlinable cases + if (OnlyMandatory) + return getAdviceNoML(); + if (NotACPOAdvisor->neverInline(*CurrentCB) || + !NotACPOAdvisor->isCSInlinable(*CurrentCB)) { + Advice->addField("FI-ShouldInline", + ConstantInt::get(Type::getInt64Ty(*(getContextPtr())), + (int64_t) false)); + NotACPOAdvice = nullptr; + return Advice; + } + std::optional Env = llvm::sys::Process::GetEnv(ACPO_ENV_VAR_DIR); + if (!Env || *Env == "") { + std::optional LLVMDIROpt = + llvm::sys::Process::GetEnv("LLVM_DIR"); + if (!LLVMDIROpt) { + outs() << "ACPO_DIR not found. " + << "Did you export ACPO_DIR to $LLVM_DIR/acpo ?\n" + << "Falling back to default advisor. \n"; + return getAdviceNoML(); + } + } + assert(MLIF != nullptr); + if (!MLIF->loadModel("model-fi.acpo")) { + outs() << "Model not loaded correctly. \n"; + return getAdviceNoML(); + } + if (!MLIF->initializeFeatures("FI", CustomFeatureValues)) { + outs() << "Features not initialized correctly. 
\n"; + return getAdviceNoML(); + } + bool ModelRunOK = MLIF->runModel("FI"); + assert(ModelRunOK); + ShouldInline = MLIF->getModelResultI("FI-ShouldInline"); + assert(getContextPtr() != nullptr); + Advice->addField("FI-ShouldInline", + ConstantInt::get(Type::getInt64Ty(*(getContextPtr())), + (int64_t)ShouldInline)); + return Advice; +} + +std::unique_ptr ACPOFIModel::getAdviceNoML() { + // Use the advisor used by default inlining + std::unique_ptr Advice = std::make_unique(); + assert(getContextPtr() != nullptr); + NotACPOAdvice = NotACPOAdvisor->getAdvice(*CurrentCB, OnlyMandatory); + bool ShouldInline = NotACPOAdvice->isInliningRecommended(); + Advice->addField("FI-ShouldInline", + ConstantInt::get(Type::getInt64Ty(*(getContextPtr())), + (int64_t)ShouldInline)); + return Advice; +} + +ACPOFIModel::FunctionFeaturesCache ACPOFIModel::FeatureCache; +ACPOFIModel::FunctionAnalysisMap ACPOFIModel::FunctionAnalysisCache; +#endif // ENABLE_ACPO diff --git a/llvm/lib/Analysis/ACPOMLInterface.cpp b/llvm/lib/Analysis/ACPOMLInterface.cpp index 271dcfe7d851..f48eb46638e3 100644 --- a/llvm/lib/Analysis/ACPOMLInterface.cpp +++ b/llvm/lib/Analysis/ACPOMLInterface.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#if defined(ENABLE_ACPO) #include "llvm/Analysis/ACPOMLInterface.h" #include "llvm/Analysis/ACPOModelRunner.h" #include "llvm/Analysis/FIModelRunner.h" @@ -1403,3 +1404,4 @@ const std::unordered_map + +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/ACPOFIModel.h" +#include "llvm/Analysis/DumpFeature.h" +#include "llvm/Analysis/FunctionPropertiesAnalysis.h" +#include "llvm/Analysis/InlineModelFeatureMaps.h" +#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/IPO/Inliner.h" +#include +#endif + using 
namespace llvm; #define DEBUG_TYPE "inline" #ifdef LLVM_HAVE_TF_AOT_INLINERSIZEMODEL @@ -537,6 +556,46 @@ void llvm::emitInlinedIntoBasedOnCost( PassName); } +#if defined(ENABLE_ACPO) +CallBase *InlineAdvisor::getInlinableCS(Instruction &I) { + if (auto *CS = dyn_cast(&I)) + if (Function *Callee = CS->getCalledFunction()) { + if (!Callee->isDeclaration()) { + return CS; + } + } + return nullptr; +} + +// TODO: We can make this faster on large programs by applying +// this patch from MLGO f46dd19b480496d2ba0a57d12935882e530f2b93. +// This patch incrementally computes FunctionPropertiesInfo +// instead of recomputing. +int64_t InlineAdvisor::getLocalCalls(Function &F) { + return FAM.getResult(F) + .DirectCallsToDefinedFunctions; +} + +unsigned InlineAdvisor::getCallLoopLevel(CallBase &CB) const { + Function *F = CB.getCaller(); + BasicBlock *BB = CB.getParent(); + LoopInfo &LI = FAM.getResult(*F); + return LI.getLoopDepth(BB); +} + +uint64_t InlineAdvisor::getCalleeBlockFreq(CallBase &CB) const { + Function *F = CB.getCaller(); + BasicBlock *BB = CB.getParent(); + BlockFrequencyInfo &BFI = FAM.getResult(*F); + return BFI.getBlockFreq(BB).getFrequency(); +} + +unsigned InlineAdvisor::getCallSiteHeight(CallBase *CB) { + Function *Caller = CB->getCaller(); + return FunctionLevels[Caller]; +} +#endif + InlineAdvisor::InlineAdvisor(Module &M, FunctionAnalysisManager &FAM, std::optional IC) : M(M), FAM(FAM), IC(IC), @@ -548,6 +607,35 @@ InlineAdvisor::InlineAdvisor(Module &M, FunctionAnalysisManager &FAM, std::make_unique(); ImportedFunctionsStats->setModuleInfo(M); } +#if defined(ENABLE_ACPO) + std::unique_ptr CG(std::make_unique(M)); + for (auto I = scc_begin(CG.get()); !I.isAtEnd(); ++I) { + const std::vector &CGNodes = *I; + unsigned Level = 0; + for (auto *CGNode : CGNodes) { + Function *F = CGNode->getFunction(); + if (!F || F->isDeclaration()) + continue; + for (auto &I : instructions(F)) { + if (auto *CS = getInlinableCS(I)) { + auto *Called = 
CS->getCalledFunction(); + auto Pos = FunctionLevels.find(Called); + // In bottom up traversal, an inlinable callee is either in the + // same SCC, or to a function in a visited SCC. So not finding its + // level means we haven't visited it yet, meaning it's in this SCC. + if (Pos == FunctionLevels.end()) + continue; + Level = std::max(Level, Pos->second + 1); + } + } + } + for (auto *CGNode : CGNodes) { + Function *F = CGNode->getFunction(); + if (F && !F->isDeclaration()) + FunctionLevels[F] = Level; + } + } +#endif } InlineAdvisor::~InlineAdvisor() { @@ -639,6 +727,33 @@ std::unique_ptr InlineAdvisor::getAdvice(CallBase &CB, return getMandatoryAdvice(CB, Advice); } +#if defined(ENABLE_ACPO) +bool InlineAdvisor::neverInline(CallBase &CB) const { + auto &Caller = *CB.getCaller(); + auto &Callee = *CB.getCalledFunction(); + auto &ORE = FAM.getResult(Caller); + auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE); + // If this is a "never inline" case, there won't be any changes to internal + // state we need to track, so we can just return the base InlineAdvice, + // which will do nothing interesting. Same thing if this is a recursive + // case. 
+ return MandatoryKind == InlineAdvisor::MandatoryInliningKind::Never || + &Caller == &Callee; +} + +bool InlineAdvisor::isCSInlinable(CallBase &CB) const { + auto &Callee = *CB.getCalledFunction(); + + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return FAM.getResult(F); + }; + auto &TIR = FAM.getResult(Callee); + auto IsCallSiteInlinable = + llvm::getInliningCostEstimate(CB, TIR, GetAssumptionCache); + return !!IsCallSiteInlinable; +} +#endif + OptimizationRemarkEmitter &InlineAdvisor::getCallerORE(CallBase &CB) { return FAM.getResult(*CB.getCaller()); } diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp index 0660a9993b6d..c7ea2eb8ffe9 100644 --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -323,6 +323,11 @@ std::unique_ptr MLInlineAdvisor::getAdviceImpl(CallBase &CB) { auto &Caller = *CB.getCaller(); auto &Callee = *CB.getCalledFunction(); +#if defined(ENABLE_ACPO) + LLVM_DEBUG(dbgs() << "Advice on call: " << Caller.getName() << " to " + << Callee.getName() << "\n"); +#endif + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return FAM.getResult(F); }; diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index a02c603a14a5..370b248c2b85 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -86,15 +86,19 @@ #include #include +#if defined(ENABLE_ACPO) #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Support/CommandLine.h" +#endif using namespace llvm; +#if defined(ENABLE_ACPO) cl::opt UnnamedVariablePrefix( "unnamed-var-prefix", cl::Hidden, cl::desc("Specify the prefix added to unnamed variables"), cl::init("")); +#endif // Make virtual table appear in this compilation unit. 
AssemblyAnnotationWriter::~AssemblyAnnotationWriter() = default; @@ -2494,12 +2498,17 @@ static void WriteAsOperandInternal(raw_ostream &Out, const Value *V, } else { Slot = -1; } - +#if defined(ENABLE_ACPO) if (Slot != -1) { // By default, UnnamedVariablePrefix is empty so it matches original behaviour // unless specified. Out << Prefix << UnnamedVariablePrefix << Slot; } else +#else + if (Slot != -1) + Out << Prefix << Slot; + else +#endif Out << ""; } @@ -2635,7 +2644,11 @@ public: SmallVector ExitBlocks); #endif void printArgument(const Argument *FA, AttributeSet Attrs); +#if defined(ENABLE_ACPO) void printBasicBlock(const BasicBlock *BB, bool PrintLabelOnly = false); +#else + void printBasicBlock(const BasicBlock *BB); +#endif void printInstructionLine(const Instruction &I); void printInstruction(const Instruction &I); @@ -4214,17 +4227,27 @@ void AssemblyWriter::printArgument(const Argument *Arg, AttributeSet Attrs) { } else { int Slot = Machine.getLocalSlot(Arg); assert(Slot != -1 && "expect argument in function here"); +#if defined(ENABLE_ACPO) // By default, UnnamedVariablePrefix is empty so it matches original behaviour // unless specified. Out << " %" << UnnamedVariablePrefix << Slot; +#else + Out << " %" << Slot; +#endif } } + /// printBasicBlock - This member is called for each basic block in a method. +#if defined(ENABLE_ACPO) void AssemblyWriter::printBasicBlock(const BasicBlock *BB, bool PrintLabelOnly) { assert(BB && BB->getParent() && "block without parent!"); bool IsEntryBlock = BB == &BB->getParent()->getEntryBlock(); +#else +void AssemblyWriter::printBasicBlock(const BasicBlock *BB) { + bool IsEntryBlock = BB == &BB->getParent()->getEntryBlock(); +#endif if (BB->hasName()) { // Print out the label if it exists... 
Out << "\n"; PrintLLVMName(Out, BB->getName(), LabelPrefix); @@ -4233,17 +4256,23 @@ void AssemblyWriter::printBasicBlock(const BasicBlock *BB, Out << "\n"; int Slot = Machine.getLocalSlot(BB); if (Slot != -1) { +#if defined(ENABLE_ACPO) // By default, UnnamedVariablePrefix is empty so it matches original behaviour // unless specified. Out << UnnamedVariablePrefix << Slot << ":"; +#else + Out << Slot << ":"; +#endif } else Out << ":"; } +#if defined(ENABLE_ACPO) if (PrintLabelOnly) { Out << "\n"; return; } +#endif if (!IsEntryBlock) { // Output predecessors for the block. @@ -4339,9 +4368,13 @@ void AssemblyWriter::printInstruction(const Instruction &I) { if (SlotNum == -1) Out << " = "; else { +#if defined(ENABLE_ACPO) // By default, UnnamedVariablePrefix is empty so it matches original behaviour // unless specified. Out << '%' << UnnamedVariablePrefix << SlotNum << " = "; +#else + Out << '%' << SlotNum << " = "; +#endif } } diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp index 802667819c44..3a0a2494c36a 100644 --- a/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/llvm/lib/Transforms/IPO/Inliner.cpp @@ -64,10 +64,21 @@ #include #include #include -#if defined(ENABLE_AUTOTUNER) +#if defined(ENABLE_AUTOTUNER) || defined(ENABLE_ACPO) #include "llvm/AutoTuner/AutoTuning.h" #endif +#if defined(ENABLE_ACPO) +#include "llvm/Analysis/ACPOFIModel.h" +#include "llvm/Analysis/ModelDataCollector.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/CallGraphUpdater.h" +#include +#endif + using namespace llvm; #define DEBUG_TYPE "inline" @@ -149,6 +160,124 @@ static cl::opt CGSCCInlineReplayFormat( ":. 
(default)")), cl::desc("How cgscc inline replay file is formatted"), cl::Hidden); +#if defined(ENABLE_ACPO) +static cl::opt + ACPOVerboseFI("acpo-verbose-fi", cl::init(false), cl::Hidden, + cl::desc("Print ACPO invocation messages for FI.")); + +static cl::opt FeatureDump("enable-fi-feature-dump", cl::init(false)); + +// Defined in 'lib/Analysis/ACPOFIModel.cpp' +extern cl::opt EnableACPOFI; +extern cl::opt EnableAOTFI; +// In "llvm/lib/Analysis/ModelDataCollector.cpp" +extern cl::opt ACPOModelFile; + +namespace { +/// Class for collecting inlining model data +class ModelDataFICollector : public ModelDataCollector { +public: + ModelDataFICollector(formatted_raw_ostream &OS, bool OnlyMandatory, + std::string OutputFileName) + : ModelDataCollector(OS, OutputFileName), OnlyMandatory(OnlyMandatory) {} + + void collectFeatures(CallBase *CB, InlineAdvisor *IA, + FunctionAnalysisManager *FAM) { + bool MandatoryOnly = getOnlyMandatory(); + resetRegisteredFeatures(); + ACPOCollectFeatures::FeaturesInfo CallerFeatures{ + {ACPOCollectFeatures::FeatureIndex::BasicBlockCount, + /* ACPOCollectFeatures::Scope::Function, */ + /* ACPOCollectFeatures::GroupID::FPIRelated, */ + {FAM, nullptr}, + {CB->getCaller(), nullptr, nullptr, nullptr, nullptr}, + {MandatoryOnly, IA}}}; + ACPOCollectFeatures::FeaturesInfo CalleeFeatures{ + {ACPOCollectFeatures::FeatureIndex::BasicBlockCount, + /* ACPOCollectFeatures::Scope::Function, */ + /* ACPOCollectFeatures::GroupID::FPIRelated, */ + {FAM, nullptr}, + {CB->getCalledFunction(), nullptr, nullptr, nullptr, nullptr}, + {MandatoryOnly, IA}}}; + BasicBlock *GlobalBB = CB->getParent(); + Function *GlobalF = GlobalBB->getParent(); + Module *GlobalM = GlobalF->getParent(); + ACPOCollectFeatures::FeatureInfo GlobalFeatureInfo{ + ACPOCollectFeatures::FeatureIndex::NumOfFeatures, + {FAM, nullptr}, + {GlobalF, CB, GlobalBB, GlobalM, nullptr}, + {MandatoryOnly, IA}}; + ACPOCollectFeatures::FeatureInfo CallerInfo{ + 
ACPOCollectFeatures::FeatureIndex::NumOfFeatures, + {FAM, nullptr}, + {CB->getCaller(), CB, GlobalBB, GlobalM, nullptr}, + {MandatoryOnly, IA}}; + ACPOCollectFeatures::FeatureInfo CalleeInfo{ + ACPOCollectFeatures::FeatureIndex::NumOfFeatures, + {FAM, nullptr}, + {CB->getCalledFunction(), CB, GlobalBB, GlobalM, nullptr}, + {MandatoryOnly, IA}}; + + registerFeature({ACPOCollectFeatures::Scope::Function}, CalleeInfo, + "callee"); + registerFeature({ACPOCollectFeatures::Scope::Function}, CallerInfo, + "caller"); + registerFeature({ACPOCollectFeatures::Scope::CallSite}, GlobalFeatureInfo); + registerFeature({ACPOCollectFeatures::Scope::Module}, GlobalFeatureInfo); + ModelDataCollector::collectFeatures(); + } + bool getOnlyMandatory() { return OnlyMandatory; } + +private: + bool OnlyMandatory = false; +}; + +llvm::SmallDenseSet, 4> + InlinedInternalEdges = + llvm::SmallDenseSet, 4>(); +} // end anonymous namespace + +/// helper function for getting advice with acpo infrastructure +bool getACPOAdvice(CallBase *CB, std::unique_ptr &FI, + ModelDataFICollector *MDC, InlineAdvisor *Advisor, + FunctionAnalysisManager *FAM) { + bool ShouldInline = false; + // ------------------------------------------------------------------------ + // Begin ACPO invocation + if ((EnableACPOFI || EnableAOTFI) && !MDC->getOnlyMandatory() && + !Advisor->neverInline(*CB) && Advisor->isCSInlinable(*CB)) { + if (ACPOVerboseFI) { + errs() << "--- ACPOModel is activated\n"; + } + MDC->collectFeatures(CB, Advisor, FAM); + std::vector> Features = + MDC->getFeatures(); + FI->setMLCustomFeatures(Features); + } + std::unique_ptr Advice = FI->getAdvice(); + Constant *Val = Advice->getField("FI-ShouldInline"); + assert(Val != nullptr); + assert(isa(Val)); + ConstantInt *ACPOInline = dyn_cast(Val); + ShouldInline = ACPOInline->getSExtValue(); + if ((EnableACPOFI || EnableAOTFI) && ACPOVerboseFI) { + errs() << "ACPOModel's inline prediction: " << ShouldInline << "\n"; + } + if (FeatureDump) { + 
+ MDC->collectFeatures(CB, Advisor, FAM); + std::vector> Features = + MDC->getFeatures(); + if (MDC->isEmptyOutputFile()) { + MDC->printRow(true); + } + MDC->printRow(); + } + return ShouldInline; + // End ACPO Invocation + // --------------------------------------------------------------------- +} +#endif + /// Return true if the specified inline history ID /// indicates an inline history that includes the specified function. static bool inlineHistoryIncludes( @@ -205,6 +334,14 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { +#if defined(ENABLE_ACPO) + if (EnableACPOFI || EnableAOTFI) { + // Need to clear the cache at the beginning of the inliner pass, since during + // optimization we may have transformed the code which invalidated the cache. + ACPOFIModel::clearCache(); + } +#endif + const auto &MAMProxy = AM.getResult(InitialC, CG); bool Changed = false; @@ -221,6 +358,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Advisor.onPassEntry(&InitialC); auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(&InitialC); }); +#if defined(ENABLE_ACPO) + if (EnableACPOFI || EnableAOTFI) + ACPOCollectFeatures::clearFunctionLevel(); +#endif // We use a single common worklist for calls across the entire SCC. 
We // process these in-order and append new calls introduced during inlining to @@ -377,8 +518,51 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, continue; } - std::unique_ptr Advice = - Advisor.getAdvice(*CB, OnlyMandatory); + std::unique_ptr Advice = nullptr; + #if defined(ENABLE_ACPO) + std::unique_ptr FI = nullptr; + if (EnableACPOFI || EnableAOTFI) { + auto &ORE = + FAM.getResult(*CB->getCaller()); + FI = std::make_unique( + CB, &Advisor, &ORE, OnlyMandatory, EnableACPOFI || EnableAOTFI); + std::error_code EC; + raw_fd_ostream RawOS(ACPOModelFile.getValue(), EC, sys::fs::CD_OpenAlways, + sys::fs::FA_Write, sys::fs::OF_Append); + if (EC) + errs() << "Could not create/open training data file (Falling back to " + "debug mode): " + << EC.message() << "\n"; + + formatted_raw_ostream OS(RawOS); + ModelDataFICollector MDC(OS, OnlyMandatory, ACPOModelFile); + if (EnableACPOFI) + LLVM_DEBUG(dbgs() << "ACPO Python ML infra is activated" << "\n"); + else if (EnableAOTFI) + LLVM_DEBUG(dbgs() << "ACPO AOT C++ ML infra is activated" << "\n"); + bool ShouldInline = getACPOAdvice(CB, FI, &MDC, &Advisor, &FAM); + + // Check whether we want to inline this callsite. + if (!ShouldInline) { + FI->recordUnattemptedInlining(); + continue; + } else { + ACPOFIModel::invalidateCache(CB); + } + } else { + Advice = Advisor.getAdvice(*CB, OnlyMandatory); + + // Check whether we want to inline this callsite. + if (!Advice) + continue; + + if (!Advice->isInliningRecommended()) { + Advice->recordUnattemptedInlining(); + continue; + } + } +#else + Advice = Advisor.getAdvice(*CB, OnlyMandatory); // Check whether we want to inline this callsite. 
if (!Advice) @@ -388,6 +572,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Advice->recordUnattemptedInlining(); continue; } +#endif int CBCostMult = getStringFnAttrAsInt( @@ -396,6 +581,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // Setup the data structure used to plumb customization into the // `InlineFunction` routine. +#if defined(ENABLE_ACPO) + if ((EnableACPOFI || EnableAOTFI) && ACPOVerboseFI) { + Function &F2 = *CB->getCaller(); + LLVM_DEBUG(dbgs() << "check: " << F2.getName() << ", " + << Callee.getName() << "\n"); + } +#endif InlineFunctionInfo IFI( GetAssumptionCache, PSI, &FAM.getResult(*(CB->getCaller())), @@ -405,7 +597,14 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, InlineFunction(*CB, IFI, /*MergeAttributes=*/true, &FAM.getResult(*CB->getCaller())); if (!IR.isSuccess()) { +#if defined(ENABLE_ACPO) + if (EnableACPOFI || EnableAOTFI) + FI->recordUnsuccessfulInlining(IR); + else + Advice->recordUnsuccessfulInlining(IR); +#else Advice->recordUnsuccessfulInlining(IR); +#endif continue; } @@ -494,10 +693,24 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, DeadFunctionsInComdats.push_back(&Callee); } } +#if defined(ENABLE_ACPO) + if (EnableACPOFI || EnableAOTFI) { + if (CalleeWasDeleted) + FI->recordInliningWithCalleeDeleted(); + else + FI->recordInlining(); + } else { + if (CalleeWasDeleted) + Advice->recordInliningWithCalleeDeleted(); + else + Advice->recordInlining(); + } +#else if (CalleeWasDeleted) Advice->recordInliningWithCalleeDeleted(); else Advice->recordInlining(); +#endif } // Back the call index up by one to put us in a good position to go around diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp index 1401352647cd..671a33309a1b 100644 --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -430,6 +430,9 @@ int main(int argc, char **argv) { initializeTransformUtils(Registry); initializeInstCombine(Registry); initializeTarget(Registry); 
+#if defined(ENABLE_ACPO) + initializeDumpCallsiteLegacyPass(Registry); +#endif // For codegen passes, only passes that do IR to IR transformation are // supported. initializeExpandLargeDivRemLegacyPassPass(Registry); -- 2.38.1.windows.1