1//===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains raw_ostream implementations for streams that communicate
// via UNIX domain sockets (AF_UNIX).
11//
12//===----------------------------------------------------------------------===//
13
#include "llvm/Support/raw_socket_stream.h"
#include "llvm/Config/config.h"
#include "llvm/Support/Error.h"

#include <utility>

#ifndef _WIN32
#include <sys/socket.h>
#include <sys/un.h>
#else
#include "llvm/Support/Windows/WindowsSupport.h"
// winsock2.h must be included before afunix.h. Briefly turn off clang-format to
// avoid error.
// clang-format off
#include <winsock2.h>
#include <afunix.h>
// clang-format on
#include <io.h>
#endif // _WIN32

#if defined(HAVE_UNISTD_H)
#include <unistd.h>
#endif
35
36using namespace llvm;
37
38#ifdef _WIN32
39WSABalancer::WSABalancer() {
40  WSADATA WsaData;
41  ::memset(&WsaData, 0, sizeof(WsaData));
42  if (WSAStartup(MAKEWORD(2, 2), &WsaData) != 0) {
43    llvm::report_fatal_error("WSAStartup failed");
44  }
45}
46
47WSABalancer::~WSABalancer() { WSACleanup(); }
48
49#endif // _WIN32
50
/// Return the calling thread's most recent socket error as a std::error_code
/// in the system category: WSAGetLastError() on Windows, errno elsewhere.
static std::error_code getLastSocketErrorCode() {
#ifdef _WIN32
  const int LastError = ::WSAGetLastError();
#else
  const int LastError = errno;
#endif
  return std::error_code(LastError, std::system_category());
}
58
59ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath)
60    : FD(SocketFD), SocketPath(SocketPath) {}
61
62ListeningSocket::ListeningSocket(ListeningSocket &&LS)
63    : FD(LS.FD), SocketPath(LS.SocketPath) {
64  LS.FD = -1;
65}
66
67Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
68                                                      int MaxBacklog) {
69
70#ifdef _WIN32
71  WSABalancer _;
72  SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
73  if (MaybeWinsocket == INVALID_SOCKET) {
74#else
75  int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
76  if (MaybeWinsocket == -1) {
77#endif
78    return llvm::make_error<StringError>(getLastSocketErrorCode(),
79                                         "socket create failed");
80  }
81
82  struct sockaddr_un Addr;
83  memset(&Addr, 0, sizeof(Addr));
84  Addr.sun_family = AF_UNIX;
85  strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
86
87  if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
88    std::error_code Err = getLastSocketErrorCode();
89    if (Err == std::errc::address_in_use)
90      ::close(MaybeWinsocket);
91    return llvm::make_error<StringError>(Err, "Bind error");
92  }
93  if (listen(MaybeWinsocket, MaxBacklog) == -1) {
94    return llvm::make_error<StringError>(getLastSocketErrorCode(),
95                                         "Listen error");
96  }
97  int UnixSocket;
98#ifdef _WIN32
99  UnixSocket = _open_osfhandle(MaybeWinsocket, 0);
100#else
101  UnixSocket = MaybeWinsocket;
102#endif // _WIN32
103  return ListeningSocket{UnixSocket, SocketPath};
104}
105
106Expected<std::unique_ptr<raw_socket_stream>> ListeningSocket::accept() {
107  int AcceptFD;
108#ifdef _WIN32
109  SOCKET WinServerSock = _get_osfhandle(FD);
110  SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
111  AcceptFD = _open_osfhandle(WinAcceptSock, 0);
112#else
113  AcceptFD = ::accept(FD, NULL, NULL);
114#endif //_WIN32
115  if (AcceptFD == -1)
116    return llvm::make_error<StringError>(getLastSocketErrorCode(),
117                                         "Accept failed");
118  return std::make_unique<raw_socket_stream>(AcceptFD);
119}
120
121ListeningSocket::~ListeningSocket() {
122  if (FD == -1)
123    return;
124  ::close(FD);
125  unlink(SocketPath.c_str());
126}
127
128static Expected<int> GetSocketFD(StringRef SocketPath) {
129#ifdef _WIN32
130  SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
131  if (MaybeWinsocket == INVALID_SOCKET) {
132#else
133  int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
134  if (MaybeWinsocket == -1) {
135#endif // _WIN32
136    return llvm::make_error<StringError>(getLastSocketErrorCode(),
137                                         "Create socket failed");
138  }
139
140  struct sockaddr_un Addr;
141  memset(&Addr, 0, sizeof(Addr));
142  Addr.sun_family = AF_UNIX;
143  strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
144
145  int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr));
146  if (status == -1) {
147    return llvm::make_error<StringError>(getLastSocketErrorCode(),
148                                         "Connect socket failed");
149  }
150#ifdef _WIN32
151  return _open_osfhandle(MaybeWinsocket, 0);
152#else
153  return MaybeWinsocket;
154#endif // _WIN32
155}
156
157raw_socket_stream::raw_socket_stream(int SocketFD)
158    : raw_fd_stream(SocketFD, true) {}
159
160Expected<std::unique_ptr<raw_socket_stream>>
161raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
162#ifdef _WIN32
163  WSABalancer _;
164#endif // _WIN32
165  Expected<int> FD = GetSocketFD(SocketPath);
166  if (!FD)
167    return FD.takeError();
168  return std::make_unique<raw_socket_stream>(*FD);
169}
170
171raw_socket_stream::~raw_socket_stream() {}
172
173//===----------------------------------------------------------------------===//
174//  raw_string_ostream
175//===----------------------------------------------------------------------===//
176
/// Append the \p Size bytes starting at \p Ptr to the stream's target buffer
/// OS (presumably the std::string this stream writes into — confirm in the
/// header).
// NOTE(review): this definition looks out of place in raw_socket_stream.cpp;
// if raw_ostream.cpp also defines it, linking will report a duplicate symbol.
void raw_string_ostream::write_impl(const char *Ptr, size_t Size) {
  OS.append(Ptr, Size);
}
180