1356843Sdim//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2356843Sdim//
3356843Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4356843Sdim// See https://llvm.org/LICENSE.txt for license information.
5356843Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6356843Sdim//
7356843Sdim//===----------------------------------------------------------------------===//
8356843Sdim//
9356843Sdim// Utilities to support construction of simple RPC APIs.
10356843Sdim//
11356843Sdim// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12356843Sdim// programmers, high performance, low memory overhead, and efficient use of the
13356843Sdim// communications channel.
14356843Sdim//
15356843Sdim//===----------------------------------------------------------------------===//
16356843Sdim
17356843Sdim#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
18356843Sdim#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19356843Sdim
#include <map>
#include <thread>
#include <type_traits>
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/OrcError.h"
#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h"
#include "llvm/Support/MSVCErrorWorkarounds.h"

#include <future>
30356843Sdim
31356843Sdimnamespace llvm {
32356843Sdimnamespace orc {
33356843Sdimnamespace rpc {
34356843Sdim
35356843Sdim/// Base class of all fatal RPC errors (those that necessarily result in the
36356843Sdim/// termination of the RPC session).
37356843Sdimclass RPCFatalError : public ErrorInfo<RPCFatalError> {
38356843Sdimpublic:
39356843Sdim  static char ID;
40356843Sdim};
41356843Sdim
42356843Sdim/// RPCConnectionClosed is returned from RPC operations if the RPC connection
43356843Sdim/// has already been closed due to either an error or graceful disconnection.
44356843Sdimclass ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45356843Sdimpublic:
46356843Sdim  static char ID;
47356843Sdim  std::error_code convertToErrorCode() const override;
48356843Sdim  void log(raw_ostream &OS) const override;
49356843Sdim};
50356843Sdim
51356843Sdim/// BadFunctionCall is returned from handleOne when the remote makes a call with
52356843Sdim/// an unrecognized function id.
53356843Sdim///
54356843Sdim/// This error is fatal because Orc RPC needs to know how to parse a function
55356843Sdim/// call to know where the next call starts, and if it doesn't recognize the
56356843Sdim/// function id it cannot parse the call.
57356843Sdimtemplate <typename FnIdT, typename SeqNoT>
58356843Sdimclass BadFunctionCall
59356843Sdim  : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60356843Sdimpublic:
61356843Sdim  static char ID;
62356843Sdim
63356843Sdim  BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64356843Sdim      : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65356843Sdim
66356843Sdim  std::error_code convertToErrorCode() const override {
67356843Sdim    return orcError(OrcErrorCode::UnexpectedRPCCall);
68356843Sdim  }
69356843Sdim
70356843Sdim  void log(raw_ostream &OS) const override {
71356843Sdim    OS << "Call to invalid RPC function id '" << FnId << "' with "
72356843Sdim          "sequence number " << SeqNo;
73356843Sdim  }
74356843Sdim
75356843Sdimprivate:
76356843Sdim  FnIdT FnId;
77356843Sdim  SeqNoT SeqNo;
78356843Sdim};
79356843Sdim
80356843Sdimtemplate <typename FnIdT, typename SeqNoT>
81356843Sdimchar BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82356843Sdim
83356843Sdim/// InvalidSequenceNumberForResponse is returned from handleOne when a response
84356843Sdim/// call arrives with a sequence number that doesn't correspond to any in-flight
85356843Sdim/// function call.
86356843Sdim///
87356843Sdim/// This error is fatal because Orc RPC needs to know how to parse the rest of
88356843Sdim/// the response call to know where the next call starts, and if it doesn't have
89356843Sdim/// a result parser for this sequence number it can't do that.
90356843Sdimtemplate <typename SeqNoT>
91356843Sdimclass InvalidSequenceNumberForResponse
92356843Sdim    : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93356843Sdimpublic:
94356843Sdim  static char ID;
95356843Sdim
96356843Sdim  InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97356843Sdim      : SeqNo(std::move(SeqNo)) {}
98356843Sdim
99356843Sdim  std::error_code convertToErrorCode() const override {
100356843Sdim    return orcError(OrcErrorCode::UnexpectedRPCCall);
101356843Sdim  };
102356843Sdim
103356843Sdim  void log(raw_ostream &OS) const override {
104356843Sdim    OS << "Response has unknown sequence number " << SeqNo;
105356843Sdim  }
106356843Sdimprivate:
107356843Sdim  SeqNoT SeqNo;
108356843Sdim};
109356843Sdim
110356843Sdimtemplate <typename SeqNoT>
111356843Sdimchar InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112356843Sdim
113356843Sdim/// This non-fatal error will be passed to asynchronous result handlers in place
114356843Sdim/// of a result if the connection goes down before a result returns, or if the
115356843Sdim/// function to be called cannot be negotiated with the remote.
116356843Sdimclass ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117356843Sdimpublic:
118356843Sdim  static char ID;
119356843Sdim
120356843Sdim  std::error_code convertToErrorCode() const override;
121356843Sdim  void log(raw_ostream &OS) const override;
122356843Sdim};
123356843Sdim
124356843Sdim/// This error is returned if the remote does not have a handler installed for
125356843Sdim/// the given RPC function.
126356843Sdimclass CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127356843Sdimpublic:
128356843Sdim  static char ID;
129356843Sdim
130356843Sdim  CouldNotNegotiate(std::string Signature);
131356843Sdim  std::error_code convertToErrorCode() const override;
132356843Sdim  void log(raw_ostream &OS) const override;
133356843Sdim  const std::string &getSignature() const { return Signature; }
134356843Sdimprivate:
135356843Sdim  std::string Signature;
136356843Sdim};
137356843Sdim
138356843Sdimtemplate <typename DerivedFunc, typename FnT> class Function;
139356843Sdim
140356843Sdim// RPC Function class.
141356843Sdim// DerivedFunc should be a user defined class with a static 'getName()' method
142356843Sdim// returning a const char* representing the function's name.
143356843Sdimtemplate <typename DerivedFunc, typename RetT, typename... ArgTs>
144356843Sdimclass Function<DerivedFunc, RetT(ArgTs...)> {
145356843Sdimpublic:
146356843Sdim  /// User defined function type.
147356843Sdim  using Type = RetT(ArgTs...);
148356843Sdim
149356843Sdim  /// Return type.
150356843Sdim  using ReturnType = RetT;
151356843Sdim
152356843Sdim  /// Returns the full function prototype as a string.
153356843Sdim  static const char *getPrototype() {
154356843Sdim    static std::string Name = [] {
155356843Sdim      std::string Name;
156356843Sdim      raw_string_ostream(Name)
157356843Sdim          << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158356843Sdim          << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159356843Sdim      return Name;
160356843Sdim    }();
161356843Sdim    return Name.data();
162356843Sdim  }
163356843Sdim};
164356843Sdim
/// Allocates RPC function ids during autonegotiation.
///
/// Every specialization must supply these four members:
///
///   static T getInvalidId()
///     A reserved id standing for functions that are missing during
///     autonegotiation.
///
///   static T getResponseId()
///     A reserved id used when sending function responses (return values).
///
///   static T getNegotiateId()
///     A reserved id for the negotiate function, which is used to negotiate
///     ids for user-defined functions.
///
///   template <typename Func> T allocate()
///     Hands out a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator for integral id types.
///
/// Ids 0, 1 and 2 are reserved (invalid, response, negotiate). Each call to
/// allocate() then returns 3, 4, 5, ... in call order.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id. Func is not inspected, and no overflow check
  /// is performed.
  template <typename Func> T allocate() {
    const T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id not taken by the reserved getters above.
  T NextId = 3;
};
199356843Sdim
200356843Sdimnamespace detail {
201356843Sdim
/// Maps a function type to a std::tuple of its argument types, with each type
/// decayed to its value form (references and cv-qualifiers removed; arrays and
/// functions turned into pointers). For example,
/// void(const int &, std::string &&) maps to std::tuple<int, std::string>.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already strips references, so no separate
  // std::remove_reference step is needed.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
211356843Sdim
212356843Sdim// ResultTraits provides typedefs and utilities specific to the return type
213356843Sdim// of functions.
214356843Sdimtemplate <typename RetT> class ResultTraits {
215356843Sdimpublic:
216356843Sdim  // The return type wrapped in llvm::Expected.
217356843Sdim  using ErrorReturnType = Expected<RetT>;
218356843Sdim
219356843Sdim#ifdef _MSC_VER
220356843Sdim  // The ErrorReturnType wrapped in a std::promise.
221356843Sdim  using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
222356843Sdim
223356843Sdim  // The ErrorReturnType wrapped in a std::future.
224356843Sdim  using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
225356843Sdim#else
226356843Sdim  // The ErrorReturnType wrapped in a std::promise.
227356843Sdim  using ReturnPromiseType = std::promise<ErrorReturnType>;
228356843Sdim
229356843Sdim  // The ErrorReturnType wrapped in a std::future.
230356843Sdim  using ReturnFutureType = std::future<ErrorReturnType>;
231356843Sdim#endif
232356843Sdim
233356843Sdim  // Create a 'blank' value of the ErrorReturnType, ready and safe to
234356843Sdim  // overwrite.
235356843Sdim  static ErrorReturnType createBlankErrorReturnValue() {
236356843Sdim    return ErrorReturnType(RetT());
237356843Sdim  }
238356843Sdim
239356843Sdim  // Consume an abandoned ErrorReturnType.
240356843Sdim  static void consumeAbandoned(ErrorReturnType RetOrErr) {
241356843Sdim    consumeError(RetOrErr.takeError());
242356843Sdim  }
243356843Sdim};
244356843Sdim
245356843Sdim// ResultTraits specialization for void functions.
246356843Sdimtemplate <> class ResultTraits<void> {
247356843Sdimpublic:
248356843Sdim  // For void functions, ErrorReturnType is llvm::Error.
249356843Sdim  using ErrorReturnType = Error;
250356843Sdim
251356843Sdim#ifdef _MSC_VER
252356843Sdim  // The ErrorReturnType wrapped in a std::promise.
253356843Sdim  using ReturnPromiseType = std::promise<MSVCPError>;
254356843Sdim
255356843Sdim  // The ErrorReturnType wrapped in a std::future.
256356843Sdim  using ReturnFutureType = std::future<MSVCPError>;
257356843Sdim#else
258356843Sdim  // The ErrorReturnType wrapped in a std::promise.
259356843Sdim  using ReturnPromiseType = std::promise<ErrorReturnType>;
260356843Sdim
261356843Sdim  // The ErrorReturnType wrapped in a std::future.
262356843Sdim  using ReturnFutureType = std::future<ErrorReturnType>;
263356843Sdim#endif
264356843Sdim
265356843Sdim  // Create a 'blank' value of the ErrorReturnType, ready and safe to
266356843Sdim  // overwrite.
267356843Sdim  static ErrorReturnType createBlankErrorReturnValue() {
268356843Sdim    return ErrorReturnType::success();
269356843Sdim  }
270356843Sdim
271356843Sdim  // Consume an abandoned ErrorReturnType.
272356843Sdim  static void consumeAbandoned(ErrorReturnType Err) {
273356843Sdim    consumeError(std::move(Err));
274356843Sdim  }
275356843Sdim};
276356843Sdim
277356843Sdim// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
278356843Sdim// handlers for void RPC functions to return either void (in which case they
279356843Sdim// implicitly succeed) or Error (in which case their error return is
280356843Sdim// propagated). See usage in HandlerTraits::runHandlerHelper.
281356843Sdimtemplate <> class ResultTraits<Error> : public ResultTraits<void> {};
282356843Sdim
283356843Sdim// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
284356843Sdim// handlers for RPC functions returning a T to return either a T (in which
285356843Sdim// case they implicitly succeed) or Expected<T> (in which case their error
286356843Sdim// return is propagated). See usage in HandlerTraits::runHandlerHelper.
287356843Sdimtemplate <typename RetT>
288356843Sdimclass ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
289356843Sdim
290356843Sdim// Determines whether an RPC function's defined error return type supports
291356843Sdim// error return value.
292356843Sdimtemplate <typename T>
293356843Sdimclass SupportsErrorReturn {
294356843Sdimpublic:
295356843Sdim  static const bool value = false;
296356843Sdim};
297356843Sdim
298356843Sdimtemplate <>
299356843Sdimclass SupportsErrorReturn<Error> {
300356843Sdimpublic:
301356843Sdim  static const bool value = true;
302356843Sdim};
303356843Sdim
304356843Sdimtemplate <typename T>
305356843Sdimclass SupportsErrorReturn<Expected<T>> {
306356843Sdimpublic:
307356843Sdim  static const bool value = true;
308356843Sdim};
309356843Sdim
310356843Sdim// RespondHelper packages return values based on whether or not the declared
311356843Sdim// RPC function return type supports error returns.
312356843Sdimtemplate <bool FuncSupportsErrorReturn>
313356843Sdimclass RespondHelper;
314356843Sdim
315356843Sdim// RespondHelper specialization for functions that support error returns.
316356843Sdimtemplate <>
317356843Sdimclass RespondHelper<true> {
318356843Sdimpublic:
319356843Sdim
320356843Sdim  // Send Expected<T>.
321356843Sdim  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
322356843Sdim            typename FunctionIdT, typename SequenceNumberT>
323356843Sdim  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
324356843Sdim                          SequenceNumberT SeqNo,
325356843Sdim                          Expected<HandlerRetT> ResultOrErr) {
326356843Sdim    if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
327356843Sdim      return ResultOrErr.takeError();
328356843Sdim
329356843Sdim    // Open the response message.
330356843Sdim    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
331356843Sdim      return Err;
332356843Sdim
333356843Sdim    // Serialize the result.
334356843Sdim    if (auto Err =
335356843Sdim        SerializationTraits<ChannelT, WireRetT,
336356843Sdim                            Expected<HandlerRetT>>::serialize(
337356843Sdim                                                     C, std::move(ResultOrErr)))
338356843Sdim      return Err;
339356843Sdim
340356843Sdim    // Close the response message.
341356843Sdim    if (auto Err = C.endSendMessage())
342356843Sdim      return Err;
343356843Sdim    return C.send();
344356843Sdim  }
345356843Sdim
346356843Sdim  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
347356843Sdim  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348356843Sdim                          SequenceNumberT SeqNo, Error Err) {
349356843Sdim    if (Err && Err.isA<RPCFatalError>())
350356843Sdim      return Err;
351356843Sdim    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352356843Sdim      return Err2;
353356843Sdim    if (auto Err2 = serializeSeq(C, std::move(Err)))
354356843Sdim      return Err2;
355356843Sdim    if (auto Err2 = C.endSendMessage())
356356843Sdim      return Err2;
357356843Sdim    return C.send();
358356843Sdim  }
359356843Sdim
360356843Sdim};
361356843Sdim
362356843Sdim// RespondHelper specialization for functions that do not support error returns.
363356843Sdimtemplate <>
364356843Sdimclass RespondHelper<false> {
365356843Sdimpublic:
366356843Sdim
367356843Sdim  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
368356843Sdim            typename FunctionIdT, typename SequenceNumberT>
369356843Sdim  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
370356843Sdim                          SequenceNumberT SeqNo,
371356843Sdim                          Expected<HandlerRetT> ResultOrErr) {
372356843Sdim    if (auto Err = ResultOrErr.takeError())
373356843Sdim      return Err;
374356843Sdim
375356843Sdim    // Open the response message.
376356843Sdim    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
377356843Sdim      return Err;
378356843Sdim
379356843Sdim    // Serialize the result.
380356843Sdim    if (auto Err =
381356843Sdim        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
382356843Sdim                                                               C, *ResultOrErr))
383356843Sdim      return Err;
384356843Sdim
385356843Sdim    // End the response message.
386356843Sdim    if (auto Err = C.endSendMessage())
387356843Sdim      return Err;
388356843Sdim
389356843Sdim    return C.send();
390356843Sdim  }
391356843Sdim
392356843Sdim  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
393356843Sdim  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
394356843Sdim                          SequenceNumberT SeqNo, Error Err) {
395356843Sdim    if (Err)
396356843Sdim      return Err;
397356843Sdim    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
398356843Sdim      return Err2;
399356843Sdim    if (auto Err2 = C.endSendMessage())
400356843Sdim      return Err2;
401356843Sdim    return C.send();
402356843Sdim  }
403356843Sdim
404356843Sdim};
405356843Sdim
406356843Sdim
407356843Sdim// Send a response of the given wire return type (WireRetT) over the
408356843Sdim// channel, with the given sequence number.
409356843Sdimtemplate <typename WireRetT, typename HandlerRetT, typename ChannelT,
410356843Sdim          typename FunctionIdT, typename SequenceNumberT>
411356843SdimError respond(ChannelT &C, const FunctionIdT &ResponseId,
412356843Sdim              SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413356843Sdim  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414356843Sdim    template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
415356843Sdim}
416356843Sdim
417356843Sdim// Send an empty response message on the given channel to indicate that
418356843Sdim// the handler ran.
419356843Sdimtemplate <typename WireRetT, typename ChannelT, typename FunctionIdT,
420356843Sdim          typename SequenceNumberT>
421356843SdimError respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422356843Sdim              Error Err) {
423356843Sdim  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424356843Sdim    sendResult(C, ResponseId, SeqNo, std::move(Err));
425356843Sdim}
426356843Sdim
427356843Sdim// Converts a given type to the equivalent error return type.
428356843Sdimtemplate <typename T> class WrappedHandlerReturn {
429356843Sdimpublic:
430356843Sdim  using Type = Expected<T>;
431356843Sdim};
432356843Sdim
433356843Sdimtemplate <typename T> class WrappedHandlerReturn<Expected<T>> {
434356843Sdimpublic:
435356843Sdim  using Type = Expected<T>;
436356843Sdim};
437356843Sdim
438356843Sdimtemplate <> class WrappedHandlerReturn<void> {
439356843Sdimpublic:
440356843Sdim  using Type = Error;
441356843Sdim};
442356843Sdim
443356843Sdimtemplate <> class WrappedHandlerReturn<Error> {
444356843Sdimpublic:
445356843Sdim  using Type = Error;
446356843Sdim};
447356843Sdim
448356843Sdimtemplate <> class WrappedHandlerReturn<ErrorSuccess> {
449356843Sdimpublic:
450356843Sdim  using Type = Error;
451356843Sdim};
452356843Sdim
453356843Sdim// Traits class that strips the response function from the list of handler
454356843Sdim// arguments.
455356843Sdimtemplate <typename FnT> class AsyncHandlerTraits;
456356843Sdim
457356843Sdimtemplate <typename ResultT, typename... ArgTs>
458356843Sdimclass AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
459356843Sdimpublic:
460356843Sdim  using Type = Error(ArgTs...);
461356843Sdim  using ResultType = Expected<ResultT>;
462356843Sdim};
463356843Sdim
464356843Sdimtemplate <typename... ArgTs>
465356843Sdimclass AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
466356843Sdimpublic:
467356843Sdim  using Type = Error(ArgTs...);
468356843Sdim  using ResultType = Error;
469356843Sdim};
470356843Sdim
471356843Sdimtemplate <typename... ArgTs>
472356843Sdimclass AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
473356843Sdimpublic:
474356843Sdim  using Type = Error(ArgTs...);
475356843Sdim  using ResultType = Error;
476356843Sdim};
477356843Sdim
478356843Sdimtemplate <typename... ArgTs>
479356843Sdimclass AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
480356843Sdimpublic:
481356843Sdim  using Type = Error(ArgTs...);
482356843Sdim  using ResultType = Error;
483356843Sdim};
484356843Sdim
485356843Sdimtemplate <typename ResponseHandlerT, typename... ArgTs>
486356843Sdimclass AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
487356843Sdim    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
488356843Sdim                                    ArgTs...)> {};
489356843Sdim
// HandlerTraits<HandlerT> provides utilities related to RPC function handlers.
//
// This primary template covers non-function handler types such as lambdas and
// other function objects. Any reference on HandlerT is stripped, and the
// template then inherits from the specialization for the type of the
// handler's call operator. Function types, function pointers and member
// function pointers each have their own specialization.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<decltype(
          &std::remove_reference<HandlerT>::type::operator())> {};
498356843Sdim
// Traits for handlers with a given function type.
//
// For a handler signature RetT(ArgTs...), this class provides helpers to
// (de)serialize the argument list using ArgTs as the wire types. It also
// invokes a handler with arguments held in a std::tuple. The function-pointer,
// member-function-pointer and function-object HandlerTraits all inherit from
// this specialization.
template <typename RetT, typename... ArgTs>
class HandlerTraits<RetT(ArgTs...)> {
public:
  // Function type of the handler.
  using Type = RetT(ArgTs...);

  // Return type of the handler.
  using ReturnType = RetT;

  // Call the given handler, moving each element of Args in as an argument.
  // The result type is WrappedHandlerReturn<RetT>::Type (void -> Error,
  // T -> Expected<T>).
  template <typename HandlerT, typename... TArgTs>
  static typename WrappedHandlerReturn<RetT>::Type
  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
    return unpackAndRunHelper(Handler, Args,
                              std::index_sequence_for<TArgTs...>());
  }

  // Call the given handler with Responder as the first argument, followed by
  // the elements of Args (each moved). The helper's result is returned as
  // Error.
  template <typename HandlerT, typename ResponderT, typename... TArgTs>
  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
                                 std::tuple<TArgTs...> &Args) {
    return unpackAndRunAsyncHelper(Handler, Responder, Args,
                                   std::index_sequence_for<TArgTs...>());
  }

  // Call a handler whose own return type is void. Success is implicit, so
  // this always returns Error::success(). The parameters are the declared
  // ArgTs taken as ArgTs&& (reference-collapsed for reference ArgTs).
  template <typename HandlerT>
  static typename std::enable_if<
      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      Error>::type
  run(HandlerT &Handler, ArgTs &&... Args) {
    Handler(std::move(Args)...);
    return Error::success();
  }

  // Call a handler whose own return type is non-void, and return its result
  // unchanged. Arguments are taken by value and moved into the call.
  template <typename HandlerT, typename... TArgTs>
  static typename std::enable_if<
      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
      typename HandlerTraits<HandlerT>::ReturnType>::type
  run(HandlerT &Handler, TArgTs... Args) {
    return Handler(std::move(Args)...);
  }

  // Serialize arguments to the channel, with the declared ArgTs as wire types.
  // NOTE(review): CArgs are taken by value, so each argument is copied into
  // this frame. A const-reference pack would avoid the copy, but it would also
  // change deduction for array arguments such as string literals.
  template <typename ChannelT, typename... CArgTs>
  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
  }

  // Deserialize arguments from the channel into the elements of Args, with the
  // declared ArgTs as wire types.
  template <typename ChannelT, typename... CArgTs>
  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
    return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
  }

private:
  // Expands Args into element references for SequenceSerialization.
  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
                                     std::index_sequence<Indexes...> _) {
    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
        C, std::get<Indexes>(Args)...);
  }

  // Expands Args into moved elements and dispatches to the run() overload
  // matching the handler's return type.
  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
                     std::index_sequence<Indexes...>) {
    return run(Handler, std::move(std::get<Indexes>(Args))...);
  }

  // Like unpackAndRunHelper, but passes Responder ahead of the unpacked
  // arguments.
  template <typename HandlerT, typename ResponderT, typename ArgTuple,
            size_t... Indexes>
  static typename WrappedHandlerReturn<
      typename HandlerTraits<HandlerT>::ReturnType>::Type
  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
                          ArgTuple &Args, std::index_sequence<Indexes...>) {
    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
  }
};
580356843Sdim
581356843Sdim// Handler traits for free functions.
582356843Sdimtemplate <typename RetT, typename... ArgTs>
583356843Sdimclass HandlerTraits<RetT(*)(ArgTs...)>
584356843Sdim  : public HandlerTraits<RetT(ArgTs...)> {};
585356843Sdim
586356843Sdim// Handler traits for class methods (especially call operators for lambdas).
587356843Sdimtemplate <typename Class, typename RetT, typename... ArgTs>
588356843Sdimclass HandlerTraits<RetT (Class::*)(ArgTs...)>
589356843Sdim    : public HandlerTraits<RetT(ArgTs...)> {};
590356843Sdim
591356843Sdim// Handler traits for const class methods (especially call operators for
592356843Sdim// lambdas).
593356843Sdimtemplate <typename Class, typename RetT, typename... ArgTs>
594356843Sdimclass HandlerTraits<RetT (Class::*)(ArgTs...) const>
595356843Sdim    : public HandlerTraits<RetT(ArgTs...)> {};
596356843Sdim
597356843Sdim// Utility to peel the Expected wrapper off a response handler error type.
598356843Sdimtemplate <typename HandlerT> class ResponseHandlerArg;
599356843Sdim
600356843Sdimtemplate <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
601356843Sdimpublic:
602356843Sdim  using ArgType = Expected<ArgT>;
603356843Sdim  using UnwrappedArgType = ArgT;
604356843Sdim};
605356843Sdim
606356843Sdimtemplate <typename ArgT>
607356843Sdimclass ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
608356843Sdimpublic:
609356843Sdim  using ArgType = Expected<ArgT>;
610356843Sdim  using UnwrappedArgType = ArgT;
611356843Sdim};
612356843Sdim
613356843Sdimtemplate <> class ResponseHandlerArg<Error(Error)> {
614356843Sdimpublic:
615356843Sdim  using ArgType = Error;
616356843Sdim};
617356843Sdim
618356843Sdimtemplate <> class ResponseHandlerArg<ErrorSuccess(Error)> {
619356843Sdimpublic:
620356843Sdim  using ArgType = Error;
621356843Sdim};
622356843Sdim
623356843Sdim// ResponseHandler represents a handler for a not-yet-received function call
624356843Sdim// result.
625356843Sdimtemplate <typename ChannelT> class ResponseHandler {
626356843Sdimpublic:
627356843Sdim  virtual ~ResponseHandler() {}
628356843Sdim
629356843Sdim  // Reads the function result off the wire and acts on it. The meaning of
630356843Sdim  // "act" will depend on how this method is implemented in any given
631356843Sdim  // ResponseHandler subclass but could, for example, mean running a
632356843Sdim  // user-specified handler or setting a promise value.
633356843Sdim  virtual Error handleResponse(ChannelT &C) = 0;
634356843Sdim
635356843Sdim  // Abandons this outstanding result.
636356843Sdim  virtual void abandon() = 0;
637356843Sdim
638356843Sdim  // Create an error instance representing an abandoned response.
639356843Sdim  static Error createAbandonedResponseError() {
640356843Sdim    return make_error<ResponseAbandoned>();
641356843Sdim  }
642356843Sdim};
643356843Sdim
644356843Sdim// ResponseHandler subclass for RPC functions with non-void returns.
645356843Sdimtemplate <typename ChannelT, typename FuncRetT, typename HandlerT>
646356843Sdimclass ResponseHandlerImpl : public ResponseHandler<ChannelT> {
647356843Sdimpublic:
648356843Sdim  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
649356843Sdim
650356843Sdim  // Handle the result by deserializing it from the channel then passing it
651356843Sdim  // to the user defined handler.
652356843Sdim  Error handleResponse(ChannelT &C) override {
653356843Sdim    using UnwrappedArgType = typename ResponseHandlerArg<
654356843Sdim        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
655356843Sdim    UnwrappedArgType Result;
656356843Sdim    if (auto Err =
657356843Sdim            SerializationTraits<ChannelT, FuncRetT,
658356843Sdim                                UnwrappedArgType>::deserialize(C, Result))
659356843Sdim      return Err;
660356843Sdim    if (auto Err = C.endReceiveMessage())
661356843Sdim      return Err;
662356843Sdim    return Handler(std::move(Result));
663356843Sdim  }
664356843Sdim
665356843Sdim  // Abandon this response by calling the handler with an 'abandoned response'
666356843Sdim  // error.
667356843Sdim  void abandon() override {
668356843Sdim    if (auto Err = Handler(this->createAbandonedResponseError())) {
669356843Sdim      // Handlers should not fail when passed an abandoned response error.
670356843Sdim      report_fatal_error(std::move(Err));
671356843Sdim    }
672356843Sdim  }
673356843Sdim
674356843Sdimprivate:
675356843Sdim  HandlerT Handler;
676356843Sdim};
677356843Sdim
678356843Sdim// ResponseHandler subclass for RPC functions with void returns.
679356843Sdimtemplate <typename ChannelT, typename HandlerT>
680356843Sdimclass ResponseHandlerImpl<ChannelT, void, HandlerT>
681356843Sdim    : public ResponseHandler<ChannelT> {
682356843Sdimpublic:
683356843Sdim  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
684356843Sdim
685356843Sdim  // Handle the result (no actual value, just a notification that the function
686356843Sdim  // has completed on the remote end) by calling the user-defined handler with
687356843Sdim  // Error::success().
688356843Sdim  Error handleResponse(ChannelT &C) override {
689356843Sdim    if (auto Err = C.endReceiveMessage())
690356843Sdim      return Err;
691356843Sdim    return Handler(Error::success());
692356843Sdim  }
693356843Sdim
694356843Sdim  // Abandon this response by calling the handler with an 'abandoned response'
695356843Sdim  // error.
696356843Sdim  void abandon() override {
697356843Sdim    if (auto Err = Handler(this->createAbandonedResponseError())) {
698356843Sdim      // Handlers should not fail when passed an abandoned response error.
699356843Sdim      report_fatal_error(std::move(Err));
700356843Sdim    }
701356843Sdim  }
702356843Sdim
703356843Sdimprivate:
704356843Sdim  HandlerT Handler;
705356843Sdim};
706356843Sdim
707356843Sdimtemplate <typename ChannelT, typename FuncRetT, typename HandlerT>
708356843Sdimclass ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
709356843Sdim    : public ResponseHandler<ChannelT> {
710356843Sdimpublic:
711356843Sdim  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
712356843Sdim
713356843Sdim  // Handle the result by deserializing it from the channel then passing it
714356843Sdim  // to the user defined handler.
715356843Sdim  Error handleResponse(ChannelT &C) override {
716356843Sdim    using HandlerArgType = typename ResponseHandlerArg<
717356843Sdim        typename HandlerTraits<HandlerT>::Type>::ArgType;
718356843Sdim    HandlerArgType Result((typename HandlerArgType::value_type()));
719356843Sdim
720356843Sdim    if (auto Err =
721356843Sdim            SerializationTraits<ChannelT, Expected<FuncRetT>,
722356843Sdim                                HandlerArgType>::deserialize(C, Result))
723356843Sdim      return Err;
724356843Sdim    if (auto Err = C.endReceiveMessage())
725356843Sdim      return Err;
726356843Sdim    return Handler(std::move(Result));
727356843Sdim  }
728356843Sdim
729356843Sdim  // Abandon this response by calling the handler with an 'abandoned response'
730356843Sdim  // error.
731356843Sdim  void abandon() override {
732356843Sdim    if (auto Err = Handler(this->createAbandonedResponseError())) {
733356843Sdim      // Handlers should not fail when passed an abandoned response error.
734356843Sdim      report_fatal_error(std::move(Err));
735356843Sdim    }
736356843Sdim  }
737356843Sdim
738356843Sdimprivate:
739356843Sdim  HandlerT Handler;
740356843Sdim};
741356843Sdim
742356843Sdimtemplate <typename ChannelT, typename HandlerT>
743356843Sdimclass ResponseHandlerImpl<ChannelT, Error, HandlerT>
744356843Sdim    : public ResponseHandler<ChannelT> {
745356843Sdimpublic:
746356843Sdim  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
747356843Sdim
748356843Sdim  // Handle the result by deserializing it from the channel then passing it
749356843Sdim  // to the user defined handler.
750356843Sdim  Error handleResponse(ChannelT &C) override {
751356843Sdim    Error Result = Error::success();
752356843Sdim    if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
753356843Sdim            C, Result)) {
754356843Sdim      consumeError(std::move(Result));
755356843Sdim      return Err;
756356843Sdim    }
757356843Sdim    if (auto Err = C.endReceiveMessage()) {
758356843Sdim      consumeError(std::move(Result));
759356843Sdim      return Err;
760356843Sdim    }
761356843Sdim    return Handler(std::move(Result));
762356843Sdim  }
763356843Sdim
764356843Sdim  // Abandon this response by calling the handler with an 'abandoned response'
765356843Sdim  // error.
766356843Sdim  void abandon() override {
767356843Sdim    if (auto Err = Handler(this->createAbandonedResponseError())) {
768356843Sdim      // Handlers should not fail when passed an abandoned response error.
769356843Sdim      report_fatal_error(std::move(Err));
770356843Sdim    }
771356843Sdim  }
772356843Sdim
773356843Sdimprivate:
774356843Sdim  HandlerT Handler;
775356843Sdim};
776356843Sdim
777356843Sdim// Create a ResponseHandler from a given user handler.
778356843Sdimtemplate <typename ChannelT, typename FuncRetT, typename HandlerT>
779356843Sdimstd::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
780356843Sdim  return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
781356843Sdim      std::move(H));
782356843Sdim}
783356843Sdim
784356843Sdim// Helper for wrapping member functions up as functors. This is useful for
785356843Sdim// installing methods as result handlers.
786356843Sdimtemplate <typename ClassT, typename RetT, typename... ArgTs>
787356843Sdimclass MemberFnWrapper {
788356843Sdimpublic:
789356843Sdim  using MethodT = RetT (ClassT::*)(ArgTs...);
790356843Sdim  MemberFnWrapper(ClassT &Instance, MethodT Method)
791356843Sdim      : Instance(Instance), Method(Method) {}
792356843Sdim  RetT operator()(ArgTs &&... Args) {
793356843Sdim    return (Instance.*Method)(std::move(Args)...);
794356843Sdim  }
795356843Sdim
796356843Sdimprivate:
797356843Sdim  ClassT &Instance;
798356843Sdim  MethodT Method;
799356843Sdim};
800356843Sdim
801356843Sdim// Helper that provides a Functor for deserializing arguments.
802356843Sdimtemplate <typename... ArgTs> class ReadArgs {
803356843Sdimpublic:
804356843Sdim  Error operator()() { return Error::success(); }
805356843Sdim};
806356843Sdim
807356843Sdimtemplate <typename ArgT, typename... ArgTs>
808356843Sdimclass ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
809356843Sdimpublic:
810356843Sdim  ReadArgs(ArgT &Arg, ArgTs &... Args)
811356843Sdim      : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
812356843Sdim
813356843Sdim  Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
814356843Sdim    this->Arg = std::move(ArgVal);
815356843Sdim    return ReadArgs<ArgTs...>::operator()(ArgVals...);
816356843Sdim  }
817356843Sdim
818356843Sdimprivate:
819356843Sdim  ArgT &Arg;
820356843Sdim};
821356843Sdim
822356843Sdim// Manage sequence numbers.
823356843Sdimtemplate <typename SequenceNumberT> class SequenceNumberManager {
824356843Sdimpublic:
825356843Sdim  // Reset, making all sequence numbers available.
826356843Sdim  void reset() {
827356843Sdim    std::lock_guard<std::mutex> Lock(SeqNoLock);
828356843Sdim    NextSequenceNumber = 0;
829356843Sdim    FreeSequenceNumbers.clear();
830356843Sdim  }
831356843Sdim
832356843Sdim  // Get the next available sequence number. Will re-use numbers that have
833356843Sdim  // been released.
834356843Sdim  SequenceNumberT getSequenceNumber() {
835356843Sdim    std::lock_guard<std::mutex> Lock(SeqNoLock);
836356843Sdim    if (FreeSequenceNumbers.empty())
837356843Sdim      return NextSequenceNumber++;
838356843Sdim    auto SequenceNumber = FreeSequenceNumbers.back();
839356843Sdim    FreeSequenceNumbers.pop_back();
840356843Sdim    return SequenceNumber;
841356843Sdim  }
842356843Sdim
843356843Sdim  // Release a sequence number, making it available for re-use.
844356843Sdim  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
845356843Sdim    std::lock_guard<std::mutex> Lock(SeqNoLock);
846356843Sdim    FreeSequenceNumbers.push_back(SequenceNumber);
847356843Sdim  }
848356843Sdim
849356843Sdimprivate:
850356843Sdim  std::mutex SeqNoLock;
851356843Sdim  SequenceNumberT NextSequenceNumber = 0;
852356843Sdim  std::vector<SequenceNumberT> FreeSequenceNumbers;
853356843Sdim};
854356843Sdim
855356843Sdim// Checks that predicate P holds for each corresponding pair of type arguments
856356843Sdim// from T1 and T2 tuple.
857356843Sdimtemplate <template <class, class> class P, typename T1Tuple, typename T2Tuple>
858356843Sdimclass RPCArgTypeCheckHelper;
859356843Sdim
860356843Sdimtemplate <template <class, class> class P>
861356843Sdimclass RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
862356843Sdimpublic:
863356843Sdim  static const bool value = true;
864356843Sdim};
865356843Sdim
866356843Sdimtemplate <template <class, class> class P, typename T, typename... Ts,
867356843Sdim          typename U, typename... Us>
868356843Sdimclass RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
869356843Sdimpublic:
870356843Sdim  static const bool value =
871356843Sdim      P<T, U>::value &&
872356843Sdim      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
873356843Sdim};
874356843Sdim
875356843Sdimtemplate <template <class, class> class P, typename T1Sig, typename T2Sig>
876356843Sdimclass RPCArgTypeCheck {
877356843Sdimpublic:
878356843Sdim  using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
879356843Sdim  using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
880356843Sdim
881356843Sdim  static_assert(std::tuple_size<T1Tuple>::value >=
882356843Sdim                    std::tuple_size<T2Tuple>::value,
883356843Sdim                "Too many arguments to RPC call");
884356843Sdim  static_assert(std::tuple_size<T1Tuple>::value <=
885356843Sdim                    std::tuple_size<T2Tuple>::value,
886356843Sdim                "Too few arguments to RPC call");
887356843Sdim
888356843Sdim  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
889356843Sdim};
890356843Sdim
891356843Sdimtemplate <typename ChannelT, typename WireT, typename ConcreteT>
892356843Sdimclass CanSerialize {
893356843Sdimprivate:
894356843Sdim  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
895356843Sdim
896356843Sdim  template <typename T>
897356843Sdim  static std::true_type
898356843Sdim  check(typename std::enable_if<
899356843Sdim        std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
900356843Sdim                                           std::declval<const ConcreteT &>())),
901356843Sdim                     Error>::value,
902356843Sdim        void *>::type);
903356843Sdim
904356843Sdim  template <typename> static std::false_type check(...);
905356843Sdim
906356843Sdimpublic:
907356843Sdim  static const bool value = decltype(check<S>(0))::value;
908356843Sdim};
909356843Sdim
910356843Sdimtemplate <typename ChannelT, typename WireT, typename ConcreteT>
911356843Sdimclass CanDeserialize {
912356843Sdimprivate:
913356843Sdim  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
914356843Sdim
915356843Sdim  template <typename T>
916356843Sdim  static std::true_type
917356843Sdim  check(typename std::enable_if<
918356843Sdim        std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
919356843Sdim                                             std::declval<ConcreteT &>())),
920356843Sdim                     Error>::value,
921356843Sdim        void *>::type);
922356843Sdim
923356843Sdim  template <typename> static std::false_type check(...);
924356843Sdim
925356843Sdimpublic:
926356843Sdim  static const bool value = decltype(check<S>(0))::value;
927356843Sdim};
928356843Sdim
929356843Sdim/// Contains primitive utilities for defining, calling and handling calls to
930356843Sdim/// remote procedures. ChannelT is a bidirectional stream conforming to the
931356843Sdim/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
932356843Sdim/// identifier type that must be serializable on ChannelT, and SequenceNumberT
933356843Sdim/// is an integral type that will be used to number in-flight function calls.
934356843Sdim///
935356843Sdim/// These utilities support the construction of very primitive RPC utilities.
936356843Sdim/// Their intent is to ensure correct serialization and deserialization of
937356843Sdim/// procedure arguments, and to keep the client and server's view of the API in
938356843Sdim/// sync.
939356843Sdimtemplate <typename ImplT, typename ChannelT, typename FunctionIdT,
940356843Sdim          typename SequenceNumberT>
941356843Sdimclass RPCEndpointBase {
942356843Sdimprotected:
943356843Sdim  class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
944356843Sdim  public:
945356843Sdim    static const char *getName() { return "__orc_rpc$invalid"; }
946356843Sdim  };
947356843Sdim
948356843Sdim  class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
949356843Sdim  public:
950356843Sdim    static const char *getName() { return "__orc_rpc$response"; }
951356843Sdim  };
952356843Sdim
953356843Sdim  class OrcRPCNegotiate
954356843Sdim      : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
955356843Sdim  public:
956356843Sdim    static const char *getName() { return "__orc_rpc$negotiate"; }
957356843Sdim  };
958356843Sdim
959356843Sdim  // Helper predicate for testing for the presence of SerializeTraits
960356843Sdim  // serializers.
961356843Sdim  template <typename WireT, typename ConcreteT>
962356843Sdim  class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
963356843Sdim  public:
964356843Sdim    using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
965356843Sdim
966356843Sdim    static_assert(value, "Missing serializer for argument (Can't serialize the "
967356843Sdim                         "first template type argument of CanSerializeCheck "
968356843Sdim                         "from the second)");
969356843Sdim  };
970356843Sdim
971356843Sdim  // Helper predicate for testing for the presence of SerializeTraits
972356843Sdim  // deserializers.
973356843Sdim  template <typename WireT, typename ConcreteT>
974356843Sdim  class CanDeserializeCheck
975356843Sdim      : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
976356843Sdim  public:
977356843Sdim    using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
978356843Sdim
979356843Sdim    static_assert(value, "Missing deserializer for argument (Can't deserialize "
980356843Sdim                         "the second template type argument of "
981356843Sdim                         "CanDeserializeCheck from the first)");
982356843Sdim  };
983356843Sdim
984356843Sdimpublic:
985356843Sdim  /// Construct an RPC instance on a channel.
986356843Sdim  RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
987356843Sdim      : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
988356843Sdim    // Hold ResponseId in a special variable, since we expect Response to be
989356843Sdim    // called relatively frequently, and want to avoid the map lookup.
990356843Sdim    ResponseId = FnIdAllocator.getResponseId();
991356843Sdim    RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
992356843Sdim
993356843Sdim    // Register the negotiate function id and handler.
994356843Sdim    auto NegotiateId = FnIdAllocator.getNegotiateId();
995356843Sdim    RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
996356843Sdim    Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
997356843Sdim        [this](const std::string &Name) { return handleNegotiate(Name); });
998356843Sdim  }
999356843Sdim
1000356843Sdim
1001356843Sdim  /// Negotiate a function id for Func with the other end of the channel.
1002356843Sdim  template <typename Func> Error negotiateFunction(bool Retry = false) {
1003356843Sdim    return getRemoteFunctionId<Func>(true, Retry).takeError();
1004356843Sdim  }
1005356843Sdim
1006356843Sdim  /// Append a call Func, does not call send on the channel.
1007356843Sdim  /// The first argument specifies a user-defined handler to be run when the
1008356843Sdim  /// function returns. The handler should take an Expected<Func::ReturnType>,
1009356843Sdim  /// or an Error (if Func::ReturnType is void). The handler will be called
1010356843Sdim  /// with an error if the return value is abandoned due to a channel error.
1011356843Sdim  template <typename Func, typename HandlerT, typename... ArgTs>
1012356843Sdim  Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1013356843Sdim
1014356843Sdim    static_assert(
1015356843Sdim        detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1016356843Sdim                                void(ArgTs...)>::value,
1017356843Sdim        "");
1018356843Sdim
1019356843Sdim    // Look up the function ID.
1020356843Sdim    FunctionIdT FnId;
1021356843Sdim    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1022356843Sdim      FnId = *FnIdOrErr;
1023356843Sdim    else {
1024356843Sdim      // Negotiation failed. Notify the handler then return the negotiate-failed
1025356843Sdim      // error.
1026356843Sdim      cantFail(Handler(make_error<ResponseAbandoned>()));
1027356843Sdim      return FnIdOrErr.takeError();
1028356843Sdim    }
1029356843Sdim
1030356843Sdim    SequenceNumberT SeqNo; // initialized in locked scope below.
1031356843Sdim    {
1032356843Sdim      // Lock the pending responses map and sequence number manager.
1033356843Sdim      std::lock_guard<std::mutex> Lock(ResponsesMutex);
1034356843Sdim
1035356843Sdim      // Allocate a sequence number.
1036356843Sdim      SeqNo = SequenceNumberMgr.getSequenceNumber();
1037356843Sdim      assert(!PendingResponses.count(SeqNo) &&
1038356843Sdim             "Sequence number already allocated");
1039356843Sdim
1040356843Sdim      // Install the user handler.
1041356843Sdim      PendingResponses[SeqNo] =
1042356843Sdim        detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1043356843Sdim            std::move(Handler));
1044356843Sdim    }
1045356843Sdim
1046356843Sdim    // Open the function call message.
1047356843Sdim    if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1048356843Sdim      abandonPendingResponses();
1049356843Sdim      return Err;
1050356843Sdim    }
1051356843Sdim
1052356843Sdim    // Serialize the call arguments.
1053356843Sdim    if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1054356843Sdim            C, Args...)) {
1055356843Sdim      abandonPendingResponses();
1056356843Sdim      return Err;
1057356843Sdim    }
1058356843Sdim
1059356843Sdim    // Close the function call messagee.
1060356843Sdim    if (auto Err = C.endSendMessage()) {
1061356843Sdim      abandonPendingResponses();
1062356843Sdim      return Err;
1063356843Sdim    }
1064356843Sdim
1065356843Sdim    return Error::success();
1066356843Sdim  }
1067356843Sdim
1068356843Sdim  Error sendAppendedCalls() { return C.send(); };
1069356843Sdim
1070356843Sdim  template <typename Func, typename HandlerT, typename... ArgTs>
1071356843Sdim  Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1072356843Sdim    if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1073356843Sdim      return Err;
1074356843Sdim    return C.send();
1075356843Sdim  }
1076356843Sdim
1077356843Sdim  /// Handle one incoming call.
1078356843Sdim  Error handleOne() {
1079356843Sdim    FunctionIdT FnId;
1080356843Sdim    SequenceNumberT SeqNo;
1081356843Sdim    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1082356843Sdim      abandonPendingResponses();
1083356843Sdim      return Err;
1084356843Sdim    }
1085356843Sdim    if (FnId == ResponseId)
1086356843Sdim      return handleResponse(SeqNo);
1087356843Sdim    auto I = Handlers.find(FnId);
1088356843Sdim    if (I != Handlers.end())
1089356843Sdim      return I->second(C, SeqNo);
1090356843Sdim
1091356843Sdim    // else: No handler found. Report error to client?
1092356843Sdim    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1093356843Sdim                                                                     SeqNo);
1094356843Sdim  }
1095356843Sdim
1096356843Sdim  /// Helper for handling setter procedures - this method returns a functor that
1097356843Sdim  /// sets the variables referred to by Args... to values deserialized from the
1098356843Sdim  /// channel.
1099356843Sdim  /// E.g.
1100356843Sdim  ///
1101356843Sdim  ///   typedef Function<0, bool, int> Func1;
1102356843Sdim  ///
1103356843Sdim  ///   ...
1104356843Sdim  ///   bool B;
1105356843Sdim  ///   int I;
1106356843Sdim  ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1107356843Sdim  ///     /* Handle Args */ ;
1108356843Sdim  ///
1109356843Sdim  template <typename... ArgTs>
1110356843Sdim  static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1111356843Sdim    return detail::ReadArgs<ArgTs...>(Args...);
1112356843Sdim  }
1113356843Sdim
1114356843Sdim  /// Abandon all outstanding result handlers.
1115356843Sdim  ///
1116356843Sdim  /// This will call all currently registered result handlers to receive an
1117356843Sdim  /// "abandoned" error as their argument. This is used internally by the RPC
1118356843Sdim  /// in error situations, but can also be called directly by clients who are
1119356843Sdim  /// disconnecting from the remote and don't or can't expect responses to their
1120356843Sdim  /// outstanding calls. (Especially for outstanding blocking calls, calling
1121356843Sdim  /// this function may be necessary to avoid dead threads).
1122356843Sdim  void abandonPendingResponses() {
1123356843Sdim    // Lock the pending responses map and sequence number manager.
1124356843Sdim    std::lock_guard<std::mutex> Lock(ResponsesMutex);
1125356843Sdim
1126356843Sdim    for (auto &KV : PendingResponses)
1127356843Sdim      KV.second->abandon();
1128356843Sdim    PendingResponses.clear();
1129356843Sdim    SequenceNumberMgr.reset();
1130356843Sdim  }
1131356843Sdim
1132356843Sdim  /// Remove the handler for the given function.
1133356843Sdim  /// A handler must currently be registered for this function.
1134356843Sdim  template <typename Func>
1135356843Sdim  void removeHandler() {
1136356843Sdim    auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1137356843Sdim    assert(IdItr != LocalFunctionIds.end() &&
1138356843Sdim           "Function does not have a registered handler");
1139356843Sdim    auto HandlerItr = Handlers.find(IdItr->second);
1140356843Sdim    assert(HandlerItr != Handlers.end() &&
1141356843Sdim           "Function does not have a registered handler");
1142356843Sdim    Handlers.erase(HandlerItr);
1143356843Sdim  }
1144356843Sdim
1145356843Sdim  /// Clear all handlers.
1146356843Sdim  void clearHandlers() {
1147356843Sdim    Handlers.clear();
1148356843Sdim  }
1149356843Sdim
1150356843Sdimprotected:
1151356843Sdim
1152356843Sdim  FunctionIdT getInvalidFunctionId() const {
1153356843Sdim    return FnIdAllocator.getInvalidId();
1154356843Sdim  }
1155356843Sdim
1156356843Sdim  /// Add the given handler to the handler map and make it available for
1157356843Sdim  /// autonegotiation and execution.
1158356843Sdim  template <typename Func, typename HandlerT>
1159356843Sdim  void addHandlerImpl(HandlerT Handler) {
1160356843Sdim
1161356843Sdim    static_assert(detail::RPCArgTypeCheck<
1162356843Sdim                      CanDeserializeCheck, typename Func::Type,
1163356843Sdim                      typename detail::HandlerTraits<HandlerT>::Type>::value,
1164356843Sdim                  "");
1165356843Sdim
1166356843Sdim    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1167356843Sdim    LocalFunctionIds[Func::getPrototype()] = NewFnId;
1168356843Sdim    Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1169356843Sdim  }
1170356843Sdim
1171356843Sdim  template <typename Func, typename HandlerT>
1172356843Sdim  void addAsyncHandlerImpl(HandlerT Handler) {
1173356843Sdim
1174356843Sdim    static_assert(detail::RPCArgTypeCheck<
1175356843Sdim                      CanDeserializeCheck, typename Func::Type,
1176356843Sdim                      typename detail::AsyncHandlerTraits<
1177356843Sdim                        typename detail::HandlerTraits<HandlerT>::Type
1178356843Sdim                      >::Type>::value,
1179356843Sdim                  "");
1180356843Sdim
1181356843Sdim    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1182356843Sdim    LocalFunctionIds[Func::getPrototype()] = NewFnId;
1183356843Sdim    Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1184356843Sdim  }
1185356843Sdim
1186356843Sdim  Error handleResponse(SequenceNumberT SeqNo) {
1187356843Sdim    using Handler = typename decltype(PendingResponses)::mapped_type;
1188356843Sdim    Handler PRHandler;
1189356843Sdim
1190356843Sdim    {
1191356843Sdim      // Lock the pending responses map and sequence number manager.
1192356843Sdim      std::unique_lock<std::mutex> Lock(ResponsesMutex);
1193356843Sdim      auto I = PendingResponses.find(SeqNo);
1194356843Sdim
1195356843Sdim      if (I != PendingResponses.end()) {
1196356843Sdim        PRHandler = std::move(I->second);
1197356843Sdim        PendingResponses.erase(I);
1198356843Sdim        SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1199356843Sdim      } else {
1200356843Sdim        // Unlock the pending results map to prevent recursive lock.
1201356843Sdim        Lock.unlock();
1202356843Sdim        abandonPendingResponses();
1203356843Sdim        return make_error<
1204356843Sdim                 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1205356843Sdim      }
1206356843Sdim    }
1207356843Sdim
1208356843Sdim    assert(PRHandler &&
1209356843Sdim           "If we didn't find a response handler we should have bailed out");
1210356843Sdim
1211356843Sdim    if (auto Err = PRHandler->handleResponse(C)) {
1212356843Sdim      abandonPendingResponses();
1213356843Sdim      return Err;
1214356843Sdim    }
1215356843Sdim
1216356843Sdim    return Error::success();
1217356843Sdim  }
1218356843Sdim
1219356843Sdim  FunctionIdT handleNegotiate(const std::string &Name) {
1220356843Sdim    auto I = LocalFunctionIds.find(Name);
1221356843Sdim    if (I == LocalFunctionIds.end())
1222356843Sdim      return getInvalidFunctionId();
1223356843Sdim    return I->second;
1224356843Sdim  }
1225356843Sdim
1226356843Sdim  // Find the remote FunctionId for the given function.
1227356843Sdim  template <typename Func>
1228356843Sdim  Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1229356843Sdim                                            bool NegotiateIfInvalid) {
1230356843Sdim    bool DoNegotiate;
1231356843Sdim
1232356843Sdim    // Check if we already have a function id...
1233356843Sdim    auto I = RemoteFunctionIds.find(Func::getPrototype());
1234356843Sdim    if (I != RemoteFunctionIds.end()) {
1235356843Sdim      // If it's valid there's nothing left to do.
1236356843Sdim      if (I->second != getInvalidFunctionId())
1237356843Sdim        return I->second;
1238356843Sdim      DoNegotiate = NegotiateIfInvalid;
1239356843Sdim    } else
1240356843Sdim      DoNegotiate = NegotiateIfNotInMap;
1241356843Sdim
1242356843Sdim    // We don't have a function id for Func yet, but we're allowed to try to
1243356843Sdim    // negotiate one.
1244356843Sdim    if (DoNegotiate) {
1245356843Sdim      auto &Impl = static_cast<ImplT &>(*this);
1246356843Sdim      if (auto RemoteIdOrErr =
1247356843Sdim          Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1248356843Sdim        RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1249356843Sdim        if (*RemoteIdOrErr == getInvalidFunctionId())
1250356843Sdim          return make_error<CouldNotNegotiate>(Func::getPrototype());
1251356843Sdim        return *RemoteIdOrErr;
1252356843Sdim      } else
1253356843Sdim        return RemoteIdOrErr.takeError();
1254356843Sdim    }
1255356843Sdim
1256356843Sdim    // No key was available in the map and we weren't allowed to try to
1257356843Sdim    // negotiate one, so return an unknown function error.
1258356843Sdim    return make_error<CouldNotNegotiate>(Func::getPrototype());
1259356843Sdim  }
1260356843Sdim
1261356843Sdim  using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1262356843Sdim
1263356843Sdim  // Wrap the given user handler in the necessary argument-deserialization code,
1264356843Sdim  // result-serialization code, and call to the launch policy (if present).
1265356843Sdim  template <typename Func, typename HandlerT>
1266356843Sdim  WrappedHandlerFn wrapHandler(HandlerT Handler) {
1267356843Sdim    return [this, Handler](ChannelT &Channel,
1268356843Sdim                           SequenceNumberT SeqNo) mutable -> Error {
1269356843Sdim      // Start by deserializing the arguments.
1270356843Sdim      using ArgsTuple =
1271356843Sdim          typename detail::FunctionArgsTuple<
1272356843Sdim            typename detail::HandlerTraits<HandlerT>::Type>::Type;
1273356843Sdim      auto Args = std::make_shared<ArgsTuple>();
1274356843Sdim
1275356843Sdim      if (auto Err =
1276356843Sdim              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1277356843Sdim                  Channel, *Args))
1278356843Sdim        return Err;
1279356843Sdim
1280356843Sdim      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1281356843Sdim      // for RPCArgs. Void cast RPCArgs to work around this for now.
1282356843Sdim      // FIXME: Remove this workaround once we can assume a working GCC version.
1283356843Sdim      (void)Args;
1284356843Sdim
1285356843Sdim      // End receieve message, unlocking the channel for reading.
1286356843Sdim      if (auto Err = Channel.endReceiveMessage())
1287356843Sdim        return Err;
1288356843Sdim
1289356843Sdim      using HTraits = detail::HandlerTraits<HandlerT>;
1290356843Sdim      using FuncReturn = typename Func::ReturnType;
1291356843Sdim      return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1292356843Sdim                                         HTraits::unpackAndRun(Handler, *Args));
1293356843Sdim    };
1294356843Sdim  }
1295356843Sdim
1296356843Sdim  // Wrap the given user handler in the necessary argument-deserialization code,
1297356843Sdim  // result-serialization code, and call to the launch policy (if present).
1298356843Sdim  template <typename Func, typename HandlerT>
1299356843Sdim  WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1300356843Sdim    return [this, Handler](ChannelT &Channel,
1301356843Sdim                           SequenceNumberT SeqNo) mutable -> Error {
1302356843Sdim      // Start by deserializing the arguments.
1303356843Sdim      using AHTraits = detail::AsyncHandlerTraits<
1304356843Sdim                         typename detail::HandlerTraits<HandlerT>::Type>;
1305356843Sdim      using ArgsTuple =
1306356843Sdim          typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1307356843Sdim      auto Args = std::make_shared<ArgsTuple>();
1308356843Sdim
1309356843Sdim      if (auto Err =
1310356843Sdim              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1311356843Sdim                  Channel, *Args))
1312356843Sdim        return Err;
1313356843Sdim
1314356843Sdim      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1315356843Sdim      // for RPCArgs. Void cast RPCArgs to work around this for now.
1316356843Sdim      // FIXME: Remove this workaround once we can assume a working GCC version.
1317356843Sdim      (void)Args;
1318356843Sdim
1319356843Sdim      // End receieve message, unlocking the channel for reading.
1320356843Sdim      if (auto Err = Channel.endReceiveMessage())
1321356843Sdim        return Err;
1322356843Sdim
1323356843Sdim      using HTraits = detail::HandlerTraits<HandlerT>;
1324356843Sdim      using FuncReturn = typename Func::ReturnType;
1325356843Sdim      auto Responder =
1326356843Sdim        [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1327356843Sdim          return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1328356843Sdim                                             std::move(RetVal));
1329356843Sdim        };
1330356843Sdim
1331356843Sdim      return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1332356843Sdim    };
1333356843Sdim  }
1334356843Sdim
1335356843Sdim  ChannelT &C;
1336356843Sdim
1337356843Sdim  bool LazyAutoNegotiation;
1338356843Sdim
1339356843Sdim  RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1340356843Sdim
1341356843Sdim  FunctionIdT ResponseId;
1342356843Sdim  std::map<std::string, FunctionIdT> LocalFunctionIds;
1343356843Sdim  std::map<const char *, FunctionIdT> RemoteFunctionIds;
1344356843Sdim
1345356843Sdim  std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1346356843Sdim
1347356843Sdim  std::mutex ResponsesMutex;
1348356843Sdim  detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1349356843Sdim  std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1350356843Sdim      PendingResponses;
1351356843Sdim};
1352356843Sdim
1353356843Sdim} // end namespace detail
1354356843Sdim
1355356843Sdimtemplate <typename ChannelT, typename FunctionIdT = uint32_t,
1356356843Sdim          typename SequenceNumberT = uint32_t>
1357356843Sdimclass MultiThreadedRPCEndpoint
1358356843Sdim    : public detail::RPCEndpointBase<
1359356843Sdim          MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1360356843Sdim          ChannelT, FunctionIdT, SequenceNumberT> {
1361356843Sdimprivate:
1362356843Sdim  using BaseClass =
1363356843Sdim      detail::RPCEndpointBase<
1364356843Sdim        MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1365356843Sdim        ChannelT, FunctionIdT, SequenceNumberT>;
1366356843Sdim
1367356843Sdimpublic:
1368356843Sdim  MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1369356843Sdim      : BaseClass(C, LazyAutoNegotiation) {}
1370356843Sdim
1371356843Sdim  /// Add a handler for the given RPC function.
1372356843Sdim  /// This installs the given handler functor for the given RPC Function, and
1373356843Sdim  /// makes the RPC function available for negotiation/calling from the remote.
1374356843Sdim  template <typename Func, typename HandlerT>
1375356843Sdim  void addHandler(HandlerT Handler) {
1376356843Sdim    return this->template addHandlerImpl<Func>(std::move(Handler));
1377356843Sdim  }
1378356843Sdim
1379356843Sdim  /// Add a class-method as a handler.
1380356843Sdim  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1381356843Sdim  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1382356843Sdim    addHandler<Func>(
1383356843Sdim      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1384356843Sdim  }
1385356843Sdim
1386356843Sdim  template <typename Func, typename HandlerT>
1387356843Sdim  void addAsyncHandler(HandlerT Handler) {
1388356843Sdim    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1389356843Sdim  }
1390356843Sdim
1391356843Sdim  /// Add a class-method as a handler.
1392356843Sdim  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1393356843Sdim  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1394356843Sdim    addAsyncHandler<Func>(
1395356843Sdim      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1396356843Sdim  }
1397356843Sdim
1398356843Sdim  /// Return type for non-blocking call primitives.
1399356843Sdim  template <typename Func>
1400356843Sdim  using NonBlockingCallResult = typename detail::ResultTraits<
1401356843Sdim      typename Func::ReturnType>::ReturnFutureType;
1402356843Sdim
1403356843Sdim  /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1404356843Sdim  /// of a future result and the sequence number assigned to the result.
1405356843Sdim  ///
1406356843Sdim  /// This utility function is primarily used for single-threaded mode support,
1407356843Sdim  /// where the sequence number can be used to wait for the corresponding
1408356843Sdim  /// result. In multi-threaded mode the appendCallNB method, which does not
1409356843Sdim  /// return the sequence numeber, should be preferred.
1410356843Sdim  template <typename Func, typename... ArgTs>
1411356843Sdim  Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1412356843Sdim    using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1413356843Sdim    using ErrorReturn = typename RTraits::ErrorReturnType;
1414356843Sdim    using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1415356843Sdim
1416356843Sdim    ErrorReturnPromise Promise;
1417356843Sdim    auto FutureResult = Promise.get_future();
1418356843Sdim
1419356843Sdim    if (auto Err = this->template appendCallAsync<Func>(
1420356843Sdim            [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
1421356843Sdim              Promise.set_value(std::move(RetOrErr));
1422356843Sdim              return Error::success();
1423356843Sdim            },
1424356843Sdim            Args...)) {
1425356843Sdim      RTraits::consumeAbandoned(FutureResult.get());
1426356843Sdim      return std::move(Err);
1427356843Sdim    }
1428356843Sdim    return std::move(FutureResult);
1429356843Sdim  }
1430356843Sdim
1431356843Sdim  /// The same as appendCallNBWithSeq, except that it calls C.send() to
1432356843Sdim  /// flush the channel after serializing the call.
1433356843Sdim  template <typename Func, typename... ArgTs>
1434356843Sdim  Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1435356843Sdim    auto Result = appendCallNB<Func>(Args...);
1436356843Sdim    if (!Result)
1437356843Sdim      return Result;
1438356843Sdim    if (auto Err = this->C.send()) {
1439356843Sdim      this->abandonPendingResponses();
1440356843Sdim      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1441356843Sdim          std::move(Result->get()));
1442356843Sdim      return std::move(Err);
1443356843Sdim    }
1444356843Sdim    return Result;
1445356843Sdim  }
1446356843Sdim
1447356843Sdim  /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1448356843Sdim  /// for void functions or an Expected<T> for functions returning a T.
1449356843Sdim  ///
1450356843Sdim  /// This function is for use in threaded code where another thread is
1451356843Sdim  /// handling responses and incoming calls.
1452356843Sdim  template <typename Func, typename... ArgTs,
1453356843Sdim            typename AltRetT = typename Func::ReturnType>
1454356843Sdim  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1455356843Sdim  callB(const ArgTs &... Args) {
1456356843Sdim    if (auto FutureResOrErr = callNB<Func>(Args...))
1457356843Sdim      return FutureResOrErr->get();
1458356843Sdim    else
1459356843Sdim      return FutureResOrErr.takeError();
1460356843Sdim  }
1461356843Sdim
1462356843Sdim  /// Handle incoming RPC calls.
1463356843Sdim  Error handlerLoop() {
1464356843Sdim    while (true)
1465356843Sdim      if (auto Err = this->handleOne())
1466356843Sdim        return Err;
1467356843Sdim    return Error::success();
1468356843Sdim  }
1469356843Sdim};
1470356843Sdim
1471356843Sdimtemplate <typename ChannelT, typename FunctionIdT = uint32_t,
1472356843Sdim          typename SequenceNumberT = uint32_t>
1473356843Sdimclass SingleThreadedRPCEndpoint
1474356843Sdim    : public detail::RPCEndpointBase<
1475356843Sdim          SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1476356843Sdim          ChannelT, FunctionIdT, SequenceNumberT> {
1477356843Sdimprivate:
1478356843Sdim  using BaseClass =
1479356843Sdim      detail::RPCEndpointBase<
1480356843Sdim        SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1481356843Sdim        ChannelT, FunctionIdT, SequenceNumberT>;
1482356843Sdim
1483356843Sdimpublic:
1484356843Sdim  SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1485356843Sdim      : BaseClass(C, LazyAutoNegotiation) {}
1486356843Sdim
1487356843Sdim  template <typename Func, typename HandlerT>
1488356843Sdim  void addHandler(HandlerT Handler) {
1489356843Sdim    return this->template addHandlerImpl<Func>(std::move(Handler));
1490356843Sdim  }
1491356843Sdim
1492356843Sdim  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1493356843Sdim  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1494356843Sdim    addHandler<Func>(
1495356843Sdim        detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1496356843Sdim  }
1497356843Sdim
1498356843Sdim  template <typename Func, typename HandlerT>
1499356843Sdim  void addAsyncHandler(HandlerT Handler) {
1500356843Sdim    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1501356843Sdim  }
1502356843Sdim
1503356843Sdim  /// Add a class-method as a handler.
1504356843Sdim  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1505356843Sdim  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1506356843Sdim    addAsyncHandler<Func>(
1507356843Sdim      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1508356843Sdim  }
1509356843Sdim
1510356843Sdim  template <typename Func, typename... ArgTs,
1511356843Sdim            typename AltRetT = typename Func::ReturnType>
1512356843Sdim  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1513356843Sdim  callB(const ArgTs &... Args) {
1514356843Sdim    bool ReceivedResponse = false;
1515356843Sdim    using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1516356843Sdim    auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1517356843Sdim
1518356843Sdim    // We have to 'Check' result (which we know is in a success state at this
1519356843Sdim    // point) so that it can be overwritten in the async handler.
1520356843Sdim    (void)!!Result;
1521356843Sdim
1522356843Sdim    if (auto Err = this->template appendCallAsync<Func>(
1523356843Sdim            [&](ResultType R) {
1524356843Sdim              Result = std::move(R);
1525356843Sdim              ReceivedResponse = true;
1526356843Sdim              return Error::success();
1527356843Sdim            },
1528356843Sdim            Args...)) {
1529356843Sdim      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1530356843Sdim          std::move(Result));
1531356843Sdim      return std::move(Err);
1532356843Sdim    }
1533356843Sdim
1534356843Sdim    if (auto Err = this->C.send()) {
1535356843Sdim      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1536356843Sdim          std::move(Result));
1537356843Sdim      return std::move(Err);
1538356843Sdim    }
1539356843Sdim
1540356843Sdim    while (!ReceivedResponse) {
1541356843Sdim      if (auto Err = this->handleOne()) {
1542356843Sdim        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1543356843Sdim            std::move(Result));
1544356843Sdim        return std::move(Err);
1545356843Sdim      }
1546356843Sdim    }
1547356843Sdim
1548356843Sdim    return Result;
1549356843Sdim  }
1550356843Sdim};
1551356843Sdim
1552356843Sdim/// Asynchronous dispatch for a function on an RPC endpoint.
1553356843Sdimtemplate <typename RPCClass, typename Func>
1554356843Sdimclass RPCAsyncDispatch {
1555356843Sdimpublic:
1556356843Sdim  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1557356843Sdim
1558356843Sdim  template <typename HandlerT, typename... ArgTs>
1559356843Sdim  Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1560356843Sdim    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1561356843Sdim  }
1562356843Sdim
1563356843Sdimprivate:
1564356843Sdim  RPCClass &Endpoint;
1565356843Sdim};
1566356843Sdim
1567356843Sdim/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1568356843Sdimtemplate <typename Func, typename RPCEndpointT>
1569356843SdimRPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1570356843Sdim  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1571356843Sdim}
1572356843Sdim
1573356843Sdim/// Allows a set of asynchrounous calls to be dispatched, and then
1574356843Sdim///        waited on as a group.
1575356843Sdimclass ParallelCallGroup {
1576356843Sdimpublic:
1577356843Sdim
1578356843Sdim  ParallelCallGroup() = default;
1579356843Sdim  ParallelCallGroup(const ParallelCallGroup &) = delete;
1580356843Sdim  ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1581356843Sdim
1582356843Sdim  /// Make as asynchronous call.
1583356843Sdim  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1584356843Sdim  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1585356843Sdim             const ArgTs &... Args) {
1586356843Sdim    // Increment the count of outstanding calls. This has to happen before
1587356843Sdim    // we invoke the call, as the handler may (depending on scheduling)
1588356843Sdim    // be run immediately on another thread, and we don't want the decrement
1589356843Sdim    // in the wrapped handler below to run before the increment.
1590356843Sdim    {
1591356843Sdim      std::unique_lock<std::mutex> Lock(M);
1592356843Sdim      ++NumOutstandingCalls;
1593356843Sdim    }
1594356843Sdim
1595356843Sdim    // Wrap the user handler in a lambda that will decrement the
1596356843Sdim    // outstanding calls count, then poke the condition variable.
1597356843Sdim    using ArgType = typename detail::ResponseHandlerArg<
1598356843Sdim        typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1599356843Sdim    auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1600356843Sdim      auto Err = Handler(std::move(Arg));
1601356843Sdim      std::unique_lock<std::mutex> Lock(M);
1602356843Sdim      --NumOutstandingCalls;
1603356843Sdim      CV.notify_all();
1604356843Sdim      return Err;
1605356843Sdim    };
1606356843Sdim
1607356843Sdim    return AsyncDispatch(std::move(WrappedHandler), Args...);
1608356843Sdim  }
1609356843Sdim
1610356843Sdim  /// Blocks until all calls have been completed and their return value
1611356843Sdim  ///        handlers run.
1612356843Sdim  void wait() {
1613356843Sdim    std::unique_lock<std::mutex> Lock(M);
1614356843Sdim    while (NumOutstandingCalls > 0)
1615356843Sdim      CV.wait(Lock);
1616356843Sdim  }
1617356843Sdim
1618356843Sdimprivate:
1619356843Sdim  std::mutex M;
1620356843Sdim  std::condition_variable CV;
1621356843Sdim  uint32_t NumOutstandingCalls = 0;
1622356843Sdim};
1623356843Sdim
1624356843Sdim/// Convenience class for grouping RPC Functions into APIs that can be
1625356843Sdim///        negotiated as a block.
1626356843Sdim///
1627356843Sdimtemplate <typename... Funcs>
1628356843Sdimclass APICalls {
1629356843Sdimpublic:
1630356843Sdim
1631356843Sdim  /// Test whether this API contains Function F.
1632356843Sdim  template <typename F>
1633356843Sdim  class Contains {
1634356843Sdim  public:
1635356843Sdim    static const bool value = false;
1636356843Sdim  };
1637356843Sdim
1638356843Sdim  /// Negotiate all functions in this API.
1639356843Sdim  template <typename RPCEndpoint>
1640356843Sdim  static Error negotiate(RPCEndpoint &R) {
1641356843Sdim    return Error::success();
1642356843Sdim  }
1643356843Sdim};
1644356843Sdim
1645356843Sdimtemplate <typename Func, typename... Funcs>
1646356843Sdimclass APICalls<Func, Funcs...> {
1647356843Sdimpublic:
1648356843Sdim
1649356843Sdim  template <typename F>
1650356843Sdim  class Contains {
1651356843Sdim  public:
1652356843Sdim    static const bool value = std::is_same<F, Func>::value |
1653356843Sdim                              APICalls<Funcs...>::template Contains<F>::value;
1654356843Sdim  };
1655356843Sdim
1656356843Sdim  template <typename RPCEndpoint>
1657356843Sdim  static Error negotiate(RPCEndpoint &R) {
1658356843Sdim    if (auto Err = R.template negotiateFunction<Func>())
1659356843Sdim      return Err;
1660356843Sdim    return APICalls<Funcs...>::negotiate(R);
1661356843Sdim  }
1662356843Sdim
1663356843Sdim};
1664356843Sdim
1665356843Sdimtemplate <typename... InnerFuncs, typename... Funcs>
1666356843Sdimclass APICalls<APICalls<InnerFuncs...>, Funcs...> {
1667356843Sdimpublic:
1668356843Sdim
1669356843Sdim  template <typename F>
1670356843Sdim  class Contains {
1671356843Sdim  public:
1672356843Sdim    static const bool value =
1673356843Sdim      APICalls<InnerFuncs...>::template Contains<F>::value |
1674356843Sdim      APICalls<Funcs...>::template Contains<F>::value;
1675356843Sdim  };
1676356843Sdim
1677356843Sdim  template <typename RPCEndpoint>
1678356843Sdim  static Error negotiate(RPCEndpoint &R) {
1679356843Sdim    if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1680356843Sdim      return Err;
1681356843Sdim    return APICalls<Funcs...>::negotiate(R);
1682356843Sdim  }
1683356843Sdim
1684356843Sdim};
1685356843Sdim
1686356843Sdim} // end namespace rpc
1687356843Sdim} // end namespace orc
1688356843Sdim} // end namespace llvm
1689356843Sdim
1690356843Sdim#endif
1691