TCPSocket.cpp revision 296417
1//===-- TcpSocket.cpp -------------------------------------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#include "lldb/Host/common/TCPSocket.h"
11
12#include "lldb/Core/Log.h"
13#include "lldb/Host/Config.h"
14
15#ifndef LLDB_DISABLE_POSIX
16#include <arpa/inet.h>
17#include <netinet/tcp.h>
18#include <sys/socket.h>
19#endif
20
21using namespace lldb;
22using namespace lldb_private;
23
namespace {

// Socket type requested for every socket created in this file.
const int kType = SOCK_STREAM;

// Address family used for every socket and address built in this file.
const int kDomain = AF_INET;

} // anonymous namespace
30
31TCPSocket::TCPSocket(NativeSocket socket, bool should_close)
32    : Socket(socket, ProtocolTcp, should_close)
33{
34
35}
36
37TCPSocket::TCPSocket(bool child_processes_inherit, Error &error)
38    : TCPSocket(CreateSocket(kDomain, kType, IPPROTO_TCP, child_processes_inherit, error), true)
39{
40}
41
42
43// Return the port number that is being used by the socket.
44uint16_t
45TCPSocket::GetLocalPortNumber() const
46{
47    if (m_socket != kInvalidSocketValue)
48    {
49        SocketAddress sock_addr;
50        socklen_t sock_addr_len = sock_addr.GetMaxLength ();
51        if (::getsockname (m_socket, sock_addr, &sock_addr_len) == 0)
52            return sock_addr.GetPort ();
53    }
54    return 0;
55}
56
57std::string
58TCPSocket::GetLocalIPAddress() const
59{
60    // We bound to port zero, so we need to figure out which port we actually bound to
61    if (m_socket != kInvalidSocketValue)
62    {
63        SocketAddress sock_addr;
64        socklen_t sock_addr_len = sock_addr.GetMaxLength ();
65        if (::getsockname (m_socket, sock_addr, &sock_addr_len) == 0)
66            return sock_addr.GetIPAddress ();
67    }
68    return "";
69}
70
71uint16_t
72TCPSocket::GetRemotePortNumber() const
73{
74    if (m_socket != kInvalidSocketValue)
75    {
76        SocketAddress sock_addr;
77        socklen_t sock_addr_len = sock_addr.GetMaxLength ();
78        if (::getpeername (m_socket, sock_addr, &sock_addr_len) == 0)
79            return sock_addr.GetPort ();
80    }
81    return 0;
82}
83
84std::string
85TCPSocket::GetRemoteIPAddress () const
86{
87    // We bound to port zero, so we need to figure out which port we actually bound to
88    if (m_socket != kInvalidSocketValue)
89    {
90        SocketAddress sock_addr;
91        socklen_t sock_addr_len = sock_addr.GetMaxLength ();
92        if (::getpeername (m_socket, sock_addr, &sock_addr_len) == 0)
93            return sock_addr.GetIPAddress ();
94    }
95    return "";
96}
97
98Error
99TCPSocket::Connect(llvm::StringRef name)
100{
101    if (m_socket == kInvalidSocketValue)
102        return Error("Invalid socket");
103
104    Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_COMMUNICATION));
105    if (log)
106        log->Printf ("TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
107
108    Error error;
109    std::string host_str;
110    std::string port_str;
111    int32_t port = INT32_MIN;
112    if (!DecodeHostAndPort (name, host_str, port_str, port, &error))
113        return error;
114
115    // Enable local address reuse
116    SetOptionReuseAddress();
117
118    struct sockaddr_in sa;
119    ::memset (&sa, 0, sizeof (sa));
120    sa.sin_family = kDomain;
121    sa.sin_port = htons (port);
122
123    int inet_pton_result = ::inet_pton (kDomain, host_str.c_str(), &sa.sin_addr);
124
125    if (inet_pton_result <= 0)
126    {
127        struct hostent *host_entry = gethostbyname (host_str.c_str());
128        if (host_entry)
129            host_str = ::inet_ntoa (*(struct in_addr *)*host_entry->h_addr_list);
130        inet_pton_result = ::inet_pton (kDomain, host_str.c_str(), &sa.sin_addr);
131        if (inet_pton_result <= 0)
132        {
133            if (inet_pton_result == -1)
134                SetLastError(error);
135            else
136                error.SetErrorStringWithFormat("invalid host string: '%s'", host_str.c_str());
137
138            return error;
139        }
140    }
141
142    if (-1 == ::connect (GetNativeSocket(), (const struct sockaddr *)&sa, sizeof(sa)))
143    {
144        SetLastError (error);
145        return error;
146    }
147
148    // Keep our TCP packets coming without any delays.
149    SetOptionNoDelay();
150    error.Clear();
151    return error;
152}
153
154Error
155TCPSocket::Listen(llvm::StringRef name, int backlog)
156{
157    Error error;
158
159    // enable local address reuse
160    SetOptionReuseAddress();
161
162    Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_CONNECTION));
163    if (log)
164        log->Printf ("TCPSocket::%s (%s)", __FUNCTION__, name.data());
165
166    std::string host_str;
167    std::string port_str;
168    int32_t port = INT32_MIN;
169    if (!DecodeHostAndPort (name, host_str, port_str, port, &error))
170        return error;
171
172    SocketAddress bind_addr;
173
174    // Only bind to the loopback address if we are expecting a connection from
175    // localhost to avoid any firewall issues.
176    const bool bind_addr_success = (host_str == "127.0.0.1") ?
177                                    bind_addr.SetToLocalhost (kDomain, port) :
178                                    bind_addr.SetToAnyAddress (kDomain, port);
179
180    if (!bind_addr_success)
181    {
182        error.SetErrorString("Failed to bind port");
183        return error;
184    }
185
186    int err = ::bind (GetNativeSocket(), bind_addr, bind_addr.GetLength());
187    if (err != -1)
188        err = ::listen (GetNativeSocket(), backlog);
189
190    if (err == -1)
191        SetLastError (error);
192
193    return error;
194}
195
196Error
197TCPSocket::Accept(llvm::StringRef name, bool child_processes_inherit, Socket *&conn_socket)
198{
199    Error error;
200    std::string host_str;
201    std::string port_str;
202    int32_t port;
203    if (!DecodeHostAndPort(name, host_str, port_str, port, &error))
204        return error;
205
206    const sa_family_t family = kDomain;
207    const int socktype = kType;
208    const int protocol = IPPROTO_TCP;
209    SocketAddress listen_addr;
210    if (host_str.empty())
211        listen_addr.SetToLocalhost(family, port);
212    else if (host_str.compare("*") == 0)
213        listen_addr.SetToAnyAddress(family, port);
214    else
215    {
216        if (!listen_addr.getaddrinfo(host_str.c_str(), port_str.c_str(), family, socktype, protocol))
217        {
218            error.SetErrorStringWithFormat("unable to resolve hostname '%s'", host_str.c_str());
219            return error;
220        }
221    }
222
223    bool accept_connection = false;
224    std::unique_ptr<TCPSocket> accepted_socket;
225
226    // Loop until we are happy with our connection
227    while (!accept_connection)
228    {
229        struct sockaddr_in accept_addr;
230        ::memset (&accept_addr, 0, sizeof accept_addr);
231#if !(defined (__linux__) || defined(_WIN32))
232        accept_addr.sin_len = sizeof accept_addr;
233#endif
234        socklen_t accept_addr_len = sizeof accept_addr;
235
236        int sock = AcceptSocket (GetNativeSocket(),
237                                 (struct sockaddr *)&accept_addr,
238                                 &accept_addr_len,
239                                 child_processes_inherit,
240                                 error);
241
242        if (error.Fail())
243            break;
244
245        bool is_same_addr = true;
246#if !(defined(__linux__) || (defined(_WIN32)))
247        is_same_addr = (accept_addr_len == listen_addr.sockaddr_in().sin_len);
248#endif
249        if (is_same_addr)
250            is_same_addr = (accept_addr.sin_addr.s_addr == listen_addr.sockaddr_in().sin_addr.s_addr);
251
252        if (is_same_addr || (listen_addr.sockaddr_in().sin_addr.s_addr == INADDR_ANY))
253        {
254            accept_connection = true;
255            accepted_socket.reset(new TCPSocket(sock, true));
256        }
257        else
258        {
259            const uint8_t *accept_ip = (const uint8_t *)&accept_addr.sin_addr.s_addr;
260            const uint8_t *listen_ip = (const uint8_t *)&listen_addr.sockaddr_in().sin_addr.s_addr;
261            ::fprintf (stderr, "error: rejecting incoming connection from %u.%u.%u.%u (expecting %u.%u.%u.%u)\n",
262                        accept_ip[0], accept_ip[1], accept_ip[2], accept_ip[3],
263                        listen_ip[0], listen_ip[1], listen_ip[2], listen_ip[3]);
264            accepted_socket.reset();
265        }
266    }
267
268    if (!accepted_socket)
269        return error;
270
271    // Keep our TCP packets coming without any delays.
272    accepted_socket->SetOptionNoDelay();
273    error.Clear();
274    conn_socket = accepted_socket.release();
275    return error;
276}
277
278int
279TCPSocket::SetOptionNoDelay()
280{
281    return SetOption (IPPROTO_TCP, TCP_NODELAY, 1);
282}
283
284int
285TCPSocket::SetOptionReuseAddress()
286{
287    return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
288}
289