1//===- llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
10#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
11
12#include "llvm/ADT/StringRef.h"
13#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h"
14#include "llvm/Support/Endian.h"
15#include "llvm/Support/Error.h"
16#include <cstdint>
17#include <mutex>
18#include <string>
19#include <type_traits>
20
21namespace llvm {
22namespace orc {
23namespace rpc {
24
25/// Interface for byte-streams to be used with RPC.
26class RawByteChannel {
27public:
28  virtual ~RawByteChannel() = default;
29
30  /// Read Size bytes from the stream into *Dst.
31  virtual Error readBytes(char *Dst, unsigned Size) = 0;
32
33  /// Read size bytes from *Src and append them to the stream.
34  virtual Error appendBytes(const char *Src, unsigned Size) = 0;
35
36  /// Flush the stream if possible.
37  virtual Error send() = 0;
38
39  /// Notify the channel that we're starting a message send.
40  /// Locks the channel for writing.
41  template <typename FunctionIdT, typename SequenceIdT>
42  Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
43    writeLock.lock();
44    if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
45      writeLock.unlock();
46      return Err;
47    }
48    return Error::success();
49  }
50
51  /// Notify the channel that we're ending a message send.
52  /// Unlocks the channel for writing.
53  Error endSendMessage() {
54    writeLock.unlock();
55    return Error::success();
56  }
57
58  /// Notify the channel that we're starting a message receive.
59  /// Locks the channel for reading.
60  template <typename FunctionIdT, typename SequenceNumberT>
61  Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
62    readLock.lock();
63    if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
64      readLock.unlock();
65      return Err;
66    }
67    return Error::success();
68  }
69
70  /// Notify the channel that we're ending a message receive.
71  /// Unlocks the channel for reading.
72  Error endReceiveMessage() {
73    readLock.unlock();
74    return Error::success();
75  }
76
77  /// Get the lock for stream reading.
78  std::mutex &getReadLock() { return readLock; }
79
80  /// Get the lock for stream writing.
81  std::mutex &getWriteLock() { return writeLock; }
82
83private:
84  std::mutex readLock, writeLock;
85};
86
87template <typename ChannelT, typename T>
88class SerializationTraits<
89    ChannelT, T, T,
90    typename std::enable_if<
91        std::is_base_of<RawByteChannel, ChannelT>::value &&
92        (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
93         std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
94         std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
95         std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
96         std::is_same<T, char>::value)>::type> {
97public:
98  static Error serialize(ChannelT &C, T V) {
99    support::endian::byte_swap<T, support::big>(V);
100    return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
101  };
102
103  static Error deserialize(ChannelT &C, T &V) {
104    if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
105      return Err;
106    support::endian::byte_swap<T, support::big>(V);
107    return Error::success();
108  };
109};
110
111template <typename ChannelT>
112class SerializationTraits<ChannelT, bool, bool,
113                          typename std::enable_if<std::is_base_of<
114                              RawByteChannel, ChannelT>::value>::type> {
115public:
116  static Error serialize(ChannelT &C, bool V) {
117    uint8_t Tmp = V ? 1 : 0;
118    if (auto Err =
119          C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
120      return Err;
121    return Error::success();
122  }
123
124  static Error deserialize(ChannelT &C, bool &V) {
125    uint8_t Tmp = 0;
126    if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
127      return Err;
128    V = Tmp != 0;
129    return Error::success();
130  }
131};
132
133template <typename ChannelT>
134class SerializationTraits<ChannelT, std::string, StringRef,
135                          typename std::enable_if<std::is_base_of<
136                              RawByteChannel, ChannelT>::value>::type> {
137public:
138  /// RPC channel serialization for std::strings.
139  static Error serialize(RawByteChannel &C, StringRef S) {
140    if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
141      return Err;
142    return C.appendBytes((const char *)S.data(), S.size());
143  }
144};
145
146template <typename ChannelT, typename T>
147class SerializationTraits<ChannelT, std::string, T,
148                          typename std::enable_if<
149                            std::is_base_of<RawByteChannel, ChannelT>::value &&
150                            (std::is_same<T, const char*>::value ||
151                             std::is_same<T, char*>::value)>::type> {
152public:
153  static Error serialize(RawByteChannel &C, const char *S) {
154    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
155                                                                            S);
156  }
157};
158
159template <typename ChannelT>
160class SerializationTraits<ChannelT, std::string, std::string,
161                          typename std::enable_if<std::is_base_of<
162                              RawByteChannel, ChannelT>::value>::type> {
163public:
164  /// RPC channel serialization for std::strings.
165  static Error serialize(RawByteChannel &C, const std::string &S) {
166    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
167                                                                            S);
168  }
169
170  /// RPC channel deserialization for std::strings.
171  static Error deserialize(RawByteChannel &C, std::string &S) {
172    uint64_t Count = 0;
173    if (auto Err = deserializeSeq(C, Count))
174      return Err;
175    S.resize(Count);
176    return C.readBytes(&S[0], Count);
177  }
178};
179
180} // end namespace rpc
181} // end namespace orc
182} // end namespace llvm
183
184#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
185