1//===---- OrcRemoteTargetServer.h - Orc Remote-target Server ----*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file defines the OrcRemoteTargetServer class. It can be used to build a
11// JIT server that can execute code sent from an OrcRemoteTargetClient.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
16#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
17
18#include "OrcRemoteTargetRPCAPI.h"
19#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/Format.h"
22#include "llvm/Support/Process.h"
23#include "llvm/Support/raw_ostream.h"
24#include <map>
25
26#define DEBUG_TYPE "orc-remote"
27
28namespace llvm {
29namespace orc {
30namespace remote {
31
32template <typename ChannelT, typename TargetT>
33class OrcRemoteTargetServer : public OrcRemoteTargetRPCAPI {
34public:
35  typedef std::function<TargetAddress(const std::string &Name)>
36      SymbolLookupFtor;
37
38  OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup)
39      : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {}
40
41  std::error_code getNextProcId(JITProcId &Id) {
42    return deserialize(Channel, Id);
43  }
44
45  std::error_code handleKnownProcedure(JITProcId Id) {
46    typedef OrcRemoteTargetServer ThisT;
47
48    DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
49
50    switch (Id) {
51    case CallIntVoidId:
52      return handle<CallIntVoid>(Channel, *this, &ThisT::handleCallIntVoid);
53    case CallMainId:
54      return handle<CallMain>(Channel, *this, &ThisT::handleCallMain);
55    case CallVoidVoidId:
56      return handle<CallVoidVoid>(Channel, *this, &ThisT::handleCallVoidVoid);
57    case CreateRemoteAllocatorId:
58      return handle<CreateRemoteAllocator>(Channel, *this,
59                                           &ThisT::handleCreateRemoteAllocator);
60    case CreateIndirectStubsOwnerId:
61      return handle<CreateIndirectStubsOwner>(
62          Channel, *this, &ThisT::handleCreateIndirectStubsOwner);
63    case DestroyRemoteAllocatorId:
64      return handle<DestroyRemoteAllocator>(
65          Channel, *this, &ThisT::handleDestroyRemoteAllocator);
66    case DestroyIndirectStubsOwnerId:
67      return handle<DestroyIndirectStubsOwner>(
68          Channel, *this, &ThisT::handleDestroyIndirectStubsOwner);
69    case EmitIndirectStubsId:
70      return handle<EmitIndirectStubs>(Channel, *this,
71                                       &ThisT::handleEmitIndirectStubs);
72    case EmitResolverBlockId:
73      return handle<EmitResolverBlock>(Channel, *this,
74                                       &ThisT::handleEmitResolverBlock);
75    case EmitTrampolineBlockId:
76      return handle<EmitTrampolineBlock>(Channel, *this,
77                                         &ThisT::handleEmitTrampolineBlock);
78    case GetSymbolAddressId:
79      return handle<GetSymbolAddress>(Channel, *this,
80                                      &ThisT::handleGetSymbolAddress);
81    case GetRemoteInfoId:
82      return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo);
83    case ReadMemId:
84      return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem);
85    case ReserveMemId:
86      return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem);
87    case SetProtectionsId:
88      return handle<SetProtections>(Channel, *this,
89                                    &ThisT::handleSetProtections);
90    case WriteMemId:
91      return handle<WriteMem>(Channel, *this, &ThisT::handleWriteMem);
92    case WritePtrId:
93      return handle<WritePtr>(Channel, *this, &ThisT::handleWritePtr);
94    default:
95      return orcError(OrcErrorCode::UnexpectedRPCCall);
96    }
97
98    llvm_unreachable("Unhandled JIT RPC procedure Id.");
99  }
100
101  std::error_code requestCompile(TargetAddress &CompiledFnAddr,
102                                 TargetAddress TrampolineAddr) {
103    if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
104      return EC;
105
106    while (1) {
107      JITProcId Id = InvalidId;
108      if (auto EC = getNextProcId(Id))
109        return EC;
110
111      switch (Id) {
112      case RequestCompileResponseId:
113        return handle<RequestCompileResponse>(Channel,
114                                              readArgs(CompiledFnAddr));
115      default:
116        if (auto EC = handleKnownProcedure(Id))
117          return EC;
118      }
119    }
120
121    llvm_unreachable("Fell through request-compile command loop.");
122  }
123
124private:
125  struct Allocator {
126    Allocator() = default;
127    Allocator(Allocator &&Other) : Allocs(std::move(Other.Allocs)) {}
128    Allocator &operator=(Allocator &&Other) {
129      Allocs = std::move(Other.Allocs);
130      return *this;
131    }
132
133    ~Allocator() {
134      for (auto &Alloc : Allocs)
135        sys::Memory::releaseMappedMemory(Alloc.second);
136    }
137
138    std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) {
139      std::error_code EC;
140      sys::MemoryBlock MB = sys::Memory::allocateMappedMemory(
141          Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC);
142      if (EC)
143        return EC;
144
145      Addr = MB.base();
146      assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc");
147      Allocs[MB.base()] = std::move(MB);
148      return std::error_code();
149    }
150
151    std::error_code setProtections(void *block, unsigned Flags) {
152      auto I = Allocs.find(block);
153      if (I == Allocs.end())
154        return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized);
155      return sys::Memory::protectMappedMemory(I->second, Flags);
156    }
157
158  private:
159    std::map<void *, sys::MemoryBlock> Allocs;
160  };
161
162  static std::error_code doNothing() { return std::error_code(); }
163
164  static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) {
165    TargetAddress CompiledFnAddr = 0;
166
167    auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
168    auto EC = T->requestCompile(
169        CompiledFnAddr, static_cast<TargetAddress>(
170                            reinterpret_cast<uintptr_t>(TrampolineAddr)));
171    assert(!EC && "Compile request failed");
172    (void)EC;
173    return CompiledFnAddr;
174  }
175
176  std::error_code handleCallIntVoid(TargetAddress Addr) {
177    typedef int (*IntVoidFnTy)();
178    IntVoidFnTy Fn =
179        reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
180
181    DEBUG(dbgs() << "  Calling "
182                 << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn))
183                 << "\n");
184    int Result = Fn();
185    DEBUG(dbgs() << "  Result = " << Result << "\n");
186
187    return call<CallIntVoidResponse>(Channel, Result);
188  }
189
190  std::error_code handleCallMain(TargetAddress Addr,
191                                 std::vector<std::string> Args) {
192    typedef int (*MainFnTy)(int, const char *[]);
193
194    MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
195    int ArgC = Args.size() + 1;
196    int Idx = 1;
197    std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
198    ArgV[0] = "<jit process>";
199    for (auto &Arg : Args)
200      ArgV[Idx++] = Arg.c_str();
201
202    DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
203    int Result = Fn(ArgC, ArgV.get());
204    DEBUG(dbgs() << "  Result = " << Result << "\n");
205
206    return call<CallMainResponse>(Channel, Result);
207  }
208
209  std::error_code handleCallVoidVoid(TargetAddress Addr) {
210    typedef void (*VoidVoidFnTy)();
211    VoidVoidFnTy Fn =
212        reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
213
214    DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
215    Fn();
216    DEBUG(dbgs() << "  Complete.\n");
217
218    return call<CallVoidVoidResponse>(Channel);
219  }
220
221  std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
222    auto I = Allocators.find(Id);
223    if (I != Allocators.end())
224      return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
225    DEBUG(dbgs() << "  Created allocator " << Id << "\n");
226    Allocators[Id] = Allocator();
227    return std::error_code();
228  }
229
230  std::error_code handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
231    auto I = IndirectStubsOwners.find(Id);
232    if (I != IndirectStubsOwners.end())
233      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
234    DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
235    IndirectStubsOwners[Id] = ISBlockOwnerList();
236    return std::error_code();
237  }
238
239  std::error_code handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
240    auto I = Allocators.find(Id);
241    if (I == Allocators.end())
242      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
243    Allocators.erase(I);
244    DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
245    return std::error_code();
246  }
247
248  std::error_code
249  handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
250    auto I = IndirectStubsOwners.find(Id);
251    if (I == IndirectStubsOwners.end())
252      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
253    IndirectStubsOwners.erase(I);
254    return std::error_code();
255  }
256
257  std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
258                                          uint32_t NumStubsRequired) {
259    DEBUG(dbgs() << "  ISMgr " << Id << " request " << NumStubsRequired
260                 << " stubs.\n");
261
262    auto StubOwnerItr = IndirectStubsOwners.find(Id);
263    if (StubOwnerItr == IndirectStubsOwners.end())
264      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
265
266    typename TargetT::IndirectStubsInfo IS;
267    if (auto EC =
268            TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr))
269      return EC;
270
271    TargetAddress StubsBase =
272        static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getStub(0)));
273    TargetAddress PtrsBase =
274        static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getPtr(0)));
275    uint32_t NumStubsEmitted = IS.getNumStubs();
276
277    auto &BlockList = StubOwnerItr->second;
278    BlockList.push_back(std::move(IS));
279
280    return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
281                                           NumStubsEmitted);
282  }
283
284  std::error_code handleEmitResolverBlock() {
285    std::error_code EC;
286    ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
287        TargetT::ResolverCodeSize, nullptr,
288        sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
289    if (EC)
290      return EC;
291
292    TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()),
293                               &reenter, this);
294
295    return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(),
296                                            sys::Memory::MF_READ |
297                                                sys::Memory::MF_EXEC);
298  }
299
300  std::error_code handleEmitTrampolineBlock() {
301    std::error_code EC;
302    auto TrampolineBlock =
303        sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
304            sys::Process::getPageSize(), nullptr,
305            sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
306    if (EC)
307      return EC;
308
309    unsigned NumTrampolines =
310        (sys::Process::getPageSize() - TargetT::PointerSize) /
311        TargetT::TrampolineSize;
312
313    uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base());
314    TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(),
315                              NumTrampolines);
316
317    EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(),
318                                          sys::Memory::MF_READ |
319                                              sys::Memory::MF_EXEC);
320
321    TrampolineBlocks.push_back(std::move(TrampolineBlock));
322
323    return call<EmitTrampolineBlockResponse>(
324        Channel,
325        static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
326        NumTrampolines);
327  }
328
329  std::error_code handleGetSymbolAddress(const std::string &Name) {
330    TargetAddress Addr = SymbolLookup(Name);
331    DEBUG(dbgs() << "  Symbol '" << Name << "' =  " << format("0x%016x", Addr)
332                 << "\n");
333    return call<GetSymbolAddressResponse>(Channel, Addr);
334  }
335
336  std::error_code handleGetRemoteInfo() {
337    std::string ProcessTriple = sys::getProcessTriple();
338    uint32_t PointerSize = TargetT::PointerSize;
339    uint32_t PageSize = sys::Process::getPageSize();
340    uint32_t TrampolineSize = TargetT::TrampolineSize;
341    uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize;
342    DEBUG(dbgs() << "  Remote info:\n"
343                 << "    triple             = '" << ProcessTriple << "'\n"
344                 << "    pointer size       = " << PointerSize << "\n"
345                 << "    page size          = " << PageSize << "\n"
346                 << "    trampoline size    = " << TrampolineSize << "\n"
347                 << "    indirect stub size = " << IndirectStubSize << "\n");
348    return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize,
349                                       PageSize, TrampolineSize,
350                                       IndirectStubSize);
351  }
352
353  std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
354    char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
355
356    DEBUG(dbgs() << "  Reading " << Size << " bytes from "
357                 << static_cast<void *>(Src) << "\n");
358
359    if (auto EC = call<ReadMemResponse>(Channel))
360      return EC;
361
362    if (auto EC = Channel.appendBytes(Src, Size))
363      return EC;
364
365    return Channel.send();
366  }
367
368  std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
369                                   uint32_t Align) {
370    auto I = Allocators.find(Id);
371    if (I == Allocators.end())
372      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
373    auto &Allocator = I->second;
374    void *LocalAllocAddr = nullptr;
375    if (auto EC = Allocator.allocate(LocalAllocAddr, Size, Align))
376      return EC;
377
378    DEBUG(dbgs() << "  Allocator " << Id << " reserved " << LocalAllocAddr
379                 << " (" << Size << " bytes, alignment " << Align << ")\n");
380
381    TargetAddress AllocAddr =
382        static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
383
384    return call<ReserveMemResponse>(Channel, AllocAddr);
385  }
386
387  std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
388                                       TargetAddress Addr, uint32_t Flags) {
389    auto I = Allocators.find(Id);
390    if (I == Allocators.end())
391      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
392    auto &Allocator = I->second;
393    void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
394    DEBUG(dbgs() << "  Allocator " << Id << " set permissions on " << LocalAddr
395                 << " to " << (Flags & sys::Memory::MF_READ ? 'R' : '-')
396                 << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
397                 << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
398    return Allocator.setProtections(LocalAddr, Flags);
399  }
400
401  std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
402    char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
403    DEBUG(dbgs() << "  Writing " << Size << " bytes to "
404                 << format("0x%016x", RDst) << "\n");
405    return Channel.readBytes(Dst, Size);
406  }
407
408  std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) {
409    DEBUG(dbgs() << "  Writing pointer *" << format("0x%016x", Addr) << " = "
410                 << format("0x%016x", PtrVal) << "\n");
411    uintptr_t *Ptr =
412        reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
413    *Ptr = static_cast<uintptr_t>(PtrVal);
414    return std::error_code();
415  }
416
417  ChannelT &Channel;
418  SymbolLookupFtor SymbolLookup;
419  std::map<ResourceIdMgr::ResourceId, Allocator> Allocators;
420  typedef std::vector<typename TargetT::IndirectStubsInfo> ISBlockOwnerList;
421  std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners;
422  sys::OwningMemoryBlock ResolverBlock;
423  std::vector<sys::OwningMemoryBlock> TrampolineBlocks;
424};
425
426} // end namespace remote
427} // end namespace orc
428} // end namespace llvm
429
430#undef DEBUG_TYPE
431
432#endif
433