1//==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file is part of the WebAssembly Assembler.
11///
/// It contains code to type check the instructions of a parsed .s file.
13///
14//===----------------------------------------------------------------------===//
15
16#include "AsmParser/WebAssemblyAsmTypeCheck.h"
17#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
19#include "MCTargetDesc/WebAssemblyTargetStreamer.h"
20#include "TargetInfo/WebAssemblyTargetInfo.h"
21#include "WebAssembly.h"
22#include "llvm/MC/MCContext.h"
23#include "llvm/MC/MCExpr.h"
24#include "llvm/MC/MCInst.h"
25#include "llvm/MC/MCInstrInfo.h"
26#include "llvm/MC/MCParser/MCParsedAsmOperand.h"
27#include "llvm/MC/MCParser/MCTargetAsmParser.h"
28#include "llvm/MC/MCSectionWasm.h"
29#include "llvm/MC/MCStreamer.h"
30#include "llvm/MC/MCSubtargetInfo.h"
31#include "llvm/MC/MCSymbol.h"
32#include "llvm/MC/MCSymbolWasm.h"
33#include "llvm/MC/TargetRegistry.h"
34#include "llvm/Support/Compiler.h"
35#include "llvm/Support/SourceMgr.h"
36
37using namespace llvm;
38
39#define DEBUG_TYPE "wasm-asm-parser"
40
41extern StringRef GetMnemonic(unsigned Opc);
42
43namespace llvm {
44
45WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
46                                                 const MCInstrInfo &MII,
47                                                 bool is64)
48    : Parser(Parser), MII(MII), is64(is64) {}
49
50void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
51  LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
52  ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
53  BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
54}
55
56void WebAssemblyAsmTypeCheck::localDecl(
57    const SmallVectorImpl<wasm::ValType> &Locals) {
58  LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
59}
60
61void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
62  LLVM_DEBUG({
63    std::string s;
64    for (auto VT : Stack) {
65      s += WebAssembly::typeToString(VT);
66      s += " ";
67    }
68    dbgs() << Msg << s << '\n';
69  });
70}
71
72bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
73  // Once you get one type error in a function, it will likely trigger more
74  // which are mostly not helpful.
75  if (TypeErrorThisFunction)
76    return true;
77  // If we're currently in unreachable code, we suppress errors completely.
78  if (Unreachable)
79    return false;
80  TypeErrorThisFunction = true;
81  dumpTypeStack("current stack: ");
82  return Parser.Error(ErrorLoc, Msg);
83}
84
85bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
86                                      std::optional<wasm::ValType> EVT) {
87  if (Stack.empty()) {
88    return typeError(ErrorLoc,
89                     EVT ? StringRef("empty stack while popping ") +
90                               WebAssembly::typeToString(*EVT)
91                         : StringRef("empty stack while popping value"));
92  }
93  auto PVT = Stack.pop_back_val();
94  if (EVT && *EVT != PVT) {
95    return typeError(ErrorLoc,
96                     StringRef("popped ") + WebAssembly::typeToString(PVT) +
97                         ", expected " + WebAssembly::typeToString(*EVT));
98  }
99  return false;
100}
101
102bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
103  if (Stack.empty()) {
104    return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
105  }
106  auto PVT = Stack.pop_back_val();
107  if (!WebAssembly::isRefType(PVT)) {
108    return typeError(ErrorLoc, StringRef("popped ") +
109                                   WebAssembly::typeToString(PVT) +
110                                   ", expected reftype");
111  }
112  return false;
113}
114
115bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
116                                       wasm::ValType &Type) {
117  auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
118  if (Local >= LocalTypes.size())
119    return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120                                   std::to_string(Local));
121  Type = LocalTypes[Local];
122  return false;
123}
124
125static std::optional<std::string>
126checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
127              const SmallVectorImpl<wasm::ValType> &Got) {
128  for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
129    auto EVT = ExpectedStackTop[I];
130    auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
131    if (PVT != EVT)
132      return std::string{"got "} + WebAssembly::typeToString(PVT) +
133             ", expected " + WebAssembly::typeToString(EVT);
134  }
135  return std::nullopt;
136}
137
138bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
139  if (Level >= BrStack.size())
140    return typeError(ErrorLoc,
141                     StringRef("br: invalid depth ") + std::to_string(Level));
142  const SmallVector<wasm::ValType, 4> &Expected =
143      BrStack[BrStack.size() - Level - 1];
144  if (Expected.size() > Stack.size())
145    return typeError(ErrorLoc, "br: insufficient values on the type stack");
146  auto IsStackTopInvalid = checkStackTop(Expected, Stack);
147  if (IsStackTopInvalid)
148    return typeError(ErrorLoc, "br " + IsStackTopInvalid.value());
149  return false;
150}
151
152bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
153  if (!PopVals)
154    BrStack.pop_back();
155  if (LastSig.Returns.size() > Stack.size())
156    return typeError(ErrorLoc, "end: insufficient values on the type stack");
157
158  if (PopVals) {
159    for (auto VT : llvm::reverse(LastSig.Returns)) {
160      if (popType(ErrorLoc, VT))
161        return true;
162    }
163    return false;
164  }
165
166  auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack);
167  if (IsStackTopInvalid)
168    return typeError(ErrorLoc, "end " + IsStackTopInvalid.value());
169  return false;
170}
171
172bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
173                                       const wasm::WasmSignature &Sig) {
174  for (auto VT : llvm::reverse(Sig.Params))
175    if (popType(ErrorLoc, VT))
176      return true;
177  Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
178  return false;
179}
180
181bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
182                                        const MCSymbolRefExpr *&SymRef) {
183  auto Op = Inst.getOperand(0);
184  if (!Op.isExpr())
185    return typeError(ErrorLoc, StringRef("expected expression operand"));
186  SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
187  if (!SymRef)
188    return typeError(ErrorLoc, StringRef("expected symbol operand"));
189  return false;
190}
191
192bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
193                                        wasm::ValType &Type) {
194  const MCSymbolRefExpr *SymRef;
195  if (getSymRef(ErrorLoc, Inst, SymRef))
196    return true;
197  auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
198  switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
199  case wasm::WASM_SYMBOL_TYPE_GLOBAL:
200    Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
201    break;
202  case wasm::WASM_SYMBOL_TYPE_FUNCTION:
203  case wasm::WASM_SYMBOL_TYPE_DATA:
204    switch (SymRef->getKind()) {
205    case MCSymbolRefExpr::VK_GOT:
206    case MCSymbolRefExpr::VK_WASM_GOT_TLS:
207      Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
208      return false;
209    default:
210      break;
211    }
212    [[fallthrough]];
213  default:
214    return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
215                                   " missing .globaltype");
216  }
217  return false;
218}
219
220bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
221                                       wasm::ValType &Type) {
222  const MCSymbolRefExpr *SymRef;
223  if (getSymRef(ErrorLoc, Inst, SymRef))
224    return true;
225  auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
226  if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
227      wasm::WASM_SYMBOL_TYPE_TABLE)
228    return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
229                                   " missing .tabletype");
230  Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
231  return false;
232}
233
234bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
235  // Check the return types.
236  for (auto RVT : llvm::reverse(ReturnTypes)) {
237    if (popType(ErrorLoc, RVT))
238      return true;
239  }
240  if (!Stack.empty()) {
241    return typeError(ErrorLoc, std::to_string(Stack.size()) +
242                                   " superfluous return values");
243  }
244  Unreachable = true;
245  return false;
246}
247
248bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
249                                        OperandVector &Operands) {
250  auto Opc = Inst.getOpcode();
251  auto Name = GetMnemonic(Opc);
252  dumpTypeStack("typechecking " + Name + ": ");
253  wasm::ValType Type;
254  if (Name == "local.get") {
255    if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
256      return true;
257    Stack.push_back(Type);
258  } else if (Name == "local.set") {
259    if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
260      return true;
261    if (popType(ErrorLoc, Type))
262      return true;
263  } else if (Name == "local.tee") {
264    if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
265      return true;
266    if (popType(ErrorLoc, Type))
267      return true;
268    Stack.push_back(Type);
269  } else if (Name == "global.get") {
270    if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
271      return true;
272    Stack.push_back(Type);
273  } else if (Name == "global.set") {
274    if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
275      return true;
276    if (popType(ErrorLoc, Type))
277      return true;
278  } else if (Name == "table.get") {
279    if (getTable(Operands[1]->getStartLoc(), Inst, Type))
280      return true;
281    if (popType(ErrorLoc, wasm::ValType::I32))
282      return true;
283    Stack.push_back(Type);
284  } else if (Name == "table.set") {
285    if (getTable(Operands[1]->getStartLoc(), Inst, Type))
286      return true;
287    if (popType(ErrorLoc, Type))
288      return true;
289    if (popType(ErrorLoc, wasm::ValType::I32))
290      return true;
291  } else if (Name == "table.fill") {
292    if (getTable(Operands[1]->getStartLoc(), Inst, Type))
293      return true;
294    if (popType(ErrorLoc, wasm::ValType::I32))
295      return true;
296    if (popType(ErrorLoc, Type))
297      return true;
298    if (popType(ErrorLoc, wasm::ValType::I32))
299      return true;
300  } else if (Name == "memory.fill") {
301    Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
302    if (popType(ErrorLoc, Type))
303      return true;
304    if (popType(ErrorLoc, wasm::ValType::I32))
305      return true;
306    if (popType(ErrorLoc, Type))
307      return true;
308  } else if (Name == "memory.copy") {
309    Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
310    if (popType(ErrorLoc, Type))
311      return true;
312    if (popType(ErrorLoc, Type))
313      return true;
314    if (popType(ErrorLoc, Type))
315      return true;
316  } else if (Name == "memory.init") {
317    Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
318    if (popType(ErrorLoc, wasm::ValType::I32))
319      return true;
320    if (popType(ErrorLoc, wasm::ValType::I32))
321      return true;
322    if (popType(ErrorLoc, Type))
323      return true;
324  } else if (Name == "drop") {
325    if (popType(ErrorLoc, {}))
326      return true;
327  } else if (Name == "try" || Name == "block" || Name == "loop" ||
328             Name == "if") {
329    if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
330      return true;
331    if (Name == "loop")
332      BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
333    else
334      BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
335  } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
336             Name == "else" || Name == "end_try" || Name == "catch" ||
337             Name == "catch_all" || Name == "delegate") {
338    if (checkEnd(ErrorLoc,
339                 Name == "else" || Name == "catch" || Name == "catch_all"))
340      return true;
341    Unreachable = false;
342    if (Name == "catch") {
343      const MCSymbolRefExpr *SymRef;
344      if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
345        return true;
346      const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
347      const auto *Sig = WasmSym->getSignature();
348      if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
349        return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
350                                                         WasmSym->getName() +
351                                                         " missing .tagtype");
352      // catch instruction pushes values whose types are specified in the tag's
353      // "params" part
354      Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
355    }
356  } else if (Name == "br") {
357    const MCOperand &Operand = Inst.getOperand(0);
358    if (!Operand.isImm())
359      return false;
360    if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm())))
361      return true;
362  } else if (Name == "return") {
363    if (endOfFunction(ErrorLoc))
364      return true;
365  } else if (Name == "call_indirect" || Name == "return_call_indirect") {
366    // Function value.
367    if (popType(ErrorLoc, wasm::ValType::I32))
368      return true;
369    if (checkSig(ErrorLoc, LastSig))
370      return true;
371    if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
372      return true;
373  } else if (Name == "call" || Name == "return_call") {
374    const MCSymbolRefExpr *SymRef;
375    if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
376      return true;
377    auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
378    auto Sig = WasmSym->getSignature();
379    if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
380      return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
381                                                       WasmSym->getName() +
382                                                       " missing .functype");
383    if (checkSig(ErrorLoc, *Sig))
384      return true;
385    if (Name == "return_call" && endOfFunction(ErrorLoc))
386      return true;
387  } else if (Name == "unreachable") {
388    Unreachable = true;
389  } else if (Name == "ref.is_null") {
390    if (popRefType(ErrorLoc))
391      return true;
392    Stack.push_back(wasm::ValType::I32);
393  } else {
394    // The current instruction is a stack instruction which doesn't have
395    // explicit operands that indicate push/pop types, so we get those from
396    // the register version of the same instruction.
397    auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
398    assert(RegOpc != -1 && "Failed to get register version of MC instruction");
399    const auto &II = MII.get(RegOpc);
400    // First pop all the uses off the stack and check them.
401    for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
402      const auto &Op = II.operands()[I - 1];
403      if (Op.OperandType == MCOI::OPERAND_REGISTER) {
404        auto VT = WebAssembly::regClassToValType(Op.RegClass);
405        if (popType(ErrorLoc, VT))
406          return true;
407      }
408    }
409    // Now push all the defs onto the stack.
410    for (unsigned I = 0; I < II.getNumDefs(); I++) {
411      const auto &Op = II.operands()[I];
412      assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
413      auto VT = WebAssembly::regClassToValType(Op.RegClass);
414      Stack.push_back(VT);
415    }
416  }
417  return false;
418}
419
420} // end namespace llvm
421