1292932Sdim//===-- Acceptor.cpp --------------------------------------------*- C++ -*-===//
2292932Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6292932Sdim//
7292932Sdim//===----------------------------------------------------------------------===//
8292932Sdim
9292932Sdim#include "Acceptor.h"
10292932Sdim
11292932Sdim#include "llvm/ADT/StringRef.h"
12314564Sdim#include "llvm/Support/ScopedPrinter.h"
13292932Sdim
14292932Sdim#include "lldb/Host/ConnectionFileDescriptor.h"
15292932Sdim#include "lldb/Host/common/TCPSocket.h"
16321369Sdim#include "lldb/Utility/StreamString.h"
17321369Sdim#include "lldb/Utility/UriParser.h"
18292932Sdim
19292932Sdimusing namespace lldb;
20292932Sdimusing namespace lldb_private;
21292932Sdimusing namespace lldb_private::lldb_server;
22292932Sdimusing namespace llvm;
23292932Sdim
24292932Sdimnamespace {
25292932Sdim
26314564Sdimstruct SocketScheme {
27314564Sdim  const char *m_scheme;
28314564Sdim  const Socket::SocketProtocol m_protocol;
29292932Sdim};
30292932Sdim
31292932SdimSocketScheme socket_schemes[] = {
32292932Sdim    {"tcp", Socket::ProtocolTcp},
33292932Sdim    {"udp", Socket::ProtocolUdp},
34292932Sdim    {"unix", Socket::ProtocolUnixDomain},
35292932Sdim    {"unix-abstract", Socket::ProtocolUnixAbstract},
36292932Sdim};
37292932Sdim
38314564Sdimbool FindProtocolByScheme(const char *scheme,
39314564Sdim                          Socket::SocketProtocol &protocol) {
40314564Sdim  for (auto s : socket_schemes) {
41314564Sdim    if (!strcmp(s.m_scheme, scheme)) {
42314564Sdim      protocol = s.m_protocol;
43314564Sdim      return true;
44292932Sdim    }
45314564Sdim  }
46314564Sdim  return false;
47292932Sdim}
48292932Sdim
49314564Sdimconst char *FindSchemeByProtocol(const Socket::SocketProtocol protocol) {
50314564Sdim  for (auto s : socket_schemes) {
51314564Sdim    if (s.m_protocol == protocol)
52314564Sdim      return s.m_scheme;
53314564Sdim  }
54314564Sdim  return nullptr;
55292932Sdim}
56292932Sdim}
57292932Sdim
58321369SdimStatus Acceptor::Listen(int backlog) {
59314564Sdim  return m_listener_socket_up->Listen(StringRef(m_name), backlog);
60292932Sdim}
61292932Sdim
62321369SdimStatus Acceptor::Accept(const bool child_processes_inherit, Connection *&conn) {
63314564Sdim  Socket *conn_socket = nullptr;
64321369Sdim  auto error = m_listener_socket_up->Accept(conn_socket);
65314564Sdim  if (error.Success())
66314564Sdim    conn = new ConnectionFileDescriptor(conn_socket);
67292932Sdim
68314564Sdim  return error;
69292932Sdim}
70292932Sdim
71314564SdimSocket::SocketProtocol Acceptor::GetSocketProtocol() const {
72314564Sdim  return m_listener_socket_up->GetSocketProtocol();
73292932Sdim}
74292932Sdim
75314564Sdimconst char *Acceptor::GetSocketScheme() const {
76314564Sdim  return FindSchemeByProtocol(GetSocketProtocol());
77292932Sdim}
78292932Sdim
79314564Sdimstd::string Acceptor::GetLocalSocketId() const { return m_local_socket_id(); }
80292932Sdim
81314564Sdimstd::unique_ptr<Acceptor> Acceptor::Create(StringRef name,
82314564Sdim                                           const bool child_processes_inherit,
83321369Sdim                                           Status &error) {
84314564Sdim  error.Clear();
85292932Sdim
86314564Sdim  Socket::SocketProtocol socket_protocol = Socket::ProtocolUnixDomain;
87314564Sdim  int port;
88314564Sdim  StringRef scheme, host, path;
89314564Sdim  // Try to match socket name as URL - e.g., tcp://localhost:5555
90314564Sdim  if (UriParser::Parse(name, scheme, host, port, path)) {
91314564Sdim    if (!FindProtocolByScheme(scheme.str().c_str(), socket_protocol))
92314564Sdim      error.SetErrorStringWithFormat("Unknown protocol scheme \"%s\"",
93314564Sdim                                     scheme.str().c_str());
94292932Sdim    else
95314564Sdim      name = name.drop_front(scheme.size() + strlen("://"));
96314564Sdim  } else {
97314564Sdim    std::string host_str;
98314564Sdim    std::string port_str;
99314564Sdim    int32_t port = INT32_MIN;
100314564Sdim    // Try to match socket name as $host:port - e.g., localhost:5555
101314564Sdim    if (Socket::DecodeHostAndPort(name, host_str, port_str, port, nullptr))
102314564Sdim      socket_protocol = Socket::ProtocolTcp;
103314564Sdim  }
104292932Sdim
105314564Sdim  if (error.Fail())
106314564Sdim    return std::unique_ptr<Acceptor>();
107292932Sdim
108314564Sdim  std::unique_ptr<Socket> listener_socket_up =
109314564Sdim      Socket::Create(socket_protocol, child_processes_inherit, error);
110292932Sdim
111314564Sdim  LocalSocketIdFunc local_socket_id;
112314564Sdim  if (error.Success()) {
113314564Sdim    if (listener_socket_up->GetSocketProtocol() == Socket::ProtocolTcp) {
114314564Sdim      TCPSocket *tcp_socket =
115314564Sdim          static_cast<TCPSocket *>(listener_socket_up.get());
116314564Sdim      local_socket_id = [tcp_socket]() {
117314564Sdim        auto local_port = tcp_socket->GetLocalPortNumber();
118314564Sdim        return (local_port != 0) ? llvm::to_string(local_port) : "";
119314564Sdim      };
120314564Sdim    } else {
121314564Sdim      const std::string socket_name = name;
122314564Sdim      local_socket_id = [socket_name]() { return socket_name; };
123292932Sdim    }
124292932Sdim
125314564Sdim    return std::unique_ptr<Acceptor>(
126314564Sdim        new Acceptor(std::move(listener_socket_up), name, local_socket_id));
127314564Sdim  }
128314564Sdim
129314564Sdim  return std::unique_ptr<Acceptor>();
130292932Sdim}
131292932Sdim
132314564SdimAcceptor::Acceptor(std::unique_ptr<Socket> &&listener_socket, StringRef name,
133292932Sdim                   const LocalSocketIdFunc &local_socket_id)
134314564Sdim    : m_listener_socket_up(std::move(listener_socket)), m_name(name.str()),
135314564Sdim      m_local_socket_id(local_socket_id) {}
136