1//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utilities to support construction of simple RPC APIs.
10//
11// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
12// programmers, high performance, low memory overhead, and efficient use of the
13// communications channel.
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
18#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19
20#include <map>
21#include <thread>
22#include <vector>
23
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ExecutionEngine/Orc/OrcError.h"
26#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h"
27#include "llvm/Support/MSVCErrorWorkarounds.h"
28
29#include <future>
30
31namespace llvm {
32namespace orc {
33namespace rpc {
34
/// Base class of all fatal RPC errors (those that necessarily result in the
/// termination of the RPC session).
///
/// This class only declares the ErrorInfo ID; it provides no log() or
/// convertToErrorCode() of its own. Concrete fatal errors (e.g.
/// BadFunctionCall) derive from it via ErrorInfo<Derived, RPCFatalError>, and
/// the RespondHelper code in this header tests for it with isA / errorIsA so
/// that fatal errors are returned locally instead of being sent to the remote.
class RPCFatalError : public ErrorInfo<RPCFatalError> {
public:
  // ErrorInfo class-identity anchor; presumably defined in a .cpp file.
  static char ID;
};
41
/// ConnectionClosed is returned from RPC operations if the RPC connection
/// has already been closed due to either an error or graceful disconnection.
///
/// convertToErrorCode() and log() are only declared here; their definitions
/// live outside this header.
class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
public:
  static char ID;
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
50
51/// BadFunctionCall is returned from handleOne when the remote makes a call with
52/// an unrecognized function id.
53///
54/// This error is fatal because Orc RPC needs to know how to parse a function
55/// call to know where the next call starts, and if it doesn't recognize the
56/// function id it cannot parse the call.
57template <typename FnIdT, typename SeqNoT>
58class BadFunctionCall
59  : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60public:
61  static char ID;
62
63  BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64      : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65
66  std::error_code convertToErrorCode() const override {
67    return orcError(OrcErrorCode::UnexpectedRPCCall);
68  }
69
70  void log(raw_ostream &OS) const override {
71    OS << "Call to invalid RPC function id '" << FnId << "' with "
72          "sequence number " << SeqNo;
73  }
74
75private:
76  FnIdT FnId;
77  SeqNoT SeqNo;
78};
79
80template <typename FnIdT, typename SeqNoT>
81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82
83/// InvalidSequenceNumberForResponse is returned from handleOne when a response
84/// call arrives with a sequence number that doesn't correspond to any in-flight
85/// function call.
86///
87/// This error is fatal because Orc RPC needs to know how to parse the rest of
88/// the response call to know where the next call starts, and if it doesn't have
89/// a result parser for this sequence number it can't do that.
90template <typename SeqNoT>
91class InvalidSequenceNumberForResponse
92    : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93public:
94  static char ID;
95
96  InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97      : SeqNo(std::move(SeqNo)) {}
98
99  std::error_code convertToErrorCode() const override {
100    return orcError(OrcErrorCode::UnexpectedRPCCall);
101  };
102
103  void log(raw_ostream &OS) const override {
104    OS << "Response has unknown sequence number " << SeqNo;
105  }
106private:
107  SeqNoT SeqNo;
108};
109
110template <typename SeqNoT>
111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112
/// This non-fatal error will be passed to asynchronous result handlers in place
/// of a result if the connection goes down before a result returns, or if the
/// function to be called cannot be negotiated with the remote.
///
/// In this header it is created by
/// ResponseHandler::createAbandonedResponseError() and handed to user result
/// handlers by the ResponseHandlerImpl::abandon() overrides.
/// convertToErrorCode() and log() are defined outside this header.
class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
public:
  static char ID;

  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
};
123
/// This error is returned if the remote does not have a handler installed for
/// the given RPC function.
///
/// It carries the signature string of the function that could not be
/// negotiated, available through getSignature(). The constructor,
/// convertToErrorCode() and log() are defined outside this header.
class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
public:
  static char ID;

  // NOTE(review): Signature is presumably the text produced by
  // Function::getPrototype() for the failed function -- verify at call sites.
  CouldNotNegotiate(std::string Signature);
  std::error_code convertToErrorCode() const override;
  void log(raw_ostream &OS) const override;
  /// Returns the signature of the function that could not be negotiated.
  const std::string &getSignature() const { return Signature; }
private:
  std::string Signature;
};
137
138template <typename DerivedFunc, typename FnT> class Function;
139
140// RPC Function class.
141// DerivedFunc should be a user defined class with a static 'getName()' method
142// returning a const char* representing the function's name.
143template <typename DerivedFunc, typename RetT, typename... ArgTs>
144class Function<DerivedFunc, RetT(ArgTs...)> {
145public:
146  /// User defined function type.
147  using Type = RetT(ArgTs...);
148
149  /// Return type.
150  using ReturnType = RetT;
151
152  /// Returns the full function prototype as a string.
153  static const char *getPrototype() {
154    static std::string Name = [] {
155      std::string Name;
156      raw_string_ostream(Name)
157          << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158          << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159      return Name;
160    }();
161    return Name.data();
162  }
163};
164
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Should return a reserved id that will be used to represent missing
///   functions during autonegotiation.
///
/// static T getResponseId():
///   Should return a reserved id that will be used to send function responses
///   (return values).
///
/// static T getNegotiateId():
///   Should return a reserved id for the negotiate function, which will be used
///   to negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocate a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator for integral id types.
///
/// The ids 0, 1 and 2 are reserved as the invalid, response and negotiate ids
/// respectively. allocate() then hands out 3, 4, 5, ... in call order; the
/// Func template argument does not affect which id is returned.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  template <typename Func> T allocate() {
    const T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id that is not one of the reserved ids above.
  T NextId = static_cast<T>(3);
};
199
200namespace detail {
201
/// Maps a function type RetT(ArgTs...) to a std::tuple of its argument types
/// with std::decay applied to each (references and cv-qualifiers removed,
/// arrays/functions turned into pointers).
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already strips references, so no separate remove_reference
  // step is needed.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
211
212// ResultTraits provides typedefs and utilities specific to the return type
213// of functions.
214template <typename RetT> class ResultTraits {
215public:
216  // The return type wrapped in llvm::Expected.
217  using ErrorReturnType = Expected<RetT>;
218
219#ifdef _MSC_VER
220  // The ErrorReturnType wrapped in a std::promise.
221  using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
222
223  // The ErrorReturnType wrapped in a std::future.
224  using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
225#else
226  // The ErrorReturnType wrapped in a std::promise.
227  using ReturnPromiseType = std::promise<ErrorReturnType>;
228
229  // The ErrorReturnType wrapped in a std::future.
230  using ReturnFutureType = std::future<ErrorReturnType>;
231#endif
232
233  // Create a 'blank' value of the ErrorReturnType, ready and safe to
234  // overwrite.
235  static ErrorReturnType createBlankErrorReturnValue() {
236    return ErrorReturnType(RetT());
237  }
238
239  // Consume an abandoned ErrorReturnType.
240  static void consumeAbandoned(ErrorReturnType RetOrErr) {
241    consumeError(RetOrErr.takeError());
242  }
243};
244
245// ResultTraits specialization for void functions.
246template <> class ResultTraits<void> {
247public:
248  // For void functions, ErrorReturnType is llvm::Error.
249  using ErrorReturnType = Error;
250
251#ifdef _MSC_VER
252  // The ErrorReturnType wrapped in a std::promise.
253  using ReturnPromiseType = std::promise<MSVCPError>;
254
255  // The ErrorReturnType wrapped in a std::future.
256  using ReturnFutureType = std::future<MSVCPError>;
257#else
258  // The ErrorReturnType wrapped in a std::promise.
259  using ReturnPromiseType = std::promise<ErrorReturnType>;
260
261  // The ErrorReturnType wrapped in a std::future.
262  using ReturnFutureType = std::future<ErrorReturnType>;
263#endif
264
265  // Create a 'blank' value of the ErrorReturnType, ready and safe to
266  // overwrite.
267  static ErrorReturnType createBlankErrorReturnValue() {
268    return ErrorReturnType::success();
269  }
270
271  // Consume an abandoned ErrorReturnType.
272  static void consumeAbandoned(ErrorReturnType Err) {
273    consumeError(std::move(Err));
274  }
275};
276
// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
// handlers for void RPC functions to return either void (in which case they
// implicitly succeed) or Error (in which case their error return is
// propagated). Compare HandlerTraits::run, whose void overload returns
// Error::success() on the handler's behalf.
template <> class ResultTraits<Error> : public ResultTraits<void> {};

// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
// handlers for RPC functions returning a T to return either a T (in which
// case they implicitly succeed) or Expected<T> (in which case their error
// return is propagated). All members are inherited unchanged from
// ResultTraits<T>, so e.g. ErrorReturnType is Expected<T>, not
// Expected<Expected<T>>.
template <typename RetT>
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
289
290// Determines whether an RPC function's defined error return type supports
291// error return value.
292template <typename T>
293class SupportsErrorReturn {
294public:
295  static const bool value = false;
296};
297
298template <>
299class SupportsErrorReturn<Error> {
300public:
301  static const bool value = true;
302};
303
304template <typename T>
305class SupportsErrorReturn<Expected<T>> {
306public:
307  static const bool value = true;
308};
309
310// RespondHelper packages return values based on whether or not the declared
311// RPC function return type supports error returns.
312template <bool FuncSupportsErrorReturn>
313class RespondHelper;
314
315// RespondHelper specialization for functions that support error returns.
316template <>
317class RespondHelper<true> {
318public:
319
320  // Send Expected<T>.
321  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
322            typename FunctionIdT, typename SequenceNumberT>
323  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
324                          SequenceNumberT SeqNo,
325                          Expected<HandlerRetT> ResultOrErr) {
326    if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
327      return ResultOrErr.takeError();
328
329    // Open the response message.
330    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
331      return Err;
332
333    // Serialize the result.
334    if (auto Err =
335        SerializationTraits<ChannelT, WireRetT,
336                            Expected<HandlerRetT>>::serialize(
337                                                     C, std::move(ResultOrErr)))
338      return Err;
339
340    // Close the response message.
341    if (auto Err = C.endSendMessage())
342      return Err;
343    return C.send();
344  }
345
346  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
347  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
348                          SequenceNumberT SeqNo, Error Err) {
349    if (Err && Err.isA<RPCFatalError>())
350      return Err;
351    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
352      return Err2;
353    if (auto Err2 = serializeSeq(C, std::move(Err)))
354      return Err2;
355    if (auto Err2 = C.endSendMessage())
356      return Err2;
357    return C.send();
358  }
359
360};
361
362// RespondHelper specialization for functions that do not support error returns.
363template <>
364class RespondHelper<false> {
365public:
366
367  template <typename WireRetT, typename HandlerRetT, typename ChannelT,
368            typename FunctionIdT, typename SequenceNumberT>
369  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
370                          SequenceNumberT SeqNo,
371                          Expected<HandlerRetT> ResultOrErr) {
372    if (auto Err = ResultOrErr.takeError())
373      return Err;
374
375    // Open the response message.
376    if (auto Err = C.startSendMessage(ResponseId, SeqNo))
377      return Err;
378
379    // Serialize the result.
380    if (auto Err =
381        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
382                                                               C, *ResultOrErr))
383      return Err;
384
385    // End the response message.
386    if (auto Err = C.endSendMessage())
387      return Err;
388
389    return C.send();
390  }
391
392  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
393  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
394                          SequenceNumberT SeqNo, Error Err) {
395    if (Err)
396      return Err;
397    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
398      return Err2;
399    if (auto Err2 = C.endSendMessage())
400      return Err2;
401    return C.send();
402  }
403
404};
405
406
407// Send a response of the given wire return type (WireRetT) over the
408// channel, with the given sequence number.
409template <typename WireRetT, typename HandlerRetT, typename ChannelT,
410          typename FunctionIdT, typename SequenceNumberT>
411Error respond(ChannelT &C, const FunctionIdT &ResponseId,
412              SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414    template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
415}
416
417// Send an empty response message on the given channel to indicate that
418// the handler ran.
419template <typename WireRetT, typename ChannelT, typename FunctionIdT,
420          typename SequenceNumberT>
421Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422              Error Err) {
423  return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424    sendResult(C, ResponseId, SeqNo, std::move(Err));
425}
426
427// Converts a given type to the equivalent error return type.
428template <typename T> class WrappedHandlerReturn {
429public:
430  using Type = Expected<T>;
431};
432
433template <typename T> class WrappedHandlerReturn<Expected<T>> {
434public:
435  using Type = Expected<T>;
436};
437
438template <> class WrappedHandlerReturn<void> {
439public:
440  using Type = Error;
441};
442
443template <> class WrappedHandlerReturn<Error> {
444public:
445  using Type = Error;
446};
447
448template <> class WrappedHandlerReturn<ErrorSuccess> {
449public:
450  using Type = Error;
451};
452
// Traits class that strips the response function from the list of handler
// arguments.
//
// An asynchronous handler takes a responder callback as its first parameter,
// followed by the RPC function's own arguments. Each specialization exposes:
//   Type       - Error(ArgTs...): the handler signature with the responder
//                parameter removed.
//   ResultType - the type the responder accepts (Expected<ResultT> or Error).
template <typename FnT> class AsyncHandlerTraits;

// Responder accepts Expected<ResultT>.
template <typename ResultT, typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Expected<ResultT>;
};

// Responder accepts Error; handler returns Error.
template <typename... ArgTs>
class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Responder accepts Error; handler is declared to return ErrorSuccess.
template <typename... ArgTs>
class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Responder accepts Error; handler returns void.
template <typename... ArgTs>
class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
public:
  using Type = Error(ArgTs...);
  using ResultType = Error;
};

// Fallback for Error-returning handlers whose responder parameter is not
// spelled exactly as one of the std::function types above (for example, it is
// taken by reference): retry the match with the decayed responder type.
// Partial ordering prefers the more specialized std::function forms above
// whenever they match directly.
// NOTE(review): if the decayed type is still not one of those std::function
// forms, this specialization matches again and would derive from itself; it is
// presumably only intended for ref-qualified std::function responders --
// confirm against callers.
template <typename ResponseHandlerT, typename... ArgTs>
class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
                                    ArgTs...)> {};
489
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
// specialization for the given non-function type's call operator. This lets
// lambdas and other functors serve as handlers, provided &T::operator() names
// a single, non-template call operator.
template <typename HandlerT>
class HandlerTraits : public HandlerTraits<decltype(
                          &std::remove_reference<HandlerT>::type::operator())> {
};
498
499// Traits for handlers with a given function type.
500template <typename RetT, typename... ArgTs>
501class HandlerTraits<RetT(ArgTs...)> {
502public:
503  // Function type of the handler.
504  using Type = RetT(ArgTs...);
505
506  // Return type of the handler.
507  using ReturnType = RetT;
508
509  // Call the given handler with the given arguments.
510  template <typename HandlerT, typename... TArgTs>
511  static typename WrappedHandlerReturn<RetT>::Type
512  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
513    return unpackAndRunHelper(Handler, Args,
514                              std::index_sequence_for<TArgTs...>());
515  }
516
517  // Call the given handler with the given arguments.
518  template <typename HandlerT, typename ResponderT, typename... TArgTs>
519  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
520                                 std::tuple<TArgTs...> &Args) {
521    return unpackAndRunAsyncHelper(Handler, Responder, Args,
522                                   std::index_sequence_for<TArgTs...>());
523  }
524
525  // Call the given handler with the given arguments.
526  template <typename HandlerT>
527  static typename std::enable_if<
528      std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529      Error>::type
530  run(HandlerT &Handler, ArgTs &&... Args) {
531    Handler(std::move(Args)...);
532    return Error::success();
533  }
534
535  template <typename HandlerT, typename... TArgTs>
536  static typename std::enable_if<
537      !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
538      typename HandlerTraits<HandlerT>::ReturnType>::type
539  run(HandlerT &Handler, TArgTs... Args) {
540    return Handler(std::move(Args)...);
541  }
542
543  // Serialize arguments to the channel.
544  template <typename ChannelT, typename... CArgTs>
545  static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
546    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
547  }
548
549  // Deserialize arguments from the channel.
550  template <typename ChannelT, typename... CArgTs>
551  static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
552    return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>());
553  }
554
555private:
556  template <typename ChannelT, typename... CArgTs, size_t... Indexes>
557  static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
558                                     std::index_sequence<Indexes...> _) {
559    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
560        C, std::get<Indexes>(Args)...);
561  }
562
563  template <typename HandlerT, typename ArgTuple, size_t... Indexes>
564  static typename WrappedHandlerReturn<
565      typename HandlerTraits<HandlerT>::ReturnType>::Type
566  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
567                     std::index_sequence<Indexes...>) {
568    return run(Handler, std::move(std::get<Indexes>(Args))...);
569  }
570
571  template <typename HandlerT, typename ResponderT, typename ArgTuple,
572            size_t... Indexes>
573  static typename WrappedHandlerReturn<
574      typename HandlerTraits<HandlerT>::ReturnType>::Type
575  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
576                          ArgTuple &Args, std::index_sequence<Indexes...>) {
577    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
578  }
579};
580
581// Handler traits for free functions.
582template <typename RetT, typename... ArgTs>
583class HandlerTraits<RetT(*)(ArgTs...)>
584  : public HandlerTraits<RetT(ArgTs...)> {};
585
586// Handler traits for class methods (especially call operators for lambdas).
587template <typename Class, typename RetT, typename... ArgTs>
588class HandlerTraits<RetT (Class::*)(ArgTs...)>
589    : public HandlerTraits<RetT(ArgTs...)> {};
590
591// Handler traits for const class methods (especially call operators for
592// lambdas).
593template <typename Class, typename RetT, typename... ArgTs>
594class HandlerTraits<RetT (Class::*)(ArgTs...) const>
595    : public HandlerTraits<RetT(ArgTs...)> {};
596
597// Utility to peel the Expected wrapper off a response handler error type.
598template <typename HandlerT> class ResponseHandlerArg;
599
600template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
601public:
602  using ArgType = Expected<ArgT>;
603  using UnwrappedArgType = ArgT;
604};
605
606template <typename ArgT>
607class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
608public:
609  using ArgType = Expected<ArgT>;
610  using UnwrappedArgType = ArgT;
611};
612
613template <> class ResponseHandlerArg<Error(Error)> {
614public:
615  using ArgType = Error;
616};
617
618template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
619public:
620  using ArgType = Error;
621};
622
623// ResponseHandler represents a handler for a not-yet-received function call
624// result.
625template <typename ChannelT> class ResponseHandler {
626public:
627  virtual ~ResponseHandler() {}
628
629  // Reads the function result off the wire and acts on it. The meaning of
630  // "act" will depend on how this method is implemented in any given
631  // ResponseHandler subclass but could, for example, mean running a
632  // user-specified handler or setting a promise value.
633  virtual Error handleResponse(ChannelT &C) = 0;
634
635  // Abandons this outstanding result.
636  virtual void abandon() = 0;
637
638  // Create an error instance representing an abandoned response.
639  static Error createAbandonedResponseError() {
640    return make_error<ResponseAbandoned>();
641  }
642};
643
644// ResponseHandler subclass for RPC functions with non-void returns.
645template <typename ChannelT, typename FuncRetT, typename HandlerT>
646class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
647public:
648  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
649
650  // Handle the result by deserializing it from the channel then passing it
651  // to the user defined handler.
652  Error handleResponse(ChannelT &C) override {
653    using UnwrappedArgType = typename ResponseHandlerArg<
654        typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
655    UnwrappedArgType Result;
656    if (auto Err =
657            SerializationTraits<ChannelT, FuncRetT,
658                                UnwrappedArgType>::deserialize(C, Result))
659      return Err;
660    if (auto Err = C.endReceiveMessage())
661      return Err;
662    return Handler(std::move(Result));
663  }
664
665  // Abandon this response by calling the handler with an 'abandoned response'
666  // error.
667  void abandon() override {
668    if (auto Err = Handler(this->createAbandonedResponseError())) {
669      // Handlers should not fail when passed an abandoned response error.
670      report_fatal_error(std::move(Err));
671    }
672  }
673
674private:
675  HandlerT Handler;
676};
677
678// ResponseHandler subclass for RPC functions with void returns.
679template <typename ChannelT, typename HandlerT>
680class ResponseHandlerImpl<ChannelT, void, HandlerT>
681    : public ResponseHandler<ChannelT> {
682public:
683  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
684
685  // Handle the result (no actual value, just a notification that the function
686  // has completed on the remote end) by calling the user-defined handler with
687  // Error::success().
688  Error handleResponse(ChannelT &C) override {
689    if (auto Err = C.endReceiveMessage())
690      return Err;
691    return Handler(Error::success());
692  }
693
694  // Abandon this response by calling the handler with an 'abandoned response'
695  // error.
696  void abandon() override {
697    if (auto Err = Handler(this->createAbandonedResponseError())) {
698      // Handlers should not fail when passed an abandoned response error.
699      report_fatal_error(std::move(Err));
700    }
701  }
702
703private:
704  HandlerT Handler;
705};
706
707template <typename ChannelT, typename FuncRetT, typename HandlerT>
708class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
709    : public ResponseHandler<ChannelT> {
710public:
711  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
712
713  // Handle the result by deserializing it from the channel then passing it
714  // to the user defined handler.
715  Error handleResponse(ChannelT &C) override {
716    using HandlerArgType = typename ResponseHandlerArg<
717        typename HandlerTraits<HandlerT>::Type>::ArgType;
718    HandlerArgType Result((typename HandlerArgType::value_type()));
719
720    if (auto Err =
721            SerializationTraits<ChannelT, Expected<FuncRetT>,
722                                HandlerArgType>::deserialize(C, Result))
723      return Err;
724    if (auto Err = C.endReceiveMessage())
725      return Err;
726    return Handler(std::move(Result));
727  }
728
729  // Abandon this response by calling the handler with an 'abandoned response'
730  // error.
731  void abandon() override {
732    if (auto Err = Handler(this->createAbandonedResponseError())) {
733      // Handlers should not fail when passed an abandoned response error.
734      report_fatal_error(std::move(Err));
735    }
736  }
737
738private:
739  HandlerT Handler;
740};
741
742template <typename ChannelT, typename HandlerT>
743class ResponseHandlerImpl<ChannelT, Error, HandlerT>
744    : public ResponseHandler<ChannelT> {
745public:
746  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
747
748  // Handle the result by deserializing it from the channel then passing it
749  // to the user defined handler.
750  Error handleResponse(ChannelT &C) override {
751    Error Result = Error::success();
752    if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize(
753            C, Result)) {
754      consumeError(std::move(Result));
755      return Err;
756    }
757    if (auto Err = C.endReceiveMessage()) {
758      consumeError(std::move(Result));
759      return Err;
760    }
761    return Handler(std::move(Result));
762  }
763
764  // Abandon this response by calling the handler with an 'abandoned response'
765  // error.
766  void abandon() override {
767    if (auto Err = Handler(this->createAbandonedResponseError())) {
768      // Handlers should not fail when passed an abandoned response error.
769      report_fatal_error(std::move(Err));
770    }
771  }
772
773private:
774  HandlerT Handler;
775};
776
777// Create a ResponseHandler from a given user handler.
778template <typename ChannelT, typename FuncRetT, typename HandlerT>
779std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
780  return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
781      std::move(H));
782}
783
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper stores a reference to Instance, so Instance must outlive every
// call made through the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);
  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invoke the wrapped method. Each argument is forwarded with the value
  // category of the method's declared parameter type. By-value and
  // rvalue-reference parameters receive rvalues; lvalue-reference parameters
  // (e.g. 'T &') receive lvalues. std::move would instead produce rvalues that
  // cannot bind to such 'T &' parameters.
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
800
801// Helper that provides a Functor for deserializing arguments.
802template <typename... ArgTs> class ReadArgs {
803public:
804  Error operator()() { return Error::success(); }
805};
806
807template <typename ArgT, typename... ArgTs>
808class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
809public:
810  ReadArgs(ArgT &Arg, ArgTs &... Args)
811      : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
812
813  Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
814    this->Arg = std::move(ArgVal);
815    return ReadArgs<ArgTs...>::operator()(ArgVals...);
816  }
817
818private:
819  ArgT &Arg;
820};
821
// Manage sequence numbers. Every member function takes SeqNoLock, so one
// manager can be shared between threads.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Reset, making all sequence numbers available.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number. The most recently released
  // number is re-used first; fresh numbers are minted only when none are free.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Reused = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Reused;
    }
    return NextSequenceNumber++;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
854
// Checks that predicate P holds for each corresponding pair of type arguments
// from the T1 and T2 tuples. Only equal-length tuples are handled; there is no
// specialization for mismatched lengths, so such uses fail to compile.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy P.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Recursive case: P must hold for the two heads and, recursively, the tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using TailCheck =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && TailCheck::value;
};
874
875template <template <class, class> class P, typename T1Sig, typename T2Sig>
876class RPCArgTypeCheck {
877public:
878  using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
879  using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
880
881  static_assert(std::tuple_size<T1Tuple>::value >=
882                    std::tuple_size<T2Tuple>::value,
883                "Too many arguments to RPC call");
884  static_assert(std::tuple_size<T1Tuple>::value <=
885                    std::tuple_size<T2Tuple>::value,
886                "Too few arguments to RPC call");
887
888  static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
889};
890
891template <typename ChannelT, typename WireT, typename ConcreteT>
892class CanSerialize {
893private:
894  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
895
896  template <typename T>
897  static std::true_type
898  check(typename std::enable_if<
899        std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
900                                           std::declval<const ConcreteT &>())),
901                     Error>::value,
902        void *>::type);
903
904  template <typename> static std::false_type check(...);
905
906public:
907  static const bool value = decltype(check<S>(0))::value;
908};
909
910template <typename ChannelT, typename WireT, typename ConcreteT>
911class CanDeserialize {
912private:
913  using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
914
915  template <typename T>
916  static std::true_type
917  check(typename std::enable_if<
918        std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
919                                             std::declval<ConcreteT &>())),
920                     Error>::value,
921        void *>::type);
922
923  template <typename> static std::false_type check(...);
924
925public:
926  static const bool value = decltype(check<S>(0))::value;
927};
928
/// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures. ChannelT is a bidirectional stream conforming to the
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
/// identifier type that must be serializable on ChannelT, and SequenceNumberT
/// is an integral type that will be used to number in-flight function calls.
///
/// These utilities support the construction of very primitive RPC utilities.
/// Their intent is to ensure correct serialization and deserialization of
/// procedure arguments, and to keep the client and server's view of the API in
/// sync.
///
/// ImplT is the concrete endpoint class deriving from this one (CRTP). It is
/// used by getRemoteFunctionId to issue the blocking negotiation call through
/// ImplT::callB.
///
/// Locking: ResponsesMutex guards PendingResponses and SequenceNumberMgr. The
/// Handlers, LocalFunctionIds and RemoteFunctionIds maps are read and written
/// without any lock in this class.
/// NOTE(review): presumably handler registration and id negotiation must not
/// race with handleOne running on another thread -- confirm against callers.
template <typename ImplT, typename ChannelT, typename FunctionIdT,
          typename SequenceNumberT>
class RPCEndpointBase {
protected:
  // Reserved internal function named "__orc_rpc$invalid".
  // NOTE(review): not referenced elsewhere in this class; presumably it exists
  // to name/reserve the invalid function id -- confirm.
  class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
  public:
    static const char *getName() { return "__orc_rpc$invalid"; }
  };

  // Reserved function whose id (ResponseId) tags response messages. handleOne
  // routes every message carrying that id to handleResponse.
  class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
  public:
    static const char *getName() { return "__orc_rpc$response"; }
  };

  // Reserved function used for id negotiation. It takes a function prototype
  // string and returns the remote side's id for it (see handleNegotiate).
  class OrcRPCNegotiate
      : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
  public:
    static const char *getName() { return "__orc_rpc$negotiate"; }
  };

  // Helper predicate for testing for the presence of SerializeTraits
  // serializers. It is used as the RPCArgTypeCheck predicate in
  // appendCallAsync. If a (WireT, ConcreteT) pair cannot be serialized, the
  // static_assert below fires with a readable message.
  template <typename WireT, typename ConcreteT>
  class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing serializer for argument (Can't serialize the "
                         "first template type argument of CanSerializeCheck "
                         "from the second)");
  };

  // Helper predicate for testing for the presence of SerializeTraits
  // deserializers. It is used as the RPCArgTypeCheck predicate when handlers
  // are added, and static_asserts if a deserializer is missing.
  template <typename WireT, typename ConcreteT>
  class CanDeserializeCheck
      : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
  public:
    using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;

    static_assert(value, "Missing deserializer for argument (Can't deserialize "
                         "the second template type argument of "
                         "CanDeserializeCheck from the first)");
  };

public:
  /// Construct an RPC instance on a channel.
  ///
  /// \p C is stored by reference and must outlive this endpoint.
  /// If \p LazyAutoNegotiation is true, appendCallAsync negotiates an id for
  /// any function that has no cached remote id. Otherwise calling such a
  /// function fails with CouldNotNegotiate.
  RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
      : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
    // Hold ResponseId in a special variable, since we expect Response to be
    // called relatively frequently, and want to avoid the map lookup.
    ResponseId = FnIdAllocator.getResponseId();
    RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;

    // Register the negotiate function id and handler. The reserved negotiate
    // id is used both for calling the remote's negotiator and for serving
    // negotiation requests from the remote.
    auto NegotiateId = FnIdAllocator.getNegotiateId();
    RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
    Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
        [this](const std::string &Name) { return handleNegotiate(Name); });
  }


  /// Negotiate a function id for Func with the other end of the channel.
  ///
  /// Negotiation always happens if no id is cached for Func. If an invalid id
  /// is cached from an earlier failed negotiation, negotiation is retried only
  /// when \p Retry is true; otherwise CouldNotNegotiate is returned.
  template <typename Func> Error negotiateFunction(bool Retry = false) {
    return getRemoteFunctionId<Func>(true, Retry).takeError();
  }

  /// Append a call Func, does not call send on the channel.
  /// The first argument specifies a user-defined handler to be run when the
  /// function returns. The handler should take an Expected<Func::ReturnType>,
  /// or an Error (if Func::ReturnType is void). The handler will be called
  /// with an error if the return value is abandoned due to a channel error.
  ///
  /// Failure modes:
  /// - If the function id cannot be obtained, the handler is called with
  ///   ResponseAbandoned and the negotiation error is returned.
  /// - If writing the message to the channel fails, abandonPendingResponses()
  ///   is called, which abandons every pending response, not just this one.
  ///   The channel error is then returned.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {

    // Compile-time check that each argument can be serialized as the
    // corresponding parameter type of Func.
    static_assert(
        detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
                                void(ArgTs...)>::value,
        "");

    // Look up the function ID.
    FunctionIdT FnId;
    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
      FnId = *FnIdOrErr;
    else {
      // Negotiation failed. Notify the handler then return the negotiate-failed
      // error.
      cantFail(Handler(make_error<ResponseAbandoned>()));
      return FnIdOrErr.takeError();
    }

    SequenceNumberT SeqNo; // initialized in locked scope below.
    {
      // Lock the pending responses map and sequence number manager.
      std::lock_guard<std::mutex> Lock(ResponsesMutex);

      // Allocate a sequence number.
      SeqNo = SequenceNumberMgr.getSequenceNumber();
      assert(!PendingResponses.count(SeqNo) &&
             "Sequence number already allocated");

      // Install the user handler. This must happen before the call is sent,
      // so the handler is registered before any response for SeqNo can be
      // handled.
      PendingResponses[SeqNo] =
        detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
            std::move(Handler));
    }

    // Open the function call message.
    if (auto Err = C.startSendMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }

    // Serialize the call arguments.
    if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
            C, Args...)) {
      abandonPendingResponses();
      return Err;
    }

    // Close the function call message.
    if (auto Err = C.endSendMessage()) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  /// Flush calls previously queued with appendCallAsync by calling send() on
  /// the channel.
  Error sendAppendedCalls() { return C.send(); };

  /// Like appendCallAsync, but also calls send() on the channel.
  /// If send() fails, its error is returned directly; pending responses are
  /// not abandoned here.
  template <typename Func, typename HandlerT, typename... ArgTs>
  Error callAsync(HandlerT Handler, const ArgTs &... Args) {
    if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
      return Err;
    return C.send();
  }

  /// Handle one incoming call.
  ///
  /// Reads a message header from the channel. Responses (FnId == ResponseId)
  /// go to handleResponse. Any other message runs the handler registered for
  /// FnId, and that handler's error (if any) is returned. If no handler is
  /// registered, BadFunctionCall is returned.
  /// Pending responses are abandoned here only if reading the header fails;
  /// handleResponse abandons them on its own failure paths.
  Error handleOne() {
    FunctionIdT FnId;
    SequenceNumberT SeqNo;
    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
      abandonPendingResponses();
      return Err;
    }
    if (FnId == ResponseId)
      return handleResponse(SeqNo);
    auto I = Handlers.find(FnId);
    if (I != Handlers.end())
      return I->second(C, SeqNo);

    // else: No handler found. Report error to client?
    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
                                                                     SeqNo);
  }

  /// Helper for handling setter procedures - this method returns a functor that
  /// sets the variables referred to by Args... to values deserialized from the
  /// channel.
  /// E.g.
  ///
  ///   typedef Function<0, bool, int> Func1;
  ///
  ///   ...
  ///   bool B;
  ///   int I;
  ///   if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
  ///     /* Handle Args */ ;
  ///
  /// NOTE(review): this example appears to predate the current API. Function
  /// is now used as Function<DerivedFunc, Signature> (see OrcRPCResponse
  /// above), and no `expect` is visible in this class -- confirm or update.
  template <typename... ArgTs>
  static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
    return detail::ReadArgs<ArgTs...>(Args...);
  }

  /// Abandon all outstanding result handlers.
  ///
  /// This will call all currently registered result handlers to receive an
  /// "abandoned" error as their argument. This is used internally by the RPC
  /// in error situations, but can also be called directly by clients who are
  /// disconnecting from the remote and don't or can't expect responses to their
  /// outstanding calls. (Especially for outstanding blocking calls, calling
  /// this function may be necessary to avoid dead threads).
  ///
  /// Also resets the sequence number manager. Takes ResponsesMutex, so it must
  /// not be called while that mutex is held. Each handler's abandon() runs
  /// with the mutex held.
  /// NOTE(review): if abandon() invokes the user handler, that handler must not
  /// call back into anything that locks ResponsesMutex.
  void abandonPendingResponses() {
    // Lock the pending responses map and sequence number manager.
    std::lock_guard<std::mutex> Lock(ResponsesMutex);

    for (auto &KV : PendingResponses)
      KV.second->abandon();
    PendingResponses.clear();
    SequenceNumberMgr.reset();
  }

  /// Remove the handler for the given function.
  /// A handler must currently be registered for this function.
  ///
  /// Only the Handlers entry is erased. The LocalFunctionIds entry is kept,
  /// so the remote can still negotiate an id for Func; calls to that id then
  /// fail in handleOne with BadFunctionCall.
  template <typename Func>
  void removeHandler() {
    auto IdItr = LocalFunctionIds.find(Func::getPrototype());
    assert(IdItr != LocalFunctionIds.end() &&
           "Function does not have a registered handler");
    auto HandlerItr = Handlers.find(IdItr->second);
    assert(HandlerItr != Handlers.end() &&
           "Function does not have a registered handler");
    Handlers.erase(HandlerItr);
  }

  /// Clear all handlers.
  ///
  /// This also removes the built-in negotiation handler installed by the
  /// constructor, so the remote can no longer negotiate ids with this
  /// endpoint afterwards.
  void clearHandlers() {
    Handlers.clear();
  }

protected:

  /// The id the allocator reserves as "invalid". handleNegotiate returns it
  /// for unknown names, and RemoteFunctionIds caches it to record a failed
  /// negotiation.
  FunctionIdT getInvalidFunctionId() const {
    return FnIdAllocator.getInvalidId();
  }

  /// Add the given handler to the handler map and make it available for
  /// autonegotiation and execution.
  ///
  /// The static_assert checks that every parameter of Func::Type can be
  /// deserialized into the corresponding parameter type of the handler.
  template <typename Func, typename HandlerT>
  void addHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::HandlerTraits<HandlerT>::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
  }

  /// Like addHandlerImpl, but for handlers that report their result through a
  /// responder callback (see wrapAsyncHandler).
  /// NOTE(review): AsyncHandlerTraits presumably strips the responder
  /// parameter from the handler signature before the argument check -- confirm.
  template <typename Func, typename HandlerT>
  void addAsyncHandlerImpl(HandlerT Handler) {

    static_assert(detail::RPCArgTypeCheck<
                      CanDeserializeCheck, typename Func::Type,
                      typename detail::AsyncHandlerTraits<
                        typename detail::HandlerTraits<HandlerT>::Type
                      >::Type>::value,
                  "");

    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
    LocalFunctionIds[Func::getPrototype()] = NewFnId;
    Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
  }

  /// Dispatch the response for SeqNo to its pending handler.
  ///
  /// Under ResponsesMutex, the handler is removed from PendingResponses and
  /// SeqNo is released for reuse. The handler then runs without the lock
  /// held. If SeqNo is unknown, or the handler fails, all pending responses
  /// are abandoned and an error is returned.
  Error handleResponse(SequenceNumberT SeqNo) {
    using Handler = typename decltype(PendingResponses)::mapped_type;
    Handler PRHandler;

    {
      // Lock the pending responses map and sequence number manager.
      std::unique_lock<std::mutex> Lock(ResponsesMutex);
      auto I = PendingResponses.find(SeqNo);

      if (I != PendingResponses.end()) {
        PRHandler = std::move(I->second);
        PendingResponses.erase(I);
        SequenceNumberMgr.releaseSequenceNumber(SeqNo);
      } else {
        // Unlock the pending results map to prevent recursive lock.
        Lock.unlock();
        abandonPendingResponses();
        return make_error<
                 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
      }
    }

    assert(PRHandler &&
           "If we didn't find a response handler we should have bailed out");

    if (auto Err = PRHandler->handleResponse(C)) {
      abandonPendingResponses();
      return Err;
    }

    return Error::success();
  }

  /// Local implementation of OrcRPCNegotiate. Returns the local id registered
  /// for the prototype string \p Name, or the invalid id if there is none.
  FunctionIdT handleNegotiate(const std::string &Name) {
    auto I = LocalFunctionIds.find(Name);
    if (I == LocalFunctionIds.end())
      return getInvalidFunctionId();
    return I->second;
  }

  // Find the remote FunctionId for the given function.
  //
  // A valid cached id is returned directly. Otherwise negotiation is attempted
  // when NegotiateIfNotInMap (no cached id) or NegotiateIfInvalid (cached
  // invalid id) allows it. The negotiated id is cached even if it is invalid.
  template <typename Func>
  Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
                                            bool NegotiateIfInvalid) {
    bool DoNegotiate;

    // Check if we already have a function id...
    auto I = RemoteFunctionIds.find(Func::getPrototype());
    if (I != RemoteFunctionIds.end()) {
      // If it's valid there's nothing left to do.
      if (I->second != getInvalidFunctionId())
        return I->second;
      DoNegotiate = NegotiateIfInvalid;
    } else
      DoNegotiate = NegotiateIfNotInMap;

    // We don't have a valid function id for Func yet, but we're allowed to try
    // to negotiate one. The blocking call goes through the derived endpoint
    // (CRTP).
    if (DoNegotiate) {
      auto &Impl = static_cast<ImplT &>(*this);
      if (auto RemoteIdOrErr =
          Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
        RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
        if (*RemoteIdOrErr == getInvalidFunctionId())
          return make_error<CouldNotNegotiate>(Func::getPrototype());
        return *RemoteIdOrErr;
      } else
        return RemoteIdOrErr.takeError();
    }

    // Either no id was cached and NegotiateIfNotInMap was false, or the cached
    // id is invalid and NegotiateIfInvalid was false. Report failure.
    return make_error<CouldNotNegotiate>(Func::getPrototype());
  }

  // Type-erased form of an incoming-call handler: deserializes the call's
  // arguments from the channel and sends the response for the given SeqNo.
  using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;

  // Wrap the given user handler in the necessary argument-deserialization code
  // and result-serialization code.
  // NOTE(review): earlier wording also mentioned "a call to the launch policy
  // (if present)"; no launch policy is visible in this code.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using ArgsTuple =
          typename detail::FunctionArgsTuple<
            typename detail::HandlerTraits<HandlerT>::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      // Run the handler synchronously and send its result back tagged with
      // ResponseId and the caller's SeqNo.
      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
                                         HTraits::unpackAndRun(Handler, *Args));
    };
  }

  // Wrap the given user handler in the necessary argument-deserialization code
  // and result-serialization code. The handler is given a Responder callback
  // that sends the result on C when invoked, instead of the result being sent
  // when the handler returns.
  template <typename Func, typename HandlerT>
  WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
    return [this, Handler](ChannelT &Channel,
                           SequenceNumberT SeqNo) mutable -> Error {
      // Start by deserializing the arguments.
      using AHTraits = detail::AsyncHandlerTraits<
                         typename detail::HandlerTraits<HandlerT>::Type>;
      using ArgsTuple =
          typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
      auto Args = std::make_shared<ArgsTuple>();

      if (auto Err =
              detail::HandlerTraits<typename Func::Type>::deserializeArgs(
                  Channel, *Args))
        return Err;

      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
      // for Args. Void cast Args to work around this for now.
      // FIXME: Remove this workaround once we can assume a working GCC version.
      (void)Args;

      // End receive message, unlocking the channel for reading.
      if (auto Err = Channel.endReceiveMessage())
        return Err;

      using HTraits = detail::HandlerTraits<HandlerT>;
      using FuncReturn = typename Func::ReturnType;
      // The responder captures `this` (not Channel) and writes to the
      // endpoint's channel C.
      auto Responder =
        [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
          return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
                                             std::move(RetVal));
        };

      return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
    };
  }

  // Channel all messages are read from and written to (not owned).
  ChannelT &C;

  // If true, appendCallAsync negotiates ids for functions with no cached id.
  bool LazyAutoNegotiation;

  // Supplies the reserved response/negotiate/invalid ids and fresh ids for
  // locally registered handlers.
  RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;

  // Reserved id that tags response messages.
  FunctionIdT ResponseId;
  // Local ids of functions with registered handlers, keyed by prototype.
  std::map<std::string, FunctionIdT> LocalFunctionIds;
  // Remote ids keyed by the Func::getPrototype() pointer value. A cached
  // invalid id records a failed negotiation.
  // NOTE(review): pointer keys assume getPrototype() returns the same pointer
  // on every call -- confirm.
  std::map<const char *, FunctionIdT> RemoteFunctionIds;

  // Wrapped handlers for incoming calls, keyed by local function id.
  std::map<FunctionIdT, WrappedHandlerFn> Handlers;

  // Guards SequenceNumberMgr and PendingResponses.
  std::mutex ResponsesMutex;
  detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
  // Result handlers for outstanding calls, keyed by the sequence number sent
  // with the call. Entries are removed when the response is handled or when
  // abandonPendingResponses runs.
  std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
      PendingResponses;
};
1352
1353} // end namespace detail
1354
1355template <typename ChannelT, typename FunctionIdT = uint32_t,
1356          typename SequenceNumberT = uint32_t>
1357class MultiThreadedRPCEndpoint
1358    : public detail::RPCEndpointBase<
1359          MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1360          ChannelT, FunctionIdT, SequenceNumberT> {
1361private:
1362  using BaseClass =
1363      detail::RPCEndpointBase<
1364        MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1365        ChannelT, FunctionIdT, SequenceNumberT>;
1366
1367public:
1368  MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1369      : BaseClass(C, LazyAutoNegotiation) {}
1370
1371  /// Add a handler for the given RPC function.
1372  /// This installs the given handler functor for the given RPC Function, and
1373  /// makes the RPC function available for negotiation/calling from the remote.
1374  template <typename Func, typename HandlerT>
1375  void addHandler(HandlerT Handler) {
1376    return this->template addHandlerImpl<Func>(std::move(Handler));
1377  }
1378
1379  /// Add a class-method as a handler.
1380  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1381  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1382    addHandler<Func>(
1383      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1384  }
1385
1386  template <typename Func, typename HandlerT>
1387  void addAsyncHandler(HandlerT Handler) {
1388    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1389  }
1390
1391  /// Add a class-method as a handler.
1392  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1393  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1394    addAsyncHandler<Func>(
1395      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1396  }
1397
1398  /// Return type for non-blocking call primitives.
1399  template <typename Func>
1400  using NonBlockingCallResult = typename detail::ResultTraits<
1401      typename Func::ReturnType>::ReturnFutureType;
1402
1403  /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1404  /// of a future result and the sequence number assigned to the result.
1405  ///
1406  /// This utility function is primarily used for single-threaded mode support,
1407  /// where the sequence number can be used to wait for the corresponding
1408  /// result. In multi-threaded mode the appendCallNB method, which does not
1409  /// return the sequence numeber, should be preferred.
1410  template <typename Func, typename... ArgTs>
1411  Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1412    using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1413    using ErrorReturn = typename RTraits::ErrorReturnType;
1414    using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1415
1416    ErrorReturnPromise Promise;
1417    auto FutureResult = Promise.get_future();
1418
1419    if (auto Err = this->template appendCallAsync<Func>(
1420            [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable {
1421              Promise.set_value(std::move(RetOrErr));
1422              return Error::success();
1423            },
1424            Args...)) {
1425      RTraits::consumeAbandoned(FutureResult.get());
1426      return std::move(Err);
1427    }
1428    return std::move(FutureResult);
1429  }
1430
1431  /// The same as appendCallNBWithSeq, except that it calls C.send() to
1432  /// flush the channel after serializing the call.
1433  template <typename Func, typename... ArgTs>
1434  Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1435    auto Result = appendCallNB<Func>(Args...);
1436    if (!Result)
1437      return Result;
1438    if (auto Err = this->C.send()) {
1439      this->abandonPendingResponses();
1440      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1441          std::move(Result->get()));
1442      return std::move(Err);
1443    }
1444    return Result;
1445  }
1446
1447  /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1448  /// for void functions or an Expected<T> for functions returning a T.
1449  ///
1450  /// This function is for use in threaded code where another thread is
1451  /// handling responses and incoming calls.
1452  template <typename Func, typename... ArgTs,
1453            typename AltRetT = typename Func::ReturnType>
1454  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1455  callB(const ArgTs &... Args) {
1456    if (auto FutureResOrErr = callNB<Func>(Args...))
1457      return FutureResOrErr->get();
1458    else
1459      return FutureResOrErr.takeError();
1460  }
1461
1462  /// Handle incoming RPC calls.
1463  Error handlerLoop() {
1464    while (true)
1465      if (auto Err = this->handleOne())
1466        return Err;
1467    return Error::success();
1468  }
1469};
1470
1471template <typename ChannelT, typename FunctionIdT = uint32_t,
1472          typename SequenceNumberT = uint32_t>
1473class SingleThreadedRPCEndpoint
1474    : public detail::RPCEndpointBase<
1475          SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1476          ChannelT, FunctionIdT, SequenceNumberT> {
1477private:
1478  using BaseClass =
1479      detail::RPCEndpointBase<
1480        SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1481        ChannelT, FunctionIdT, SequenceNumberT>;
1482
1483public:
1484  SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1485      : BaseClass(C, LazyAutoNegotiation) {}
1486
1487  template <typename Func, typename HandlerT>
1488  void addHandler(HandlerT Handler) {
1489    return this->template addHandlerImpl<Func>(std::move(Handler));
1490  }
1491
1492  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1493  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1494    addHandler<Func>(
1495        detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1496  }
1497
1498  template <typename Func, typename HandlerT>
1499  void addAsyncHandler(HandlerT Handler) {
1500    return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1501  }
1502
1503  /// Add a class-method as a handler.
1504  template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1505  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1506    addAsyncHandler<Func>(
1507      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1508  }
1509
1510  template <typename Func, typename... ArgTs,
1511            typename AltRetT = typename Func::ReturnType>
1512  typename detail::ResultTraits<AltRetT>::ErrorReturnType
1513  callB(const ArgTs &... Args) {
1514    bool ReceivedResponse = false;
1515    using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1516    auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1517
1518    // We have to 'Check' result (which we know is in a success state at this
1519    // point) so that it can be overwritten in the async handler.
1520    (void)!!Result;
1521
1522    if (auto Err = this->template appendCallAsync<Func>(
1523            [&](ResultType R) {
1524              Result = std::move(R);
1525              ReceivedResponse = true;
1526              return Error::success();
1527            },
1528            Args...)) {
1529      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1530          std::move(Result));
1531      return std::move(Err);
1532    }
1533
1534    if (auto Err = this->C.send()) {
1535      detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1536          std::move(Result));
1537      return std::move(Err);
1538    }
1539
1540    while (!ReceivedResponse) {
1541      if (auto Err = this->handleOne()) {
1542        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1543            std::move(Result));
1544        return std::move(Err);
1545      }
1546    }
1547
1548    return Result;
1549  }
1550};
1551
1552/// Asynchronous dispatch for a function on an RPC endpoint.
1553template <typename RPCClass, typename Func>
1554class RPCAsyncDispatch {
1555public:
1556  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1557
1558  template <typename HandlerT, typename... ArgTs>
1559  Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1560    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1561  }
1562
1563private:
1564  RPCClass &Endpoint;
1565};
1566
1567/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1568template <typename Func, typename RPCEndpointT>
1569RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1570  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1571}
1572
1573/// Allows a set of asynchrounous calls to be dispatched, and then
1574///        waited on as a group.
1575class ParallelCallGroup {
1576public:
1577
1578  ParallelCallGroup() = default;
1579  ParallelCallGroup(const ParallelCallGroup &) = delete;
1580  ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1581
1582  /// Make as asynchronous call.
1583  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1584  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1585             const ArgTs &... Args) {
1586    // Increment the count of outstanding calls. This has to happen before
1587    // we invoke the call, as the handler may (depending on scheduling)
1588    // be run immediately on another thread, and we don't want the decrement
1589    // in the wrapped handler below to run before the increment.
1590    {
1591      std::unique_lock<std::mutex> Lock(M);
1592      ++NumOutstandingCalls;
1593    }
1594
1595    // Wrap the user handler in a lambda that will decrement the
1596    // outstanding calls count, then poke the condition variable.
1597    using ArgType = typename detail::ResponseHandlerArg<
1598        typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1599    auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) {
1600      auto Err = Handler(std::move(Arg));
1601      std::unique_lock<std::mutex> Lock(M);
1602      --NumOutstandingCalls;
1603      CV.notify_all();
1604      return Err;
1605    };
1606
1607    return AsyncDispatch(std::move(WrappedHandler), Args...);
1608  }
1609
1610  /// Blocks until all calls have been completed and their return value
1611  ///        handlers run.
1612  void wait() {
1613    std::unique_lock<std::mutex> Lock(M);
1614    while (NumOutstandingCalls > 0)
1615      CV.wait(Lock);
1616  }
1617
1618private:
1619  std::mutex M;
1620  std::condition_variable CV;
1621  uint32_t NumOutstandingCalls = 0;
1622};
1623
1624/// Convenience class for grouping RPC Functions into APIs that can be
1625///        negotiated as a block.
1626///
1627template <typename... Funcs>
1628class APICalls {
1629public:
1630
1631  /// Test whether this API contains Function F.
1632  template <typename F>
1633  class Contains {
1634  public:
1635    static const bool value = false;
1636  };
1637
1638  /// Negotiate all functions in this API.
1639  template <typename RPCEndpoint>
1640  static Error negotiate(RPCEndpoint &R) {
1641    return Error::success();
1642  }
1643};
1644
1645template <typename Func, typename... Funcs>
1646class APICalls<Func, Funcs...> {
1647public:
1648
1649  template <typename F>
1650  class Contains {
1651  public:
1652    static const bool value = std::is_same<F, Func>::value |
1653                              APICalls<Funcs...>::template Contains<F>::value;
1654  };
1655
1656  template <typename RPCEndpoint>
1657  static Error negotiate(RPCEndpoint &R) {
1658    if (auto Err = R.template negotiateFunction<Func>())
1659      return Err;
1660    return APICalls<Funcs...>::negotiate(R);
1661  }
1662
1663};
1664
1665template <typename... InnerFuncs, typename... Funcs>
1666class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1667public:
1668
1669  template <typename F>
1670  class Contains {
1671  public:
1672    static const bool value =
1673      APICalls<InnerFuncs...>::template Contains<F>::value |
1674      APICalls<Funcs...>::template Contains<F>::value;
1675  };
1676
1677  template <typename RPCEndpoint>
1678  static Error negotiate(RPCEndpoint &R) {
1679    if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1680      return Err;
1681    return APICalls<Funcs...>::negotiate(R);
1682  }
1683
1684};
1685
1686} // end namespace rpc
1687} // end namespace orc
1688} // end namespace llvm
1689
1690#endif
1691