1/**
2 * \file
3 * \brief TFTP library
4 */
5
6/*
7 * Copyright (c) 2015 ETH Zurich.
8 * All rights reserved.
9 *
10 * This file is distributed under the terms in the attached LICENSE file.
11 * If you do not find this file, copies can be found by writing to:
12 * ETH Zurich D-INFK, Universitaetsstrasse 6, CH-8092 Zurich. Attn: Systems Group.
13 */
14
15#include <stdlib.h>
16#include <stdio.h>
17
18#include <barrelfish/barrelfish.h>
19#include <barrelfish/waitset.h>
20#include <barrelfish/nameservice_client.h>
21
22#include <net_sockets/net_sockets.h>
23#include <arpa/inet.h>
24
25#include <tftp/tftp.h>
26
27#include "tftp_internal.h"
28
29
// error definitions
// NOTE(review): every code below expands to the same value (1), so callers
// cannot tell these failures apart from one another; presumably placeholders
// until proper errval_t codes exist -- confirm before relying on distinct values.
#define TFTP_ERR_BUSY 1
#define TFTP_ERR_DISCONNECTED 1
#define TFTP_ERR_NOT_FOUND 1
#define TFTP_ERR_ACCESS_DENIED 1
#define TFTP_ERR_FILE_EXISTS 1
36
37
38
/**
 * \brief state of the TFTP client
 *
 * This file keeps a single instance of this struct (tftp_client below), so only
 * one connection and one transfer can be active at a time.
 */
struct tftp_client
{
    /* client state: TFTP_ST_IDLE after tftp_client_connect(), moved through the
     * request/data/ack states by a transfer, TFTP_ST_ERROR on failure */
    tftp_st_t state;

    /* server address and transfer mode, set by tftp_client_connect() */
    struct in_addr server_ip;   /* destination of request packets */
    uint16_t server_port;       /* destination UDP port of request packets */
    tftp_mode_t mode;           /* transfer mode written into requests by set_mode() */

    /* current transfer */
    uint32_t block;             /* block expected next (read) or awaiting ACK (write); starts at 1 */
    size_t bytes;               /* number of bytes stored in buf so far (read) */
    void *buf;                  /* caller-supplied data buffer */
    size_t buflen;              /* capacity of buf in bytes */

    /* socket and send buffer */
    struct net_socket *pcb;     /* UDP socket created by tftp_client_connect() */
    void *ppayload;             /* TFTP_MAX_MSGSIZE-byte buffer used to build outgoing packets */
};

/* the single client instance used by every function in this file.
 * NOTE(review): this has external linkage; it could be static unless another
 * translation unit references it -- confirm. */
struct tftp_client tftp_client;
63
64
65static errval_t tftp_client_send_data(struct net_socket *socket, uint32_t blockno, void *buf,
66                                      uint32_t length, struct in_addr addr, uint16_t port)
67{
68    void *payload = tftp_client.ppayload;
69    errval_t err;
70
71    size_t offset = set_opcode(payload, TFTP_OP_DATA);
72    offset += set_block_no(payload + offset, blockno);
73    if (length > TFTP_BLOCKSIZE) {
74        length = TFTP_BLOCKSIZE;
75    }
76
77    memcpy(payload + offset, buf, length);
78    err = net_send_to(socket, payload, length + offset, addr, port);
79    assert(err_is_ok(err));
80    return SYS_ERR_OK;
81}
82
83
84/*
85 * ------------------------------------------------------------------------------
86 * Recv Handlers
87 * ------------------------------------------------------------------------------
88 */
89
90static void tftp_client_handle_write(struct net_socket *socket, void *data,
91    size_t size, struct in_addr ip_address, uint16_t port)
92{
93    USER_PANIC("NYI");
94    tpft_op_t op = get_opcode(data);
95    uint32_t blockno;
96    switch(op) {
97        case TFTP_OP_ACK :
98            blockno = get_block_no(data, size);
99            if (blockno == TFTP_ERR_INVALID_BUFFER) {
100                TFTP_DEBUG("failed to decode block number in data packet\n");
101                break;
102            }
103
104            if (blockno == tftp_client.block) {
105                if (tftp_client.state == TFTP_ST_LAST_DATA_SENT) {
106                    tftp_client.state = TFTP_ST_CLOSED;
107                    break;
108                }
109
110                uint32_t offset = TFTP_BLOCKSIZE * blockno;
111                uint32_t length = TFTP_BLOCKSIZE;
112                if (tftp_client.buflen - offset < TFTP_BLOCKSIZE) {
113                    length = tftp_client.buflen - offset;
114                    tftp_client.state = TFTP_ST_LAST_DATA_SENT;
115                }
116
117                tftp_client.block++;
118
119                tftp_client_send_data(socket, tftp_client.block, tftp_client.buf + offset, length,
120                                      ip_address, port);
121                tftp_client.state = TFTP_ST_DATA_SENT;
122            } else  {
123                TFTP_DEBUG("got double packet: %u\n", blockno);
124            }
125
126            break;
127        case TFTP_OP_ERROR :
128            TFTP_DEBUG("got a error packet\n");
129            break;
130        default:
131            tftp_client.state = TFTP_ST_ERROR;
132            break;
133    }
134}
135
136static void tftp_client_handle_read(struct net_socket *socket, void *data,
137    size_t size, struct in_addr ip_address, uint16_t port)
138{
139    tpft_op_t op = get_opcode(data);
140    uint32_t blockno;
141    switch(op) {
142        case TFTP_OP_DATA :
143            blockno = get_block_no(data, size);
144            if (blockno == TFTP_ERR_INVALID_BUFFER) {
145                TFTP_DEBUG("failed to decode block number in data packet\n");
146                break;
147            }
148
149            if (blockno == tftp_client.block) {
150                if (size < 5) {
151                    TFTP_DEBUG("too small pbuf lenth\n");
152                }
153
154                void *buf = data + 4;
155                size_t length = size - 4;
156                TFTP_DEBUG_PACKETS("received block %u of size %lu bytes\n", blockno, length);
157
158                if (tftp_client.buflen < tftp_client.bytes + length) {
159                    TFTP_DEBUG("too less bufferspace available\n");
160                    length = tftp_client.buflen - tftp_client.bytes;
161                }
162                memcpy(tftp_client.buf + tftp_client.bytes, buf, length);
163
164                int r = tftp_send_ack(socket, blockno, ip_address, port,
165                                      tftp_client.ppayload);
166                if (r != SYS_ERR_OK) {
167                    tftp_client.state = TFTP_ST_ERROR;
168                    break;
169                }
170                tftp_client.state = TFTP_ST_ACK_SENT;
171                tftp_client.block++;
172                tftp_client.bytes += length;
173                if (length < TFTP_BLOCKSIZE) {
174                    TFTP_DEBUG("setting the last ack state\n");
175                    tftp_client.state = TFTP_ST_LAST_ACK_SENT;
176                }
177            } else  {
178                TFTP_DEBUG("got double packet: %u\n", blockno);
179                int r = tftp_send_ack(socket, blockno, ip_address, port,
180                                      tftp_client.ppayload);
181                if (r != SYS_ERR_OK) {
182                    tftp_client.state = TFTP_ST_ERROR;
183                    break;
184                }
185                tftp_client.state = TFTP_ST_ACK_SENT;
186            }
187
188            break;
189        case TFTP_OP_ERROR :
190            TFTP_DEBUG("got a error packet\n");
191            get_error(data, size);
192            tftp_client.state = TFTP_ST_ERROR;
193            break;
194        default:
195            tftp_client.state = TFTP_ST_ERROR;
196            TFTP_DEBUG("unexpected packet\n");
197            break;
198    }
199}
200
201
202static void tftp_client_recv_handler(void *user_state, struct net_socket *socket,
203    void *data, size_t size, struct in_addr ip_address, uint16_t port)
204{
205    switch(tftp_client.state) {
206        case TFTP_ST_WRITE_REQ_SENT:
207        case TFTP_ST_DATA_SENT :
208        case TFTP_ST_LAST_DATA_SENT :
209            tftp_client_handle_write(socket, data, size, ip_address, port);
210            break;
211        case TFTP_ST_READ_REQ_SENT :
212        case TFTP_ST_ACK_SENT :
213            tftp_client_handle_read(socket, data, size, ip_address, port);
214            break;
215        default:
216            TFTP_DEBUG("unexpected state: %u\n", tftp_client.state);
217            break;
218    }
219}
220
221static void new_request(char *path, tpft_op_t opcode)
222{
223    size_t path_length = strlen(path);
224    assert(strlen(path) + 14 < TFTP_MAX_MSGSIZE);
225
226    void *payload = tftp_client.ppayload;
227
228    memset(payload, 0, path_length + 16);
229
230    size_t length = set_opcode(payload, opcode);
231
232    length += snprintf(payload + length, path_length + 1, "%s", path) + 1;
233    length += set_mode(payload + length, tftp_client.mode);
234
235    TFTP_DEBUG("sending udp payload of %lu bytes\n", length);
236
237    errval_t err;
238    err = net_send_to(tftp_client.pcb, payload, length, tftp_client.server_ip, tftp_client.server_port);
239    if (err != SYS_ERR_OK) {
240        TFTP_DEBUG("send failed\n");
241    }
242}
243
244
245errval_t tftp_client_write_file(char *name, void *buf, size_t buflen)
246{
247    if (tftp_client.state < TFTP_ST_IDLE) {
248        TFTP_DEBUG("attempt to read file with no connection");
249        return TFTP_ERR_DISCONNECTED;
250    }
251
252    if (tftp_client.state > TFTP_ST_IDLE) {
253        return TFTP_ERR_BUSY;
254    }
255
256    tftp_client.buf = buf;
257    tftp_client.buflen = buflen;
258    tftp_client.block = 1;
259    tftp_client.state = TFTP_ST_WRITE_REQ_SENT;
260    tftp_client.bytes = 0;
261
262    return SYS_ERR_OK;
263}
264
265errval_t tftp_client_read_file(char *path, void *buf, size_t buflen, size_t *ret_size)
266{
267    if (tftp_client.state < TFTP_ST_IDLE) {
268        TFTP_DEBUG("attempt to read file with no connection");
269        return TFTP_ERR_DISCONNECTED;
270    }
271
272    if (tftp_client.state > TFTP_ST_IDLE) {
273        return TFTP_ERR_BUSY;
274    }
275
276    tftp_client.buf = buf;
277    tftp_client.buflen = buflen;
278    tftp_client.block = 1;
279    tftp_client.state = TFTP_ST_READ_REQ_SENT;
280    tftp_client.bytes = 0;
281
282    assert(tftp_client.pcb);
283
284    TFTP_DEBUG("read request of file %s\n", path);
285
286    new_request(path, TFTP_OP_READ_REQ);
287
288    while(tftp_client.state > TFTP_ST_ERROR) {
289        event_dispatch(get_default_waitset());
290    }
291
292    TFTP_DEBUG("tftp read file done.\n");
293
294    if (ret_size) {
295        *ret_size = tftp_client.bytes;
296    }
297
298    if (tftp_client.state == TFTP_ST_ERROR) {
299        tftp_client.state = TFTP_ST_IDLE;
300        return -1;
301    }
302
303    tftp_client.state = TFTP_ST_IDLE;
304
305    return SYS_ERR_OK;
306}
307
308
309
310/**
311 * \brief attempts to initialize a new TFTP connection to a server
312 *
313 * \returns SYS_ERR_OK on success
314 *          TFTP_ERR_* on failure
315 */
316errval_t tftp_client_connect(char *ip, uint16_t port)
317{
318    switch(tftp_client.state) {
319        case TFTP_ST_INVALID :
320            net_sockets_init();
321            tftp_client.pcb = net_udp_socket();
322            TFTP_DEBUG("new connection from uninitialized state\n");
323            break;
324        case TFTP_ST_CLOSED :
325            TFTP_DEBUG("new connection from closed state\n");
326            tftp_client.pcb = net_udp_socket();
327            break;
328        default:
329            TFTP_DEBUG("connection already established, cannot connect\n");
330            return TFTP_ERR_BUSY;
331    }
332
333    if (tftp_client.pcb == NULL) {
334        return LIB_ERR_MALLOC_FAIL;
335    }
336
337    tftp_client.server_port = port;
338
339    int ret = inet_aton(ip, &tftp_client.server_ip);
340    if (ret == 0) {
341        TFTP_DEBUG("Invalid IP addr: %s\n", ip);
342        return 1;
343    }
344
345    TFTP_DEBUG("connecting to %s:%" PRIu16 "\n", ip, port);
346
347    errval_t r;
348    r = net_bind(tftp_client.pcb, (struct in_addr){(INADDR_ANY)}, 0);
349    if (r != SYS_ERR_OK) {
350        USER_PANIC("UDP bind failed");
351    }
352    debug_printf("bound to %d\n", tftp_client.pcb->bound_port);
353
354    // r = net_connect(tftp_client.pcb, tftp_client.server_ip, tftp_client.server_port, NULL);
355    // if (r != SYS_ERR_OK) {
356    //     USER_PANIC("UDP connect failed");
357    // }
358
359    TFTP_DEBUG("registering recv handler\n");
360    net_set_on_received(tftp_client.pcb, tftp_client_recv_handler);
361
362    tftp_client.state = TFTP_ST_IDLE;
363    tftp_client.mode = TFTP_MODE_OCTET;
364    tftp_client.ppayload = net_alloc(TFTP_MAX_MSGSIZE);
365    TFTP_DEBUG("all set up. connection idle\n");
366    return SYS_ERR_OK;
367}
368
369errval_t tftp_client_disconnect(void)
370{
371    net_free(tftp_client.ppayload);
372    net_close(tftp_client.pcb);
373    tftp_client.state = TFTP_ST_CLOSED;
374    return SYS_ERR_OK;
375}
376