1356843Sdim//===- llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h ----------------*- C++ -*-===//
2356843Sdim//
3356843Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4356843Sdim// See https://llvm.org/LICENSE.txt for license information.
5356843Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6356843Sdim//
7356843Sdim//===----------------------------------------------------------------------===//
8356843Sdim
9356843Sdim#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
10356843Sdim#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
11356843Sdim
12356843Sdim#include "llvm/ADT/StringRef.h"
13356843Sdim#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h"
14356843Sdim#include "llvm/Support/Endian.h"
15356843Sdim#include "llvm/Support/Error.h"
16356843Sdim#include <cstdint>
17356843Sdim#include <mutex>
18356843Sdim#include <string>
19356843Sdim#include <type_traits>
20356843Sdim
21356843Sdimnamespace llvm {
22356843Sdimnamespace orc {
23356843Sdimnamespace rpc {
24356843Sdim
25356843Sdim/// Interface for byte-streams to be used with RPC.
26356843Sdimclass RawByteChannel {
27356843Sdimpublic:
28356843Sdim  virtual ~RawByteChannel() = default;
29356843Sdim
30356843Sdim  /// Read Size bytes from the stream into *Dst.
31356843Sdim  virtual Error readBytes(char *Dst, unsigned Size) = 0;
32356843Sdim
33356843Sdim  /// Read size bytes from *Src and append them to the stream.
34356843Sdim  virtual Error appendBytes(const char *Src, unsigned Size) = 0;
35356843Sdim
36356843Sdim  /// Flush the stream if possible.
37356843Sdim  virtual Error send() = 0;
38356843Sdim
39356843Sdim  /// Notify the channel that we're starting a message send.
40356843Sdim  /// Locks the channel for writing.
41356843Sdim  template <typename FunctionIdT, typename SequenceIdT>
42356843Sdim  Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
43356843Sdim    writeLock.lock();
44356843Sdim    if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
45356843Sdim      writeLock.unlock();
46356843Sdim      return Err;
47356843Sdim    }
48356843Sdim    return Error::success();
49356843Sdim  }
50356843Sdim
51356843Sdim  /// Notify the channel that we're ending a message send.
52356843Sdim  /// Unlocks the channel for writing.
53356843Sdim  Error endSendMessage() {
54356843Sdim    writeLock.unlock();
55356843Sdim    return Error::success();
56356843Sdim  }
57356843Sdim
58356843Sdim  /// Notify the channel that we're starting a message receive.
59356843Sdim  /// Locks the channel for reading.
60356843Sdim  template <typename FunctionIdT, typename SequenceNumberT>
61356843Sdim  Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
62356843Sdim    readLock.lock();
63356843Sdim    if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
64356843Sdim      readLock.unlock();
65356843Sdim      return Err;
66356843Sdim    }
67356843Sdim    return Error::success();
68356843Sdim  }
69356843Sdim
70356843Sdim  /// Notify the channel that we're ending a message receive.
71356843Sdim  /// Unlocks the channel for reading.
72356843Sdim  Error endReceiveMessage() {
73356843Sdim    readLock.unlock();
74356843Sdim    return Error::success();
75356843Sdim  }
76356843Sdim
77356843Sdim  /// Get the lock for stream reading.
78356843Sdim  std::mutex &getReadLock() { return readLock; }
79356843Sdim
80356843Sdim  /// Get the lock for stream writing.
81356843Sdim  std::mutex &getWriteLock() { return writeLock; }
82356843Sdim
83356843Sdimprivate:
84356843Sdim  std::mutex readLock, writeLock;
85356843Sdim};
86356843Sdim
87356843Sdimtemplate <typename ChannelT, typename T>
88356843Sdimclass SerializationTraits<
89356843Sdim    ChannelT, T, T,
90356843Sdim    typename std::enable_if<
91356843Sdim        std::is_base_of<RawByteChannel, ChannelT>::value &&
92356843Sdim        (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
93356843Sdim         std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
94356843Sdim         std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
95356843Sdim         std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
96356843Sdim         std::is_same<T, char>::value)>::type> {
97356843Sdimpublic:
98356843Sdim  static Error serialize(ChannelT &C, T V) {
99356843Sdim    support::endian::byte_swap<T, support::big>(V);
100356843Sdim    return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
101356843Sdim  };
102356843Sdim
103356843Sdim  static Error deserialize(ChannelT &C, T &V) {
104356843Sdim    if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
105356843Sdim      return Err;
106356843Sdim    support::endian::byte_swap<T, support::big>(V);
107356843Sdim    return Error::success();
108356843Sdim  };
109356843Sdim};
110356843Sdim
111356843Sdimtemplate <typename ChannelT>
112356843Sdimclass SerializationTraits<ChannelT, bool, bool,
113356843Sdim                          typename std::enable_if<std::is_base_of<
114356843Sdim                              RawByteChannel, ChannelT>::value>::type> {
115356843Sdimpublic:
116356843Sdim  static Error serialize(ChannelT &C, bool V) {
117356843Sdim    uint8_t Tmp = V ? 1 : 0;
118356843Sdim    if (auto Err =
119356843Sdim          C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
120356843Sdim      return Err;
121356843Sdim    return Error::success();
122356843Sdim  }
123356843Sdim
124356843Sdim  static Error deserialize(ChannelT &C, bool &V) {
125356843Sdim    uint8_t Tmp = 0;
126356843Sdim    if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
127356843Sdim      return Err;
128356843Sdim    V = Tmp != 0;
129356843Sdim    return Error::success();
130356843Sdim  }
131356843Sdim};
132356843Sdim
133356843Sdimtemplate <typename ChannelT>
134356843Sdimclass SerializationTraits<ChannelT, std::string, StringRef,
135356843Sdim                          typename std::enable_if<std::is_base_of<
136356843Sdim                              RawByteChannel, ChannelT>::value>::type> {
137356843Sdimpublic:
138356843Sdim  /// RPC channel serialization for std::strings.
139356843Sdim  static Error serialize(RawByteChannel &C, StringRef S) {
140356843Sdim    if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
141356843Sdim      return Err;
142356843Sdim    return C.appendBytes((const char *)S.data(), S.size());
143356843Sdim  }
144356843Sdim};
145356843Sdim
146356843Sdimtemplate <typename ChannelT, typename T>
147356843Sdimclass SerializationTraits<ChannelT, std::string, T,
148356843Sdim                          typename std::enable_if<
149356843Sdim                            std::is_base_of<RawByteChannel, ChannelT>::value &&
150356843Sdim                            (std::is_same<T, const char*>::value ||
151356843Sdim                             std::is_same<T, char*>::value)>::type> {
152356843Sdimpublic:
153356843Sdim  static Error serialize(RawByteChannel &C, const char *S) {
154356843Sdim    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
155356843Sdim                                                                            S);
156356843Sdim  }
157356843Sdim};
158356843Sdim
159356843Sdimtemplate <typename ChannelT>
160356843Sdimclass SerializationTraits<ChannelT, std::string, std::string,
161356843Sdim                          typename std::enable_if<std::is_base_of<
162356843Sdim                              RawByteChannel, ChannelT>::value>::type> {
163356843Sdimpublic:
164356843Sdim  /// RPC channel serialization for std::strings.
165356843Sdim  static Error serialize(RawByteChannel &C, const std::string &S) {
166356843Sdim    return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
167356843Sdim                                                                            S);
168356843Sdim  }
169356843Sdim
170356843Sdim  /// RPC channel deserialization for std::strings.
171356843Sdim  static Error deserialize(RawByteChannel &C, std::string &S) {
172356843Sdim    uint64_t Count = 0;
173356843Sdim    if (auto Err = deserializeSeq(C, Count))
174356843Sdim      return Err;
175356843Sdim    S.resize(Count);
176356843Sdim    return C.readBytes(&S[0], Count);
177356843Sdim  }
178356843Sdim};
179356843Sdim
180356843Sdim} // end namespace rpc
181356843Sdim} // end namespace orc
182356843Sdim} // end namespace llvm
183356843Sdim
184356843Sdim#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
185