1//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9// Utilities to support construction of simple RPC APIs. 10// 11// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ 12// programmers, high performance, low memory overhead, and efficient use of the 13// communications channel. 14// 15//===----------------------------------------------------------------------===// 16 17#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 18#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H 19 20#include <map> 21#include <thread> 22#include <vector> 23 24#include "llvm/ADT/STLExtras.h" 25#include "llvm/ExecutionEngine/Orc/OrcError.h" 26#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h" 27#include "llvm/Support/MSVCErrorWorkarounds.h" 28 29#include <future> 30 31namespace llvm { 32namespace orc { 33namespace rpc { 34 35/// Base class of all fatal RPC errors (those that necessarily result in the 36/// termination of the RPC session). 37class RPCFatalError : public ErrorInfo<RPCFatalError> { 38public: 39 static char ID; 40}; 41 42/// RPCConnectionClosed is returned from RPC operations if the RPC connection 43/// has already been closed due to either an error or graceful disconnection. 44class ConnectionClosed : public ErrorInfo<ConnectionClosed> { 45public: 46 static char ID; 47 std::error_code convertToErrorCode() const override; 48 void log(raw_ostream &OS) const override; 49}; 50 51/// BadFunctionCall is returned from handleOne when the remote makes a call with 52/// an unrecognized function id. 53/// 54/// This error is fatal because Orc RPC needs to know how to parse a function 55/// call to know where the next call starts, and if it doesn't recognize the 56/// function id it cannot parse the call. 57template <typename FnIdT, typename SeqNoT> 58class BadFunctionCall 59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { 60public: 61 static char ID; 62 63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) 64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} 65 66 std::error_code convertToErrorCode() const override { 67 return orcError(OrcErrorCode::UnexpectedRPCCall); 68 } 69 70 void log(raw_ostream &OS) const override { 71 OS << "Call to invalid RPC function id '" << FnId << "' with " 72 "sequence number " << SeqNo; 73 } 74 75private: 76 FnIdT FnId; 77 SeqNoT SeqNo; 78}; 79 80template <typename FnIdT, typename SeqNoT> 81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; 82 83/// InvalidSequenceNumberForResponse is returned from handleOne when a response 84/// call arrives with a sequence number that doesn't correspond to any in-flight 85/// function call. 86/// 87/// This error is fatal because Orc RPC needs to know how to parse the rest of 88/// the response call to know where the next call starts, and if it doesn't have 89/// a result parser for this sequence number it can't do that. 90template <typename SeqNoT> 91class InvalidSequenceNumberForResponse 92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { 93public: 94 static char ID; 95 96 InvalidSequenceNumberForResponse(SeqNoT SeqNo) 97 : SeqNo(std::move(SeqNo)) {} 98 99 std::error_code convertToErrorCode() const override { 100 return orcError(OrcErrorCode::UnexpectedRPCCall); 101 }; 102 103 void log(raw_ostream &OS) const override { 104 OS << "Response has unknown sequence number " << SeqNo; 105 } 106private: 107 SeqNoT SeqNo; 108}; 109 110template <typename SeqNoT> 111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; 112 113/// This non-fatal error will be passed to asynchronous result handlers in place 114/// of a result if the connection goes down before a result returns, or if the 115/// function to be called cannot be negotiated with the remote. 116class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { 117public: 118 static char ID; 119 120 std::error_code convertToErrorCode() const override; 121 void log(raw_ostream &OS) const override; 122}; 123 124/// This error is returned if the remote does not have a handler installed for 125/// the given RPC function. 126class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { 127public: 128 static char ID; 129 130 CouldNotNegotiate(std::string Signature); 131 std::error_code convertToErrorCode() const override; 132 void log(raw_ostream &OS) const override; 133 const std::string &getSignature() const { return Signature; } 134private: 135 std::string Signature; 136}; 137 138template <typename DerivedFunc, typename FnT> class Function; 139 140// RPC Function class. 141// DerivedFunc should be a user defined class with a static 'getName()' method 142// returning a const char* representing the function's name. 143template <typename DerivedFunc, typename RetT, typename... ArgTs> 144class Function<DerivedFunc, RetT(ArgTs...)> { 145public: 146 /// User defined function type. 147 using Type = RetT(ArgTs...); 148 149 /// Return type. 150 using ReturnType = RetT; 151 152 /// Returns the full function prototype as a string. 153 static const char *getPrototype() { 154 static std::string Name = [] { 155 std::string Name; 156 raw_string_ostream(Name) 157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName() 158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")"; 159 return Name; 160 }(); 161 return Name.data(); 162 } 163}; 164 165/// Allocates RPC function ids during autonegotiation. 166/// Specializations of this class must provide four members: 167/// 168/// static T getInvalidId(): 169/// Should return a reserved id that will be used to represent missing 170/// functions during autonegotiation. 171/// 172/// static T getResponseId(): 173/// Should return a reserved id that will be used to send function responses 174/// (return values). 175/// 176/// static T getNegotiateId(): 177/// Should return a reserved id for the negotiate function, which will be used 178/// to negotiate ids for user defined functions. 179/// 180/// template <typename Func> T allocate(): 181/// Allocate a unique id for function Func. 182template <typename T, typename = void> class RPCFunctionIdAllocator; 183 184/// This specialization of RPCFunctionIdAllocator provides a default 185/// implementation for integral types. 186template <typename T> 187class RPCFunctionIdAllocator< 188 T, typename std::enable_if<std::is_integral<T>::value>::type> { 189public: 190 static T getInvalidId() { return T(0); } 191 static T getResponseId() { return T(1); } 192 static T getNegotiateId() { return T(2); } 193 194 template <typename Func> T allocate() { return NextId++; } 195 196private: 197 T NextId = 3; 198}; 199 200namespace detail { 201 202/// Provides a typedef for a tuple containing the decayed argument types. 203template <typename T> class FunctionArgsTuple; 204 205template <typename RetT, typename... ArgTs> 206class FunctionArgsTuple<RetT(ArgTs...)> { 207public: 208 using Type = std::tuple<typename std::decay< 209 typename std::remove_reference<ArgTs>::type>::type...>; 210}; 211 212// ResultTraits provides typedefs and utilities specific to the return type 213// of functions. 214template <typename RetT> class ResultTraits { 215public: 216 // The return type wrapped in llvm::Expected. 217 using ErrorReturnType = Expected<RetT>; 218 219#ifdef _MSC_VER 220 // The ErrorReturnType wrapped in a std::promise. 221 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>; 222 223 // The ErrorReturnType wrapped in a std::future. 224 using ReturnFutureType = std::future<MSVCPExpected<RetT>>; 225#else 226 // The ErrorReturnType wrapped in a std::promise. 227 using ReturnPromiseType = std::promise<ErrorReturnType>; 228 229 // The ErrorReturnType wrapped in a std::future. 230 using ReturnFutureType = std::future<ErrorReturnType>; 231#endif 232 233 // Create a 'blank' value of the ErrorReturnType, ready and safe to 234 // overwrite. 235 static ErrorReturnType createBlankErrorReturnValue() { 236 return ErrorReturnType(RetT()); 237 } 238 239 // Consume an abandoned ErrorReturnType. 240 static void consumeAbandoned(ErrorReturnType RetOrErr) { 241 consumeError(RetOrErr.takeError()); 242 } 243}; 244 245// ResultTraits specialization for void functions. 246template <> class ResultTraits<void> { 247public: 248 // For void functions, ErrorReturnType is llvm::Error. 249 using ErrorReturnType = Error; 250 251#ifdef _MSC_VER 252 // The ErrorReturnType wrapped in a std::promise. 253 using ReturnPromiseType = std::promise<MSVCPError>; 254 255 // The ErrorReturnType wrapped in a std::future. 256 using ReturnFutureType = std::future<MSVCPError>; 257#else 258 // The ErrorReturnType wrapped in a std::promise. 259 using ReturnPromiseType = std::promise<ErrorReturnType>; 260 261 // The ErrorReturnType wrapped in a std::future. 262 using ReturnFutureType = std::future<ErrorReturnType>; 263#endif 264 265 // Create a 'blank' value of the ErrorReturnType, ready and safe to 266 // overwrite. 267 static ErrorReturnType createBlankErrorReturnValue() { 268 return ErrorReturnType::success(); 269 } 270 271 // Consume an abandoned ErrorReturnType. 272 static void consumeAbandoned(ErrorReturnType Err) { 273 consumeError(std::move(Err)); 274 } 275}; 276 277// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows 278// handlers for void RPC functions to return either void (in which case they 279// implicitly succeed) or Error (in which case their error return is 280// propagated). See usage in HandlerTraits::runHandlerHelper. 281template <> class ResultTraits<Error> : public ResultTraits<void> {}; 282 283// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows 284// handlers for RPC functions returning a T to return either a T (in which 285// case they implicitly succeed) or Expected<T> (in which case their error 286// return is propagated). See usage in HandlerTraits::runHandlerHelper. 287template <typename RetT> 288class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; 289 290// Determines whether an RPC function's defined error return type supports 291// error return value. 292template <typename T> 293class SupportsErrorReturn { 294public: 295 static const bool value = false; 296}; 297 298template <> 299class SupportsErrorReturn<Error> { 300public: 301 static const bool value = true; 302}; 303 304template <typename T> 305class SupportsErrorReturn<Expected<T>> { 306public: 307 static const bool value = true; 308}; 309 310// RespondHelper packages return values based on whether or not the declared 311// RPC function return type supports error returns. 312template <bool FuncSupportsErrorReturn> 313class RespondHelper; 314 315// RespondHelper specialization for functions that support error returns. 316template <> 317class RespondHelper<true> { 318public: 319 320 // Send Expected<T>. 321 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 322 typename FunctionIdT, typename SequenceNumberT> 323 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 324 SequenceNumberT SeqNo, 325 Expected<HandlerRetT> ResultOrErr) { 326 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) 327 return ResultOrErr.takeError(); 328 329 // Open the response message. 330 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 331 return Err; 332 333 // Serialize the result. 334 if (auto Err = 335 SerializationTraits<ChannelT, WireRetT, 336 Expected<HandlerRetT>>::serialize( 337 C, std::move(ResultOrErr))) 338 return Err; 339 340 // Close the response message. 341 if (auto Err = C.endSendMessage()) 342 return Err; 343 return C.send(); 344 } 345 346 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 347 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 348 SequenceNumberT SeqNo, Error Err) { 349 if (Err && Err.isA<RPCFatalError>()) 350 return Err; 351 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 352 return Err2; 353 if (auto Err2 = serializeSeq(C, std::move(Err))) 354 return Err2; 355 if (auto Err2 = C.endSendMessage()) 356 return Err2; 357 return C.send(); 358 } 359 360}; 361 362// RespondHelper specialization for functions that do not support error returns. 363template <> 364class RespondHelper<false> { 365public: 366 367 template <typename WireRetT, typename HandlerRetT, typename ChannelT, 368 typename FunctionIdT, typename SequenceNumberT> 369 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 370 SequenceNumberT SeqNo, 371 Expected<HandlerRetT> ResultOrErr) { 372 if (auto Err = ResultOrErr.takeError()) 373 return Err; 374 375 // Open the response message. 376 if (auto Err = C.startSendMessage(ResponseId, SeqNo)) 377 return Err; 378 379 // Serialize the result. 380 if (auto Err = 381 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( 382 C, *ResultOrErr)) 383 return Err; 384 385 // End the response message. 386 if (auto Err = C.endSendMessage()) 387 return Err; 388 389 return C.send(); 390 } 391 392 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> 393 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, 394 SequenceNumberT SeqNo, Error Err) { 395 if (Err) 396 return Err; 397 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) 398 return Err2; 399 if (auto Err2 = C.endSendMessage()) 400 return Err2; 401 return C.send(); 402 } 403 404}; 405 406 407// Send a response of the given wire return type (WireRetT) over the 408// channel, with the given sequence number. 409template <typename WireRetT, typename HandlerRetT, typename ChannelT, 410 typename FunctionIdT, typename SequenceNumberT> 411Error respond(ChannelT &C, const FunctionIdT &ResponseId, 412 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { 413 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 414 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); 415} 416 417// Send an empty response message on the given channel to indicate that 418// the handler ran. 419template <typename WireRetT, typename ChannelT, typename FunctionIdT, 420 typename SequenceNumberT> 421Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, 422 Error Err) { 423 return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: 424 sendResult(C, ResponseId, SeqNo, std::move(Err)); 425} 426 427// Converts a given type to the equivalent error return type. 428template <typename T> class WrappedHandlerReturn { 429public: 430 using Type = Expected<T>; 431}; 432 433template <typename T> class WrappedHandlerReturn<Expected<T>> { 434public: 435 using Type = Expected<T>; 436}; 437 438template <> class WrappedHandlerReturn<void> { 439public: 440 using Type = Error; 441}; 442 443template <> class WrappedHandlerReturn<Error> { 444public: 445 using Type = Error; 446}; 447 448template <> class WrappedHandlerReturn<ErrorSuccess> { 449public: 450 using Type = Error; 451}; 452 453// Traits class that strips the response function from the list of handler 454// arguments. 455template <typename FnT> class AsyncHandlerTraits; 456 457template <typename ResultT, typename... ArgTs> 458class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { 459public: 460 using Type = Error(ArgTs...); 461 using ResultType = Expected<ResultT>; 462}; 463 464template <typename... ArgTs> 465class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { 466public: 467 using Type = Error(ArgTs...); 468 using ResultType = Error; 469}; 470 471template <typename... ArgTs> 472class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> { 473public: 474 using Type = Error(ArgTs...); 475 using ResultType = Error; 476}; 477 478template <typename... ArgTs> 479class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> { 480public: 481 using Type = Error(ArgTs...); 482 using ResultType = Error; 483}; 484 485template <typename ResponseHandlerT, typename... ArgTs> 486class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : 487 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, 488 ArgTs...)> {}; 489 490// This template class provides utilities related to RPC function handlers. 491// The base case applies to non-function types (the template class is 492// specialized for function types) and inherits from the appropriate 493// speciilization for the given non-function type's call operator. 494template <typename HandlerT> 495class HandlerTraits : public HandlerTraits<decltype( 496 &std::remove_reference<HandlerT>::type::operator())> { 497}; 498 499// Traits for handlers with a given function type. 500template <typename RetT, typename... ArgTs> 501class HandlerTraits<RetT(ArgTs...)> { 502public: 503 // Function type of the handler. 504 using Type = RetT(ArgTs...); 505 506 // Return type of the handler. 507 using ReturnType = RetT; 508 509 // Call the given handler with the given arguments. 510 template <typename HandlerT, typename... TArgTs> 511 static typename WrappedHandlerReturn<RetT>::Type 512 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { 513 return unpackAndRunHelper(Handler, Args, 514 std::index_sequence_for<TArgTs...>()); 515 } 516 517 // Call the given handler with the given arguments. 518 template <typename HandlerT, typename ResponderT, typename... TArgTs> 519 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, 520 std::tuple<TArgTs...> &Args) { 521 return unpackAndRunAsyncHelper(Handler, Responder, Args, 522 std::index_sequence_for<TArgTs...>()); 523 } 524 525 // Call the given handler with the given arguments. 526 template <typename HandlerT> 527 static typename std::enable_if< 528 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, 529 Error>::type 530 run(HandlerT &Handler, ArgTs &&... Args) { 531 Handler(std::move(Args)...); 532 return Error::success(); 533 } 534 535 template <typename HandlerT, typename... TArgTs> 536 static typename std::enable_if< 537 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, 538 typename HandlerTraits<HandlerT>::ReturnType>::type 539 run(HandlerT &Handler, TArgTs... Args) { 540 return Handler(std::move(Args)...); 541 } 542 543 // Serialize arguments to the channel. 544 template <typename ChannelT, typename... CArgTs> 545 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { 546 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); 547 } 548 549 // Deserialize arguments from the channel. 550 template <typename ChannelT, typename... CArgTs> 551 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) { 552 return deserializeArgsHelper(C, Args, std::index_sequence_for<CArgTs...>()); 553 } 554 555private: 556 template <typename ChannelT, typename... CArgTs, size_t... Indexes> 557 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args, 558 std::index_sequence<Indexes...> _) { 559 return SequenceSerialization<ChannelT, ArgTs...>::deserialize( 560 C, std::get<Indexes>(Args)...); 561 } 562 563 template <typename HandlerT, typename ArgTuple, size_t... Indexes> 564 static typename WrappedHandlerReturn< 565 typename HandlerTraits<HandlerT>::ReturnType>::Type 566 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, 567 std::index_sequence<Indexes...>) { 568 return run(Handler, std::move(std::get<Indexes>(Args))...); 569 } 570 571 template <typename HandlerT, typename ResponderT, typename ArgTuple, 572 size_t... Indexes> 573 static typename WrappedHandlerReturn< 574 typename HandlerTraits<HandlerT>::ReturnType>::Type 575 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, 576 ArgTuple &Args, std::index_sequence<Indexes...>) { 577 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); 578 } 579}; 580 581// Handler traits for free functions. 582template <typename RetT, typename... ArgTs> 583class HandlerTraits<RetT(*)(ArgTs...)> 584 : public HandlerTraits<RetT(ArgTs...)> {}; 585 586// Handler traits for class methods (especially call operators for lambdas). 587template <typename Class, typename RetT, typename... ArgTs> 588class HandlerTraits<RetT (Class::*)(ArgTs...)> 589 : public HandlerTraits<RetT(ArgTs...)> {}; 590 591// Handler traits for const class methods (especially call operators for 592// lambdas). 593template <typename Class, typename RetT, typename... ArgTs> 594class HandlerTraits<RetT (Class::*)(ArgTs...) const> 595 : public HandlerTraits<RetT(ArgTs...)> {}; 596 597// Utility to peel the Expected wrapper off a response handler error type. 598template <typename HandlerT> class ResponseHandlerArg; 599 600template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> { 601public: 602 using ArgType = Expected<ArgT>; 603 using UnwrappedArgType = ArgT; 604}; 605 606template <typename ArgT> 607class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> { 608public: 609 using ArgType = Expected<ArgT>; 610 using UnwrappedArgType = ArgT; 611}; 612 613template <> class ResponseHandlerArg<Error(Error)> { 614public: 615 using ArgType = Error; 616}; 617 618template <> class ResponseHandlerArg<ErrorSuccess(Error)> { 619public: 620 using ArgType = Error; 621}; 622 623// ResponseHandler represents a handler for a not-yet-received function call 624// result. 625template <typename ChannelT> class ResponseHandler { 626public: 627 virtual ~ResponseHandler() {} 628 629 // Reads the function result off the wire and acts on it. The meaning of 630 // "act" will depend on how this method is implemented in any given 631 // ResponseHandler subclass but could, for example, mean running a 632 // user-specified handler or setting a promise value. 633 virtual Error handleResponse(ChannelT &C) = 0; 634 635 // Abandons this outstanding result. 636 virtual void abandon() = 0; 637 638 // Create an error instance representing an abandoned response. 639 static Error createAbandonedResponseError() { 640 return make_error<ResponseAbandoned>(); 641 } 642}; 643 644// ResponseHandler subclass for RPC functions with non-void returns. 645template <typename ChannelT, typename FuncRetT, typename HandlerT> 646class ResponseHandlerImpl : public ResponseHandler<ChannelT> { 647public: 648 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 649 650 // Handle the result by deserializing it from the channel then passing it 651 // to the user defined handler. 652 Error handleResponse(ChannelT &C) override { 653 using UnwrappedArgType = typename ResponseHandlerArg< 654 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType; 655 UnwrappedArgType Result; 656 if (auto Err = 657 SerializationTraits<ChannelT, FuncRetT, 658 UnwrappedArgType>::deserialize(C, Result)) 659 return Err; 660 if (auto Err = C.endReceiveMessage()) 661 return Err; 662 return Handler(std::move(Result)); 663 } 664 665 // Abandon this response by calling the handler with an 'abandoned response' 666 // error. 667 void abandon() override { 668 if (auto Err = Handler(this->createAbandonedResponseError())) { 669 // Handlers should not fail when passed an abandoned response error. 670 report_fatal_error(std::move(Err)); 671 } 672 } 673 674private: 675 HandlerT Handler; 676}; 677 678// ResponseHandler subclass for RPC functions with void returns. 679template <typename ChannelT, typename HandlerT> 680class ResponseHandlerImpl<ChannelT, void, HandlerT> 681 : public ResponseHandler<ChannelT> { 682public: 683 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 684 685 // Handle the result (no actual value, just a notification that the function 686 // has completed on the remote end) by calling the user-defined handler with 687 // Error::success(). 688 Error handleResponse(ChannelT &C) override { 689 if (auto Err = C.endReceiveMessage()) 690 return Err; 691 return Handler(Error::success()); 692 } 693 694 // Abandon this response by calling the handler with an 'abandoned response' 695 // error. 696 void abandon() override { 697 if (auto Err = Handler(this->createAbandonedResponseError())) { 698 // Handlers should not fail when passed an abandoned response error. 699 report_fatal_error(std::move(Err)); 700 } 701 } 702 703private: 704 HandlerT Handler; 705}; 706 707template <typename ChannelT, typename FuncRetT, typename HandlerT> 708class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> 709 : public ResponseHandler<ChannelT> { 710public: 711 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 712 713 // Handle the result by deserializing it from the channel then passing it 714 // to the user defined handler. 715 Error handleResponse(ChannelT &C) override { 716 using HandlerArgType = typename ResponseHandlerArg< 717 typename HandlerTraits<HandlerT>::Type>::ArgType; 718 HandlerArgType Result((typename HandlerArgType::value_type())); 719 720 if (auto Err = 721 SerializationTraits<ChannelT, Expected<FuncRetT>, 722 HandlerArgType>::deserialize(C, Result)) 723 return Err; 724 if (auto Err = C.endReceiveMessage()) 725 return Err; 726 return Handler(std::move(Result)); 727 } 728 729 // Abandon this response by calling the handler with an 'abandoned response' 730 // error. 731 void abandon() override { 732 if (auto Err = Handler(this->createAbandonedResponseError())) { 733 // Handlers should not fail when passed an abandoned response error. 734 report_fatal_error(std::move(Err)); 735 } 736 } 737 738private: 739 HandlerT Handler; 740}; 741 742template <typename ChannelT, typename HandlerT> 743class ResponseHandlerImpl<ChannelT, Error, HandlerT> 744 : public ResponseHandler<ChannelT> { 745public: 746 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} 747 748 // Handle the result by deserializing it from the channel then passing it 749 // to the user defined handler. 750 Error handleResponse(ChannelT &C) override { 751 Error Result = Error::success(); 752 if (auto Err = SerializationTraits<ChannelT, Error, Error>::deserialize( 753 C, Result)) { 754 consumeError(std::move(Result)); 755 return Err; 756 } 757 if (auto Err = C.endReceiveMessage()) { 758 consumeError(std::move(Result)); 759 return Err; 760 } 761 return Handler(std::move(Result)); 762 } 763 764 // Abandon this response by calling the handler with an 'abandoned response' 765 // error. 766 void abandon() override { 767 if (auto Err = Handler(this->createAbandonedResponseError())) { 768 // Handlers should not fail when passed an abandoned response error. 769 report_fatal_error(std::move(Err)); 770 } 771 } 772 773private: 774 HandlerT Handler; 775}; 776 777// Create a ResponseHandler from a given user handler. 778template <typename ChannelT, typename FuncRetT, typename HandlerT> 779std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { 780 return std::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>( 781 std::move(H)); 782} 783 784// Helper for wrapping member functions up as functors. This is useful for 785// installing methods as result handlers. 786template <typename ClassT, typename RetT, typename... ArgTs> 787class MemberFnWrapper { 788public: 789 using MethodT = RetT (ClassT::*)(ArgTs...); 790 MemberFnWrapper(ClassT &Instance, MethodT Method) 791 : Instance(Instance), Method(Method) {} 792 RetT operator()(ArgTs &&... Args) { 793 return (Instance.*Method)(std::move(Args)...); 794 } 795 796private: 797 ClassT &Instance; 798 MethodT Method; 799}; 800 801// Helper that provides a Functor for deserializing arguments. 802template <typename... ArgTs> class ReadArgs { 803public: 804 Error operator()() { return Error::success(); } 805}; 806 807template <typename ArgT, typename... ArgTs> 808class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> { 809public: 810 ReadArgs(ArgT &Arg, ArgTs &... Args) 811 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} 812 813 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { 814 this->Arg = std::move(ArgVal); 815 return ReadArgs<ArgTs...>::operator()(ArgVals...); 816 } 817 818private: 819 ArgT &Arg; 820}; 821 822// Manage sequence numbers. 823template <typename SequenceNumberT> class SequenceNumberManager { 824public: 825 // Reset, making all sequence numbers available. 826 void reset() { 827 std::lock_guard<std::mutex> Lock(SeqNoLock); 828 NextSequenceNumber = 0; 829 FreeSequenceNumbers.clear(); 830 } 831 832 // Get the next available sequence number. Will re-use numbers that have 833 // been released. 834 SequenceNumberT getSequenceNumber() { 835 std::lock_guard<std::mutex> Lock(SeqNoLock); 836 if (FreeSequenceNumbers.empty()) 837 return NextSequenceNumber++; 838 auto SequenceNumber = FreeSequenceNumbers.back(); 839 FreeSequenceNumbers.pop_back(); 840 return SequenceNumber; 841 } 842 843 // Release a sequence number, making it available for re-use. 844 void releaseSequenceNumber(SequenceNumberT SequenceNumber) { 845 std::lock_guard<std::mutex> Lock(SeqNoLock); 846 FreeSequenceNumbers.push_back(SequenceNumber); 847 } 848 849private: 850 std::mutex SeqNoLock; 851 SequenceNumberT NextSequenceNumber = 0; 852 std::vector<SequenceNumberT> FreeSequenceNumbers; 853}; 854 855// Checks that predicate P holds for each corresponding pair of type arguments 856// from T1 and T2 tuple. 857template <template <class, class> class P, typename T1Tuple, typename T2Tuple> 858class RPCArgTypeCheckHelper; 859 860template <template <class, class> class P> 861class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> { 862public: 863 static const bool value = true; 864}; 865 866template <template <class, class> class P, typename T, typename... Ts, 867 typename U, typename... Us> 868class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> { 869public: 870 static const bool value = 871 P<T, U>::value && 872 RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value; 873}; 874 875template <template <class, class> class P, typename T1Sig, typename T2Sig> 876class RPCArgTypeCheck { 877public: 878 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type; 879 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type; 880 881 static_assert(std::tuple_size<T1Tuple>::value >= 882 std::tuple_size<T2Tuple>::value, 883 "Too many arguments to RPC call"); 884 static_assert(std::tuple_size<T1Tuple>::value <= 885 std::tuple_size<T2Tuple>::value, 886 "Too few arguments to RPC call"); 887 888 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value; 889}; 890 891template <typename ChannelT, typename WireT, typename ConcreteT> 892class CanSerialize { 893private: 894 using S = SerializationTraits<ChannelT, WireT, ConcreteT>; 895 896 template <typename T> 897 static std::true_type 898 check(typename std::enable_if< 899 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(), 900 std::declval<const ConcreteT &>())), 901 Error>::value, 902 void *>::type); 903 904 template <typename> static std::false_type check(...); 905 906public: 907 static const bool value = decltype(check<S>(0))::value; 908}; 909 910template <typename ChannelT, typename WireT, typename ConcreteT> 911class CanDeserialize { 912private: 913 using S = SerializationTraits<ChannelT, WireT, ConcreteT>; 914 915 template <typename T> 916 static std::true_type 917 check(typename std::enable_if< 918 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(), 919 std::declval<ConcreteT &>())), 920 Error>::value, 921 void *>::type); 922 923 template <typename> static std::false_type check(...); 924 925public: 926 static const bool value = decltype(check<S>(0))::value; 927}; 928 929/// Contains primitive utilities for defining, calling and handling calls to 930/// remote procedures. ChannelT is a bidirectional stream conforming to the 931/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure 932/// identifier type that must be serializable on ChannelT, and SequenceNumberT 933/// is an integral type that will be used to number in-flight function calls. 934/// 935/// These utilities support the construction of very primitive RPC utilities. 936/// Their intent is to ensure correct serialization and deserialization of 937/// procedure arguments, and to keep the client and server's view of the API in 938/// sync. 939template <typename ImplT, typename ChannelT, typename FunctionIdT, 940 typename SequenceNumberT> 941class RPCEndpointBase { 942protected: 943 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { 944 public: 945 static const char *getName() { return "__orc_rpc$invalid"; } 946 }; 947 948 class OrcRPCResponse : public Function<OrcRPCResponse, void()> { 949 public: 950 static const char *getName() { return "__orc_rpc$response"; } 951 }; 952 953 class OrcRPCNegotiate 954 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> { 955 public: 956 static const char *getName() { return "__orc_rpc$negotiate"; } 957 }; 958 959 // Helper predicate for testing for the presence of SerializeTraits 960 // serializers. 961 template <typename WireT, typename ConcreteT> 962 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> { 963 public: 964 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value; 965 966 static_assert(value, "Missing serializer for argument (Can't serialize the " 967 "first template type argument of CanSerializeCheck " 968 "from the second)"); 969 }; 970 971 // Helper predicate for testing for the presence of SerializeTraits 972 // deserializers. 973 template <typename WireT, typename ConcreteT> 974 class CanDeserializeCheck 975 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> { 976 public: 977 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value; 978 979 static_assert(value, "Missing deserializer for argument (Can't deserialize " 980 "the second template type argument of " 981 "CanDeserializeCheck from the first)"); 982 }; 983 984public: 985 /// Construct an RPC instance on a channel. 986 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) 987 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { 988 // Hold ResponseId in a special variable, since we expect Response to be 989 // called relatively frequently, and want to avoid the map lookup. 990 ResponseId = FnIdAllocator.getResponseId(); 991 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; 992 993 // Register the negotiate function id and handler. 994 auto NegotiateId = FnIdAllocator.getNegotiateId(); 995 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; 996 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( 997 [this](const std::string &Name) { return handleNegotiate(Name); }); 998 } 999 1000 1001 /// Negotiate a function id for Func with the other end of the channel. 1002 template <typename Func> Error negotiateFunction(bool Retry = false) { 1003 return getRemoteFunctionId<Func>(true, Retry).takeError(); 1004 } 1005 1006 /// Append a call Func, does not call send on the channel. 1007 /// The first argument specifies a user-defined handler to be run when the 1008 /// function returns. The handler should take an Expected<Func::ReturnType>, 1009 /// or an Error (if Func::ReturnType is void). The handler will be called 1010 /// with an error if the return value is abandoned due to a channel error. 1011 template <typename Func, typename HandlerT, typename... ArgTs> 1012 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { 1013 1014 static_assert( 1015 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type, 1016 void(ArgTs...)>::value, 1017 ""); 1018 1019 // Look up the function ID. 1020 FunctionIdT FnId; 1021 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) 1022 FnId = *FnIdOrErr; 1023 else { 1024 // Negotiation failed. Notify the handler then return the negotiate-failed 1025 // error. 1026 cantFail(Handler(make_error<ResponseAbandoned>())); 1027 return FnIdOrErr.takeError(); 1028 } 1029 1030 SequenceNumberT SeqNo; // initialized in locked scope below. 1031 { 1032 // Lock the pending responses map and sequence number manager. 1033 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1034 1035 // Allocate a sequence number. 1036 SeqNo = SequenceNumberMgr.getSequenceNumber(); 1037 assert(!PendingResponses.count(SeqNo) && 1038 "Sequence number already allocated"); 1039 1040 // Install the user handler. 1041 PendingResponses[SeqNo] = 1042 detail::createResponseHandler<ChannelT, typename Func::ReturnType>( 1043 std::move(Handler)); 1044 } 1045 1046 // Open the function call message. 1047 if (auto Err = C.startSendMessage(FnId, SeqNo)) { 1048 abandonPendingResponses(); 1049 return Err; 1050 } 1051 1052 // Serialize the call arguments. 1053 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( 1054 C, Args...)) { 1055 abandonPendingResponses(); 1056 return Err; 1057 } 1058 1059 // Close the function call messagee. 1060 if (auto Err = C.endSendMessage()) { 1061 abandonPendingResponses(); 1062 return Err; 1063 } 1064 1065 return Error::success(); 1066 } 1067 1068 Error sendAppendedCalls() { return C.send(); }; 1069 1070 template <typename Func, typename HandlerT, typename... ArgTs> 1071 Error callAsync(HandlerT Handler, const ArgTs &... Args) { 1072 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...)) 1073 return Err; 1074 return C.send(); 1075 } 1076 1077 /// Handle one incoming call. 1078 Error handleOne() { 1079 FunctionIdT FnId; 1080 SequenceNumberT SeqNo; 1081 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { 1082 abandonPendingResponses(); 1083 return Err; 1084 } 1085 if (FnId == ResponseId) 1086 return handleResponse(SeqNo); 1087 auto I = Handlers.find(FnId); 1088 if (I != Handlers.end()) 1089 return I->second(C, SeqNo); 1090 1091 // else: No handler found. Report error to client? 1092 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, 1093 SeqNo); 1094 } 1095 1096 /// Helper for handling setter procedures - this method returns a functor that 1097 /// sets the variables referred to by Args... to values deserialized from the 1098 /// channel. 1099 /// E.g. 1100 /// 1101 /// typedef Function<0, bool, int> Func1; 1102 /// 1103 /// ... 1104 /// bool B; 1105 /// int I; 1106 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) 1107 /// /* Handle Args */ ; 1108 /// 1109 template <typename... ArgTs> 1110 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { 1111 return detail::ReadArgs<ArgTs...>(Args...); 1112 } 1113 1114 /// Abandon all outstanding result handlers. 1115 /// 1116 /// This will call all currently registered result handlers to receive an 1117 /// "abandoned" error as their argument. This is used internally by the RPC 1118 /// in error situations, but can also be called directly by clients who are 1119 /// disconnecting from the remote and don't or can't expect responses to their 1120 /// outstanding calls. (Especially for outstanding blocking calls, calling 1121 /// this function may be necessary to avoid dead threads). 1122 void abandonPendingResponses() { 1123 // Lock the pending responses map and sequence number manager. 1124 std::lock_guard<std::mutex> Lock(ResponsesMutex); 1125 1126 for (auto &KV : PendingResponses) 1127 KV.second->abandon(); 1128 PendingResponses.clear(); 1129 SequenceNumberMgr.reset(); 1130 } 1131 1132 /// Remove the handler for the given function. 1133 /// A handler must currently be registered for this function. 1134 template <typename Func> 1135 void removeHandler() { 1136 auto IdItr = LocalFunctionIds.find(Func::getPrototype()); 1137 assert(IdItr != LocalFunctionIds.end() && 1138 "Function does not have a registered handler"); 1139 auto HandlerItr = Handlers.find(IdItr->second); 1140 assert(HandlerItr != Handlers.end() && 1141 "Function does not have a registered handler"); 1142 Handlers.erase(HandlerItr); 1143 } 1144 1145 /// Clear all handlers. 1146 void clearHandlers() { 1147 Handlers.clear(); 1148 } 1149 1150protected: 1151 1152 FunctionIdT getInvalidFunctionId() const { 1153 return FnIdAllocator.getInvalidId(); 1154 } 1155 1156 /// Add the given handler to the handler map and make it available for 1157 /// autonegotiation and execution. 1158 template <typename Func, typename HandlerT> 1159 void addHandlerImpl(HandlerT Handler) { 1160 1161 static_assert(detail::RPCArgTypeCheck< 1162 CanDeserializeCheck, typename Func::Type, 1163 typename detail::HandlerTraits<HandlerT>::Type>::value, 1164 ""); 1165 1166 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1167 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1168 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); 1169 } 1170 1171 template <typename Func, typename HandlerT> 1172 void addAsyncHandlerImpl(HandlerT Handler) { 1173 1174 static_assert(detail::RPCArgTypeCheck< 1175 CanDeserializeCheck, typename Func::Type, 1176 typename detail::AsyncHandlerTraits< 1177 typename detail::HandlerTraits<HandlerT>::Type 1178 >::Type>::value, 1179 ""); 1180 1181 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); 1182 LocalFunctionIds[Func::getPrototype()] = NewFnId; 1183 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); 1184 } 1185 1186 Error handleResponse(SequenceNumberT SeqNo) { 1187 using Handler = typename decltype(PendingResponses)::mapped_type; 1188 Handler PRHandler; 1189 1190 { 1191 // Lock the pending responses map and sequence number manager. 1192 std::unique_lock<std::mutex> Lock(ResponsesMutex); 1193 auto I = PendingResponses.find(SeqNo); 1194 1195 if (I != PendingResponses.end()) { 1196 PRHandler = std::move(I->second); 1197 PendingResponses.erase(I); 1198 SequenceNumberMgr.releaseSequenceNumber(SeqNo); 1199 } else { 1200 // Unlock the pending results map to prevent recursive lock. 1201 Lock.unlock(); 1202 abandonPendingResponses(); 1203 return make_error< 1204 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); 1205 } 1206 } 1207 1208 assert(PRHandler && 1209 "If we didn't find a response handler we should have bailed out"); 1210 1211 if (auto Err = PRHandler->handleResponse(C)) { 1212 abandonPendingResponses(); 1213 return Err; 1214 } 1215 1216 return Error::success(); 1217 } 1218 1219 FunctionIdT handleNegotiate(const std::string &Name) { 1220 auto I = LocalFunctionIds.find(Name); 1221 if (I == LocalFunctionIds.end()) 1222 return getInvalidFunctionId(); 1223 return I->second; 1224 } 1225 1226 // Find the remote FunctionId for the given function. 1227 template <typename Func> 1228 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, 1229 bool NegotiateIfInvalid) { 1230 bool DoNegotiate; 1231 1232 // Check if we already have a function id... 1233 auto I = RemoteFunctionIds.find(Func::getPrototype()); 1234 if (I != RemoteFunctionIds.end()) { 1235 // If it's valid there's nothing left to do. 1236 if (I->second != getInvalidFunctionId()) 1237 return I->second; 1238 DoNegotiate = NegotiateIfInvalid; 1239 } else 1240 DoNegotiate = NegotiateIfNotInMap; 1241 1242 // We don't have a function id for Func yet, but we're allowed to try to 1243 // negotiate one. 1244 if (DoNegotiate) { 1245 auto &Impl = static_cast<ImplT &>(*this); 1246 if (auto RemoteIdOrErr = 1247 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { 1248 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; 1249 if (*RemoteIdOrErr == getInvalidFunctionId()) 1250 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1251 return *RemoteIdOrErr; 1252 } else 1253 return RemoteIdOrErr.takeError(); 1254 } 1255 1256 // No key was available in the map and we weren't allowed to try to 1257 // negotiate one, so return an unknown function error. 1258 return make_error<CouldNotNegotiate>(Func::getPrototype()); 1259 } 1260 1261 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; 1262 1263 // Wrap the given user handler in the necessary argument-deserialization code, 1264 // result-serialization code, and call to the launch policy (if present). 1265 template <typename Func, typename HandlerT> 1266 WrappedHandlerFn wrapHandler(HandlerT Handler) { 1267 return [this, Handler](ChannelT &Channel, 1268 SequenceNumberT SeqNo) mutable -> Error { 1269 // Start by deserializing the arguments. 1270 using ArgsTuple = 1271 typename detail::FunctionArgsTuple< 1272 typename detail::HandlerTraits<HandlerT>::Type>::Type; 1273 auto Args = std::make_shared<ArgsTuple>(); 1274 1275 if (auto Err = 1276 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1277 Channel, *Args)) 1278 return Err; 1279 1280 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1281 // for RPCArgs. Void cast RPCArgs to work around this for now. 1282 // FIXME: Remove this workaround once we can assume a working GCC version. 1283 (void)Args; 1284 1285 // End receieve message, unlocking the channel for reading. 1286 if (auto Err = Channel.endReceiveMessage()) 1287 return Err; 1288 1289 using HTraits = detail::HandlerTraits<HandlerT>; 1290 using FuncReturn = typename Func::ReturnType; 1291 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, 1292 HTraits::unpackAndRun(Handler, *Args)); 1293 }; 1294 } 1295 1296 // Wrap the given user handler in the necessary argument-deserialization code, 1297 // result-serialization code, and call to the launch policy (if present). 1298 template <typename Func, typename HandlerT> 1299 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { 1300 return [this, Handler](ChannelT &Channel, 1301 SequenceNumberT SeqNo) mutable -> Error { 1302 // Start by deserializing the arguments. 1303 using AHTraits = detail::AsyncHandlerTraits< 1304 typename detail::HandlerTraits<HandlerT>::Type>; 1305 using ArgsTuple = 1306 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; 1307 auto Args = std::make_shared<ArgsTuple>(); 1308 1309 if (auto Err = 1310 detail::HandlerTraits<typename Func::Type>::deserializeArgs( 1311 Channel, *Args)) 1312 return Err; 1313 1314 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning 1315 // for RPCArgs. Void cast RPCArgs to work around this for now. 1316 // FIXME: Remove this workaround once we can assume a working GCC version. 1317 (void)Args; 1318 1319 // End receieve message, unlocking the channel for reading. 1320 if (auto Err = Channel.endReceiveMessage()) 1321 return Err; 1322 1323 using HTraits = detail::HandlerTraits<HandlerT>; 1324 using FuncReturn = typename Func::ReturnType; 1325 auto Responder = 1326 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { 1327 return detail::respond<FuncReturn>(C, ResponseId, SeqNo, 1328 std::move(RetVal)); 1329 }; 1330 1331 return HTraits::unpackAndRunAsync(Handler, Responder, *Args); 1332 }; 1333 } 1334 1335 ChannelT &C; 1336 1337 bool LazyAutoNegotiation; 1338 1339 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator; 1340 1341 FunctionIdT ResponseId; 1342 std::map<std::string, FunctionIdT> LocalFunctionIds; 1343 std::map<const char *, FunctionIdT> RemoteFunctionIds; 1344 1345 std::map<FunctionIdT, WrappedHandlerFn> Handlers; 1346 1347 std::mutex ResponsesMutex; 1348 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; 1349 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> 1350 PendingResponses; 1351}; 1352 1353} // end namespace detail 1354 1355template <typename ChannelT, typename FunctionIdT = uint32_t, 1356 typename SequenceNumberT = uint32_t> 1357class MultiThreadedRPCEndpoint 1358 : public detail::RPCEndpointBase< 1359 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1360 ChannelT, FunctionIdT, SequenceNumberT> { 1361private: 1362 using BaseClass = 1363 detail::RPCEndpointBase< 1364 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1365 ChannelT, FunctionIdT, SequenceNumberT>; 1366 1367public: 1368 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1369 : BaseClass(C, LazyAutoNegotiation) {} 1370 1371 /// Add a handler for the given RPC function. 1372 /// This installs the given handler functor for the given RPC Function, and 1373 /// makes the RPC function available for negotiation/calling from the remote. 1374 template <typename Func, typename HandlerT> 1375 void addHandler(HandlerT Handler) { 1376 return this->template addHandlerImpl<Func>(std::move(Handler)); 1377 } 1378 1379 /// Add a class-method as a handler. 1380 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1381 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1382 addHandler<Func>( 1383 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1384 } 1385 1386 template <typename Func, typename HandlerT> 1387 void addAsyncHandler(HandlerT Handler) { 1388 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1389 } 1390 1391 /// Add a class-method as a handler. 1392 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1393 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1394 addAsyncHandler<Func>( 1395 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1396 } 1397 1398 /// Return type for non-blocking call primitives. 1399 template <typename Func> 1400 using NonBlockingCallResult = typename detail::ResultTraits< 1401 typename Func::ReturnType>::ReturnFutureType; 1402 1403 /// Call Func on Channel C. Does not block, does not call send. Returns a pair 1404 /// of a future result and the sequence number assigned to the result. 1405 /// 1406 /// This utility function is primarily used for single-threaded mode support, 1407 /// where the sequence number can be used to wait for the corresponding 1408 /// result. In multi-threaded mode the appendCallNB method, which does not 1409 /// return the sequence numeber, should be preferred. 1410 template <typename Func, typename... ArgTs> 1411 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) { 1412 using RTraits = detail::ResultTraits<typename Func::ReturnType>; 1413 using ErrorReturn = typename RTraits::ErrorReturnType; 1414 using ErrorReturnPromise = typename RTraits::ReturnPromiseType; 1415 1416 ErrorReturnPromise Promise; 1417 auto FutureResult = Promise.get_future(); 1418 1419 if (auto Err = this->template appendCallAsync<Func>( 1420 [Promise = std::move(Promise)](ErrorReturn RetOrErr) mutable { 1421 Promise.set_value(std::move(RetOrErr)); 1422 return Error::success(); 1423 }, 1424 Args...)) { 1425 RTraits::consumeAbandoned(FutureResult.get()); 1426 return std::move(Err); 1427 } 1428 return std::move(FutureResult); 1429 } 1430 1431 /// The same as appendCallNBWithSeq, except that it calls C.send() to 1432 /// flush the channel after serializing the call. 1433 template <typename Func, typename... ArgTs> 1434 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) { 1435 auto Result = appendCallNB<Func>(Args...); 1436 if (!Result) 1437 return Result; 1438 if (auto Err = this->C.send()) { 1439 this->abandonPendingResponses(); 1440 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1441 std::move(Result->get())); 1442 return std::move(Err); 1443 } 1444 return Result; 1445 } 1446 1447 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error 1448 /// for void functions or an Expected<T> for functions returning a T. 1449 /// 1450 /// This function is for use in threaded code where another thread is 1451 /// handling responses and incoming calls. 1452 template <typename Func, typename... ArgTs, 1453 typename AltRetT = typename Func::ReturnType> 1454 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1455 callB(const ArgTs &... Args) { 1456 if (auto FutureResOrErr = callNB<Func>(Args...)) 1457 return FutureResOrErr->get(); 1458 else 1459 return FutureResOrErr.takeError(); 1460 } 1461 1462 /// Handle incoming RPC calls. 1463 Error handlerLoop() { 1464 while (true) 1465 if (auto Err = this->handleOne()) 1466 return Err; 1467 return Error::success(); 1468 } 1469}; 1470 1471template <typename ChannelT, typename FunctionIdT = uint32_t, 1472 typename SequenceNumberT = uint32_t> 1473class SingleThreadedRPCEndpoint 1474 : public detail::RPCEndpointBase< 1475 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1476 ChannelT, FunctionIdT, SequenceNumberT> { 1477private: 1478 using BaseClass = 1479 detail::RPCEndpointBase< 1480 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, 1481 ChannelT, FunctionIdT, SequenceNumberT>; 1482 1483public: 1484 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) 1485 : BaseClass(C, LazyAutoNegotiation) {} 1486 1487 template <typename Func, typename HandlerT> 1488 void addHandler(HandlerT Handler) { 1489 return this->template addHandlerImpl<Func>(std::move(Handler)); 1490 } 1491 1492 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1493 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1494 addHandler<Func>( 1495 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1496 } 1497 1498 template <typename Func, typename HandlerT> 1499 void addAsyncHandler(HandlerT Handler) { 1500 return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); 1501 } 1502 1503 /// Add a class-method as a handler. 1504 template <typename Func, typename ClassT, typename RetT, typename... ArgTs> 1505 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { 1506 addAsyncHandler<Func>( 1507 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); 1508 } 1509 1510 template <typename Func, typename... ArgTs, 1511 typename AltRetT = typename Func::ReturnType> 1512 typename detail::ResultTraits<AltRetT>::ErrorReturnType 1513 callB(const ArgTs &... Args) { 1514 bool ReceivedResponse = false; 1515 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType; 1516 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue(); 1517 1518 // We have to 'Check' result (which we know is in a success state at this 1519 // point) so that it can be overwritten in the async handler. 1520 (void)!!Result; 1521 1522 if (auto Err = this->template appendCallAsync<Func>( 1523 [&](ResultType R) { 1524 Result = std::move(R); 1525 ReceivedResponse = true; 1526 return Error::success(); 1527 }, 1528 Args...)) { 1529 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1530 std::move(Result)); 1531 return std::move(Err); 1532 } 1533 1534 if (auto Err = this->C.send()) { 1535 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1536 std::move(Result)); 1537 return std::move(Err); 1538 } 1539 1540 while (!ReceivedResponse) { 1541 if (auto Err = this->handleOne()) { 1542 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( 1543 std::move(Result)); 1544 return std::move(Err); 1545 } 1546 } 1547 1548 return Result; 1549 } 1550}; 1551 1552/// Asynchronous dispatch for a function on an RPC endpoint. 1553template <typename RPCClass, typename Func> 1554class RPCAsyncDispatch { 1555public: 1556 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} 1557 1558 template <typename HandlerT, typename... ArgTs> 1559 Error operator()(HandlerT Handler, const ArgTs &... Args) const { 1560 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); 1561 } 1562 1563private: 1564 RPCClass &Endpoint; 1565}; 1566 1567/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. 1568template <typename Func, typename RPCEndpointT> 1569RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { 1570 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); 1571} 1572 1573/// Allows a set of asynchrounous calls to be dispatched, and then 1574/// waited on as a group. 1575class ParallelCallGroup { 1576public: 1577 1578 ParallelCallGroup() = default; 1579 ParallelCallGroup(const ParallelCallGroup &) = delete; 1580 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; 1581 1582 /// Make as asynchronous call. 1583 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> 1584 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, 1585 const ArgTs &... Args) { 1586 // Increment the count of outstanding calls. This has to happen before 1587 // we invoke the call, as the handler may (depending on scheduling) 1588 // be run immediately on another thread, and we don't want the decrement 1589 // in the wrapped handler below to run before the increment. 1590 { 1591 std::unique_lock<std::mutex> Lock(M); 1592 ++NumOutstandingCalls; 1593 } 1594 1595 // Wrap the user handler in a lambda that will decrement the 1596 // outstanding calls count, then poke the condition variable. 1597 using ArgType = typename detail::ResponseHandlerArg< 1598 typename detail::HandlerTraits<HandlerT>::Type>::ArgType; 1599 auto WrappedHandler = [this, Handler = std::move(Handler)](ArgType Arg) { 1600 auto Err = Handler(std::move(Arg)); 1601 std::unique_lock<std::mutex> Lock(M); 1602 --NumOutstandingCalls; 1603 CV.notify_all(); 1604 return Err; 1605 }; 1606 1607 return AsyncDispatch(std::move(WrappedHandler), Args...); 1608 } 1609 1610 /// Blocks until all calls have been completed and their return value 1611 /// handlers run. 1612 void wait() { 1613 std::unique_lock<std::mutex> Lock(M); 1614 while (NumOutstandingCalls > 0) 1615 CV.wait(Lock); 1616 } 1617 1618private: 1619 std::mutex M; 1620 std::condition_variable CV; 1621 uint32_t NumOutstandingCalls = 0; 1622}; 1623 1624/// Convenience class for grouping RPC Functions into APIs that can be 1625/// negotiated as a block. 1626/// 1627template <typename... Funcs> 1628class APICalls { 1629public: 1630 1631 /// Test whether this API contains Function F. 1632 template <typename F> 1633 class Contains { 1634 public: 1635 static const bool value = false; 1636 }; 1637 1638 /// Negotiate all functions in this API. 1639 template <typename RPCEndpoint> 1640 static Error negotiate(RPCEndpoint &R) { 1641 return Error::success(); 1642 } 1643}; 1644 1645template <typename Func, typename... Funcs> 1646class APICalls<Func, Funcs...> { 1647public: 1648 1649 template <typename F> 1650 class Contains { 1651 public: 1652 static const bool value = std::is_same<F, Func>::value | 1653 APICalls<Funcs...>::template Contains<F>::value; 1654 }; 1655 1656 template <typename RPCEndpoint> 1657 static Error negotiate(RPCEndpoint &R) { 1658 if (auto Err = R.template negotiateFunction<Func>()) 1659 return Err; 1660 return APICalls<Funcs...>::negotiate(R); 1661 } 1662 1663}; 1664 1665template <typename... InnerFuncs, typename... Funcs> 1666class APICalls<APICalls<InnerFuncs...>, Funcs...> { 1667public: 1668 1669 template <typename F> 1670 class Contains { 1671 public: 1672 static const bool value = 1673 APICalls<InnerFuncs...>::template Contains<F>::value | 1674 APICalls<Funcs...>::template Contains<F>::value; 1675 }; 1676 1677 template <typename RPCEndpoint> 1678 static Error negotiate(RPCEndpoint &R) { 1679 if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) 1680 return Err; 1681 return APICalls<Funcs...>::negotiate(R); 1682 } 1683 1684}; 1685 1686} // end namespace rpc 1687} // end namespace orc 1688} // end namespace llvm 1689 1690#endif 1691