1//==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8/// 9/// \file 10/// This file is part of the WebAssembly Assembler. 11/// 12/// It contains code to translate a parsed .s file into MCInsts. 13/// 14//===----------------------------------------------------------------------===// 15 16#include "AsmParser/WebAssemblyAsmTypeCheck.h" 17#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" 18#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" 19#include "MCTargetDesc/WebAssemblyTargetStreamer.h" 20#include "TargetInfo/WebAssemblyTargetInfo.h" 21#include "WebAssembly.h" 22#include "llvm/MC/MCContext.h" 23#include "llvm/MC/MCExpr.h" 24#include "llvm/MC/MCInst.h" 25#include "llvm/MC/MCInstrInfo.h" 26#include "llvm/MC/MCParser/MCParsedAsmOperand.h" 27#include "llvm/MC/MCParser/MCTargetAsmParser.h" 28#include "llvm/MC/MCSectionWasm.h" 29#include "llvm/MC/MCStreamer.h" 30#include "llvm/MC/MCSubtargetInfo.h" 31#include "llvm/MC/MCSymbol.h" 32#include "llvm/MC/MCSymbolWasm.h" 33#include "llvm/MC/TargetRegistry.h" 34#include "llvm/Support/Compiler.h" 35#include "llvm/Support/SourceMgr.h" 36 37using namespace llvm; 38 39#define DEBUG_TYPE "wasm-asm-parser" 40 41extern StringRef GetMnemonic(unsigned Opc); 42 43namespace llvm { 44 45WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser, 46 const MCInstrInfo &MII, 47 bool is64) 48 : Parser(Parser), MII(MII), is64(is64) {} 49 50void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) { 51 LocalTypes.assign(Sig.Params.begin(), Sig.Params.end()); 52 ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end()); 53 BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end()); 54} 55 56void WebAssemblyAsmTypeCheck::localDecl( 57 const SmallVectorImpl<wasm::ValType> &Locals) { 58 LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end()); 59} 60 61void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) { 62 LLVM_DEBUG({ 63 std::string s; 64 for (auto VT : Stack) { 65 s += WebAssembly::typeToString(VT); 66 s += " "; 67 } 68 dbgs() << Msg << s << '\n'; 69 }); 70} 71 72bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) { 73 // Once you get one type error in a function, it will likely trigger more 74 // which are mostly not helpful. 75 if (TypeErrorThisFunction) 76 return true; 77 // If we're currently in unreachable code, we suppress errors completely. 78 if (Unreachable) 79 return false; 80 TypeErrorThisFunction = true; 81 dumpTypeStack("current stack: "); 82 return Parser.Error(ErrorLoc, Msg); 83} 84 85bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc, 86 std::optional<wasm::ValType> EVT) { 87 if (Stack.empty()) { 88 return typeError(ErrorLoc, 89 EVT ? StringRef("empty stack while popping ") + 90 WebAssembly::typeToString(*EVT) 91 : StringRef("empty stack while popping value")); 92 } 93 auto PVT = Stack.pop_back_val(); 94 if (EVT && *EVT != PVT) { 95 return typeError(ErrorLoc, 96 StringRef("popped ") + WebAssembly::typeToString(PVT) + 97 ", expected " + WebAssembly::typeToString(*EVT)); 98 } 99 return false; 100} 101 102bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) { 103 if (Stack.empty()) { 104 return typeError(ErrorLoc, StringRef("empty stack while popping reftype")); 105 } 106 auto PVT = Stack.pop_back_val(); 107 if (!WebAssembly::isRefType(PVT)) { 108 return typeError(ErrorLoc, StringRef("popped ") + 109 WebAssembly::typeToString(PVT) + 110 ", expected reftype"); 111 } 112 return false; 113} 114 115bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst, 116 wasm::ValType &Type) { 117 auto Local = static_cast<size_t>(Inst.getOperand(0).getImm()); 118 if (Local >= LocalTypes.size()) 119 return typeError(ErrorLoc, StringRef("no local type specified for index ") + 120 std::to_string(Local)); 121 Type = LocalTypes[Local]; 122 return false; 123} 124 125static std::optional<std::string> 126checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop, 127 const SmallVectorImpl<wasm::ValType> &Got) { 128 for (size_t I = 0; I < ExpectedStackTop.size(); I++) { 129 auto EVT = ExpectedStackTop[I]; 130 auto PVT = Got[Got.size() - ExpectedStackTop.size() + I]; 131 if (PVT != EVT) 132 return std::string{"got "} + WebAssembly::typeToString(PVT) + 133 ", expected " + WebAssembly::typeToString(EVT); 134 } 135 return std::nullopt; 136} 137 138bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) { 139 if (Level >= BrStack.size()) 140 return typeError(ErrorLoc, 141 StringRef("br: invalid depth ") + std::to_string(Level)); 142 const SmallVector<wasm::ValType, 4> &Expected = 143 BrStack[BrStack.size() - Level - 1]; 144 if (Expected.size() > Stack.size()) 145 return typeError(ErrorLoc, "br: insufficient values on the type stack"); 146 auto IsStackTopInvalid = checkStackTop(Expected, Stack); 147 if (IsStackTopInvalid) 148 return typeError(ErrorLoc, "br " + IsStackTopInvalid.value()); 149 return false; 150} 151 152bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) { 153 if (!PopVals) 154 BrStack.pop_back(); 155 if (LastSig.Returns.size() > Stack.size()) 156 return typeError(ErrorLoc, "end: insufficient values on the type stack"); 157 158 if (PopVals) { 159 for (auto VT : llvm::reverse(LastSig.Returns)) { 160 if (popType(ErrorLoc, VT)) 161 return true; 162 } 163 return false; 164 } 165 166 auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack); 167 if (IsStackTopInvalid) 168 return typeError(ErrorLoc, "end " + IsStackTopInvalid.value()); 169 return false; 170} 171 172bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc, 173 const wasm::WasmSignature &Sig) { 174 for (auto VT : llvm::reverse(Sig.Params)) 175 if (popType(ErrorLoc, VT)) 176 return true; 177 Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end()); 178 return false; 179} 180 181bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst, 182 const MCSymbolRefExpr *&SymRef) { 183 auto Op = Inst.getOperand(0); 184 if (!Op.isExpr()) 185 return typeError(ErrorLoc, StringRef("expected expression operand")); 186 SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr()); 187 if (!SymRef) 188 return typeError(ErrorLoc, StringRef("expected symbol operand")); 189 return false; 190} 191 192bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst, 193 wasm::ValType &Type) { 194 const MCSymbolRefExpr *SymRef; 195 if (getSymRef(ErrorLoc, Inst, SymRef)) 196 return true; 197 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 198 switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) { 199 case wasm::WASM_SYMBOL_TYPE_GLOBAL: 200 Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type); 201 break; 202 case wasm::WASM_SYMBOL_TYPE_FUNCTION: 203 case wasm::WASM_SYMBOL_TYPE_DATA: 204 switch (SymRef->getKind()) { 205 case MCSymbolRefExpr::VK_GOT: 206 case MCSymbolRefExpr::VK_WASM_GOT_TLS: 207 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 208 return false; 209 default: 210 break; 211 } 212 [[fallthrough]]; 213 default: 214 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() + 215 " missing .globaltype"); 216 } 217 return false; 218} 219 220bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst, 221 wasm::ValType &Type) { 222 const MCSymbolRefExpr *SymRef; 223 if (getSymRef(ErrorLoc, Inst, SymRef)) 224 return true; 225 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 226 if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) != 227 wasm::WASM_SYMBOL_TYPE_TABLE) 228 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() + 229 " missing .tabletype"); 230 Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType); 231 return false; 232} 233 234bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) { 235 // Check the return types. 236 for (auto RVT : llvm::reverse(ReturnTypes)) { 237 if (popType(ErrorLoc, RVT)) 238 return true; 239 } 240 if (!Stack.empty()) { 241 return typeError(ErrorLoc, std::to_string(Stack.size()) + 242 " superfluous return values"); 243 } 244 Unreachable = true; 245 return false; 246} 247 248bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst, 249 OperandVector &Operands) { 250 auto Opc = Inst.getOpcode(); 251 auto Name = GetMnemonic(Opc); 252 dumpTypeStack("typechecking " + Name + ": "); 253 wasm::ValType Type; 254 if (Name == "local.get") { 255 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 256 return true; 257 Stack.push_back(Type); 258 } else if (Name == "local.set") { 259 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 260 return true; 261 if (popType(ErrorLoc, Type)) 262 return true; 263 } else if (Name == "local.tee") { 264 if (getLocal(Operands[1]->getStartLoc(), Inst, Type)) 265 return true; 266 if (popType(ErrorLoc, Type)) 267 return true; 268 Stack.push_back(Type); 269 } else if (Name == "global.get") { 270 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type)) 271 return true; 272 Stack.push_back(Type); 273 } else if (Name == "global.set") { 274 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type)) 275 return true; 276 if (popType(ErrorLoc, Type)) 277 return true; 278 } else if (Name == "table.get") { 279 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 280 return true; 281 if (popType(ErrorLoc, wasm::ValType::I32)) 282 return true; 283 Stack.push_back(Type); 284 } else if (Name == "table.set") { 285 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 286 return true; 287 if (popType(ErrorLoc, Type)) 288 return true; 289 if (popType(ErrorLoc, wasm::ValType::I32)) 290 return true; 291 } else if (Name == "table.fill") { 292 if (getTable(Operands[1]->getStartLoc(), Inst, Type)) 293 return true; 294 if (popType(ErrorLoc, wasm::ValType::I32)) 295 return true; 296 if (popType(ErrorLoc, Type)) 297 return true; 298 if (popType(ErrorLoc, wasm::ValType::I32)) 299 return true; 300 } else if (Name == "memory.fill") { 301 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 302 if (popType(ErrorLoc, Type)) 303 return true; 304 if (popType(ErrorLoc, wasm::ValType::I32)) 305 return true; 306 if (popType(ErrorLoc, Type)) 307 return true; 308 } else if (Name == "memory.copy") { 309 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 310 if (popType(ErrorLoc, Type)) 311 return true; 312 if (popType(ErrorLoc, Type)) 313 return true; 314 if (popType(ErrorLoc, Type)) 315 return true; 316 } else if (Name == "memory.init") { 317 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32; 318 if (popType(ErrorLoc, wasm::ValType::I32)) 319 return true; 320 if (popType(ErrorLoc, wasm::ValType::I32)) 321 return true; 322 if (popType(ErrorLoc, Type)) 323 return true; 324 } else if (Name == "drop") { 325 if (popType(ErrorLoc, {})) 326 return true; 327 } else if (Name == "try" || Name == "block" || Name == "loop" || 328 Name == "if") { 329 if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32)) 330 return true; 331 if (Name == "loop") 332 BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end()); 333 else 334 BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end()); 335 } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" || 336 Name == "else" || Name == "end_try" || Name == "catch" || 337 Name == "catch_all" || Name == "delegate") { 338 if (checkEnd(ErrorLoc, 339 Name == "else" || Name == "catch" || Name == "catch_all")) 340 return true; 341 Unreachable = false; 342 if (Name == "catch") { 343 const MCSymbolRefExpr *SymRef; 344 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef)) 345 return true; 346 const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 347 const auto *Sig = WasmSym->getSignature(); 348 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG) 349 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + 350 WasmSym->getName() + 351 " missing .tagtype"); 352 // catch instruction pushes values whose types are specified in the tag's 353 // "params" part 354 Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end()); 355 } 356 } else if (Name == "br") { 357 const MCOperand &Operand = Inst.getOperand(0); 358 if (!Operand.isImm()) 359 return false; 360 if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm()))) 361 return true; 362 } else if (Name == "return") { 363 if (endOfFunction(ErrorLoc)) 364 return true; 365 } else if (Name == "call_indirect" || Name == "return_call_indirect") { 366 // Function value. 367 if (popType(ErrorLoc, wasm::ValType::I32)) 368 return true; 369 if (checkSig(ErrorLoc, LastSig)) 370 return true; 371 if (Name == "return_call_indirect" && endOfFunction(ErrorLoc)) 372 return true; 373 } else if (Name == "call" || Name == "return_call") { 374 const MCSymbolRefExpr *SymRef; 375 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef)) 376 return true; 377 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol()); 378 auto Sig = WasmSym->getSignature(); 379 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION) 380 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") + 381 WasmSym->getName() + 382 " missing .functype"); 383 if (checkSig(ErrorLoc, *Sig)) 384 return true; 385 if (Name == "return_call" && endOfFunction(ErrorLoc)) 386 return true; 387 } else if (Name == "unreachable") { 388 Unreachable = true; 389 } else if (Name == "ref.is_null") { 390 if (popRefType(ErrorLoc)) 391 return true; 392 Stack.push_back(wasm::ValType::I32); 393 } else { 394 // The current instruction is a stack instruction which doesn't have 395 // explicit operands that indicate push/pop types, so we get those from 396 // the register version of the same instruction. 397 auto RegOpc = WebAssembly::getRegisterOpcode(Opc); 398 assert(RegOpc != -1 && "Failed to get register version of MC instruction"); 399 const auto &II = MII.get(RegOpc); 400 // First pop all the uses off the stack and check them. 401 for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) { 402 const auto &Op = II.operands()[I - 1]; 403 if (Op.OperandType == MCOI::OPERAND_REGISTER) { 404 auto VT = WebAssembly::regClassToValType(Op.RegClass); 405 if (popType(ErrorLoc, VT)) 406 return true; 407 } 408 } 409 // Now push all the defs onto the stack. 410 for (unsigned I = 0; I < II.getNumDefs(); I++) { 411 const auto &Op = II.operands()[I]; 412 assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected"); 413 auto VT = WebAssembly::regClassToValType(Op.RegClass); 414 Stack.push_back(VT); 415 } 416 } 417 return false; 418} 419 420} // end namespace llvm 421