1// Copyright 2016 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#define _POSIX_C_SOURCE 200809L
6
7#include "netprotocol.h"
8
9#include <arpa/inet.h>
10#include <netinet/in.h>
11#include <poll.h>
12#include <sched.h>
13#include <sys/param.h>
14#include <sys/socket.h>
15#include <sys/stat.h>
16#include <sys/time.h>
17#include <ifaddrs.h>
18
19#include <fcntl.h>
20#include <libgen.h>
21#include <stdbool.h>
22#include <stdio.h>
23#include <stdlib.h>
24#include <string.h>
25#include <unistd.h>
26
27#include <errno.h>
28#include <stdint.h>
29
30#include <zircon/boot/netboot.h>
31#include <tftp/tftp.h>
32
// Size in bytes of each of the tftp request input and output buffers.
#define TFTP_BUF_SZ 2048

// File cookie handed to the tftp library for the local side of a transfer.
typedef struct {
    int fd;       // descriptor used by file_read / file_write / file_close
    size_t size;  // file size recorded when the file was opened
} file_info_t;

// Transport cookie handed to the tftp library for the network side.
typedef struct {
    int socket;                       // socket obtained from netboot_open()
    bool connected;                   // set once the socket is connect()ed to the first responder
    uint32_t previous_timeout_ms;     // last receive timeout requested via SO_RCVTIMEO
    struct sockaddr_in6 target_addr;  // send destination before connecting; peer afterwards
} transport_info_t;

// Program name (argv[0]) used as the prefix of diagnostic messages.
static const char *appname;
48
49static ssize_t file_open_read(const char* filename, void* file_cookie) {
50    int fd = open(filename, O_RDONLY);
51    if (fd < 0) {
52        return TFTP_ERR_IO;
53    }
54    file_info_t *file_info = file_cookie;
55    file_info->fd = fd;
56    struct stat st;
57    if (fstat(file_info->fd, &st) < 0) {
58        close(fd);
59        return TFTP_ERR_IO;
60    }
61    file_info->size = st.st_size;
62    return st.st_size;
63}
64
65static tftp_status file_open_write(const char* filename, size_t size, void* file_cookie) {
66    int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
67    if (fd < 0) {
68        return TFTP_ERR_IO;
69    }
70    file_info_t* file_info = file_cookie;
71    file_info->fd = fd;
72    file_info->size = size;
73    return TFTP_NO_ERROR;
74}
75
76static tftp_status file_read(void* data, size_t* length, off_t offset, void* file_cookie) {
77    int fd = ((file_info_t*)file_cookie)->fd;
78    ssize_t n = pread(fd, data, *length, offset);
79    if (n < 0) {
80        return TFTP_ERR_IO;
81    }
82    *length = n;
83    return TFTP_NO_ERROR;
84}
85
86static tftp_status file_write(const void* data, size_t* length, off_t offset, void* file_cookie) {
87    int fd = ((file_info_t*)file_cookie)->fd;
88    ssize_t n = pwrite(fd, data, *length, offset);
89    if (n < 0) {
90        return TFTP_ERR_IO;
91    }
92    *length = n;
93    return TFTP_NO_ERROR;
94}
95
96static void file_close(void* file_cookie) {
97    close(((file_info_t*)file_cookie)->fd);
98}
99
100// Longest time we will wait for a send operation to succeed
101#define MAX_SEND_TIME_MS 1000
102
103static tftp_status transport_send(void* data, size_t len, void* transport_cookie) {
104    transport_info_t* transport_info = transport_cookie;
105    ssize_t send_result;
106    struct pollfd poll_fds = {.fd = transport_info->socket,
107                              .events = POLLOUT};
108    do {
109        int poll_result = poll(&poll_fds, 1, MAX_SEND_TIME_MS);
110        if (poll_result <= 0) {
111            // We'll treat a timeout as an IO error and not a TFTP_ERR_TIMED_OUT,
112            // since the latter is a timeout waiting for a response from the server.
113            return TFTP_ERR_IO;
114        }
115        if (!transport_info->connected) {
116            transport_info->target_addr.sin6_port = htons(NB_TFTP_INCOMING_PORT);
117            send_result = sendto(transport_info->socket, data, len, 0,
118                                 (struct sockaddr*)&transport_info->target_addr,
119                                 sizeof(transport_info->target_addr));
120        } else {
121            send_result = send(transport_info->socket, data, len, 0);
122        }
123    } while ((send_result < 0) &&
124             ((errno == EAGAIN) || (errno == EWOULDBLOCK) ||
125              (errno == ENOBUFS && sched_yield() == 0)));
126
127    if (send_result < 0) {
128        return TFTP_ERR_IO;
129    }
130    return TFTP_NO_ERROR;
131}
132
133static int transport_recv(void* data, size_t len, bool block, void* transport_cookie) {
134    transport_info_t* transport_info = transport_cookie;
135    int flags = fcntl(transport_info->socket, F_GETFL, 0);
136    if (flags < 0) {
137        return TFTP_ERR_IO;
138    }
139    if (block) {
140        flags &= ~O_NONBLOCK;
141    } else {
142        flags |= O_NONBLOCK;
143    }
144    if (fcntl(transport_info->socket, F_SETFL, flags)) {
145        return TFTP_ERR_IO;
146    }
147    ssize_t recv_result;
148    struct sockaddr_in6 connection_addr;
149    socklen_t addr_len = sizeof(connection_addr);
150    if (!transport_info->connected) {
151        recv_result = recvfrom(transport_info->socket, data, len, 0,
152                               (struct sockaddr*)&connection_addr,
153                               &addr_len);
154    } else {
155        recv_result = recv(transport_info->socket, data, len, 0);
156    }
157    if (recv_result < 0) {
158        if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
159            return TFTP_ERR_TIMED_OUT;
160        }
161        return TFTP_ERR_INTERNAL;
162    }
163    if (!transport_info->connected) {
164        if (connect(transport_info->socket, (struct sockaddr*)&connection_addr,
165                    sizeof(connection_addr)) < 0) {
166            return TFTP_ERR_IO;
167        }
168        memcpy(&transport_info->target_addr, &connection_addr,
169               sizeof(transport_info->target_addr));
170        transport_info->connected = true;
171    }
172    return recv_result;
173}
174
175static int transport_timeout_set(uint32_t timeout_ms, void* transport_cookie) {
176    transport_info_t* transport_info = transport_cookie;
177    if (transport_info->previous_timeout_ms != timeout_ms && timeout_ms > 0) {
178        transport_info->previous_timeout_ms = timeout_ms;
179        struct timeval tv;
180        tv.tv_sec = timeout_ms / 1000;
181        tv.tv_usec = 1000 * (timeout_ms - 1000 * tv.tv_sec);
182        return setsockopt(transport_info->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
183    }
184    return 0;
185}
186
187static int transfer_file(bool push, int s, struct sockaddr_in6* addr, const char* dst,
188                         const char* src) {
189    // Initialize session
190    tftp_session* session = NULL;
191    size_t session_data_sz = tftp_sizeof_session();
192    void* session_data = calloc(session_data_sz, 1);
193    if (session_data == NULL) {
194        fprintf(stderr, "%s: unable to allocate tftp session memory\n", appname);
195        return 1;
196    }
197    if (tftp_init(&session, session_data, session_data_sz) != TFTP_NO_ERROR) {
198        fprintf(stderr, "%s: unable to initiate tftp session\n", appname);
199        free(session_data);
200        return 1;
201    }
202
203    // Initialize file interface
204    file_info_t file_info;
205    tftp_file_interface file_ifc = {file_open_read, file_open_write,
206                                    file_read, file_write, file_close};
207    tftp_session_set_file_interface(session, &file_ifc);
208
209    // Initialize transport interface
210    transport_info_t transport_info;
211    transport_info.previous_timeout_ms = 0;
212    transport_info.socket = s;
213    transport_info.connected = false;
214    memcpy(&transport_info.target_addr, addr, sizeof(transport_info.target_addr));
215    tftp_transport_interface transport_ifc = {transport_send, transport_recv,
216                                              transport_timeout_set};
217    tftp_session_set_transport_interface(session, &transport_ifc);
218
219    // Set our preferred transport options
220    tftp_set_options(session, &tftp_block_size, NULL, &tftp_window_size);
221
222    // Prepare buffers
223    char err_msg[128];
224    tftp_request_opts opts = { 0 };
225    opts.inbuf = malloc(TFTP_BUF_SZ);
226    opts.inbuf_sz = TFTP_BUF_SZ;
227    opts.outbuf = malloc(TFTP_BUF_SZ);
228    opts.outbuf_sz = TFTP_BUF_SZ;
229    opts.err_msg = err_msg;
230    opts.err_msg_sz = sizeof(err_msg);
231
232    tftp_status status;
233    if (push) {
234        status = tftp_push_file(session, &transport_info, &file_info, src, dst, &opts);
235    } else {
236        status = tftp_pull_file(session, &transport_info, &file_info, dst, src, &opts);
237    }
238
239    free(session_data);
240    free(opts.inbuf);
241    free(opts.outbuf);
242
243    if (status < 0) {
244        fprintf(stderr, "%s: %s (status = %d)\n", appname, opts.err_msg, (int)status);
245        return 1;
246    }
247
248    fprintf(stderr, "wrote %zu bytes\n", file_info.size);
249
250    return 0;
251}
252
253static void usage(void) {
254    fprintf(stderr, "usage: %s [options] [hostname:]src [hostname:]dst\n", appname);
255    netboot_usage(true);
256}
257
258int main(int argc, char** argv) {
259    appname = argv[0];
260
261    int index = netboot_handle_getopt(argc, argv);
262    if (index < 0) {
263        usage();
264        return -1;
265    }
266    argv += index;
267    argc -= index;
268
269    if (argc != 2) {
270        usage();
271        return -1;
272    }
273
274    const char* src = argv[0];
275    const char* dst = argv[1];
276
277    int push = -1;
278    char* pos;
279    const char* hostname;
280    if ((pos = strpbrk(src, ":")) != 0) {
281        push = 0;
282        hostname = src;
283        pos[0] = 0;
284        src = pos+1;
285    }
286    if ((pos = strpbrk(dst, ":")) != 0) {
287        if (push == 0) {
288            fprintf(stderr, "%s: only one of src or dst can have a hostname\n", appname);
289            return -1;
290        }
291        push = 1;
292        hostname = dst;
293        pos[0] = 0;
294        dst = pos+1;
295    }
296    if (push == -1) {
297        fprintf(stderr, "%s: either src or dst needs a hostname\n", appname);
298        return -1;
299    }
300
301    int s;
302    struct sockaddr_in6 server_addr;
303    if ((s = netboot_open(hostname, NULL, &server_addr, false)) < 0) {
304        if (errno == ETIMEDOUT) {
305            fprintf(stderr, "%s: lookup of %s timed out\n", appname, hostname);
306        } else {
307            fprintf(stderr, "%s: failed to connect to %s: %d\n", appname, hostname, errno);
308        }
309        return -1;
310    }
311
312
313    int ret = transfer_file(push, s, &server_addr, dst, src);
314    close(s);
315    return ret;
316}
317