1/*
2 * Copyright (c) 2017, ETH Zurich.
3 * All rights reserved.
4 *
5 * This file is distributed under the terms in the attached LICENSE file.
6 * If you do not find this file, copies can be found by writing to:
7 * ETH Zurich D-INFK, Universitaetstrasse 4, CH-8092 Zurich. Attn: Systems Group.
8 */
9
10#include <stdlib.h>
11#include <stdio.h>
12#include <stdbool.h>
13#include <barrelfish/barrelfish.h>
14#include <barrelfish/waitset.h>
15#include <devif/queue_interface.h>
16#include <devif/queue_interface_backend.h>
17#include <devif/backends/net/udp.h>
18#include <lwip/inet_chksum.h>
19#include <lwip/lwip/inet.h>
20#include <net_interfaces/flags.h>
21#include <net/net.h>
22#include <net/net_filter.h>
23#include <net/dhcp.h>
24#include "../headers.h"
25#include <bench/bench.h>
26
27
/* Number of slots in each queue's region table; slots are picked by rid modulo this. */
#define MAX_NUM_REGIONS 64

/* Uncomment to get verbose tracing from this backend. */
//#define DEBUG_ENABLED

#ifdef DEBUG_ENABLED
/* Trace helper: prefixes every message with domain, core, function and line. */
#define DEBUG(x...)                                                     \
    do {                                                                \
        printf("UDP_QUEUE: %s.%d:%s:%d: ",                              \
               disp_name(), disp_get_core_id(), __func__, __LINE__);    \
        printf(x);                                                      \
    } while (0)
#else
/* Tracing disabled: DEBUG(...) expands to a no-op expression. */
#define DEBUG(x...) ((void)0)
#endif
41
/*
 * One slot of a queue's region table: where a registered region is mapped
 * in this domain and which region id that mapping belongs to.
 * A slot whose va is NULL is unused (udp_deregister resets it that way).
 */
struct region_vaddr {
    void* va;       // local virtual address of the region, NULL if unused
    regionid_t rid; // id of the region occupying this slot
};
46
/*
 * State of one UDP queue. It is exposed to users as a devq (my_q) and
 * forwards all operations to an underlying queue (q) created by ip_create,
 * adding/stripping the UDP header on the way.
 *
 * my_q must remain the FIRST member: every devq callback in this file
 * receives a struct devq* and casts it directly to struct udp_q*.
 */
struct udp_q {
    struct devq my_q;      // this queue as seen by devq users (must be first)
    struct devq* q;        // underlying IP queue
    struct udp_hdr header; // can fill in this header and reuse it by copying
                           // (ports stored in network byte order)
    struct region_vaddr regions[MAX_NUM_REGIONS]; // indexed by rid % MAX_NUM_REGIONS
    struct capref filter_ep;         // filter endpoint taken from the IP queue (may be null)
    struct net_filter_state* filter; // filter state the UDP filter is installed on
};
55
56
#ifdef DEBUG_ENABLED
/*
 * Debug helper (only built with DEBUG_ENABLED): hex-dump `len` bytes
 * starting at `start` to stdout, 16 bytes per line, in groups of two bytes.
 *
 * `q` is not used; it is kept so existing call sites stay unchanged.
 */
static void print_buffer(struct udp_q* q, void* start, uint64_t len)
{
    (void) q;
    const uint8_t* buf = (const uint8_t*) start;

    // uint64_t is not necessarily size_t, so print it through a cast that
    // matches the conversion specifier exactly.
    printf("Packet in region at address %p len %llu \n",
           start, (unsigned long long) len);

    // Walk one byte at a time so an odd length never reads past the end
    // (the previous version always printed buf[i + 1]).
    for (uint64_t i = 0; i < len; i++) {
        if (i > 0 && (i % 16) == 0) {
            printf("\n");
        }
        // Zero-pad so every byte is exactly two hex digits; a space after
        // every second byte keeps the original 2-byte grouping.
        printf((i % 2) ? "%02X " : "%02X", buf[i]);
    }
    printf("\n");
}
#endif
73
74static errval_t udp_register(struct devq* q, struct capref cap,
75                            regionid_t rid)
76{
77
78    errval_t err;
79    struct frame_identity frameid = { .base = 0, .bytes = 0 };
80
81    struct udp_q* que = (struct udp_q*) q;
82
83    // Map device registers
84    err = frame_identify(cap, &frameid);
85    assert(err_is_ok(err));
86
87    err = vspace_map_one_frame_attr(&que->regions[rid % MAX_NUM_REGIONS].va,
88                                    frameid.bytes, cap, VREGION_FLAGS_READ_WRITE,
89                                    NULL, NULL);
90    if (err_is_fail(err)) {
91        DEBUG_ERR(err, "vspace_map_one_frame failed");
92        return err;
93    }
94    que->regions[rid % MAX_NUM_REGIONS].rid = rid;
95    DEBUG("id-%d va-%p \n", que->regions[rid % MAX_NUM_REGIONS].rid,
96          que->regions[rid % MAX_NUM_REGIONS].va);
97
98    return que->q->f.reg(que->q, cap, rid);
99}
100
101static errval_t udp_deregister(struct devq* q, regionid_t rid)
102{
103
104    struct udp_q* que = (struct udp_q*) q;
105    que->regions[rid % MAX_NUM_REGIONS].va = NULL;
106    que->regions[rid % MAX_NUM_REGIONS].rid = 0;
107    return que->q->f.dereg(que->q, rid);
108}
109
110
111static errval_t udp_control(struct devq* q, uint64_t cmd, uint64_t value,
112                           uint64_t* result)
113{
114    struct udp_q* que = (struct udp_q*) q;
115    return que->q->f.ctrl(que->q, cmd, value, result);
116}
117
118
119static errval_t udp_notify(struct devq* q)
120{
121    struct udp_q* que = (struct udp_q*) q;
122    return que->q->f.notify(que->q);
123}
124
125static errval_t udp_enqueue(struct devq* q, regionid_t rid,
126                           genoffset_t offset, genoffset_t length,
127                           genoffset_t valid_data, genoffset_t valid_length,
128                           uint64_t flags)
129{
130
131    // for now limit length
132    //  TODO fragmentation
133
134    struct udp_q* que = (struct udp_q*) q;
135    if (flags & NETIF_TXFLAG) {
136
137        DEBUG("TX rid: %d offset %ld length %ld valid_length %ld \n", rid, offset,
138              length, valid_length);
139        assert(valid_length <= 1500);
140        que->header.len = htons(valid_length + UDP_HLEN);
141
142        assert(que->regions[rid % MAX_NUM_REGIONS].va != NULL);
143
144        uint8_t* start = (uint8_t*) que->regions[rid % MAX_NUM_REGIONS].va +
145                         offset + valid_data + ETH_HLEN + IP_HLEN;
146
147        memcpy(start, &que->header, sizeof(que->header));
148
149        return que->q->f.enq(que->q, rid, offset, length, valid_data,
150                             valid_length + UDP_HLEN, flags);
151    }
152
153    if (flags & NETIF_RXFLAG) {
154        assert(valid_length <= 2048);
155        DEBUG("RX rid: %d offset %ld length %ld valid_length %ld \n", rid, offset,
156              length, valid_length);
157        return que->q->f.enq(que->q, rid, offset, length, valid_data,
158                             valid_length, flags);
159    }
160
161    return NET_QUEUE_ERR_UNKNOWN_BUF_TYPE;
162}
163
164static errval_t udp_dequeue(struct devq* q, regionid_t* rid, genoffset_t* offset,
165                           genoffset_t* length, genoffset_t* valid_data,
166                           genoffset_t* valid_length, uint64_t* flags)
167{
168    errval_t err;
169    struct udp_q* que = (struct udp_q*) q;
170
171    err = que->q->f.deq(que->q, rid, offset, length, valid_data, valid_length, flags);
172    if (err_is_fail(err)) {
173        return err;
174    }
175
176    if (*flags & NETIF_RXFLAG) {
177        DEBUG("RX rid: %d offset %ld valid_data %ld length %ld va %p \n", *rid,
178              *offset, *valid_data,
179              *valid_length, que->regions[*rid % MAX_NUM_REGIONS].va + *offset + *valid_data);
180
181        struct udp_hdr* header = (struct udp_hdr*)
182                                 (que->regions[*rid % MAX_NUM_REGIONS].va +
183                                 *offset + *valid_data);
184
185        // Correct port for this queue?
186        if (header->dest != que->header.dest) {
187            printf("UDP queue: dropping packet, wrong port %d %d \n",
188                   header->dest, que->header.dest);
189            err = que->q->f.enq(que->q, *rid, *offset, *length, 0, 0, NETIF_RXFLAG);
190            return err_push(err, NET_QUEUE_ERR_WRONG_PORT);
191        }
192
193#ifdef DEBUG_ENABLED
194        print_buffer(que, que->regions[*rid % MAX_NUM_REGIONS].va + *offset, *valid_length);
195#endif
196
197        *valid_length = ntohs(header->len) - UDP_HLEN;
198        *valid_data += UDP_HLEN;
199        //print_buffer(que, que->regions[*rid % MAX_NUM_REGIONS].va + *offset+ *valid_data, *valid_length);
200        return SYS_ERR_OK;
201    }
202
203#ifdef DEBUG_ENABLED
204    DEBUG("TX rid: %d offset %ld length %ld \n", *rid, *offset,
205          *valid_length);
206#endif
207
208    return SYS_ERR_OK;
209}
210
211/*
212 * Public functions
213 *
214 */
215errval_t udp_create(struct udp_q** q, const char* card_name,
216                    uint16_t src_port, uint16_t dst_port,
217                    uint32_t dst_ip, void(*interrupt)(void*), bool poll)
218{
219    errval_t err;
220    struct udp_q* que;
221    que = calloc(1, sizeof(struct udp_q));
222    assert(que);
223
224    uint32_t src_ip;
225    err = net_config_current_ip_query(NET_FLAGS_BLOCKING_INIT, &src_ip);
226    if (err_is_fail(err)) {
227        return err;
228    }
229
230    // init other queue
231    uint64_t qid;
232    err = ip_create((struct ip_q**) &que->q, card_name, &qid, UDP_PROT, dst_ip,
233                    interrupt, poll);
234    if (err_is_fail(err)) {
235        return err;
236    }
237
238    ip_get_netfilter_ep((struct ip_q*) que->q, &que->filter_ep);
239    if (capref_is_null(que->filter_ep)) {
240        err = net_filter_init(&que->filter, card_name);
241        if (err_is_fail(err)) {
242            return err;
243        }
244    } else {
245        err = net_filter_init_with_ep(&que->filter, que->filter_ep);
246        if (err_is_fail(err)) {
247            return err;
248        }
249    }
250
251    src_ip = htonl(src_ip);
252    struct net_filter_ip ip = {
253        .qid = qid,
254        .ip_src = dst_ip,
255        .ip_dst = src_ip,
256        .port_dst = dst_port,
257        .type = NET_FILTER_UDP,
258    };
259
260    err = net_filter_ip_install(que->filter, &ip);
261    if (err_is_fail(err)) {
262        return err;
263    }
264
265    err = devq_init(&que->my_q, false);
266    if (err_is_fail(err)) {
267        errval_t err2;
268        err2 = net_filter_ip_remove(que->filter, &ip);
269        if (err_is_fail(err)) {
270            return err_push(err2, err);
271        }
272        return err;
273    }
274
275    // UDP fields
276    que->header.src = htons(src_port);
277    que->header.dest = htons(dst_port);
278    que->header.chksum = 0x0;
279
280    que->my_q.f.reg = udp_register;
281    que->my_q.f.dereg = udp_deregister;
282    que->my_q.f.ctrl = udp_control;
283    que->my_q.f.notify = udp_notify;
284    que->my_q.f.enq = udp_enqueue;
285    que->my_q.f.deq = udp_dequeue;
286    *q = que;
287
288    return SYS_ERR_OK;
289}
290
291errval_t udp_destroy(struct udp_q* q)
292{
293    // TODO destroy q->q;
294    free(q);
295
296    return SYS_ERR_OK;
297}
298
299errval_t udp_write_buffer(struct udp_q* q, regionid_t rid, genoffset_t offset,
300                          void* data, uint16_t len)
301{
302    assert(len <= 1500);
303    if (q->regions[rid % MAX_NUM_REGIONS].va != NULL) {
304        uint8_t* start = q->regions[rid % MAX_NUM_REGIONS].va + offset
305                         + sizeof (struct udp_hdr)
306                         + sizeof (struct ip_hdr)
307                         + sizeof (struct eth_hdr);
308        memcpy(start, data, len);
309        return SYS_ERR_OK;
310    } else {
311        return DEVQ_ERR_INVALID_REGION_ARGS;
312    }
313}
314
315struct bench_ctl* udp_get_benchmark_data(struct udp_q* q, bench_data_type_t type)
316{
317    return ip_get_benchmark_data((struct ip_q*) q->q, type);
318}
319