1//===-- TCPSocket.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#if defined(_MSC_VER)
10#define _WINSOCK_DEPRECATED_NO_WARNINGS
11#endif
12
13#include "lldb/Host/common/TCPSocket.h"
14
15#include "lldb/Host/Config.h"
16#include "lldb/Host/MainLoop.h"
17#include "lldb/Utility/LLDBLog.h"
18#include "lldb/Utility/Log.h"
19
20#include "llvm/Config/llvm-config.h"
21#include "llvm/Support/Errno.h"
22#include "llvm/Support/WindowsError.h"
23#include "llvm/Support/raw_ostream.h"
24
25#if LLDB_ENABLE_POSIX
26#include <arpa/inet.h>
27#include <netinet/tcp.h>
28#include <sys/socket.h>
29#endif
30
31#if defined(_WIN32)
32#include <winsock2.h>
33#endif
34
35#ifdef _WIN32
36#define CLOSE_SOCKET closesocket
37typedef const char *set_socket_option_arg_type;
38#else
39#include <unistd.h>
40#define CLOSE_SOCKET ::close
41typedef const void *set_socket_option_arg_type;
42#endif
43
44using namespace lldb;
45using namespace lldb_private;
46
47static Status GetLastSocketError() {
48  std::error_code EC;
49#ifdef _WIN32
50  EC = llvm::mapWindowsError(WSAGetLastError());
51#else
52  EC = std::error_code(errno, std::generic_category());
53#endif
54  return EC;
55}
56
// Socket type handed to Socket::CreateSocket() by this file.
static constexpr int kType = SOCK_STREAM;
58
59TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
60    : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
61
62TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
63    : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
64             listen_socket.m_child_processes_inherit) {
65  m_socket = socket;
66}
67
68TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
69                     bool child_processes_inherit)
70    : Socket(ProtocolTcp, should_close, child_processes_inherit) {
71  m_socket = socket;
72}
73
74TCPSocket::~TCPSocket() { CloseListenSockets(); }
75
76bool TCPSocket::IsValid() const {
77  return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
78}
79
80// Return the port number that is being used by the socket.
81uint16_t TCPSocket::GetLocalPortNumber() const {
82  if (m_socket != kInvalidSocketValue) {
83    SocketAddress sock_addr;
84    socklen_t sock_addr_len = sock_addr.GetMaxLength();
85    if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
86      return sock_addr.GetPort();
87  } else if (!m_listen_sockets.empty()) {
88    SocketAddress sock_addr;
89    socklen_t sock_addr_len = sock_addr.GetMaxLength();
90    if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
91                      &sock_addr_len) == 0)
92      return sock_addr.GetPort();
93  }
94  return 0;
95}
96
97std::string TCPSocket::GetLocalIPAddress() const {
98  // We bound to port zero, so we need to figure out which port we actually
99  // bound to
100  if (m_socket != kInvalidSocketValue) {
101    SocketAddress sock_addr;
102    socklen_t sock_addr_len = sock_addr.GetMaxLength();
103    if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
104      return sock_addr.GetIPAddress();
105  }
106  return "";
107}
108
109uint16_t TCPSocket::GetRemotePortNumber() const {
110  if (m_socket != kInvalidSocketValue) {
111    SocketAddress sock_addr;
112    socklen_t sock_addr_len = sock_addr.GetMaxLength();
113    if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
114      return sock_addr.GetPort();
115  }
116  return 0;
117}
118
119std::string TCPSocket::GetRemoteIPAddress() const {
120  // We bound to port zero, so we need to figure out which port we actually
121  // bound to
122  if (m_socket != kInvalidSocketValue) {
123    SocketAddress sock_addr;
124    socklen_t sock_addr_len = sock_addr.GetMaxLength();
125    if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
126      return sock_addr.GetIPAddress();
127  }
128  return "";
129}
130
131std::string TCPSocket::GetRemoteConnectionURI() const {
132  if (m_socket != kInvalidSocketValue) {
133    return std::string(llvm::formatv(
134        "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
135  }
136  return "";
137}
138
139Status TCPSocket::CreateSocket(int domain) {
140  Status error;
141  if (IsValid())
142    error = Close();
143  if (error.Fail())
144    return error;
145  m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
146                                  m_child_processes_inherit, error);
147  return error;
148}
149
150Status TCPSocket::Connect(llvm::StringRef name) {
151
152  Log *log = GetLog(LLDBLog::Communication);
153  LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
154
155  Status error;
156  llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
157  if (!host_port)
158    return Status(host_port.takeError());
159
160  std::vector<SocketAddress> addresses =
161      SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
162                                    AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
163  for (SocketAddress &address : addresses) {
164    error = CreateSocket(address.GetFamily());
165    if (error.Fail())
166      continue;
167
168    address.SetPort(host_port->port);
169
170    if (llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
171                                    &address.sockaddr(),
172                                    address.GetLength()) == -1) {
173      Close();
174      continue;
175    }
176
177    if (SetOptionNoDelay() == -1) {
178      Close();
179      continue;
180    }
181
182    error.Clear();
183    return error;
184  }
185
186  error.SetErrorString("Failed to connect port");
187  return error;
188}
189
190Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
191  Log *log = GetLog(LLDBLog::Connection);
192  LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
193
194  Status error;
195  llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
196  if (!host_port)
197    return Status(host_port.takeError());
198
199  if (host_port->hostname == "*")
200    host_port->hostname = "0.0.0.0";
201  std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
202      host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
203  for (SocketAddress &address : addresses) {
204    int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
205                                  m_child_processes_inherit, error);
206    if (error.Fail() || fd < 0)
207      continue;
208
209    // enable local address reuse
210    int option_value = 1;
211    set_socket_option_arg_type option_value_p =
212        reinterpret_cast<set_socket_option_arg_type>(&option_value);
213    if (::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
214                     sizeof(option_value)) == -1) {
215      CLOSE_SOCKET(fd);
216      continue;
217    }
218
219    SocketAddress listen_address = address;
220    if(!listen_address.IsLocalhost())
221      listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
222    else
223      listen_address.SetPort(host_port->port);
224
225    int err =
226        ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
227    if (err != -1)
228      err = ::listen(fd, backlog);
229
230    if (err == -1) {
231      error = GetLastSocketError();
232      CLOSE_SOCKET(fd);
233      continue;
234    }
235
236    if (host_port->port == 0) {
237      socklen_t sa_len = address.GetLength();
238      if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
239        host_port->port = address.GetPort();
240    }
241    m_listen_sockets[fd] = address;
242  }
243
244  if (m_listen_sockets.empty()) {
245    assert(error.Fail());
246    return error;
247  }
248  return Status();
249}
250
251void TCPSocket::CloseListenSockets() {
252  for (auto socket : m_listen_sockets)
253    CLOSE_SOCKET(socket.first);
254  m_listen_sockets.clear();
255}
256
257Status TCPSocket::Accept(Socket *&conn_socket) {
258  Status error;
259  if (m_listen_sockets.size() == 0) {
260    error.SetErrorString("No open listening sockets!");
261    return error;
262  }
263
264  NativeSocket sock = kInvalidSocketValue;
265  NativeSocket listen_sock = kInvalidSocketValue;
266  lldb_private::SocketAddress AcceptAddr;
267  MainLoop accept_loop;
268  std::vector<MainLoopBase::ReadHandleUP> handles;
269  for (auto socket : m_listen_sockets) {
270    auto fd = socket.first;
271    auto inherit = this->m_child_processes_inherit;
272    auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
273    handles.emplace_back(accept_loop.RegisterReadObject(
274        io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
275                        &listen_sock](MainLoopBase &loop) {
276          socklen_t sa_len = AcceptAddr.GetMaxLength();
277          sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
278                              error);
279          listen_sock = fd;
280          loop.RequestTermination();
281        }, error));
282    if (error.Fail())
283      return error;
284  }
285
286  bool accept_connection = false;
287  std::unique_ptr<TCPSocket> accepted_socket;
288  // Loop until we are happy with our connection
289  while (!accept_connection) {
290    accept_loop.Run();
291
292    if (error.Fail())
293        return error;
294
295    lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
296    if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
297      if (sock != kInvalidSocketValue) {
298        CLOSE_SOCKET(sock);
299        sock = kInvalidSocketValue;
300      }
301      llvm::errs() << llvm::formatv(
302          "error: rejecting incoming connection from {0} (expecting {1})",
303          AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
304      continue;
305    }
306    accept_connection = true;
307    accepted_socket.reset(new TCPSocket(sock, *this));
308  }
309
310  if (!accepted_socket)
311    return error;
312
313  // Keep our TCP packets coming without any delays.
314  accepted_socket->SetOptionNoDelay();
315  error.Clear();
316  conn_socket = accepted_socket.release();
317  return error;
318}
319
320int TCPSocket::SetOptionNoDelay() {
321  return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
322}
323
324int TCPSocket::SetOptionReuseAddress() {
325  return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
326}
327