1//===-- Socket.cpp ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "lldb/Host/Socket.h"
10
11#include "lldb/Host/Config.h"
12#include "lldb/Host/Host.h"
13#include "lldb/Host/SocketAddress.h"
14#include "lldb/Host/StringConvert.h"
15#include "lldb/Host/common/TCPSocket.h"
16#include "lldb/Host/common/UDPSocket.h"
17#include "lldb/Utility/Log.h"
18#include "lldb/Utility/RegularExpression.h"
19
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/Errno.h"
22#include "llvm/Support/Error.h"
23#include "llvm/Support/WindowsError.h"
24
25#if LLDB_ENABLE_POSIX
26#include "lldb/Host/posix/DomainSocket.h"
27
28#include <arpa/inet.h>
29#include <netdb.h>
30#include <netinet/in.h>
31#include <netinet/tcp.h>
32#include <sys/socket.h>
33#include <sys/un.h>
34#include <unistd.h>
35#endif
36
37#ifdef __linux__
38#include "lldb/Host/linux/AbstractSocket.h"
39#endif
40
41#ifdef __ANDROID__
42#include <arpa/inet.h>
43#include <asm-generic/errno-base.h>
44#include <errno.h>
45#include <linux/tcp.h>
46#include <fcntl.h>
47#include <sys/syscall.h>
48#include <unistd.h>
49#endif // __ANDROID__
50
51using namespace lldb;
52using namespace lldb_private;
53
54#if defined(_WIN32)
55typedef const char *set_socket_option_arg_type;
56typedef char *get_socket_option_arg_type;
57const NativeSocket Socket::kInvalidSocketValue = INVALID_SOCKET;
58#else  // #if defined(_WIN32)
59typedef const void *set_socket_option_arg_type;
60typedef void *get_socket_option_arg_type;
61const NativeSocket Socket::kInvalidSocketValue = -1;
62#endif // #if defined(_WIN32)
63
namespace {

/// Returns true when the last failing socket call was interrupted by a
/// signal, i.e. the Read()/Write() retry loops should issue the call again.
bool IsInterrupted() {
#if defined(_WIN32)
  const int last_error = ::WSAGetLastError();
  return last_error == WSAEINTR;
#else
  const int last_error = errno;
  return last_error == EINTR;
#endif
}

} // namespace
74
75Socket::Socket(SocketProtocol protocol, bool should_close,
76               bool child_processes_inherit)
77    : IOObject(eFDTypeSocket), m_protocol(protocol),
78      m_socket(kInvalidSocketValue),
79      m_child_processes_inherit(child_processes_inherit),
80      m_should_close_fd(should_close) {}
81
82Socket::~Socket() { Close(); }
83
84llvm::Error Socket::Initialize() {
85#if defined(_WIN32)
86  auto wVersion = WINSOCK_VERSION;
87  WSADATA wsaData;
88  int err = ::WSAStartup(wVersion, &wsaData);
89  if (err == 0) {
90    if (wsaData.wVersion < wVersion) {
91      WSACleanup();
92      return llvm::make_error<llvm::StringError>(
93          "WSASock version is not expected.", llvm::inconvertibleErrorCode());
94    }
95  } else {
96    return llvm::errorCodeToError(llvm::mapWindowsError(::WSAGetLastError()));
97  }
98#endif
99
100  return llvm::Error::success();
101}
102
103void Socket::Terminate() {
104#if defined(_WIN32)
105  ::WSACleanup();
106#endif
107}
108
109std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol,
110                                       bool child_processes_inherit,
111                                       Status &error) {
112  error.Clear();
113
114  std::unique_ptr<Socket> socket_up;
115  switch (protocol) {
116  case ProtocolTcp:
117    socket_up =
118        std::make_unique<TCPSocket>(true, child_processes_inherit);
119    break;
120  case ProtocolUdp:
121    socket_up =
122        std::make_unique<UDPSocket>(true, child_processes_inherit);
123    break;
124  case ProtocolUnixDomain:
125#if LLDB_ENABLE_POSIX
126    socket_up =
127        std::make_unique<DomainSocket>(true, child_processes_inherit);
128#else
129    error.SetErrorString(
130        "Unix domain sockets are not supported on this platform.");
131#endif
132    break;
133  case ProtocolUnixAbstract:
134#ifdef __linux__
135    socket_up =
136        std::make_unique<AbstractSocket>(child_processes_inherit);
137#else
138    error.SetErrorString(
139        "Abstract domain sockets are not supported on this platform.");
140#endif
141    break;
142  }
143
144  if (error.Fail())
145    socket_up.reset();
146
147  return socket_up;
148}
149
150Status Socket::TcpConnect(llvm::StringRef host_and_port,
151                          bool child_processes_inherit, Socket *&socket) {
152  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
153  LLDB_LOGF(log, "Socket::%s (host/port = %s)", __FUNCTION__,
154            host_and_port.str().c_str());
155
156  Status error;
157  std::unique_ptr<Socket> connect_socket(
158      Create(ProtocolTcp, child_processes_inherit, error));
159  if (error.Fail())
160    return error;
161
162  error = connect_socket->Connect(host_and_port);
163  if (error.Success())
164    socket = connect_socket.release();
165
166  return error;
167}
168
169Status Socket::TcpListen(llvm::StringRef host_and_port,
170                         bool child_processes_inherit, Socket *&socket,
171                         Predicate<uint16_t> *predicate, int backlog) {
172  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
173  LLDB_LOGF(log, "Socket::%s (%s)", __FUNCTION__, host_and_port.str().c_str());
174
175  Status error;
176  std::string host_str;
177  std::string port_str;
178  int32_t port = INT32_MIN;
179  if (!DecodeHostAndPort(host_and_port, host_str, port_str, port, &error))
180    return error;
181
182  std::unique_ptr<TCPSocket> listen_socket(
183      new TCPSocket(true, child_processes_inherit));
184  if (error.Fail())
185    return error;
186
187  error = listen_socket->Listen(host_and_port, backlog);
188  if (error.Success()) {
189    // We were asked to listen on port zero which means we must now read the
190    // actual port that was given to us as port zero is a special code for
191    // "find an open port for me".
192    if (port == 0)
193      port = listen_socket->GetLocalPortNumber();
194
195    // Set the port predicate since when doing a listen://<host>:<port> it
196    // often needs to accept the incoming connection which is a blocking system
197    // call. Allowing access to the bound port using a predicate allows us to
198    // wait for the port predicate to be set to a non-zero value from another
199    // thread in an efficient manor.
200    if (predicate)
201      predicate->SetValue(port, eBroadcastAlways);
202    socket = listen_socket.release();
203  }
204
205  return error;
206}
207
208Status Socket::UdpConnect(llvm::StringRef host_and_port,
209                          bool child_processes_inherit, Socket *&socket) {
210  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
211  LLDB_LOGF(log, "Socket::%s (host/port = %s)", __FUNCTION__,
212            host_and_port.str().c_str());
213
214  return UDPSocket::Connect(host_and_port, child_processes_inherit, socket);
215}
216
217Status Socket::UnixDomainConnect(llvm::StringRef name,
218                                 bool child_processes_inherit,
219                                 Socket *&socket) {
220  Status error;
221  std::unique_ptr<Socket> connect_socket(
222      Create(ProtocolUnixDomain, child_processes_inherit, error));
223  if (error.Fail())
224    return error;
225
226  error = connect_socket->Connect(name);
227  if (error.Success())
228    socket = connect_socket.release();
229
230  return error;
231}
232
233Status Socket::UnixDomainAccept(llvm::StringRef name,
234                                bool child_processes_inherit, Socket *&socket) {
235  Status error;
236  std::unique_ptr<Socket> listen_socket(
237      Create(ProtocolUnixDomain, child_processes_inherit, error));
238  if (error.Fail())
239    return error;
240
241  error = listen_socket->Listen(name, 5);
242  if (error.Fail())
243    return error;
244
245  error = listen_socket->Accept(socket);
246  return error;
247}
248
249Status Socket::UnixAbstractConnect(llvm::StringRef name,
250                                   bool child_processes_inherit,
251                                   Socket *&socket) {
252  Status error;
253  std::unique_ptr<Socket> connect_socket(
254      Create(ProtocolUnixAbstract, child_processes_inherit, error));
255  if (error.Fail())
256    return error;
257
258  error = connect_socket->Connect(name);
259  if (error.Success())
260    socket = connect_socket.release();
261  return error;
262}
263
264Status Socket::UnixAbstractAccept(llvm::StringRef name,
265                                  bool child_processes_inherit,
266                                  Socket *&socket) {
267  Status error;
268  std::unique_ptr<Socket> listen_socket(
269      Create(ProtocolUnixAbstract, child_processes_inherit, error));
270  if (error.Fail())
271    return error;
272
273  error = listen_socket->Listen(name, 5);
274  if (error.Fail())
275    return error;
276
277  error = listen_socket->Accept(socket);
278  return error;
279}
280
281bool Socket::DecodeHostAndPort(llvm::StringRef host_and_port,
282                               std::string &host_str, std::string &port_str,
283                               int32_t &port, Status *error_ptr) {
284  static RegularExpression g_regex(
285      llvm::StringRef("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)"));
286  llvm::SmallVector<llvm::StringRef, 3> matches;
287  if (g_regex.Execute(host_and_port, &matches)) {
288    host_str = matches[1].str();
289    port_str = matches[2].str();
290    // IPv6 addresses are wrapped in [] when specified with ports
291    if (host_str.front() == '[' && host_str.back() == ']')
292      host_str = host_str.substr(1, host_str.size() - 2);
293    bool ok = false;
294    port = StringConvert::ToUInt32(port_str.c_str(), UINT32_MAX, 10, &ok);
295    if (ok && port <= UINT16_MAX) {
296      if (error_ptr)
297        error_ptr->Clear();
298      return true;
299    }
300    // port is too large
301    if (error_ptr)
302      error_ptr->SetErrorStringWithFormat(
303          "invalid host:port specification: '%s'", host_and_port.str().c_str());
304    return false;
305  }
306
307  // If this was unsuccessful, then check if it's simply a signed 32-bit
308  // integer, representing a port with an empty host.
309  host_str.clear();
310  port_str.clear();
311  if (to_integer(host_and_port, port, 10) && port < UINT16_MAX) {
312    port_str = host_and_port;
313    if (error_ptr)
314      error_ptr->Clear();
315    return true;
316  }
317
318  if (error_ptr)
319    error_ptr->SetErrorStringWithFormat("invalid host:port specification: '%s'",
320                                        host_and_port.str().c_str());
321  return false;
322}
323
324IOObject::WaitableHandle Socket::GetWaitableHandle() {
325  // TODO: On Windows, use WSAEventSelect
326  return m_socket;
327}
328
329Status Socket::Read(void *buf, size_t &num_bytes) {
330  Status error;
331  int bytes_received = 0;
332  do {
333    bytes_received = ::recv(m_socket, static_cast<char *>(buf), num_bytes, 0);
334  } while (bytes_received < 0 && IsInterrupted());
335
336  if (bytes_received < 0) {
337    SetLastError(error);
338    num_bytes = 0;
339  } else
340    num_bytes = bytes_received;
341
342  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
343  if (log) {
344    LLDB_LOGF(log,
345              "%p Socket::Read() (socket = %" PRIu64
346              ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
347              " (error = %s)",
348              static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
349              static_cast<uint64_t>(num_bytes),
350              static_cast<int64_t>(bytes_received), error.AsCString());
351  }
352
353  return error;
354}
355
356Status Socket::Write(const void *buf, size_t &num_bytes) {
357  const size_t src_len = num_bytes;
358  Status error;
359  int bytes_sent = 0;
360  do {
361    bytes_sent = Send(buf, num_bytes);
362  } while (bytes_sent < 0 && IsInterrupted());
363
364  if (bytes_sent < 0) {
365    SetLastError(error);
366    num_bytes = 0;
367  } else
368    num_bytes = bytes_sent;
369
370  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
371  if (log) {
372    LLDB_LOGF(log,
373              "%p Socket::Write() (socket = %" PRIu64
374              ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
375              " (error = %s)",
376              static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
377              static_cast<uint64_t>(src_len),
378              static_cast<int64_t>(bytes_sent), error.AsCString());
379  }
380
381  return error;
382}
383
384Status Socket::PreDisconnect() {
385  Status error;
386  return error;
387}
388
389Status Socket::Close() {
390  Status error;
391  if (!IsValid() || !m_should_close_fd)
392    return error;
393
394  Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
395  LLDB_LOGF(log, "%p Socket::Close (fd = %" PRIu64 ")",
396            static_cast<void *>(this), static_cast<uint64_t>(m_socket));
397
398#if defined(_WIN32)
399  bool success = !!closesocket(m_socket);
400#else
401  bool success = !!::close(m_socket);
402#endif
403  // A reference to a FD was passed in, set it to an invalid value
404  m_socket = kInvalidSocketValue;
405  if (!success) {
406    SetLastError(error);
407  }
408
409  return error;
410}
411
412int Socket::GetOption(int level, int option_name, int &option_value) {
413  get_socket_option_arg_type option_value_p =
414      reinterpret_cast<get_socket_option_arg_type>(&option_value);
415  socklen_t option_value_size = sizeof(int);
416  return ::getsockopt(m_socket, level, option_name, option_value_p,
417                      &option_value_size);
418}
419
420int Socket::SetOption(int level, int option_name, int option_value) {
421  set_socket_option_arg_type option_value_p =
422      reinterpret_cast<get_socket_option_arg_type>(&option_value);
423  return ::setsockopt(m_socket, level, option_name, option_value_p,
424                      sizeof(option_value));
425}
426
427size_t Socket::Send(const void *buf, const size_t num_bytes) {
428  return ::send(m_socket, static_cast<const char *>(buf), num_bytes, 0);
429}
430
431void Socket::SetLastError(Status &error) {
432#if defined(_WIN32)
433  error.SetError(::WSAGetLastError(), lldb::eErrorTypeWin32);
434#else
435  error.SetErrorToErrno();
436#endif
437}
438
439NativeSocket Socket::CreateSocket(const int domain, const int type,
440                                  const int protocol,
441                                  bool child_processes_inherit, Status &error) {
442  error.Clear();
443  auto socket_type = type;
444#ifdef SOCK_CLOEXEC
445  if (!child_processes_inherit)
446    socket_type |= SOCK_CLOEXEC;
447#endif
448  auto sock = ::socket(domain, socket_type, protocol);
449  if (sock == kInvalidSocketValue)
450    SetLastError(error);
451
452  return sock;
453}
454
455NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr,
456                                  socklen_t *addrlen,
457                                  bool child_processes_inherit, Status &error) {
458  error.Clear();
459#if defined(ANDROID_USE_ACCEPT_WORKAROUND)
460  // Hack:
461  // This enables static linking lldb-server to an API 21 libc, but still
462  // having it run on older devices. It is necessary because API 21 libc's
463  // implementation of accept() uses the accept4 syscall(), which is not
464  // available in older kernels. Using an older libc would fix this issue, but
465  // introduce other ones, as the old libraries were quite buggy.
466  int fd = syscall(__NR_accept, sockfd, addr, addrlen);
467  if (fd >= 0 && !child_processes_inherit) {
468    int flags = ::fcntl(fd, F_GETFD);
469    if (flags != -1 && ::fcntl(fd, F_SETFD, flags | FD_CLOEXEC) != -1)
470      return fd;
471    SetLastError(error);
472    close(fd);
473  }
474  return fd;
475#elif defined(SOCK_CLOEXEC) && defined(HAVE_ACCEPT4)
476  int flags = 0;
477  if (!child_processes_inherit) {
478    flags |= SOCK_CLOEXEC;
479  }
480  NativeSocket fd = llvm::sys::RetryAfterSignal(
481      static_cast<NativeSocket>(-1), ::accept4, sockfd, addr, addrlen, flags);
482#else
483  NativeSocket fd = llvm::sys::RetryAfterSignal(
484      static_cast<NativeSocket>(-1), ::accept, sockfd, addr, addrlen);
485#endif
486  if (fd == kInvalidSocketValue)
487    SetLastError(error);
488  return fd;
489}
490