diff --git a/src/lib/Annotation.cc b/src/lib/Annotation.cc index ae7a42d..ae2da48 100644 --- a/src/lib/Annotation.cc +++ b/src/lib/Annotation.cc @@ -229,18 +229,8 @@ bool isAllocFn(StringRef name, int *size, int *flag) { bool isEntryFn(StringRef name) { if (name.equals("main") || - name.startswith("do_syscall_") || - name.endswith("do_softirq") || - name.equals("start_kernel") || - name.equals("init") || - name.equals("module_init") || - name.equals("module_exit") || - name.equals("init_module") || - name.equals("cleanup_module") || - name.equals("do_init_module") || - name.equals("do_cleanup_module") || - name.equals("do_one_initcall") || - name.equals("do_one_initcall_sync")) + name.startswith("LLVMFuzzerTestOneInput") || + name.startswith("FuzzerTestOneInput")) return true; else return false; } @@ -249,8 +239,12 @@ bool isExitFn(StringRef name) { if (name.equals("exit") || name.equals("_exit") || name.equals("_Exit") || + name.equals("quick_exit") || name.equals("exit_group") || + name.equals("terminate") || + name.equals("abort") || name.equals("panic") || + name.equals("png_error") || name.equals("BUG") || name.equals("BUG_ON")) return true; diff --git a/src/lib/CallGraph.cc b/src/lib/CallGraph.cc index 19f2155..bebef08 100644 --- a/src/lib/CallGraph.cc +++ b/src/lib/CallGraph.cc @@ -426,7 +426,6 @@ bool CallGraphPass::runOnFunction(Function *F) { #pragma clang diagnostic pop #endif // normal handling - bool isNull = false; Value *ptr = I->getOperand(0); NodeIndex ptrNode = NF.getValueNodeFor(ptr); auto itr = funcPtsGraph.find(ptrNode); @@ -438,7 +437,6 @@ bool CallGraphPass::runOnFunction(Function *F) { CG_LOG("Load: source obj: " << idx << "\n"); if (idx == NF.getNullObjectNode() && itr->second.find_next(idx) == end) { CG_LOG("Loading from null obj, ptr = " << ptrNode << "\n"); - isNull = true; // XXX funcPtsGraph[valNode].insert(idx); break; diff --git a/src/lib/KAMain.cc b/src/lib/KAMain.cc index b504ee2..c24eb4a 100644 --- a/src/lib/KAMain.cc +++ 
b/src/lib/KAMain.cc @@ -5,6 +5,7 @@ * Copyright (C) 2015 Byoungyoung Lee * Copyright (C) 2016 Kangjie Lu * Copyright (C) 2015 - 2024 Chengyu Song + * Copyright (C) 2024 - 2025 Haochen Zeng * * For licensing details see LICENSE */ @@ -42,6 +43,9 @@ cl::list InputFilenames( cl::opt VerboseLevel( "verbose", cl::desc("Verbose level"), cl::init(0)); +cl::opt CallStackLen( + "call-stack-len", cl::desc("The maximum call stack length from entry to the targets"), cl::init(10)); + cl::opt UseTypeBasedCallGraph( "type-based-callgraph", cl::desc("Use type-based call graph"), cl::init(false)); @@ -52,17 +56,30 @@ cl::opt EntryList( "entry-list", cl::desc("Entry list"), cl::init("")); cl::opt DumpPolicy( - "dump-policy", cl::desc("Dump static policy"), cl::init("")); + "dump-policy", cl::desc("Dump policy, format: bid,true_distance,false_distance,false_bid,true_bid"), cl::init("")); cl::opt DumpDistance( - "dump-distance", cl::desc("Dump distance"), cl::init("")); + "dump-distance", cl::desc("Dump distances, format: bid,bb_hash,loc,distance"), cl::init("")); + +cl::opt DumpCriticalBBs( + "dump-critical-branch", cl::desc("Dump critical basic blocks, format: critical_bid, exit_bid_1, exit_bid_2, ..."), cl::init("")); cl::opt DumpBidMapping( - "dump-bid-mapping", cl::desc("Dump basic block ID mapping, format: bid,fun_GUID,filepath:linenum"), cl::init("")); + "dump-bid-mapping", cl::desc("Dump basic block ID mapping, format: bid,bb_hash,fun_GUID,filepath:linenum"), cl::init("")); cl::opt DumpFuncInfo( "dump-func-info", cl::desc("Dump function info, format: fun_GUID,fun_name,filepath,start_linenum,end_linenum"), cl::init("")); +cl::opt DumpCallerCallee( + "dump-caller-callee", + cl::desc("Dump caller → callee mapping, format: caller_GUID,callee_GUID,..."), + cl::init("")); + +cl::opt DumpCalleeCaller( + "dump-callee-caller", + cl::desc("Dump callee → caller mapping, format: callee_GUID,caller_GUID,..."), + cl::init("")); + cl::opt DumpAnnotatedIR( "dump-annotated-ir", cl::desc("Dump 
annotated IR"), cl::init("")); @@ -222,7 +239,8 @@ int main(int argc, char **argv) { TyCG.run(GlobalCtx.Modules); } - ReachableCallGraphPass RCGPass(&GlobalCtx, TargetList, EntryList, UseTypeBasedCallGraph); + ReachableCallGraphPass RCGPass(&GlobalCtx, TargetList, EntryList, + UseTypeBasedCallGraph, CallStackLen); RCGPass.run(GlobalCtx.Modules); if (!DumpBidMapping.empty() && !DumpFuncInfo.empty()){ @@ -230,17 +248,27 @@ int main(int argc, char **argv) { std::ofstream funcInfo(DumpFuncInfo); RCGPass.dumpIDMapping(GlobalCtx.Modules, bbLocs, funcInfo); } + if (!DumpCallerCallee.empty() && !DumpCalleeCaller.empty()){ + std::ofstream callercallee(DumpCallerCallee); + std::ofstream calleecaller(DumpCalleeCaller); + RCGPass.dumpCallees(callercallee); + RCGPass.dumpCallers(calleecaller); + } if (!DumpPolicy.empty()) { std::ofstream policy(DumpPolicy); RCGPass.dumpPolicy(policy); } if (!DumpDistance.empty()) { std::ofstream distance(DumpDistance); - RCGPass.dumpDistance(distance, true, false); + RCGPass.dumpDistance(distance, true); } if (!DumpAnnotatedIR.empty()) { RCGPass.annotateModules(GlobalCtx.Modules, DumpAnnotatedIR); } + if (!DumpCriticalBBs.empty()) { + std::ofstream criticalBBs(DumpCriticalBBs); + RCGPass.dumpCriticalBBs(criticalBBs); + } return 0; } diff --git a/src/lib/Reachable.cc b/src/lib/Reachable.cc index 16dab8a..f3403a9 100644 --- a/src/lib/Reachable.cc +++ b/src/lib/Reachable.cc @@ -2,6 +2,7 @@ * Reachability-based Call Graph Analysis * * Copyrigth (C) 2024 - 2025 Chengyu Song + * Copyright (C) 2024 - 2025 Haochen Zeng * * For licensing details see LICENSE */ @@ -55,6 +56,191 @@ using namespace llvm; +static std::string getSourceLocation(const BasicBlock *BB) { + for (const auto &I : *BB) { + auto loc = I.getDebugLoc(); + if (loc && loc.getLine() != 0) { + // Get the filename from the debug location + std::string f = loc->getFilename().str(); + // If filename is empty, get it from the parent function + if (f.empty()) { + f = 
BB->getParent()->getParent()->getSourceFileName(); + } + // Remove leading "./" if present + if (f.find("./") == 0) { + f = f.substr(2); + } + // Extract the base filename by finding the last '/' or '\\' + size_t pos = f.find_last_of("/\\"); + if (pos != std::string::npos) { + f = f.substr(pos + 1); + } + return f + ":" + std::to_string(loc.getLine()); + } + } + return "NoLoc:0"; +} + +/// \brief Retrieve the first available debug location in \p BB that is not +/// inside /usr/ and store the **absolute, normalized path** in \p Filename. +/// Sets \p Line and \p Col accordingly. +/// +/// This version does: +/// 1) Loops over instructions in \p BB +/// 2) Checks the debug location (and possibly inlined-at location) +/// 3) Builds an absolute, normalized path (resolving "." and "..") +/// 4) Skips if the path is empty, line=0, or the path starts with "/usr/" +/// 5) Returns the first valid debug info found +static void getDebugLocationFullPath(const BasicBlock &BB, + std::string &Filename, + unsigned &Line, + unsigned &Col) { + Filename.clear(); + Line = 0; + Col = 0; + + // We don't want paths that point to system libraries + static const std::string Xlibs("/usr/"); + auto isSystemLikePath = [](StringRef P) -> bool { + if (P.empty()) return false; + // Consider any path that is exactly /usr/... or contains /usr/ segment + // as system-like (covers sysroot cases like /toolchain/sysroot/usr/...) 
+ if (P.startswith("/usr/")) return true; + return P.contains("/usr/"); + }; + + // Iterate over instructions in the basic block + for (auto &Inst : BB) { + if (DILocation *Loc = Inst.getDebugLoc()) { + // Fallback: remember the first valid system-lib location if no user code is found + std::string systemFallbackPath; + unsigned systemFallbackLine = 0; + unsigned systemFallbackCol = 0; + + // Walk inlined-at chain from inner to outer to prefer user code call sites + for (DILocation *Cur = Loc; Cur != nullptr; Cur = Cur->getInlinedAt()) { + std::string Dir = Cur->getDirectory().str(); + std::string File = Cur->getFilename().str(); + unsigned L = Cur->getLine(); + unsigned C = Cur->getColumn(); + + // Skip if missing filename or invalid line + if (File.empty() || L == 0) + continue; + + // Normalize suspicious relative system paths like "usr/..." to "/usr/..." + if (!Dir.empty() && !llvm::sys::path::is_absolute(Dir) && llvm::StringRef(Dir).startswith("usr/")) { + Dir = "/" + Dir; + } + if (!File.empty() && !llvm::sys::path::is_absolute(File) && llvm::StringRef(File).startswith("usr/")) { + File = "/" + File; + } + + // Build an absolute path in a SmallString + llvm::SmallString<256> FullPath; + + // If File itself is absolute, prefer it directly + if (!File.empty() && llvm::sys::path::is_absolute(File)) { + FullPath = File; + } else { + // If Dir is already absolute, start with that. Otherwise base on CWD. 
+ if (!Dir.empty() && llvm::sys::path::is_absolute(Dir)) { + FullPath = Dir; + } else { + llvm::sys::fs::current_path(FullPath); + if (!Dir.empty()) { + llvm::sys::path::append(FullPath, Dir); + } + } + // Append the filename (relative) + llvm::sys::path::append(FullPath, File); + } + + // Normalize dots + llvm::sys::path::remove_dots(FullPath, /*remove_dot_dot=*/true); + + // Skip if system-like, but record the first one as a fallback + StringRef FullRef(FullPath); + if (isSystemLikePath(FullRef)) { + if (systemFallbackPath.empty()) { + systemFallbackPath = FullPath.str().str(); + systemFallbackLine = L; + systemFallbackCol = C; + } + continue; + } + + // Found a valid location => set output vars + Filename = FullPath.str().str(); + Line = L; + Col = C; + break; + } + + // If we selected a valid non-system frame, stop scanning instructions + if (!Filename.empty()) + break; + + // If not found in this instruction's inlined chain, but we have a + // system fallback recorded, use it and stop. + if (Filename.empty() && !systemFallbackPath.empty()) { + Filename = systemFallbackPath; + Line = systemFallbackLine; + Col = systemFallbackCol; + break; + } + } + } +} + +// === Helpers to distinguish developer-introduced EH from compiler cleanups === +static bool isPureCleanupLP(const llvm::BasicBlock *BB) { + // Look for a landingpad as the first non-PHI instruction; treat a pure + // `cleanup` landingpad with zero clauses as compiler-generated cleanup. 
+ for (const llvm::Instruction &I : *BB) { + if (I.getOpcode() == llvm::Instruction::PHI) continue; // skip PHIs + if (auto *LPI = llvm::dyn_cast(&I)) { + return LPI->isCleanup() && LPI->getNumClauses() == 0; + } + break; // first non-PHI wasn't a landingpad + } + return false; +} + +static bool hasUserDebugLocation(const llvm::BasicBlock *BB, std::string &OutPath) { + OutPath.clear(); + unsigned L = 0, C = 0; + getDebugLocationFullPath(*BB, OutPath, L, C); + if (OutPath.empty()) return false; + llvm::StringRef P(OutPath); + // Be conservative: treat anything under /usr/ as non-user code + if (P.contains("/usr/")) return false; + return true; +} + +static bool isDeveloperExceptionBB(const llvm::BasicBlock *BB) { + // Only consider blocks that actually resume unwinding + if (!llvm::isa(BB->getTerminator())) + return false; + + // If this is a pure cleanup landing pad, it's almost certainly compiler-gen + if (isPureCleanupLP(BB)) + return false; + + // Require a non-system debug location + std::string P; + if (!hasUserDebugLocation(BB, P)) + return false; + +#if LLVM_VERSION_MAJOR >= 15 + if (auto DL = BB->getTerminator()->getDebugLoc()) { + if (DL->isImplicitCode()) + return false; // compiler-synthesized + } +#endif + return true; +} + Function* ReachableCallGraphPass::getFuncDef(Function *F) { FuncMap::iterator it = Ctx->Funcs.find(F->getGUID()); if (it != Ctx->Funcs.end()) @@ -169,7 +355,7 @@ bool ReachableCallGraphPass::isCompatibleType(Type *T1, Type *T2) { bool ReachableCallGraphPass::findCalleesByType(CallBase *CB, FuncSet &FS) { bool Changed = false; - RA_LOG("Handle indirect call: " << *CB << "\n"); + RA_DEBUG("Handle indirect call: " << *CB << "\n"); for (const Function *F : Ctx->AddressTakenFuncs) { // just compare known args if (F->getFunctionType()->isVarArg()) { @@ -219,18 +405,75 @@ bool ReachableCallGraphPass::runOnFunction(Function *F) { bool Changed = false; RA_LOG("### Run on function: " << F->getName() << "\n"); - for (auto &BB : *F) { - for 
(auto &i : BB) { - Instruction *I = &i; - // assign a BB ID - if (BBIDs.find(&BB) == BBIDs.end()) { - BBIDs[&BB] = nextBBID++; - if (auto *SI = dyn_cast(BB.getTerminator())) { - // assign a unique ID to the switch case - nextBBID += SI->getNumCases(); + // if no entry specified, use the common one + // collect the exit block of the entry function too + bool isEntry = false; + if (entryList.empty()) { + isEntry = isEntryFn(F->getName()); + } else { + auto itr = std::find(entryList.begin(), entryList.end(), F->getName().str()); + isEntry = (itr != entryList.end()); + } + if (isEntry) { + // Record entry block + entryBBs.insert(&F->getEntryBlock()); + RA_LOG("[init] Entry function detected: " << F->getName() << "\n"); + // Compute the maximum source line number for this function (first pass) + unsigned maxLine = 0; + for (const auto &BB : *F) { + for (const auto &I : BB) { + if (auto DL = I.getDebugLoc()) { + maxLine = std::max(maxLine, DL.getLine()); + } + } + } + // Seed exitBBs (second pass) + for (const auto &BB : *F) { + // Never treat the entry block as an exit block + if (&BB == &F->getEntryBlock()) { + continue; + } + if (maxLine > 0) { + // Also include any BB whose debug line equals the function's last line + for (const auto &I : BB) { + if (auto DL = I.getDebugLoc()) { + if (DL.getLine() == maxLine) { + exitBBs.insert(&BB); + RA_LOG("[init] ExitByMaxLine added: " << F->getName() << " BB @ " << getSourceLocation(&BB) << " (maxLine=" << maxLine << ")\n"); + break; + } + } } } + } + } + + for (auto &BB : *F) { + // assign a BB ID + if (BBIDs.find(&BB) == BBIDs.end() || BBIDs[&BB] == 0) { + BBIDs[&BB] = nextBBID++; + if (auto *SI = dyn_cast(BB.getTerminator())) { + // assign a unique ID to the switch case + nextBBID += SI->getNumCases(); + } + } + auto* TI = BB.getTerminator(); + // Treat unreachable as exit; treat resume (EH) as exit only when it's + // likely developer-introduced (not compiler cleanup). 
+ bool isDevEH = isa(TI) && isDeveloperExceptionBB(&BB); + if (isa(TI) || isDevEH) { + RA_DEBUG((isDevEH ? "Developer EH BB: " : "Unreachable Inst BB: ") << BBIDs[&BB] << "\n"); + exitBBs.insert(&BB); + RA_LOG("[add-exit] by terminator: BB " << BBIDs[&BB] + << " @ " << getSourceLocation(&BB) + << " func " << F->getName() + << " term=" << TI->getOpcodeName() + << (isDevEH ? ", reason=developer-exception" : "UnreachableInst") + << "\n"); + } + for (auto &i : BB) { + Instruction *I = &i; if (UseTypeBasedCallGraph) { if (CallBase *CI = dyn_cast(I)) { @@ -240,9 +483,19 @@ bool ReachableCallGraphPass::runOnFunction(Function *F) { Changed |= Ctx->Callees[CI].insert(RCF).second; Changed |= Ctx->Callers[RCF].insert(CI).second; // check for call to exit functions - if (isExitFn(RCF->getName())) { - RA_LOG("Exit Call: " << *CI << "\n"); + bool __isExitFn = isExitFn(RCF->getName()); + bool __doesNotReturn = CF->doesNotReturn(); + if (__isExitFn || __doesNotReturn) { + RA_DEBUG("Exit Call: " << *CI << "\n"); exitBBs.insert(CI->getParent()); + RA_LOG("[add-exit] by call: BB " << BBIDs[CI->getParent()] + << " @ " << getSourceLocation(CI->getParent()) + << " func " << F->getName() + << " callee=" << RCF->getName() + << " reason=" << (__isExitFn ? "isExitFn" : "") + << ((__isExitFn && __doesNotReturn) ? "+" : "") + << (__doesNotReturn ? 
"doesNotReturn" : "") + << "\n"); } } else if (!CI->isInlineAsm()) { // indirect call @@ -271,7 +524,9 @@ bool ReachableCallGraphPass::runOnFunction(Function *F) { if (f.find(target.first) != std::string::npos && loc.getLine() == target.second) { RA_LOG("Target I: " << *I << "\n"); distances[I->getParent()] = 0.0; + targetBBs.insert(I->getParent()); reachableBBs.insert(I->getParent()); + reachableFuns.insert(F); } } } @@ -328,24 +583,6 @@ bool ReachableCallGraphPass::doInitialization(Module *M) { } } } - - // if no entry specified, use the common one - // collect the exit block of the entry function too - bool isEntry = false; - if (entryList.empty()) { - isEntry = isEntryFn(F.getName()); - } else { - auto itr = std::find(entryList.begin(), entryList.end(), F.getName().str()); - isEntry = (itr != entryList.end()); - } - if (isEntry) { - entryBBs.insert(&F.getEntryBlock()); - for (auto &BB : F) { - if (isa(BB.getTerminator())) { - exitBBs.insert(&BB); - } - } - } } return false; @@ -355,15 +592,132 @@ bool ReachableCallGraphPass::doFinalization(Module *M) { return false; } +void ReachableCallGraphPass::propagateThroughReturnEdgees( + std::unordered_set &retReachable, + const BasicBlock* startBB) { + // Only collect BBs via return-edges. Do not touch the main worklist or callers. 
+ if (startBB == nullptr) { + return; + } + + std::deque local; + local.push_back(startBB); + + while (!local.empty()) { + const BasicBlock *BB = local.front(); + local.pop_front(); + retReachable.insert(BB); + + unsigned currDepth = 0; + if (auto it = retDepth.find(BB); it != retDepth.end()) { + currDepth = it->second; + } + if (currDepth >= maxCallStackDepth) { + RA_LOG("Max depth reached (" << maxCallStackDepth + << ") for BB " << BBIDs[BB] << ", skipping ret-edge propagation\n"); + continue; + } + + // Add CFG predecessors to continue backward propagation + for (auto PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { + const BasicBlock *Pred = *PI; + if (retReachable.count(Pred)) { + continue; // already processed + } + // keep same ret-depth across normal CFG edges + if (currDepth != 0) { + retDepth[Pred] = currDepth; + } + local.push_back(Pred); + } + + // If this BB has interesting callsites, push callee return blocks + auto hasCalls = BBswithCalls.find(BB); + if (hasCalls == BBswithCalls.end()) { + continue; + } + const CallSequence &calls = hasCalls->second; + for (size_t i = calls.size(); i-- > 0; ) { + const llvm::CallBase* CI = calls[i]; + // Unified lookup of direct or type-based callees + const FuncSet *callees = nullptr; + if (auto it = Ctx->Callees.find(CI); it != Ctx->Callees.end()) { + callees = &it->second; + } else if (UseTypeBasedCallGraph) { + if (auto it2 = calleeByType.find(CI); it2 != calleeByType.end()) { + callees = &it2->second; + } + } + if (!callees) { + RA_DEBUG("No callee for " << *CI << "\n"); + continue; + } + + for (const auto *F : *callees) { + if (isExitFn(F->getName()) || F->doesNotReturn()) { + RA_DEBUG("DoesNotReturn: " << F->getName() << "\n"); + break; // stop on no-return functions + } + static std::unordered_set Seen; + if (Seen.insert(F).second) { + reachableFuns.insert(F); + RA_LOG(F->getName() << " is reachable through ret edge to the targets\n"); + } + for (const auto &TBB : *F) { + if 
(isa(TBB.getTerminator())) { + if (retReachable.count(&TBB)) { + continue; // already processed + } + retDepth[&TBB] = currDepth + 1; + // Keep exploring ret-edges from new return blocks as well + local.push_back(&TBB); + RA_DEBUG("[ret] add callee ret-BB: " << F->getName() + << " -> " << BBIDs[&TBB] << "\n"); + } + } // end of propagate through return BBs + } // end of propagate through potential callees + } // end of propagate through all call sites + } // end of local worklist +} + void ReachableCallGraphPass::collectReachable(std::deque &worklist, - std::unordered_set &reachable) { + std::unordered_set &reachable, + const std::unordered_set &others) { + bool isComputingReachable = others.empty(); + // Accumulator for ret-edge-only BBs across the whole BFS + std::unordered_set retEdgeAccum; while (!worklist.empty()) { auto *BB = worklist.front(); worklist.pop_front(); + // add callee when computing reachable BBs + if (isComputingReachable) { + // collect ret-edge-only BBs into accumulator; do not mutate 'reachable' here + propagateThroughReturnEdgees(retEdgeAccum, BB); + RA_DEBUG("[collectReachable] ret-edge accum size=" << retEdgeAccum.size() << ", from BB=" << BBIDs[BB] << " @ " << getSourceLocation(BB) << "\n"); + }else if (others.find(BB) != others.end()) { + continue; + } // add predecessors for (auto PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { - auto *Pred = *PI; - if (reachable.insert(Pred).second) { + const BasicBlock *Pred = *PI; + // if the predecessor is reachable to the target + // stop propagating unreachable BB through it + if (others.find(Pred) != others.end()) { + criticalBBs[Pred].push_back(BB); + continue; + } + if (reachable.find(Pred) != reachable.end()) { + continue; // already added + } else if(reachable.insert(Pred).second) { + RA_DEBUG("Adding " << BBIDs[BB] << "'s Pred: " << BBIDs[Pred] << "\n"); + // When computing exit BBs (others is not empty), log propagation reason + if (!isComputingReachable) { + RA_LOG("[add-exit] by 
pred-edge: add BB " << BBIDs[Pred] + << " @ " << getSourceLocation(Pred) + << " func " << Pred->getParent()->getName() + << " from Succ " << BBIDs[BB] + << " @ " << getSourceLocation(BB) << "\n"); + } worklist.push_back(Pred); } } @@ -371,7 +725,6 @@ void ReachableCallGraphPass::collectReachable(std::deque &wor auto *F = BB->getParent(); if (BB == &F->getEntryBlock()) { if (entryBBs.find(BB) != entryBBs.end()) { - RA_LOG("Entry func " << F->getName() << " is reachable\n"); continue; } auto itr = Ctx->Callers.find(F); @@ -382,64 +735,61 @@ void ReachableCallGraphPass::collectReachable(std::deque &wor found = (itr != callerByType.end()); } if (!found) { - WARNING("No caller for " << F->getName() << "\n"); + static std::unordered_set WarnedNoCaller1; + if (WarnedNoCaller1.insert(F).second) { + std::string context_str = isComputingReachable ? "Reachable Analysis: " : "Unreachable Analysis: "; + WARNING(context_str << "No caller for " << F->getName() << "\n"); + } continue; } } - RA_DEBUG(F->getName() << " is reachable\n"); - for (auto CI : itr->second) { - auto CBB = CI->getParent(); - // go through instructions, handle additional callees - bool willReturn = true; - bool added = false; - if (true /*PropagateThroughReturnEdgees*/) { - // always propagate reachability through return edges - auto hasCalls = BBswithCalls.find(CBB); - assert(hasCalls != BBswithCalls.end()); - auto calls = hasCalls->second; - for (auto i = calls.size() - 1; i > 0; --i) { - if (calls[i] == CI) { - // find the current callsite and there are additional callees before it - auto PCI = calls[i - 1]; - auto fitr = Ctx->Callees.find(PCI); - if (fitr == Ctx->Callees.end()) { - if (UseTypeBasedCallGraph) { - fitr = calleeByType.find(PCI); - } - } - // any callsite here is guaranteed to have a callee - for (auto F : fitr->second) { - if (F->doesNotReturn()) { - RA_DEBUG("DoesNotReturn: " << F->getName() << "\n"); - willReturn = false; - break; // not need to continue - } - // add exit block(s) as 
reachable - for (auto &TBB : *F) { - if (isa(TBB.getTerminator())) { - RA_LOG("Adding callee: " << F->getName() << "\n"); - if (reachable.insert(&TBB).second) { - worklist.push_back(&TBB); - added = true; - } - } - } - } - if (added) break; // one callsite at a time - } - } + if (isComputingReachable) { + reachableFuns.insert(F); + RA_LOG(F->getName() << " is reachable through call edge to the targets\n"); + }else { + RA_LOG(F->getName() << " is reachable to the exit\n"); + } + unsigned currDepth = callDepth[BB]; + for (auto *CI : itr->second) { + auto *CBB = CI->getParent(); + unsigned newDepth = currDepth + 1; + if (newDepth > maxCallStackDepth) { + RA_LOG("Max depth reached (" << maxCallStackDepth + << ") for function " << F->getName() << ", skipping caller\n"); + continue; // do not propagate beyond threshold + } + // If this caller basic block is already known reachable-to-target, + // mark critical and skip adding to exit set. + if (others.find(CBB) != others.end()) { + criticalBBs[CBB].push_back(BB); + continue; + } + if (reachable.find(CBB) != reachable.end()) { + continue; // already added } - if (willReturn && !added) { - // if all callsites have been processed, add the CBB - RA_DEBUG("\tadding caller: " << CI->getFunction()->getName() << "\n"); - if (reachable.insert(CBB).second) { - worklist.push_back(CBB); + // if all callsites have been processed, add the CBB + RA_DEBUG("\tadding caller: " << CI->getFunction()->getName() << "\n"); + if (reachable.insert(CBB).second) { + callDepth[CBB] = newDepth; // record depth before enqueue + worklist.push_back(CBB); + // When computing exit BBs (others is not empty), log propagation via caller edge + if (!isComputingReachable) { + RA_LOG("[add-exit] by caller-edge: add BB " << BBIDs[CBB] + << " @ " << getSourceLocation(CBB) + << " func " << CBB->getParent()->getName() + << " via call into callee " << F->getName() << "\n"); } } } // end of callers } // end of entry block } + // Merge ret-edge-only BBs after BFS 
completes + if (isComputingReachable) { + for (const BasicBlock *RBB : retEdgeAccum) { + reachable.insert(RBB); + } + } } void ReachableCallGraphPass::run(ModuleList &modules) { @@ -467,19 +817,48 @@ void ReachableCallGraphPass::run(ModuleList &modules) { WARNING("No entry BBs found\n"); return; } + RA_LOG("[run] Num entry BBs: " << entryBBs.size() << "\n"); + for (auto *EBB : entryBBs) { + RA_LOG("[run] Entry BB: " << BBIDs[EBB] << " @ " << getSourceLocation(EBB) << " of function " << EBB->getParent()->getName() << "\n"); + } - // do a BFS search on the call graph to find BB that can reach exits + // do a BFS search on the target list, find all reachable BBs first std::deque worklist; - RA_DEBUG("\n\n=== Collecting exit BBs ===\n\n"); - worklist.insert(worklist.end(), exitBBs.begin(), exitBBs.end()); - collectReachable(worklist, exitBBs); - - // now do a BFS search on the target list, find all reachable BBs first RA_LOG("\n\n=== Collecting reachable BBs ===\n\n"); + callDepth.clear(); + retDepth.clear(); for (const auto &kv : distances) { worklist.push_back(kv.first); } collectReachable(worklist, reachableBBs); + RA_LOG("[run] reachableBBs after target-backward: " << reachableBBs.size() << "\n"); + + // clean exitBBs + { + std::vector toErase; + toErase.reserve(exitBBs.size()); + for (const auto *BB : exitBBs) { + if (reachableBBs.find(BB) != reachableBBs.end()) { + toErase.push_back(BB); + } + } + for (const auto *BB : toErase) { + RA_LOG("[run] Removing BB from exitBBs " << BBIDs[BB] << " @ " << getSourceLocation(BB) << "\n"); + exitBBs.erase(BB); + } + } + + // do a BFS search on the call graph to find BB that can reach exits + RA_LOG("\n\n=== Collecting exit BBs ===\n\n"); + worklist.clear(); + callDepth.clear(); + retDepth.clear(); + for (auto *BB : exitBBs) { + RA_LOG("[run] Seed exit BB: " << BBIDs[BB] << " @ " << getSourceLocation(BB) << "\n"); + worklist.push_back(BB); + } + collectReachable(worklist, exitBBs, reachableBBs); + RA_LOG("[run] exitBBs 
reachable to target size: " << exitBBs.size() << "\n"); // check if target is reachable bool reached = false; @@ -488,6 +867,9 @@ void ReachableCallGraphPass::run(ModuleList &modules) { RA_LOG("\n\n=== Target is reachable from entry ===\n\n"); reached = true; } + else { + RA_LOG("[run] Entry not in reachableBBs: " << BBIDs[entry] << " @ " << getSourceLocation(entry) << " func " << entry->getParent()->getName() << "\n"); + } } if (!reached) { @@ -498,7 +880,9 @@ void ReachableCallGraphPass::run(ModuleList &modules) { // now calculate distances in a bottom-up manner std::unordered_set queued; std::unordered_set queuedCalls; + callDepth.clear(); for (const auto &kv : distances) { + callDepth[kv.first] = 0; worklist.push_back(kv.first); queued.insert(kv.first); } @@ -507,101 +891,20 @@ void ReachableCallGraphPass::run(ModuleList &modules) { auto *BB = worklist.front(); worklist.pop_front(); queued.erase(BB); - if (PropagateThroughReturnEdgees) { - // go through instructions, looking for calls - auto hasCalls = BBswithCalls.find(BB); - if (hasCalls != BBswithCalls.end()) { - auto &calls = hasCalls->second; - bool finished = false; - const CallBase *propagate = nullptr; - double dist = NAN; - for (auto i = calls.size() - 1;; --i) { - // iterate through all callsites, in reverse order - auto *CI = calls[i]; - if (queuedCalls.find(CI) != queuedCalls.end()) { - // if the reachability comes from the callee - RA_DEBUG("Find current callsite: " << *CI << "\n"); - queuedCalls.erase(CI); - if (i > 0) { - // there are additional callees before the current callsite - // we need to propagate the reachability to them - RA_DEBUG("Propagate to additional callees\n"); - // get the distance of current callsite - auto itr = callDistances.find(CI); - assert(itr != callDistances.end()); - dist = itr->second; - // record the callsite to be propagated to - propagate = calls[i - 1]; - } else { - // all callees have been processed - finished = true; - } - break; // always break if coming from 
callee - } - if (i == 0) break; - } - if (!finished) { - // if not finished, we either have more callsite(s) to process, - // or the reachability is coming from the successor, - if (propagate == nullptr) { - // in the later case, we want to propagate BB distance to the last callsite - RA_DEBUG("Propagate BB distance to last callsite\n"); - propagate = calls.back(); - dist = distances[BB]; - } - // propagate to return sites in the callee - auto fitr = Ctx->Callees.find(propagate); - if (fitr == Ctx->Callees.end()) { - if (UseTypeBasedCallGraph) { - fitr = calleeByType.find(propagate); - } - } - bool added = false; - // any callsite here is guaranteed to have a callee - for (auto F : fitr->second) { - // add exit block(s) as reachable - for (auto &TBB : *F) { - if (isa(TBB.getTerminator())) { - auto itr = distances.find(&TBB); - if (itr == distances.end() || itr->second > dist) { - RA_DEBUG("Propagate distance " << dist << " to callee: " << F->getName() << "\n"); - distances[&TBB] = dist; - if (queued.insert(&TBB).second) { - worklist.push_back(&TBB); - } - added = true; - } - } - } - } - if (added) { - // if we have propagated reachability to a callee, it will come back, - // so we don't need to propagate to predecessors for now - continue; - } else { - // there is another callsite but no propagation is needed - // simulate the propagation by adding the callsite to the queue - queuedCalls.insert(propagate); - if (queued.insert(BB).second) - worklist.push_back(BB); - continue; - } - } else { - // if all callsites have been processed, use the distance of the first - // callsite as the distance of the BB - RA_DEBUG("All callees processed\n"); - auto itr = callDistances.find(calls.front()); - assert(itr != callDistances.end()); - dist = itr->second; - distances[BB] = dist; - } - } + RA_DEBUG("[distance] Pop BB: " << BBIDs[BB] << " @ " << getSourceLocation(BB) << ", depth=" << callDepth[BB] << "\n"); + unsigned currDepth = callDepth[BB]; + if (currDepth >= 
maxCallStackDepth) { + continue; // do not propagate beyond threshold } // check predecessors for (auto PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) { auto *Pred = *PI; double numSucc = 0.0; double prob = 0.0; + if (reachableBBs.find(Pred) == reachableBBs.end()) { + RA_DEBUG("Skip unreachable Pred: " << *Pred << "\n"); + continue; + } for (auto SI = succ_begin(Pred), SE = succ_end(Pred); SI != SE; ++SI) { auto *Succ = *SI; numSucc += 1.0; @@ -617,7 +920,7 @@ void ReachableCallGraphPass::run(ModuleList &modules) { } prob /= numSucc; if (prob == 0.0) { - WARNING("prob dropped to 0 for basic block in " << BB->getParent()->getName() << "\n"); + WARNING("prob dropped to 0 for BB "<< getSourceLocation(BB) << " in " << BB->getParent()->getName() << "\n"); RA_DEBUG("\t " << *BB << "\n"); continue; } @@ -630,17 +933,19 @@ void ReachableCallGraphPass::run(ModuleList &modules) { if (itr == distances.end() || itr->second > dist) { // RA_DEBUG("Adding Pred: " << *Pred << " with prob " << prob << "\n"); distances[Pred] = dist; - if (queued.insert(Pred).second) + if (queued.insert(Pred).second){ + callDepth[Pred] = currDepth; worklist.push_back(Pred); + RA_DEBUG("[distance] Enqueue Pred: " << BBIDs[Pred] << " @ " << getSourceLocation(Pred) << ", dist=" << dist*1000 << "\n"); + } } } // entry block has no predecessor, add caller auto *F = BB->getParent(); if (BB == &F->getEntryBlock()) { if (entryBBs.find(BB) != entryBBs.end()) { - RA_LOG("Entry func " << F->getName() << " is reachable\n"); - break; - // continue; + // break; + continue; } auto itr = Ctx->Callers.find(F); if (itr == Ctx->Callers.end()) { @@ -650,16 +955,18 @@ void ReachableCallGraphPass::run(ModuleList &modules) { found = (itr != callerByType.end()); } if (!found) { - if (!F->getName().equals("main")) { + static std::unordered_set WarnedNoCaller2; + if (WarnedNoCaller2.insert(F).second) { WARNING("No caller for " << F->getName() << "\n"); - } else { - RA_LOG("main is reached\n"); } continue; } } - - 
RA_LOG(F->getName() << " is reachable from " << itr->second.size() << " callers\n"); + // check callers + static std::unordered_set Seen; + if (Seen.insert(F).second) { + RA_DEBUG(F->getName() << " is reachable from " << itr->second.size() << " callers\n"); + } auto dist = distances[BB]; for (auto CI : itr->second) { auto CBB = CI->getParent(); @@ -673,21 +980,19 @@ void ReachableCallGraphPass::run(ModuleList &modules) { auto itr2 = callDistances.find(CI); if (itr2 == callDistances.end() || itr2->second > dist) { RA_DEBUG("Adding direct caller: " << CI->getFunction()->getName() << "\n"); - if (PropagateThroughReturnEdgees) { - callDistances[CI] = dist; - queuedCalls.insert(CI); - } else { - distances[CBB] = dist; - } - if (queued.insert(CBB).second) + distances[CBB] = dist; + if (queued.insert(CBB).second){ + callDepth[CBB] = currDepth + 1; worklist.push_back(CBB); + RA_DEBUG("[distance] Enqueue Caller CBB: " << BBIDs[CBB] << " @ " << getSourceLocation(CBB) << ", from F=" << F->getName() << "\n"); + } } } else { // indirect call is tricky, treat like predecessors // for each call site, check if all its callees have been processed double prob = 0.0; FuncSet &Callees = UseTypeBasedCallGraph ? calleeByType[CI] : Ctx->Callees[CI]; - RA_LOG("\tfrom indirect call @" << CF->getName() << ", callee size = " << Callees.size() << "\n"); + RA_DEBUG("\tfrom indirect call @" << CF->getName() << ", callee size = " << Callees.size() << "\n"); // XXX: skip potentially imprecise callsites? 
if (Callees.size() > 50) { RA_DEBUG("Skip indirect call with too many callees\n"); @@ -720,14 +1025,11 @@ void ReachableCallGraphPass::run(ModuleList &modules) { auto itr2 = callDistances.find(CI); if (itr2 == callDistances.end() || itr2->second > dist) { RA_DEBUG("Adding indirect caller: " << CI->getFunction()->getName() << "\n"); - if (PropagateThroughReturnEdgees) { - callDistances[CI] = dist; - queuedCalls.insert(CI); - } else { distances[CBB] = dist; - } - if (queued.insert(CBB).second) + if (queued.insert(CBB).second){ + callDepth[CBB] = currDepth + 1; worklist.push_back(CBB); + } } } } @@ -739,12 +1041,16 @@ void ReachableCallGraphPass::run(ModuleList &modules) { } } -ReachableCallGraphPass::ReachableCallGraphPass(GlobalContext *Ctx_, - std::string &TargetList, std::string &EntryList, bool typeBased, - bool propagateRet) - : Ctx(Ctx_), UseTypeBasedCallGraph(typeBased), - PropagateThroughReturnEdgees(propagateRet), - nextBBID(1000) { +ReachableCallGraphPass::ReachableCallGraphPass( + GlobalContext *Ctx_, + std::string &TargetList, + std::string &EntryList, + bool typeBased, + unsigned CallStackLen) + : Ctx(Ctx_), + UseTypeBasedCallGraph(typeBased), + nextBBID(1000), + maxCallStackDepth(CallStackLen) { // parse target list // format: filename:line_number if (!TargetList.empty()) { @@ -784,190 +1090,49 @@ ReachableCallGraphPass::ReachableCallGraphPass(GlobalContext *Ctx_, } } -std::string ReachableCallGraphPass::getSourceLocation(const BasicBlock *BB) { - for (const auto &I : *BB) { - auto loc = I.getDebugLoc(); - if (loc && loc.getLine() != 0) { - // Get the filename from the debug location - std::string f = loc->getFilename().str(); - // If filename is empty, get it from the parent function - if (f.empty()) { - f = BB->getParent()->getParent()->getSourceFileName(); - } - // Remove leading "./" if present - if (f.find("./") == 0) { - f = f.substr(2); - } - // Extract the base filename by finding the last '/' or '\\' - size_t pos = f.find_last_of("/\\"); - if 
(pos != std::string::npos) { - f = f.substr(pos + 1); - } - return f + ":" + std::to_string(loc.getLine()); - } - } - return "NoLoc:0"; -} - -/// \brief Retrieve the first available debug location in \p BB that is not -/// inside /usr/ and store the **absolute, normalized path** in \p Filename. -/// Sets \p Line and \p Col accordingly. -/// -/// This version does: -/// 1) Loops over instructions in \p BB -/// 2) Checks the debug location (and possibly inlined-at location) -/// 3) Builds an absolute, normalized path (resolving "." and "..") -/// 4) Skips if the path is empty, line=0, or the path starts with "/usr/" -/// 5) Returns the first valid debug info found -void getDebugLocationFullPath(const BasicBlock &BB, - std::string &Filename, - unsigned &Line, - unsigned &Col) { - Filename.clear(); - Line = 0; - Col = 0; - - // We don't want paths that point to system libraries in /usr/ - static const std::string Xlibs("/usr/"); - - // Iterate over instructions in the basic block - for (auto &Inst : BB) { - if (DILocation *Loc = Inst.getDebugLoc()) { - // Extract directory & filename - std::string Dir = Loc->getDirectory().str(); - std::string File = Loc->getFilename().str(); - unsigned L = Loc->getLine(); - unsigned C = Loc->getColumn(); - - // If there's no filename, check the inlined location - if (File.empty()) { - if (DILocation *inlinedAt = Loc->getInlinedAt()) { - Dir = inlinedAt->getDirectory().str(); - File = inlinedAt->getFilename().str(); - L = inlinedAt->getLine(); - C = inlinedAt->getColumn(); - } - } - - // Skip if still no filename or line==0 - if (File.empty() || L == 0) - continue; - - // Build an absolute path in a SmallString - llvm::SmallString<256> FullPath; - - // 1) If Dir is already absolute, just start with that. - // Otherwise, use the current working directory as a base. 
- if (!Dir.empty() && llvm::sys::path::is_absolute(Dir)) { - FullPath = Dir; - } else { - llvm::sys::fs::current_path(FullPath); // get the current working dir - if (!Dir.empty()) { - llvm::sys::path::append(FullPath, Dir); - } - } - - // 2) Append the filename - llvm::sys::path::append(FullPath, File); - - // 3) Remove dot segments (both "." and "..") - llvm::sys::path::remove_dots(FullPath, /*remove_dot_dot=*/true); - - // Now FullPath is absolute & normalized - // Check if it's in /usr/ - if (StringRef(FullPath).startswith(Xlibs)) - continue; // skip system-libs - - // Found a valid location => set output vars - Filename = FullPath.str().str(); // convert to std::string - Line = L; - Col = C; - break; // stop after the first valid location - } - } -} - -void ReachableCallGraphPass::dumpDistance(std::ostream &OS, bool dumpSolution, bool dumpUnreachable) { - std::deque worklist; - std::unordered_set visited; - double currentDist = std::numeric_limits::max();; - for (auto BB : entryBBs) { - if (distances.find(BB) != distances.end()) { - RA_LOG("Entry BB of " << BB->getParent()->getName() << " is reachable\n"); - worklist.push_back(BB); - visited.insert(BB); - } +void ReachableCallGraphPass::dumpDistance(std::ostream &OS, bool dumpUnreachable) { + // Set precision for output + OS << std::fixed << std::setprecision(6); + // Copy and sort distances by ascending value + std::vector> sorted; + sorted.reserve(distances.size()); + for (const auto &entry : distances) { + sorted.emplace_back(entry.first, entry.second); } - if (worklist.empty()) { - WARNING("Target not reachable from entry BBs\n"); - return; + std::sort(sorted.begin(), sorted.end(), + [](const auto &a, const auto &b) { return a.second < b.second; }); + // Output sorted distance entries + for (const auto &pair : sorted) { + const BasicBlock *BB = pair.first; + double dist = pair.second; + OS << BBIDs[BB] << "," + << getBasicBlockId(BB) << "," + << getSourceLocation(BB) << "," + << (dist * 1000) << "\n"; } - 
// set precision - OS << std::fixed << std::setprecision(6); - - // dump reachable bb - while (!worklist.empty()) { - auto *BB = worklist.front(); - worklist.pop_front(); - auto dist = distances[BB]; - if (dumpSolution && (dist < currentDist)) { - currentDist = dist; - RA_LOG("Best option: " << BB->getParent()->getName() << " at " << currentDist << "\n"); - } - OS << BBIDs[BB] << "," << getBasicBlockId(BB) << "," << getSourceLocation(BB) << "," - << distances[BB] * 1000 << "\n"; - - for (auto &I : *BB) { - // check for callees - if (const CallBase *CI = dyn_cast(&I)) { - if (CI->isInlineAsm() || CI->isIndirectCall()) { - continue; // skip inline asm - } - auto itr = Ctx->Callees.find(CI); - if (itr == Ctx->Callees.end() && UseTypeBasedCallGraph) { - itr = calleeByType.find(CI); - if (itr == calleeByType.end()) { - WARNING("No callees for " << *CI << "\n"); - continue; // no callees - } - } - for (auto F: itr->second) { - auto *FBB = &F->getEntryBlock(); - if (distances.find(FBB) != distances.end() && visited.insert(FBB).second) { - worklist.push_back(FBB); - } - } - } - } - for (auto SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) { - auto *Succ = *SI; - if (distances.find(Succ) != distances.end() && visited.insert(Succ).second) { - worklist.push_back(Succ); - } - } - } - // dump unreachable bb + // If dumpUnreachable is enabled, output unreachable basic blocks if (dumpUnreachable) { - for (auto BB : exitBBs) { - if (distances.find(BB) == distances.end()) { - OS << BBIDs[BB] << "," << getBasicBlockId(BB) << "," << getSourceLocation(BB) << ",-1\n"; - } + for (const auto *BB : exitBBs) { + OS << BBIDs[BB] << "," + << getBasicBlockId(BB) << "," + << getSourceLocation(BB) + << ",-1\n"; } } - // dump the covered functions - std::unordered_set reachedFunctions; - for (auto BB : reachableBBs) { - reachedFunctions.insert(BB->getParent()); + + // Dump the covered functions + std::unordered_set reachedFunctions; + for (const auto &entry : distances) { + 
reachedFunctions.insert(entry.first->getParent()); } OS << "##########\n"; - for (auto *F : reachedFunctions) { + for (const auto *F : reachedFunctions) { OS << "fun:" << F->getName().str() << "\n"; } } void ReachableCallGraphPass::dumpPolicy(std::ostream &OS) { - // set precision OS << std::fixed << std::setprecision(6); @@ -983,6 +1148,7 @@ void ReachableCallGraphPass::dumpPolicy(std::ostream &OS) { continue; auto TT = branch->getSuccessor(0); auto FT = branch->getSuccessor(1); + bool reached = false; std::string tdist; auto itr = distances.find(FT); @@ -1015,7 +1181,7 @@ void ReachableCallGraphPass::dumpPolicy(std::ostream &OS) { << "\nAnd no call in the BB\n"); } } else { - OS << BBIDs[BB] << "," << getBasicBlockId(BB) << "," << tdist << "," << fdist << "\n"; + OS << BBIDs[BB] << "," << tdist << "," << fdist << "," << BBIDs[FT] << "," << BBIDs[TT] << "\n"; } } @@ -1055,6 +1221,9 @@ void ReachableCallGraphPass::dumpIDMapping(ModuleList &modules, std::ostream &bb unsigned minLine = std::numeric_limits::max(); unsigned maxLine = 0; std::string filepath; + if (F.isDeclaration() || F.empty() || F.isIntrinsic()) { + continue; // skip declaration and intrinsic + } for (auto &BB : F) { unsigned line = 0; @@ -1078,13 +1247,31 @@ void ReachableCallGraphPass::dumpIDMapping(ModuleList &modules, std::ostream &bb } } +void ReachableCallGraphPass::dumpCriticalBBs(std::ostream &OS) { + for (auto const &[BB, exits] : criticalBBs) { + OS << BBIDs[BB]; + for (auto *exitBB : exits) + OS << "," << BBIDs[exitBB]; + OS << "\n"; + } +} + bool ReachableCallGraphPass::annotateModules(ModuleList &modules, std::string suffix) { + std::unordered_set inverseCriticalBBs; + for (const auto &[k,v] : criticalBBs) { + for (const auto *exitBB : v) { + inverseCriticalBBs.insert(exitBB); + } + } ModuleList::iterator i, e; - double max_dist = std::max_element(distances.begin(), distances.end(), - [](const std::pair &a, - const std::pair &b) { - return a.second < b.second; - })->second; + // double 
max_dist = INFINITY; + // if (!distances.empty()) { + // max_dist = std::max_element(distances.begin(), distances.end(), + // [](const std::pair &a, + // const std::pair &b) { + // return a.second < b.second; + // })->second; + // } for (i = modules.begin(), e = modules.end(); i != e; ++i) { Module *M = i->first; @@ -1092,35 +1279,63 @@ bool ReachableCallGraphPass::annotateModules(ModuleList &modules, std::string su auto NewName = ModName + suffix; auto VoidTy = Type::getVoidTy(M->getContext()); auto Int64Ty = Type::getInt64Ty(M->getContext()); - FunctionCallee TraceDistanceFunc = M->getOrInsertFunction( - "__taint_trace_distance", VoidTy, Int64Ty, Int64Ty); + auto *BoolTy = Type::getInt1Ty(M->getContext()); + auto *TrueVal = ConstantInt::getTrue(BoolTy); + auto *FalseVal = ConstantInt::getFalse(BoolTy); + GlobalVariable *HasReachedTarget = cast( + M->getOrInsertGlobal("has_reached_target", BoolTy)); + HasReachedTarget->setLinkage(GlobalValue::LinkOnceODRLinkage); + HasReachedTarget->setComdat(M->getOrInsertComdat(HasReachedTarget->getName())); + if (!HasReachedTarget->hasInitializer()) + HasReachedTarget->setInitializer(FalseVal); + + // FunctionCallee TraceDistanceFunc = M->getOrInsertFunction( + // "__taint_trace_distance", VoidTy, Int64Ty, Int64Ty); + FunctionCallee TraceFunc = M->getOrInsertFunction( + "__taint_trace_divergence", VoidTy, Int64Ty); for (auto &F : *M) { if (F.isDeclaration() || F.empty() || F.isIntrinsic()) { continue; // skip declaration and intrinsic } for (auto &BB : F) { - if (isa(BB.getFirstNonPHIOrDbgOrLifetime())) + if (isa(BB.getTerminator())) continue; // skip unreachable BBs if (BB.getFirstInsertionPt() == BB.end()) continue; // skip empty BBs - // annotate reachable basic block with ID and distance - if (reachableBBs.count(&BB)) { - // check if we have a distance - auto itr = distances.find(&BB); - double dist = (itr != distances.end()) ? 
itr->second : max_dist; - dist *= 1000.0; - // instrument a call to trace distance + + // add an annotation for other instrumentation + auto *BBID = ConstantInt::get(Int64Ty, BBIDs[&BB]); + auto term = BB.getTerminator(); + MDNode *MD = MDNode::get(M->getContext(), + {ConstantAsMetadata::get(BBID)}); + term->setMetadata("bbid", MD); + + // instrument __taint_trace_divergence callback + if (inverseCriticalBBs.count(&BB)) { IRBuilder<> IRB(&*BB.getFirstInsertionPt()); - auto *BBID = ConstantInt::get(Int64Ty, BBIDs[&BB]); - auto *Dist = ConstantInt::get(Int64Ty, (uint64_t)dist); - IRB.CreateCall(TraceDistanceFunc, {BBID, Dist})->setCannotMerge(); - - // add an annotation for other instrumentation - auto term = BB.getTerminator(); - MDNode *MD = MDNode::get(M->getContext(), - {ConstantAsMetadata::get(BBID)}); - term->setMetadata("bbid", MD); + auto *CI = IRB.CreateCall(TraceFunc, {BBID}); + CI->setCannotMerge(); + } + // Instrument code to set has_reached_target to true + for (const llvm::BasicBlock* tb : targetBBs) { + if (tb == &BB) { + IRBuilder<> IRB(BB.getTerminator()); + IRB.CreateStore(TrueVal, HasReachedTarget)->setMetadata( + M->getMDKindID("nosanitize"), MDNode::get(M->getContext(), None)); + break; + } } + // annotate reachable basic block with ID and distance + // if (reachableBBs.count(&BB)) { + // // check if we have a distance + // auto itr = distances.find(&BB); + // double dist = (itr != distances.end()) ? 
itr->second : max_dist; + // dist *= 1000.0; + // // instrument a call to trace distance + // IRBuilder<> IRB(&*BB.getFirstInsertionPt()); + // auto *Dist = ConstantInt::get(Int64Ty, (uint64_t)dist); + // IRB.CreateCall(TraceDistanceFunc, {BBID, Dist})->setCannotMerge(); + // } } } // verify @@ -1141,74 +1356,91 @@ bool ReachableCallGraphPass::annotateModules(ModuleList &modules, std::string su return true; } -void ReachableCallGraphPass::dumpCallees() { - RES_REPORT("\n[dumpCallees]\n"); - raw_ostream &OS = outs(); - OS << "Num of Callees: " << calleeByType.size() << "\n"; - for (CalleeMap::iterator i = calleeByType.begin(), - e = calleeByType.end(); i != e; ++i) { - - auto CI = i->first; - FuncSet &v = i->second; - // only dump indirect call? - if (CI->isInlineAsm() || CI->getCalledFunction() /*|| v.empty()*/) - continue; - - // OS << "CS:" << *CI << "\n"; - // const DebugLoc &LOC = CI->getDebugLoc(); - // OS << "LOC: "; - // LOC.print(OS); - // OS << "^@^"; - std::string prefix = "<" + CI->getParent()->getParent()->getParent()->getName().str() + ">" - + CI->getParent()->getParent()->getName().str() + "::"; -#if 1 - for (FuncSet::iterator j = v.begin(), ej = v.end(); - j != ej; ++j) { - //OS << "\t" << ((*j)->hasInternalLinkage() ? 
"f" : "F") - // << " " << (*j)->getName() << "\n"; - OS << prefix << *CI << "\t"; - OS << (*j)->getName() << "\n"; - } -#endif - // OS << "\n"; - - // v = Ctx->Callees[CI]; - // OS << "Callees: "; - // for (FuncSet::iterator j = v.begin(), ej = v.end(); - // j != ej; ++j) { - // OS << (*j)->getName() << "::"; - // } - // OS << "\n"; - if (v.empty()) { -#if LLVM_VERSION_MAJOR > 10 - OS << "!!EMPTY =>" << *CI->getCalledOperand()<<"\n"; -#else - OS << "!!EMPTY =>" << *CI->getCalledValue()<<"\n"; -#endif - OS<< "Uninitialized function pointer is dereferenced!\n"; - } +void ReachableCallGraphPass::dumpCallees(std::ostream &calleeInfo) { + RA_LOG("\n\n=== Dumping caller->callees ===\n\n"); + // Build caller -> set{callees} using direct edges first. + // If a caller has no direct callees recorded, fall back to type-based. + std::unordered_map> caller2callees; + for (const auto &kv : Ctx->Callees) { + const CallBase *CI = kv.first; + const FuncSet &FS = kv.second; + const Function *CallerF = CI->getFunction(); + auto &CalSet = caller2callees[CallerF]; + for (const Function *CalleeF : FS) { + CalSet.insert(CalleeF); + } + } + if (UseTypeBasedCallGraph) { + for (const auto &kv : calleeByType) { // calleeByType: CallBase* -> FuncSet + const CallBase *CI = kv.first; + const Function *CallerF = CI->getFunction(); + auto findIt = caller2callees.find(CallerF); + bool hasDirect = (findIt != caller2callees.end() && !findIt->second.empty()); + if (hasDirect) { + continue; // already have direct callees for this caller + } + const FuncSet &FS = kv.second; + auto &CalSet = caller2callees[CallerF]; + for (const Function *CalleeF : FS) { + CalSet.insert(CalleeF); + } } - RES_REPORT("\n[End of dumpCallees]\n"); + } + // Emit lines: callerGUID,calleeGUID,calleeGUID,... 
for callers that have any callees + for (const auto &kv : caller2callees) { + const Function *CallerF = kv.first; + const auto &Callees = kv.second; + if (Callees.empty()) { + continue; + } + calleeInfo << CallerF->getGUID(); + for (const Function *CF : Callees) { + calleeInfo << ',' << CF->getGUID(); + } + calleeInfo << '\n'; + } } -void ReachableCallGraphPass::dumpCallers() { - RES_REPORT("\n[dumpCallers]\n"); - for (auto M : Ctx->Callers) { - const Function *F = M.first; - CallInstSet &CIS = M.second; - RES_REPORT("F : " << getScopeName(F) << "\n"); - - for (auto *CI : CIS) { - auto CallerF = CI->getParent()->getParent(); - RES_REPORT("\t"); - if (CallerF && CallerF->hasName()) { - RES_REPORT("(" << getScopeName(CallerF) << ") "); - } else { - RES_REPORT("(anonymous) "); - } - - RES_REPORT(*CI << "\n"); +void ReachableCallGraphPass::dumpCallers(std::ostream &callerInfo) { + RA_LOG("\n\n=== Dumping callee->callers ===\n\n"); + // Collect all callees that have recorded callers (direct or type-based) + std::unordered_set allCallees; + for (const auto &kv : Ctx->Callers) { + allCallees.insert(kv.first); + } + if (UseTypeBasedCallGraph) { + for (const auto &kv : callerByType) { + allCallees.insert(kv.first); + } + } + // For each callee, emit one line: calleeGUID,callerGUID,callerGUID,... 
+ for (const Function *Callee : allCallees) { + std::unordered_set callerFns; + // Direct callers + bool has_direct_callers = false; + if (auto it = Ctx->Callers.find(Callee); it != Ctx->Callers.end()) { + for (const CallBase *CI : it->second) { + callerFns.insert(CI->getFunction()); + has_direct_callers = true; + } + } + // Fallback to Type-based (indirect) callers + if (!has_direct_callers && UseTypeBasedCallGraph) { + if (auto it2 = callerByType.find(Callee); it2 != callerByType.end()) { + for (const CallBase *CI : it2->second) { + callerFns.insert(CI->getFunction()); } + } } - RES_REPORT("\n[End of dumpCallers]\n"); + if (callerFns.empty()) { + // No callers recorded for this callee + continue; + } + + callerInfo << Callee->getGUID(); + for (const Function *CallerF : callerFns) { + callerInfo << ',' << CallerF->getGUID(); + } + callerInfo << '\n'; + } } diff --git a/src/lib/Reachable.h b/src/lib/Reachable.h index 3356c93..50dd914 100644 --- a/src/lib/Reachable.h +++ b/src/lib/Reachable.h @@ -17,7 +17,6 @@ class ReachableCallGraphPass { bool runOnFunction(llvm::Function*); bool isCompatibleType(llvm::Type *T1, llvm::Type *T2); bool findCalleesByType(llvm::CallBase*, FuncSet&); - std::string getSourceLocation(const llvm::BasicBlock *BB); GlobalContext *Ctx; @@ -25,40 +24,55 @@ class ReachableCallGraphPass { CallerMap callerByType; const bool UseTypeBasedCallGraph; - const bool PropagateThroughReturnEdgees; - + std::unordered_map BBIDs; uint64_t nextBBID; + // Maximum call stack depth to propagate across callers + const unsigned maxCallStackDepth; + std::unordered_map callDepth; + std::unordered_map retDepth; std::vector > targetList; std::vector entryList; + std::unordered_set targetBBs; std::unordered_set reachableBBs; + std::unordered_set reachableFuns; std::unordered_map distances; std::unordered_set exitBBs; std::unordered_set entryBBs; using CallSequence = std::vector; std::unordered_map BBswithCalls; std::unordered_map callDistances; + std::unordered_map> 
criticalBBs; std::unordered_set reachableIndirectCalls; public: ReachableCallGraphPass(GlobalContext *Ctx_, std::string &TargetList, - std::string &EntryList, bool typeBased = true, bool propagateRet = false); + std::string &EntryList, bool typeBased = true, + unsigned CallStackLen = 10); virtual bool doInitialization(llvm::Module *); virtual bool doFinalization(llvm::Module *); virtual void run(ModuleList &modules); - // simple bfs pass - void collectReachable(std::deque &worklist, - std::unordered_set &reachable); + // BFS pass + void collectReachable(std::deque &worklist, + std::unordered_set &reachable, + const std::unordered_set &others = {}); + void propagateThroughReturnEdgees(std::unordered_set &retReachable, + const BasicBlock *startBB); // debug - void dumpDistance(std::ostream &OS, bool dumpSolution = false, bool dumpUnreachable = false); void dumpPolicy(std::ostream &OS); + void dumpCriticalBBs(std::ostream &OS); + void dumpDistance(std::ostream &OS, bool dumpUnreachable = false); void dumpIDMapping(ModuleList &modules, std::ostream &bbLocs, std::ostream &funcInfo); bool annotateModules(ModuleList &modules, std::string suffix=".annotated.bc"); - void dumpCallees(); - void dumpCallers(); + // dumpCallees CSV format: + // one line per *caller* function: callerGUID,calleeGUID[,calleeGUID...] + void dumpCallees(std::ostream &calleeInfo); + // dumpCallers CSV format: + // one line per *callee* function: calleeGUID,callerGUID[,callerGUID...] 
+ void dumpCallers(std::ostream &callerInfo); }; #endif diff --git a/src/tests/8.c b/src/tests/8.c new file mode 100644 index 0000000..8e094a1 --- /dev/null +++ b/src/tests/8.c @@ -0,0 +1,34 @@ +// RUN: %clang -O0 -g -emit-llvm -c %s -o %t.bc +// RUN: %KAMain %t.bc --dump-distance=%t.distance.txt --dump-policy=%t.policy.txt --target-list=%S/BBtargets8.txt --entry-list=%S/entry.txt +// RUN: diff %t.distance.txt %S/ground_truth_distance8.txt +// RUN: diff %t.policy.txt %S/ground_truth_policy8.txt + +/* + Simple C program for reachability analysis testing with multiple conditional + branches and return edges. + + The expected outcome is that KAMain, when run over the generated LLVM bitcode, + will produce a distance file and a policy file that match the provided ground truth. +*/ + +int target() { + return 0; + } + +int somethingelse() { + return 0; +} + +void foo(int i) { + if (i) + target(); + else + somethingelse(); +} + +int main() { + int i = 0; + foo(i); + foo(i+1); + return 0; +} diff --git a/src/tests/BBtargets5.txt b/src/tests/BBtargets5.txt index aeed917..ef40f3a 100644 --- a/src/tests/BBtargets5.txt +++ b/src/tests/BBtargets5.txt @@ -1 +1 @@ -5.c:16 \ No newline at end of file +5.c:14 \ No newline at end of file diff --git a/src/tests/BBtargets8.txt b/src/tests/BBtargets8.txt new file mode 100644 index 0000000..7aacca9 --- /dev/null +++ b/src/tests/BBtargets8.txt @@ -0,0 +1 @@ +8.c:15 \ No newline at end of file diff --git a/src/tests/ground_truth_distance5.txt b/src/tests/ground_truth_distance5.txt index 781abdc..cd9ea95 100644 --- a/src/tests/ground_truth_distance5.txt +++ b/src/tests/ground_truth_distance5.txt @@ -1,7 +1,7 @@ -628124478,5.c:35,0.000000 -3546980649,5.c:23,1000.000000 -628120118,5.c:31,0.000000 -628088539,5.c:25,0.000000 -628053689,5.c:16,0.000000 -628083092,5.c:20,0.000000 -628090717,5.c:27,0.000000 +628122300,5.c:33,0.000000 +3546908775,5.c:21,-0.000000 +628092893,5.c:29,0.000000 +628086361,5.c:23,0.000000 +628088539,5.c:25,-0.000000 
+628051511,5.c:14,0.000000 +628089624,5.c:26,0.000000 diff --git a/src/tests/ground_truth_distance8.txt b/src/tests/ground_truth_distance8.txt new file mode 100644 index 0000000..7f02cdc --- /dev/null +++ b/src/tests/ground_truth_distance8.txt @@ -0,0 +1,4 @@ +3929396380,8.c:30,1000.000000 +819947467,8.c:22,1000.000000 +3929364797,8.c:24,0.000000 +3929329949,8.c:15,0.000000 diff --git a/src/tests/ground_truth_policy5.txt b/src/tests/ground_truth_policy5.txt index 12321e3..a151eb8 100644 --- a/src/tests/ground_truth_policy5.txt +++ b/src/tests/ground_truth_policy5.txt @@ -1,2 +1,2 @@ -3546980649,inf,0.000000 +3546908775,-0.000000,0.000000 ########## diff --git a/src/tests/ground_truth_policy8.txt b/src/tests/ground_truth_policy8.txt new file mode 100644 index 0000000..db369f6 --- /dev/null +++ b/src/tests/ground_truth_policy8.txt @@ -0,0 +1,2 @@ +819947467,1000.000000,0.000000 +########## diff --git a/src/tools/verify_critical_BBs.py b/src/tools/verify_critical_BBs.py new file mode 100644 index 0000000..178d989 --- /dev/null +++ b/src/tools/verify_critical_BBs.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" + Reachability-based Call Graph Analysis + + Copyrigth (C) 2024 - 2025 Haochen Zeng + + For licensing details see LICENSE +""" + +import os +import sys + +def get_bids(): + bids = [] + with open(bid_file, 'r') as fd: + for raw in fd: + line = raw.strip() + if not line or line.startswith('#'): + continue + first_field = line.split(',', 1)[0].strip() + try: + bids.append(int(first_field)) + except ValueError: + print(f"[WARN] Skip invalid bid line: {line}", file=sys.stderr) + return bids + +def load_mappings(path): + """Read mapping_file into a dict: bid → (filepath, line_no).""" + mappings = {} + with open(path, 'r') as f: + for raw in f: + line = raw.strip() + if not line or line.startswith('#'): + continue + parts = line.split(',') + if len(parts) < 2: + continue + bid_str = parts[0].strip() + loc = parts[-1].strip() + if ':' not in loc: + continue + file_path, 
lineno_str = loc.rsplit(':', 1) + try: + bid = int(bid_str) + lineno = int(lineno_str) + except ValueError: + continue + mappings[bid] = (file_path, lineno) + return mappings + +def show_context(file_path, line_no, bid): + """Print the line_no ± ctx lines from file_path.""" + if not os.path.exists(file_path): + print(f"[ERROR] File not found: {file_path}", file=sys.stderr) + return + with open(file_path, 'r') as f: + lines = f.readlines() + # zero-based indices + idx = line_no - 1 + start = max(0, idx - 2) + end = min(len(lines), idx + 10 + 1) + + print(f"\n--- BID({bid}) context: {os.path.basename(file_path)}:{line_no}---") + print("```") + for i in range(start, end): + prefix = "=> " if i == idx else " " + # Pad line numbers for readability + print(f"{prefix}{i+1:4d}: {lines[i].rstrip()}") + print("```") + +def main(): + # Path to your bid -> location mapping + global mapping_file, bid_file + target_program = "" + if len(sys.argv) > 1: + target_program = sys.argv[1].strip() + "_" + mapping_file = f'{target_program}bid_loc_mapping.txt' + bid_file = f"{target_program}critical_BBs.txt" + + if not os.path.exists(mapping_file): + print(f"[ERROR] Mapping file not found: {mapping_file}", file=sys.stderr) + sys.exit(1) + if not os.path.exists(bid_file): + print(f"[ERROR] Critical BIDs file not found: {bid_file}", file=sys.stderr) + sys.exit(1) + + # load all mappings at once + mappings = load_mappings(mapping_file) + critical_bids = get_bids() + for bid in critical_bids: + if bid not in mappings: + print(f"[WARN] No mapping found for bid {bid}", file=sys.stderr) + continue + filepath, lineno = mappings[bid] + show_context(filepath, lineno, bid) + +if __name__ == '__main__': + main() \ No newline at end of file