1311818Sdim//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2311818Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6311818Sdim//
7311818Sdim//===----------------------------------------------------------------------===//
8311818Sdim///
9311818Sdim/// \file
10341825Sdim/// Fix bitcasted functions.
11311818Sdim///
12311818Sdim/// WebAssembly requires caller and callee signatures to match, however in LLVM,
13311818Sdim/// some amount of slop is vaguely permitted. Detect mismatch by looking for
14311818Sdim/// bitcasts of functions and rewrite them to use wrapper functions instead.
15311818Sdim///
16311818Sdim/// This doesn't catch all cases, such as when a function's address is taken in
17311818Sdim/// one place and casted in another, but it works for many common cases.
18311818Sdim///
19311818Sdim/// Note that LLVM already optimizes away function bitcasts in common cases by
20311818Sdim/// dropping arguments as needed, so this pass only ends up getting used in less
21311818Sdim/// common cases.
22311818Sdim///
23311818Sdim//===----------------------------------------------------------------------===//
24311818Sdim
25311818Sdim#include "WebAssembly.h"
26327952Sdim#include "llvm/IR/CallSite.h"
27311818Sdim#include "llvm/IR/Constants.h"
28311818Sdim#include "llvm/IR/Instructions.h"
29311818Sdim#include "llvm/IR/Module.h"
30311818Sdim#include "llvm/IR/Operator.h"
31311818Sdim#include "llvm/Pass.h"
32311818Sdim#include "llvm/Support/Debug.h"
33311818Sdim#include "llvm/Support/raw_ostream.h"
34311818Sdimusing namespace llvm;
35311818Sdim
36311818Sdim#define DEBUG_TYPE "wasm-fix-function-bitcasts"
37311818Sdim
38311818Sdimnamespace {
39311818Sdimclass FixFunctionBitcasts final : public ModulePass {
40311818Sdim  StringRef getPassName() const override {
41311818Sdim    return "WebAssembly Fix Function Bitcasts";
42311818Sdim  }
43311818Sdim
44311818Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
45311818Sdim    AU.setPreservesCFG();
46311818Sdim    ModulePass::getAnalysisUsage(AU);
47311818Sdim  }
48311818Sdim
49311818Sdim  bool runOnModule(Module &M) override;
50311818Sdim
51311818Sdimpublic:
52311818Sdim  static char ID;
53311818Sdim  FixFunctionBitcasts() : ModulePass(ID) {}
54311818Sdim};
55311818Sdim} // End anonymous namespace
56311818Sdim
57311818Sdimchar FixFunctionBitcasts::ID = 0;
58341825SdimINITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
59341825Sdim                "Fix mismatching bitcasts for WebAssembly", false, false)
60341825Sdim
61311818SdimModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
62311818Sdim  return new FixFunctionBitcasts();
63311818Sdim}
64311818Sdim
65311818Sdim// Recursively descend the def-use lists from V to find non-bitcast users of
66311818Sdim// bitcasts of V.
67353358Sdimstatic void findUses(Value *V, Function &F,
68312197Sdim                     SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
69312197Sdim                     SmallPtrSetImpl<Constant *> &ConstantBCs) {
70311818Sdim  for (Use &U : V->uses()) {
71353358Sdim    if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
72353358Sdim      findUses(BC, F, Uses, ConstantBCs);
73360784Sdim    else if (auto *A = dyn_cast<GlobalAlias>(U.getUser()))
74360784Sdim      findUses(A, F, Uses, ConstantBCs);
75312197Sdim    else if (U.get()->getType() != F.getType()) {
76327952Sdim      CallSite CS(U.getUser());
77327952Sdim      if (!CS)
78327952Sdim        // Skip uses that aren't immediately called
79327952Sdim        continue;
80327952Sdim      Value *Callee = CS.getCalledValue();
81327952Sdim      if (Callee != V)
82327952Sdim        // Skip calls where the function isn't the callee
83327952Sdim        continue;
84312197Sdim      if (isa<Constant>(U.get())) {
85312197Sdim        // Only add constant bitcasts to the list once; they get RAUW'd
86353358Sdim        auto C = ConstantBCs.insert(cast<Constant>(U.get()));
87353358Sdim        if (!C.second)
88327952Sdim          continue;
89312197Sdim      }
90311818Sdim      Uses.push_back(std::make_pair(&U, &F));
91312197Sdim    }
92311818Sdim  }
93311818Sdim}
94311818Sdim
95311818Sdim// Create a wrapper function with type Ty that calls F (which may have a
96311818Sdim// different type). Attempt to support common bitcasted function idioms:
97311818Sdim//  - Call with more arguments than needed: arguments are dropped
98311818Sdim//  - Call with fewer arguments than needed: arguments are filled in with undef
99311818Sdim//  - Return value is not needed: drop it
100311818Sdim//  - Return value needed but not present: supply an undef
101321369Sdim//
102344779Sdim// If the all the argument types of trivially castable to one another (i.e.
103344779Sdim// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
104344779Sdim// instead).
105344779Sdim//
106344779Sdim// If there is a type mismatch that we know would result in an invalid wasm
107344779Sdim// module then generate wrapper that contains unreachable (i.e. abort at
108344779Sdim// runtime).  Such programs are deep into undefined behaviour territory,
109344779Sdim// but we choose to fail at runtime rather than generate and invalid module
110344779Sdim// or fail at compiler time.  The reason we delay the error is that we want
111344779Sdim// to support the CMake which expects to be able to compile and link programs
112344779Sdim// that refer to functions with entirely incorrect signatures (this is how
113344779Sdim// CMake detects the existence of a function in a toolchain).
114344779Sdim//
115344779Sdim// For bitcasts that involve struct types we don't know at this stage if they
116344779Sdim// would be equivalent at the wasm level and so we can't know if we need to
117344779Sdim// generate a wrapper.
118353358Sdimstatic Function *createWrapper(Function *F, FunctionType *Ty) {
119311818Sdim  Module *M = F->getParent();
120311818Sdim
121344779Sdim  Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
122344779Sdim                                       F->getName() + "_bitcast", M);
123311818Sdim  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
124344779Sdim  const DataLayout &DL = BB->getModule()->getDataLayout();
125311818Sdim
126311818Sdim  // Determine what arguments to pass.
127311818Sdim  SmallVector<Value *, 4> Args;
128311818Sdim  Function::arg_iterator AI = Wrapper->arg_begin();
129327952Sdim  Function::arg_iterator AE = Wrapper->arg_end();
130311818Sdim  FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
131311818Sdim  FunctionType::param_iterator PE = F->getFunctionType()->param_end();
132344779Sdim  bool TypeMismatch = false;
133344779Sdim  bool WrapperNeeded = false;
134344779Sdim
135344779Sdim  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
136344779Sdim  Type *RtnType = Ty->getReturnType();
137344779Sdim
138344779Sdim  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
139344779Sdim      (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
140344779Sdim      (ExpectedRtnType != RtnType))
141344779Sdim    WrapperNeeded = true;
142344779Sdim
143327952Sdim  for (; AI != AE && PI != PE; ++AI, ++PI) {
144344779Sdim    Type *ArgType = AI->getType();
145344779Sdim    Type *ParamType = *PI;
146344779Sdim
147344779Sdim    if (ArgType == ParamType) {
148344779Sdim      Args.push_back(&*AI);
149344779Sdim    } else {
150344779Sdim      if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
151344779Sdim        Instruction *PtrCast =
152344779Sdim            CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
153344779Sdim        BB->getInstList().push_back(PtrCast);
154344779Sdim        Args.push_back(PtrCast);
155344779Sdim      } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
156353358Sdim        LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
157344779Sdim                          << F->getName() << "\n");
158344779Sdim        WrapperNeeded = false;
159344779Sdim      } else {
160353358Sdim        LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
161344779Sdim                          << F->getName() << "\n");
162344779Sdim        LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
163344779Sdim                          << *ParamType << " Got: " << *ArgType << "\n");
164344779Sdim        TypeMismatch = true;
165344779Sdim        break;
166344779Sdim      }
167311818Sdim    }
168311818Sdim  }
169311818Sdim
170344779Sdim  if (WrapperNeeded && !TypeMismatch) {
171344779Sdim    for (; PI != PE; ++PI)
172344779Sdim      Args.push_back(UndefValue::get(*PI));
173344779Sdim    if (F->isVarArg())
174344779Sdim      for (; AI != AE; ++AI)
175344779Sdim        Args.push_back(&*AI);
176311818Sdim
177344779Sdim    CallInst *Call = CallInst::Create(F, Args, "", BB);
178344779Sdim
179344779Sdim    Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
180344779Sdim    Type *RtnType = Ty->getReturnType();
181344779Sdim    // Determine what value to return.
182344779Sdim    if (RtnType->isVoidTy()) {
183344779Sdim      ReturnInst::Create(M->getContext(), BB);
184344779Sdim    } else if (ExpectedRtnType->isVoidTy()) {
185344779Sdim      LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
186344779Sdim      ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
187344779Sdim    } else if (RtnType == ExpectedRtnType) {
188344779Sdim      ReturnInst::Create(M->getContext(), Call, BB);
189344779Sdim    } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
190344779Sdim                                                    DL)) {
191344779Sdim      Instruction *Cast =
192344779Sdim          CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
193344779Sdim      BB->getInstList().push_back(Cast);
194344779Sdim      ReturnInst::Create(M->getContext(), Cast, BB);
195344779Sdim    } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
196353358Sdim      LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
197344779Sdim                        << F->getName() << "\n");
198344779Sdim      WrapperNeeded = false;
199344779Sdim    } else {
200353358Sdim      LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
201344779Sdim                        << F->getName() << "\n");
202344779Sdim      LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
203344779Sdim                        << " Got: " << *RtnType << "\n");
204344779Sdim      TypeMismatch = true;
205344779Sdim    }
206344779Sdim  }
207344779Sdim
208344779Sdim  if (TypeMismatch) {
209344779Sdim    // Create a new wrapper that simply contains `unreachable`.
210311818Sdim    Wrapper->eraseFromParent();
211344779Sdim    Wrapper = Function::Create(Ty, Function::PrivateLinkage,
212344779Sdim                               F->getName() + "_bitcast_invalid", M);
213344779Sdim    BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
214344779Sdim    new UnreachableInst(M->getContext(), BB);
215344779Sdim    Wrapper->setName(F->getName() + "_bitcast_invalid");
216344779Sdim  } else if (!WrapperNeeded) {
217353358Sdim    LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
218344779Sdim                      << "\n");
219344779Sdim    Wrapper->eraseFromParent();
220311818Sdim    return nullptr;
221311818Sdim  }
222353358Sdim  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
223311818Sdim  return Wrapper;
224311818Sdim}
225311818Sdim
226344779Sdim// Test whether a main function with type FuncTy should be rewritten to have
227344779Sdim// type MainTy.
228353358Sdimstatic bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
229344779Sdim  // Only fix the main function if it's the standard zero-arg form. That way,
230344779Sdim  // the standard cases will work as expected, and users will see signature
231344779Sdim  // mismatches from the linker for non-standard cases.
232344779Sdim  return FuncTy->getReturnType() == MainTy->getReturnType() &&
233344779Sdim         FuncTy->getNumParams() == 0 &&
234344779Sdim         !FuncTy->isVarArg();
235344779Sdim}
236344779Sdim
237311818Sdimbool FixFunctionBitcasts::runOnModule(Module &M) {
238344779Sdim  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
239344779Sdim
240327952Sdim  Function *Main = nullptr;
241327952Sdim  CallInst *CallMain = nullptr;
242311818Sdim  SmallVector<std::pair<Use *, Function *>, 0> Uses;
243312197Sdim  SmallPtrSet<Constant *, 2> ConstantBCs;
244311818Sdim
245311818Sdim  // Collect all the places that need wrappers.
246327952Sdim  for (Function &F : M) {
247353358Sdim    findUses(&F, F, Uses, ConstantBCs);
248311818Sdim
249327952Sdim    // If we have a "main" function, and its type isn't
250327952Sdim    // "int main(int argc, char *argv[])", create an artificial call with it
251327952Sdim    // bitcasted to that type so that we generate a wrapper for it, so that
252327952Sdim    // the C runtime can call it.
253344779Sdim    if (F.getName() == "main") {
254327952Sdim      Main = &F;
255327952Sdim      LLVMContext &C = M.getContext();
256344779Sdim      Type *MainArgTys[] = {Type::getInt32Ty(C),
257344779Sdim                            PointerType::get(Type::getInt8PtrTy(C), 0)};
258327952Sdim      FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
259327952Sdim                                               /*isVarArg=*/false);
260344779Sdim      if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
261344779Sdim        LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
262344779Sdim                          << *F.getFunctionType() << "\n");
263344779Sdim        Value *Args[] = {UndefValue::get(MainArgTys[0]),
264344779Sdim                         UndefValue::get(MainArgTys[1])};
265344779Sdim        Value *Casted =
266344779Sdim            ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
267353358Sdim        CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
268327952Sdim        Use *UseMain = &CallMain->getOperandUse(2);
269327952Sdim        Uses.push_back(std::make_pair(UseMain, &F));
270327952Sdim      }
271327952Sdim    }
272327952Sdim  }
273327952Sdim
274311818Sdim  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
275311818Sdim
276311818Sdim  for (auto &UseFunc : Uses) {
277311818Sdim    Use *U = UseFunc.first;
278311818Sdim    Function *F = UseFunc.second;
279353358Sdim    auto *PTy = cast<PointerType>(U->get()->getType());
280353358Sdim    auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());
281311818Sdim
282311818Sdim    // If the function is casted to something like i8* as a "generic pointer"
283311818Sdim    // to be later casted to something else, we can't generate a wrapper for it.
284311818Sdim    // Just ignore such casts for now.
285311818Sdim    if (!Ty)
286311818Sdim      continue;
287311818Sdim
288311818Sdim    auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
289311818Sdim    if (Pair.second)
290353358Sdim      Pair.first->second = createWrapper(F, Ty);
291311818Sdim
292311818Sdim    Function *Wrapper = Pair.first->second;
293311818Sdim    if (!Wrapper)
294311818Sdim      continue;
295311818Sdim
296311818Sdim    if (isa<Constant>(U->get()))
297311818Sdim      U->get()->replaceAllUsesWith(Wrapper);
298311818Sdim    else
299311818Sdim      U->set(Wrapper);
300311818Sdim  }
301311818Sdim
302327952Sdim  // If we created a wrapper for main, rename the wrapper so that it's the
303327952Sdim  // one that gets called from startup.
304327952Sdim  if (CallMain) {
305327952Sdim    Main->setName("__original_main");
306353358Sdim    auto *MainWrapper =
307327952Sdim        cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
308327952Sdim    delete CallMain;
309344779Sdim    if (Main->isDeclaration()) {
310344779Sdim      // The wrapper is not needed in this case as we don't need to export
311344779Sdim      // it to anyone else.
312344779Sdim      MainWrapper->eraseFromParent();
313344779Sdim    } else {
314344779Sdim      // Otherwise give the wrapper the same linkage as the original main
315344779Sdim      // function, so that it can be called from the same places.
316344779Sdim      MainWrapper->setName("main");
317344779Sdim      MainWrapper->setLinkage(Main->getLinkage());
318344779Sdim      MainWrapper->setVisibility(Main->getVisibility());
319344779Sdim    }
320327952Sdim  }
321327952Sdim
322311818Sdim  return true;
323311818Sdim}
324