1/*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
7#include <arpa/inet.h>
8#include <string.h>
9
10#include "client.h"
11
12#define PRINT_RECEIVE_PACKET_IDS 0
13#if PRINT_RECEIVE_PACKET_IDS
14static uint64_t received_len;
15#endif
16
17static int num_rx_bufs = 0;
18static tx_msg_t *rx_buf_pool[NUM_TCP_BUFS];
19
20
21int socket_in;
22int tcp_echo_client;
23
24int tcp_socket_handle_async_received(tx_msg_t *msg)
25{
26    virtqueue_ring_object_t handle;
27    if (msg->socket_fd != tcp_echo_client) {
28        // socket has been closed
29        if (tcp_echo_client == -1) {
30            rx_buf_pool[num_rx_bufs] = msg;
31            num_rx_bufs++;
32            return 0;
33        } else {
34            msg->socket_fd = tcp_echo_client;
35            msg->done_len = -1;
36        }
37
38    }
39    if (msg->done_len == -1 || msg->done_len == 0) {
40        msg->total_len = TCP_READ_SIZE;
41        msg->done_len = 0;
42        virtqueue_init_ring_object(&handle);
43        if (!virtqueue_add_available_buf(&rx_virtqueue, &handle, ENCODE_DMA_ADDRESS(msg), BUF_SIZE, VQ_RW)) {
44            ZF_LOGF("tcp_handle_received: Error while enqueuing available buffer, queue full");
45        }
46
47    } else {
48        msg->total_len = msg->done_len;
49#if PRINT_RECEIVE_PACKET_IDS
50        /* Because we could potentially get back less data than we asked for
51         * the packet ID isn't always going to be at the start of the buffer.
52         */
53        if (received_len > 0) {
54            if ((msg->done_len + received_len) >= TCP_READ_SIZE) {
55                msg->done_len -= (TCP_READ_SIZE - received_len);
56                received_len = 0;
57            } else {
58                received_len += msg->done_len;
59                msg->done_len = 0;
60            }
61        }
62        while (msg->done_len > 0 && received_len == 0) {
63            if (msg->done_len < sizeof(uint64_t)) {
64                ZF_LOGE("SPLIT ID received");
65            } else {
66                uint64_t sent_id = *(uint64_t *)&msg->buf[msg->total_len - msg->done_len];
67                printf("ID: %ld\n", sent_id);
68            }
69            if (msg->done_len >= TCP_READ_SIZE) {
70                msg->done_len -= TCP_READ_SIZE;
71            } else {
72                received_len += msg->done_len;
73                msg->done_len = 0;
74            }
75        }
76#endif
77        msg->done_len = 0;
78        /* copy the packet over */
79
80        virtqueue_init_ring_object(&handle);
81        if (!virtqueue_add_available_buf(&tx_virtqueue, &handle, ENCODE_DMA_ADDRESS(msg), sizeof(*msg), VQ_RW)) {
82            ZF_LOGF("tcp_handle_received: Error while enqueuing available buffer, queue full");
83        }
84
85    }
86    return 0;
87
88}
89
90
91int tcp_socket_handle_async_sent(tx_msg_t *msg)
92{
93    virtqueue_ring_object_t handle;
94    msg->total_len = TCP_READ_SIZE;
95    msg->done_len = 0;
96    if (msg->socket_fd != tcp_echo_client) {
97        if (tcp_echo_client == -1) {
98            rx_buf_pool[num_rx_bufs] = msg;
99            num_rx_bufs++;
100            return 0;
101        } else {
102            msg->socket_fd = tcp_echo_client;
103        }
104
105    }
106    virtqueue_init_ring_object(&handle);
107    if (!virtqueue_add_available_buf(&rx_virtqueue, &handle, ENCODE_DMA_ADDRESS(msg), BUF_SIZE, VQ_RW)) {
108        ZF_LOGF("tcp_handle_sent: Error while enqueuing available buffer, queue full");
109    }
110    return 0;
111
112}
113
114void handle_tcp_echo_notification(uint16_t events, int socket)
115{
116    int ret = 0;
117    char ip_string[16] = {0};
118
119    if (events & PICOSERVER_CONN) {
120        picoserver_peer_t peer = echo_control_accept(socket);
121        if (peer.result == -1) {
122            ZF_LOGF("Failed to accept a peer");
123        }
124        tcp_echo_client = peer.socket;
125        ret = echo_control_set_async(tcp_echo_client, true);
126        if (ret) {
127            ZF_LOGF("Failed to set a socket to async: %d!", ret);
128        }
129        while (num_rx_bufs > 0) {
130            virtqueue_ring_object_t handle;
131
132            virtqueue_init_ring_object(&handle);
133            num_rx_bufs--;
134            tx_msg_t *buf = rx_buf_pool[num_rx_bufs];
135            buf->total_len = TCP_READ_SIZE;
136            buf->done_len = 0;
137            buf->socket_fd = tcp_echo_client;
138
139            if (!virtqueue_add_available_buf(&rx_virtqueue, &handle, ENCODE_DMA_ADDRESS(buf), sizeof(*buf), VQ_RW)) {
140                ZF_LOGF("Error while enqueuing available buffer, queue full");
141            }
142        }
143
144
145        inet_ntop(AF_INET, &peer.peer_addr, ip_string, 16);
146        printf("%s: Connection established with %s on socket %d from socket %d\n", get_instance_name(), ip_string,
147               tcp_echo_client, socket);
148    }
149    if (events & PICOSERVER_CLOSE) {
150        ret = echo_control_shutdown(socket, PICOSERVER_SHUT_RDWR);
151        printf("%s: Connection closing on socket %d\n", get_instance_name(), socket);
152    }
153    if (events & PICOSERVER_FIN) {
154        echo_control_close(tcp_echo_client);
155        tcp_echo_client = -1;
156        printf("%s: Connection closed on socket %d\n", get_instance_name(), socket);
157    }
158    if (events & PICOSERVER_ERR) {
159        ZF_LOGE("%s: Error with socket %d\n", get_instance_name(), socket);
160    }
161}
162
163static int setup_tcp_echo_socket(ps_io_ops_t *io_ops)
164{
165    socket_in = echo_control_open(false);
166    if (socket_in == -1) {
167        ZF_LOGF("Failed to open a socket for listening!");
168    }
169    int ret = echo_control_set_async(socket_in, true);
170    if (ret) {
171        ZF_LOGF("Failed to set a socket to async: %d!", ret);
172    }
173
174    ret = echo_control_bind(socket_in, PICOSERVER_ANY_ADDR_IPV4, TCP_ECHO_PORT);
175    if (ret) {
176        ZF_LOGF("Failed to bind a socket for listening!");
177    }
178
179    ret = echo_control_listen(socket_in, 1);
180    if (ret) {
181        ZF_LOGF("Failed to listen for incoming connections!");
182    }
183
184    for (int i = 0; i < NUM_TCP_BUFS; i++) {
185        tx_msg_t *buf = ps_dma_alloc(&io_ops->dma_manager, BUF_SIZE, 4, 1, PS_MEM_NORMAL);
186        ZF_LOGF_IF(buf == NULL, "Failed to alloc");
187        memset(buf, 0, BUF_SIZE);
188        buf->socket_fd = -1;
189        buf->client_cookie = (void *)TCP_SOCKETS_ASYNC_ID;
190        rx_buf_pool[num_rx_bufs] = buf;
191        num_rx_bufs++;
192
193    }
194
195    char ip_string[16] = {0};
196    uint32_t ip_raw = PICOSERVER_ANY_ADDR_IPV4;
197    inet_ntop(AF_INET, &ip_raw, ip_string, 16);
198    printf("%s instance starting up, going to be listening on %s:%d\n",
199           get_instance_name(), ip_string, TCP_ECHO_PORT);
200    return 0;
201
202}
203CAMKES_POST_INIT_MODULE_DEFINE(setup_tcp, setup_tcp_echo_socket);
204