//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Fix bitcasted functions.
///
/// WebAssembly requires caller and callee signatures to match; however, in
/// LLVM some amount of slop is vaguely permitted. Detect mismatch by looking for
/// bitcasts of functions and rewrite them to use wrapper functions instead.
///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and cast in another, but it works for many common cases.
///
/// Note that LLVM already optimizes away function bitcasts in common cases by
/// dropping arguments as needed, so this pass only ends up getting used in less
/// common cases.
22311818Sdim/// 23311818Sdim//===----------------------------------------------------------------------===// 24311818Sdim 25311818Sdim#include "WebAssembly.h" 26327952Sdim#include "llvm/IR/CallSite.h" 27311818Sdim#include "llvm/IR/Constants.h" 28311818Sdim#include "llvm/IR/Instructions.h" 29311818Sdim#include "llvm/IR/Module.h" 30311818Sdim#include "llvm/IR/Operator.h" 31311818Sdim#include "llvm/Pass.h" 32311818Sdim#include "llvm/Support/Debug.h" 33311818Sdim#include "llvm/Support/raw_ostream.h" 34311818Sdimusing namespace llvm; 35311818Sdim 36311818Sdim#define DEBUG_TYPE "wasm-fix-function-bitcasts" 37311818Sdim 38311818Sdimnamespace { 39311818Sdimclass FixFunctionBitcasts final : public ModulePass { 40311818Sdim StringRef getPassName() const override { 41311818Sdim return "WebAssembly Fix Function Bitcasts"; 42311818Sdim } 43311818Sdim 44311818Sdim void getAnalysisUsage(AnalysisUsage &AU) const override { 45311818Sdim AU.setPreservesCFG(); 46311818Sdim ModulePass::getAnalysisUsage(AU); 47311818Sdim } 48311818Sdim 49311818Sdim bool runOnModule(Module &M) override; 50311818Sdim 51311818Sdimpublic: 52311818Sdim static char ID; 53311818Sdim FixFunctionBitcasts() : ModulePass(ID) {} 54311818Sdim}; 55311818Sdim} // End anonymous namespace 56311818Sdim 57311818Sdimchar FixFunctionBitcasts::ID = 0; 58341825SdimINITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, 59341825Sdim "Fix mismatching bitcasts for WebAssembly", false, false) 60341825Sdim 61311818SdimModulePass *llvm::createWebAssemblyFixFunctionBitcasts() { 62311818Sdim return new FixFunctionBitcasts(); 63311818Sdim} 64311818Sdim 65311818Sdim// Recursively descend the def-use lists from V to find non-bitcast users of 66311818Sdim// bitcasts of V. 
67353358Sdimstatic void findUses(Value *V, Function &F, 68312197Sdim SmallVectorImpl<std::pair<Use *, Function *>> &Uses, 69312197Sdim SmallPtrSetImpl<Constant *> &ConstantBCs) { 70311818Sdim for (Use &U : V->uses()) { 71353358Sdim if (auto *BC = dyn_cast<BitCastOperator>(U.getUser())) 72353358Sdim findUses(BC, F, Uses, ConstantBCs); 73360784Sdim else if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) 74360784Sdim findUses(A, F, Uses, ConstantBCs); 75312197Sdim else if (U.get()->getType() != F.getType()) { 76327952Sdim CallSite CS(U.getUser()); 77327952Sdim if (!CS) 78327952Sdim // Skip uses that aren't immediately called 79327952Sdim continue; 80327952Sdim Value *Callee = CS.getCalledValue(); 81327952Sdim if (Callee != V) 82327952Sdim // Skip calls where the function isn't the callee 83327952Sdim continue; 84312197Sdim if (isa<Constant>(U.get())) { 85312197Sdim // Only add constant bitcasts to the list once; they get RAUW'd 86353358Sdim auto C = ConstantBCs.insert(cast<Constant>(U.get())); 87353358Sdim if (!C.second) 88327952Sdim continue; 89312197Sdim } 90311818Sdim Uses.push_back(std::make_pair(&U, &F)); 91312197Sdim } 92311818Sdim } 93311818Sdim} 94311818Sdim 95311818Sdim// Create a wrapper function with type Ty that calls F (which may have a 96311818Sdim// different type). Attempt to support common bitcasted function idioms: 97311818Sdim// - Call with more arguments than needed: arguments are dropped 98311818Sdim// - Call with fewer arguments than needed: arguments are filled in with undef 99311818Sdim// - Return value is not needed: drop it 100311818Sdim// - Return value needed but not present: supply an undef 101321369Sdim// 102344779Sdim// If the all the argument types of trivially castable to one another (i.e. 103344779Sdim// I32 vs pointer type) then we don't create a wrapper at all (return nullptr 104344779Sdim// instead). 
105344779Sdim// 106344779Sdim// If there is a type mismatch that we know would result in an invalid wasm 107344779Sdim// module then generate wrapper that contains unreachable (i.e. abort at 108344779Sdim// runtime). Such programs are deep into undefined behaviour territory, 109344779Sdim// but we choose to fail at runtime rather than generate and invalid module 110344779Sdim// or fail at compiler time. The reason we delay the error is that we want 111344779Sdim// to support the CMake which expects to be able to compile and link programs 112344779Sdim// that refer to functions with entirely incorrect signatures (this is how 113344779Sdim// CMake detects the existence of a function in a toolchain). 114344779Sdim// 115344779Sdim// For bitcasts that involve struct types we don't know at this stage if they 116344779Sdim// would be equivalent at the wasm level and so we can't know if we need to 117344779Sdim// generate a wrapper. 118353358Sdimstatic Function *createWrapper(Function *F, FunctionType *Ty) { 119311818Sdim Module *M = F->getParent(); 120311818Sdim 121344779Sdim Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage, 122344779Sdim F->getName() + "_bitcast", M); 123311818Sdim BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper); 124344779Sdim const DataLayout &DL = BB->getModule()->getDataLayout(); 125311818Sdim 126311818Sdim // Determine what arguments to pass. 
127311818Sdim SmallVector<Value *, 4> Args; 128311818Sdim Function::arg_iterator AI = Wrapper->arg_begin(); 129327952Sdim Function::arg_iterator AE = Wrapper->arg_end(); 130311818Sdim FunctionType::param_iterator PI = F->getFunctionType()->param_begin(); 131311818Sdim FunctionType::param_iterator PE = F->getFunctionType()->param_end(); 132344779Sdim bool TypeMismatch = false; 133344779Sdim bool WrapperNeeded = false; 134344779Sdim 135344779Sdim Type *ExpectedRtnType = F->getFunctionType()->getReturnType(); 136344779Sdim Type *RtnType = Ty->getReturnType(); 137344779Sdim 138344779Sdim if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) || 139344779Sdim (F->getFunctionType()->isVarArg() != Ty->isVarArg()) || 140344779Sdim (ExpectedRtnType != RtnType)) 141344779Sdim WrapperNeeded = true; 142344779Sdim 143327952Sdim for (; AI != AE && PI != PE; ++AI, ++PI) { 144344779Sdim Type *ArgType = AI->getType(); 145344779Sdim Type *ParamType = *PI; 146344779Sdim 147344779Sdim if (ArgType == ParamType) { 148344779Sdim Args.push_back(&*AI); 149344779Sdim } else { 150344779Sdim if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) { 151344779Sdim Instruction *PtrCast = 152344779Sdim CastInst::CreateBitOrPointerCast(AI, ParamType, "cast"); 153344779Sdim BB->getInstList().push_back(PtrCast); 154344779Sdim Args.push_back(PtrCast); 155344779Sdim } else if (ArgType->isStructTy() || ParamType->isStructTy()) { 156353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: " 157344779Sdim << F->getName() << "\n"); 158344779Sdim WrapperNeeded = false; 159344779Sdim } else { 160353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: " 161344779Sdim << F->getName() << "\n"); 162344779Sdim LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: " 163344779Sdim << *ParamType << " Got: " << *ArgType << "\n"); 164344779Sdim TypeMismatch = true; 165344779Sdim break; 166344779Sdim } 167311818Sdim } 168311818Sdim } 169311818Sdim 
170344779Sdim if (WrapperNeeded && !TypeMismatch) { 171344779Sdim for (; PI != PE; ++PI) 172344779Sdim Args.push_back(UndefValue::get(*PI)); 173344779Sdim if (F->isVarArg()) 174344779Sdim for (; AI != AE; ++AI) 175344779Sdim Args.push_back(&*AI); 176311818Sdim 177344779Sdim CallInst *Call = CallInst::Create(F, Args, "", BB); 178344779Sdim 179344779Sdim Type *ExpectedRtnType = F->getFunctionType()->getReturnType(); 180344779Sdim Type *RtnType = Ty->getReturnType(); 181344779Sdim // Determine what value to return. 182344779Sdim if (RtnType->isVoidTy()) { 183344779Sdim ReturnInst::Create(M->getContext(), BB); 184344779Sdim } else if (ExpectedRtnType->isVoidTy()) { 185344779Sdim LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n"); 186344779Sdim ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB); 187344779Sdim } else if (RtnType == ExpectedRtnType) { 188344779Sdim ReturnInst::Create(M->getContext(), Call, BB); 189344779Sdim } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType, 190344779Sdim DL)) { 191344779Sdim Instruction *Cast = 192344779Sdim CastInst::CreateBitOrPointerCast(Call, RtnType, "cast"); 193344779Sdim BB->getInstList().push_back(Cast); 194344779Sdim ReturnInst::Create(M->getContext(), Cast, BB); 195344779Sdim } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) { 196353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: " 197344779Sdim << F->getName() << "\n"); 198344779Sdim WrapperNeeded = false; 199344779Sdim } else { 200353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: " 201344779Sdim << F->getName() << "\n"); 202344779Sdim LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType 203344779Sdim << " Got: " << *RtnType << "\n"); 204344779Sdim TypeMismatch = true; 205344779Sdim } 206344779Sdim } 207344779Sdim 208344779Sdim if (TypeMismatch) { 209344779Sdim // Create a new wrapper that simply contains `unreachable`. 
210311818Sdim Wrapper->eraseFromParent(); 211344779Sdim Wrapper = Function::Create(Ty, Function::PrivateLinkage, 212344779Sdim F->getName() + "_bitcast_invalid", M); 213344779Sdim BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper); 214344779Sdim new UnreachableInst(M->getContext(), BB); 215344779Sdim Wrapper->setName(F->getName() + "_bitcast_invalid"); 216344779Sdim } else if (!WrapperNeeded) { 217353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName() 218344779Sdim << "\n"); 219344779Sdim Wrapper->eraseFromParent(); 220311818Sdim return nullptr; 221311818Sdim } 222353358Sdim LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n"); 223311818Sdim return Wrapper; 224311818Sdim} 225311818Sdim 226344779Sdim// Test whether a main function with type FuncTy should be rewritten to have 227344779Sdim// type MainTy. 228353358Sdimstatic bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) { 229344779Sdim // Only fix the main function if it's the standard zero-arg form. That way, 230344779Sdim // the standard cases will work as expected, and users will see signature 231344779Sdim // mismatches from the linker for non-standard cases. 232344779Sdim return FuncTy->getReturnType() == MainTy->getReturnType() && 233344779Sdim FuncTy->getNumParams() == 0 && 234344779Sdim !FuncTy->isVarArg(); 235344779Sdim} 236344779Sdim 237311818Sdimbool FixFunctionBitcasts::runOnModule(Module &M) { 238344779Sdim LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n"); 239344779Sdim 240327952Sdim Function *Main = nullptr; 241327952Sdim CallInst *CallMain = nullptr; 242311818Sdim SmallVector<std::pair<Use *, Function *>, 0> Uses; 243312197Sdim SmallPtrSet<Constant *, 2> ConstantBCs; 244311818Sdim 245311818Sdim // Collect all the places that need wrappers. 
246327952Sdim for (Function &F : M) { 247353358Sdim findUses(&F, F, Uses, ConstantBCs); 248311818Sdim 249327952Sdim // If we have a "main" function, and its type isn't 250327952Sdim // "int main(int argc, char *argv[])", create an artificial call with it 251327952Sdim // bitcasted to that type so that we generate a wrapper for it, so that 252327952Sdim // the C runtime can call it. 253344779Sdim if (F.getName() == "main") { 254327952Sdim Main = &F; 255327952Sdim LLVMContext &C = M.getContext(); 256344779Sdim Type *MainArgTys[] = {Type::getInt32Ty(C), 257344779Sdim PointerType::get(Type::getInt8PtrTy(C), 0)}; 258327952Sdim FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys, 259327952Sdim /*isVarArg=*/false); 260344779Sdim if (shouldFixMainFunction(F.getFunctionType(), MainTy)) { 261344779Sdim LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: " 262344779Sdim << *F.getFunctionType() << "\n"); 263344779Sdim Value *Args[] = {UndefValue::get(MainArgTys[0]), 264344779Sdim UndefValue::get(MainArgTys[1])}; 265344779Sdim Value *Casted = 266344779Sdim ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0)); 267353358Sdim CallMain = CallInst::Create(MainTy, Casted, Args, "call_main"); 268327952Sdim Use *UseMain = &CallMain->getOperandUse(2); 269327952Sdim Uses.push_back(std::make_pair(UseMain, &F)); 270327952Sdim } 271327952Sdim } 272327952Sdim } 273327952Sdim 274311818Sdim DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers; 275311818Sdim 276311818Sdim for (auto &UseFunc : Uses) { 277311818Sdim Use *U = UseFunc.first; 278311818Sdim Function *F = UseFunc.second; 279353358Sdim auto *PTy = cast<PointerType>(U->get()->getType()); 280353358Sdim auto *Ty = dyn_cast<FunctionType>(PTy->getElementType()); 281311818Sdim 282311818Sdim // If the function is casted to something like i8* as a "generic pointer" 283311818Sdim // to be later casted to something else, we can't generate a wrapper for it. 
284311818Sdim // Just ignore such casts for now. 285311818Sdim if (!Ty) 286311818Sdim continue; 287311818Sdim 288311818Sdim auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr)); 289311818Sdim if (Pair.second) 290353358Sdim Pair.first->second = createWrapper(F, Ty); 291311818Sdim 292311818Sdim Function *Wrapper = Pair.first->second; 293311818Sdim if (!Wrapper) 294311818Sdim continue; 295311818Sdim 296311818Sdim if (isa<Constant>(U->get())) 297311818Sdim U->get()->replaceAllUsesWith(Wrapper); 298311818Sdim else 299311818Sdim U->set(Wrapper); 300311818Sdim } 301311818Sdim 302327952Sdim // If we created a wrapper for main, rename the wrapper so that it's the 303327952Sdim // one that gets called from startup. 304327952Sdim if (CallMain) { 305327952Sdim Main->setName("__original_main"); 306353358Sdim auto *MainWrapper = 307327952Sdim cast<Function>(CallMain->getCalledValue()->stripPointerCasts()); 308327952Sdim delete CallMain; 309344779Sdim if (Main->isDeclaration()) { 310344779Sdim // The wrapper is not needed in this case as we don't need to export 311344779Sdim // it to anyone else. 312344779Sdim MainWrapper->eraseFromParent(); 313344779Sdim } else { 314344779Sdim // Otherwise give the wrapper the same linkage as the original main 315344779Sdim // function, so that it can be called from the same places. 316344779Sdim MainWrapper->setName("main"); 317344779Sdim MainWrapper->setLinkage(Main->getLinkage()); 318344779Sdim MainWrapper->setVisibility(Main->getVisibility()); 319344779Sdim } 320327952Sdim } 321327952Sdim 322311818Sdim return true; 323311818Sdim} 324