1/*
2 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
7#include <autoconf.h>
8#include <stdbool.h>
9
10#include <camkes.h>
11
12#undef PACKED
13#include <pico_stack.h>
14#include <pico_socket.h>
15#include <pico_ipv4.h>
16
17#include "ports.h"
18#include "tuning_params.h"
19
20
/* This file implements an echo server that listens on a TCP port and returns
 * every byte it receives, in order. At most MAX_TCP_CLIENTS clients can be
 * connected at a time; when clients disconnect, new ones can connect.
 *
 * The server tries to read data in chunks of TCP_READ_SIZE and then immediately
 * tries to send what was received. If sending becomes blocked, the server
 * won't try to receive new data until it has successfully sent the current
 * data.
 */
30
31/* Socket global data */
32
33/* TCP socket listening for clients */
34static struct pico_socket *socket_in;
35
36/* Connected client sockets */
37static struct pico_socket *connected[MAX_TCP_CLIENTS];
38
39/* Per-client buffer for receiving packets */
40static char data_packet[MAX_TCP_CLIENTS][0x1000] ALIGN(0x1000);
41
42/* Per-client state for saving a pending send */
43/* Whether a write is queued */
44static bool write_pending[MAX_TCP_CLIENTS];
45/* Amount of data left to send in the write */
46static int remaining_payload[MAX_TCP_CLIENTS];
47/* Amount of data from payload sent already */
48static int sent_payload[MAX_TCP_CLIENTS];
49
50
51/**
52 * @brief      picotcp socket callback function
53 *
54 * This function gets called when there are any socket events on the TCP
55 * socket or any sockets created from it. It isn't possible to register
56 * different handlers for the client sockets so this handler has to handle
57 * both connections and sends/receives.
58 *
59 * @param[in]  events  The picotcp events
60 * @param      s       Socket the event applies to.
61 */
62void handle_tcp_picoserver_notification(uint16_t events, struct pico_socket *s)
63{
64    int ret = 0;
65
66    /* Detect the client based on the reference to the socket given to the callback. */
67    int client_id = -1;
68    if (s != socket_in) {
69        for (int i = 0; i < MAX_TCP_CLIENTS; i++) {
70            if (connected[i] == s) {
71                client_id = i;
72                break;
73            }
74        }
75    }
76
77    /* New client connected event */
78    if (events & PICO_SOCK_EV_CONN) {
79        uint32_t peer_addr;
80        uint16_t remote_port;
81        int connect_client_id = -1;
82        assert(client_id == -1);
83        /* Find a free client ID */
84        for (int i = 0; i < MAX_TCP_CLIENTS; i++) {
85            if (connected[i] == NULL) {
86                connect_client_id = i;
87                break;
88            }
89        }
90        if (connect_client_id == -1 || connect_client_id == MAX_TCP_CLIENTS) {
91            printf("Cannot connect new client\n");
92        } else {
93            /* Accept client connection */
94            connected[connect_client_id] = pico_socket_accept(socket_in, &peer_addr, &remote_port);
95            if (connected[connect_client_id] == NULL) {
96                ZF_LOGE("pico_socket_accept: error received: %d", pico_err);
97            }
98            write_pending[connect_client_id] = false;
99            char ip_string[16] = {0};
100            pico_ipv4_to_string(ip_string, peer_addr);
101            printf("%s: Connection established with %s on socket %p\n", get_instance_name(), ip_string, connected);
102        }
103
104
105    }
106    /* Write successful event. If we have blocked writes try and resend */
107    if (events & PICO_SOCK_EV_WR && write_pending[client_id]) {
108        while (remaining_payload[client_id] > 0) {
109            int inner_ret = pico_socket_send(s, data_packet[client_id] + sent_payload[client_id], remaining_payload[client_id]);
110            if (inner_ret == -1) {
111                /* Received socket error. report and keep going. */
112                ZF_LOGE("pico_socket_send: error received: %d", pico_err);
113                break;
114            }
115            if (inner_ret == 0) {
116                write_pending[client_id] = true;
117                remaining_payload[client_id] = remaining_payload[client_id];
118                sent_payload[client_id] = sent_payload[client_id];
119
120                break;
121            } else {
122                remaining_payload[client_id] -= inner_ret;
123                sent_payload[client_id] += inner_ret;
124            }
125        }
126        /* If we successfully sent everything, clear the write pending flag */
127        if (remaining_payload[client_id] == 0) {
128            write_pending[client_id] = false;
129            /* Set the Read event bit in order to clear any reads that we skipped
130             * while blocking for send.
131             */
132            events |= PICO_SOCK_EV_RD;
133        }
134
135    }
136
137    /* Read event on client socket. Receive the data and try resend immediately. */
138    if (events & PICO_SOCK_EV_RD) {
139        while (!write_pending[client_id]) {
140            ret = pico_socket_recv(s, data_packet[client_id], TCP_READ_SIZE);
141            if (ret == -1) {
142                /* Received socket error. Report and keep going. */
143                ZF_LOGE("pico_socket_recv: error received: %d", pico_err);
144                break;
145            } else if (ret == 0) {
146                /* No data available */
147                break;
148            }
149            int done = 0;
150            while (ret > 0) {
151                int inner_ret = pico_socket_send(s, data_packet[client_id] + done, ret);
152                if (inner_ret == -1) {
153                    /* Received socket error, report and keep going. */
154                    ZF_LOGE("pico_socket_send: error received: %d", pico_err);
155                    break;
156                }
157                if (inner_ret == 0) {
158                    /* Cannot send more data. Save for write-completed event to retry */
159                    write_pending[client_id] = true;
160                    remaining_payload[client_id] = ret;
161                    sent_payload[client_id] = done;
162                    break;
163                } else {
164                    ret -= inner_ret;
165                    done += inner_ret;
166                }
167            }
168        }
169
170    }
171
172    if (events & PICO_SOCK_EV_CLOSE) {
173        ret = pico_socket_shutdown(s, PICO_SHUT_RDWR);
174        printf("%s: Connection closing on socket %p\n", get_instance_name(), s);
175    }
176    if (events & PICO_SOCK_EV_FIN) {
177        assert(client_id != -1);
178        connected[client_id] = NULL;
179        printf("%s: Connection closed on socket %p\n", get_instance_name(), s);
180    }
181    if (events & PICO_SOCK_EV_ERR) {
182        printf("%s: Error with socket %p, going to die\n", get_instance_name(), s);
183        assert(0);
184    }
185}
186
187
188int setup_tcp_socket(UNUSED ps_io_ops_t *io_ops)
189{
190    socket_in = pico_socket_open(PICO_PROTO_IPV4, PICO_PROTO_TCP, handle_tcp_picoserver_notification);
191    if (socket_in == NULL) {
192        ZF_LOGE("Failed to open a socket for listening!");
193        return -1;
194    }
195    uint32_t local_addr = PICO_IPV4_INADDR_ANY;
196    uint16_t port = short_be(TCP_ECHO_PORT);
197    int ret = pico_socket_bind(socket_in, &local_addr, &port);
198    if (ret) {
199        ZF_LOGE("Failed to bind a socket for listening: %d!", pico_err);
200        return -1;
201    } else {
202        printf("Bound to addr: %d, port %d\n", local_addr, port);
203    }
204
205    ret = pico_socket_listen(socket_in, MAX_TCP_CLIENTS);
206    if (ret) {
207        ZF_LOGE("Failed to listen for incoming connections!");
208        return -1;
209    }
210    return 0;
211}
212
213CAMKES_POST_INIT_MODULE_DEFINE(setup_tcp_socket_, setup_tcp_socket);
214