1// Copyright 2016 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <assert.h>
6#include <dirent.h>
7#include <errno.h>
8#include <fcntl.h>
9#include <stdio.h>
10#include <stdlib.h>
11#include <string.h>
12#include <threads.h>
13#include <unistd.h>
14#include <limits.h>
15
16#include <zircon/assert.h>
17#include <zircon/listnode.h>
18#include <zircon/threads.h>
19#include <zircon/types.h>
20#include <zircon/device/input.h>
21
22#include <fbl/algorithm.h>
23#include <fbl/unique_ptr.h>
24
25#include <lib/fdio/watcher.h>
26
// Pretty-prints a HID report descriptor; defined in report.cpp.
void print_report_descriptor(const uint8_t* rpt_desc, size_t desc_len);

// Directory watched by "read" when no device path is given.
#define DEV_INPUT "/dev/class/input"

// Set by the -v flag. Enables xprintf() diagnostics and hex dumps of report descriptors.
static bool verbose = false;

// printf() that only produces output in verbose mode.
#define xprintf(...)              \
    do {                          \
        if (verbose) {            \
            printf(__VA_ARGS__);  \
        }                         \
    } while (0)
34
// Prints the command-line synopsis of this tool to stdout.
void usage(void) {
    static const char* const kUsageLines[] = {
        "usage: hid [-v] <command> [<args>]\n\n",
        "  commands:\n",
        "    read [<devpath> [num reads]]\n",
        "    get <devpath> <in|out|feature> <id>\n",
        "    set <devpath> <in|out|feature> <id> [0xXX *]\n",
        "    parse <devpath>\n",
    };
    for (const char* line : kUsageLines) {
        printf("%s", line);
    }
}
43
// Per-device parameters handed to hid_input_thread() through its void* argument.
struct input_args {
    int fd;                       // open device file descriptor to read reports from
    char name[128];               // NUL-terminated label used in log output and as the thread name
    unsigned long int num_reads;  // upper bound on the number of reads to perform
};
using input_args_t = input_args;
49
50static thrd_t input_poll_thread;
51
52static mtx_t print_lock = MTX_INIT;
53#define lprintf(fmt...) \
54    do { \
55        mtx_lock(&print_lock); \
56        printf(fmt); \
57        mtx_unlock(&print_lock); \
58    } while (0)
59
60
// Prints |len| bytes of |buf| as space-separated two-digit hex, 16 bytes per line.
// The output always ends with exactly one newline, including when |len| is 0.
static void print_hex(const uint8_t* buf, size_t len) {
    for (size_t i = 0; i < len; i++) {
        printf("%02x ", buf[i]);
        if (i % 16 == 15) printf("\n");
    }
    // The loop has already ended the line when |len| is a non-zero multiple of 16.
    // Only terminate a partial (or empty) final line here, to avoid a spurious blank line.
    if ((len == 0) || (len % 16 != 0)) {
        printf("\n");
    }
}
68
69static zx_status_t parse_uint_arg(const char* arg, uint32_t min, uint32_t max, uint32_t* out_val) {
70    if ((arg == NULL) || (out_val == NULL)) {
71        return ZX_ERR_INVALID_ARGS;
72    }
73
74    bool is_hex = (arg[0] == '0') && (arg[1] == 'x');
75    if (sscanf(arg, is_hex ? "%x" : "%u", out_val) != 1) {
76        return ZX_ERR_INVALID_ARGS;
77    }
78
79    if ((*out_val < min) || (*out_val > max)) {
80        return ZX_ERR_OUT_OF_RANGE;
81    }
82
83    return ZX_OK;
84}
85
86static zx_status_t parse_input_report_type(const char* arg, input_report_type_t* out_type) {
87    if ((arg == NULL) || (out_type == NULL)) {
88        return ZX_ERR_INVALID_ARGS;
89    }
90
91    static const struct {
92        const char* name;
93        input_report_type_t type;
94    } LUT[] = {
95        { .name = "in",      .type = INPUT_REPORT_INPUT },
96        { .name = "out",     .type = INPUT_REPORT_OUTPUT },
97        { .name = "feature", .type = INPUT_REPORT_FEATURE },
98    };
99
100    for (size_t i = 0; i < fbl::count_of(LUT); ++i) {
101        if (!strcasecmp(arg, LUT[i].name)) {
102            *out_type = LUT[i].type;
103            return ZX_OK;
104        }
105    }
106
107    return ZX_ERR_INVALID_ARGS;
108}
109
110static zx_status_t parse_set_get_report_args(int argc,
111                                             const char** argv,
112                                             input_report_id_t* out_id,
113                                             input_report_type_t* out_type) {
114    if ((argc < 3) || (argv == NULL) || (out_id == NULL) || (out_type == NULL)) {
115        return ZX_ERR_INVALID_ARGS;
116    }
117
118    zx_status_t res;
119    uint32_t tmp;
120    res = parse_uint_arg(argv[2], 0, 255, &tmp);
121    if (res != ZX_OK) {
122        return res;
123    }
124
125    *out_id = static_cast<input_report_id_t>(tmp);
126
127    return parse_input_report_type(argv[1], out_type);
128}
129
130
131static ssize_t get_hid_protocol(int fd, const char* name) {
132    int proto;
133    ssize_t rc = ioctl_input_get_protocol(fd, &proto);
134    if (rc < 0) {
135        lprintf("hid: could not get protocol from %s (status=%zd)\n", name, rc);
136    } else {
137        lprintf("hid: %s proto=%d\n", name, proto);
138    }
139    return rc;
140}
141
142static ssize_t get_report_desc_len(int fd, const char* name, size_t* report_desc_len) {
143    ssize_t rc = ioctl_input_get_report_desc_size(fd, report_desc_len);
144    if (rc < 0) {
145        lprintf("hid: could not get report descriptor length from %s (status=%zd)\n", name, rc);
146    } else {
147        lprintf("hid: %s report descriptor len=%zu\n", name, *report_desc_len);
148    }
149    return rc;
150}
151
152static ssize_t get_report_desc(int fd, const char* name, size_t report_desc_len) {
153    fbl::unique_ptr<uint8_t[]> buf(new uint8_t[report_desc_len]);
154
155    ssize_t rc = ioctl_input_get_report_desc(fd, buf.get(), report_desc_len);
156    if (rc < 0) {
157        lprintf("hid: could not get report descriptor from %s (status=%zd)\n", name, rc);
158        return rc;
159    }
160    mtx_lock(&print_lock);
161    printf("hid: %s report descriptor:\n", name);
162    if (verbose) {
163        print_hex(buf.get(), report_desc_len);
164    }
165    print_report_descriptor(buf.get(), report_desc_len);
166    mtx_unlock(&print_lock);
167    return rc;
168}
169
170static ssize_t get_num_reports(int fd, const char* name, size_t* num_reports) {
171    ssize_t rc = ioctl_input_get_num_reports(fd, num_reports);
172    if (rc < 0) {
173        lprintf("hid: could not get number of reports from %s (status=%zd)\n", name, rc);
174    } else {
175        lprintf("hid: %s num reports: %zu\n", name, *num_reports);
176    }
177    return rc;
178}
179
180static ssize_t get_report_ids(int fd, const char* name, size_t num_reports) {
181    size_t out_len = num_reports * sizeof(input_report_id_t);
182    fbl::unique_ptr<input_report_id_t[]> ids(new input_report_id_t[num_reports]);
183
184    ssize_t rc = ioctl_input_get_report_ids(fd, ids.get(), out_len);
185    if (rc < 0) {
186        lprintf("hid: could not get report ids from %s (status=%zd)\n", name, rc);
187        return rc;
188    }
189
190    mtx_lock(&print_lock);
191    printf("hid: %s report ids...\n", name);
192    for (size_t i = 0; i < num_reports; i++) {
193        static const struct {
194            input_report_type_t type;
195            const char* tag;
196        } TYPES[] = {
197            { .type = INPUT_REPORT_INPUT,   .tag = "Input" },
198            { .type = INPUT_REPORT_OUTPUT,  .tag = "Output" },
199            { .type = INPUT_REPORT_FEATURE, .tag = "Feature" },
200        };
201
202        bool found = false;
203        for (size_t j = 0; j < fbl::count_of(TYPES); ++j) {
204            input_get_report_size_t arg = { .id = ids[i], .type = TYPES[j].type };
205            input_report_size_t size;
206            ssize_t size_rc;
207
208            size_rc = ioctl_input_get_report_size(fd, &arg, &size);
209            if (size_rc >= 0) {
210                printf("  ID 0x%02x : TYPE %7s : SIZE %u bytes\n",
211                        ids[i], TYPES[j].tag, size);
212                found = true;
213            }
214        }
215
216        if (!found) {
217            printf("  hid: failed to find any report sizes for report id 0x%02x's (dev %s)\n",
218                    ids[i], name);
219        }
220    }
221
222    mtx_unlock(&print_lock);
223    return rc;
224}
225
226static ssize_t get_max_report_len(int fd, const char* name, input_report_size_t* max_report_len) {
227    input_report_size_t tmp;
228    if (max_report_len == NULL) {
229        max_report_len = &tmp;
230    }
231    ssize_t rc = ioctl_input_get_max_reportsize(fd, max_report_len);
232    if (rc < 0) {
233        lprintf("hid: could not get max report size from %s (status=%zd)\n", name, rc);
234    } else {
235        lprintf("hid: %s maxreport=%u\n", name, *max_report_len);
236    }
237    return rc;
238}
239
// Evaluates |fn| (an expression yielding an ssize_t status) exactly once. If the result
// is negative, returns it from the enclosing function. The temporary's name is chosen
// so it cannot shadow a variable that |fn| itself refers to.
#define TRY(fn)                        \
    do {                               \
        ssize_t try_rc_ = (fn);        \
        if (try_rc_ < 0) {             \
            return try_rc_;            \
        }                              \
    } while (0)
246
247static ssize_t hid_status(int fd, const char* name, input_report_size_t* max_report_len) {
248    size_t num_reports;
249
250    TRY(get_hid_protocol(fd, name));
251    TRY(get_num_reports(fd, name, &num_reports));
252    TRY(get_report_ids(fd, name, num_reports));
253    TRY(get_max_report_len(fd, name, max_report_len));
254    return ZX_OK;
255}
256
257static ssize_t parse_rpt_descriptor(int fd, const char* name) {
258    size_t report_desc_len;
259    TRY(get_report_desc_len(fd, "", &report_desc_len));
260    TRY(get_report_desc(fd, "", report_desc_len));
261    return ZX_OK;
262}
263
264#undef TRY
265
266static int hid_input_thread(void* arg) {
267    input_args_t* args = (input_args_t*)arg;
268    lprintf("hid: input thread started for %s\n", args->name);
269
270    input_report_size_t max_report_len = 0;
271    ssize_t rc = hid_status(args->fd, args->name, &max_report_len);
272    if (rc < 0) {
273        return static_cast<int>(rc);
274    }
275
276    // Add 1 to the max report length to make room for a Report ID.
277    max_report_len++;
278    fbl::unique_ptr<uint8_t[]> report(new uint8_t[max_report_len]);
279
280    for (uint32_t i = 0; i < args->num_reads; i++) {
281        ssize_t r = read(args->fd, report.get(), max_report_len);
282        mtx_lock(&print_lock);
283        printf("read returned %ld\n", r);
284        if (r < 0) {
285            printf("read errno=%d (%s)\n", errno, strerror(errno));
286            mtx_unlock(&print_lock);
287            break;
288        }
289        printf("hid: input from %s\n", args->name);
290        print_hex(report.get(), r);
291        mtx_unlock(&print_lock);
292    }
293
294    lprintf("hid: closing %s\n", args->name);
295    close(args->fd);
296    delete args;
297    return ZX_OK;
298}
299
300static zx_status_t hid_input_device_added(int dirfd, int event, const char* fn, void* cookie) {
301    if (event != WATCH_EVENT_ADD_FILE) {
302        return ZX_OK;
303    }
304
305    int fd = openat(dirfd, fn, O_RDONLY);
306    if (fd < 0) {
307        return ZX_OK;
308    }
309
310    input_args_t* args = new input_args {};
311    args->fd = fd;
312    // TODO: support setting num_reads across all devices. requires a way to
313    // signal shutdown to all input threads.
314    args->num_reads = ULONG_MAX;
315    thrd_t t;
316    snprintf(args->name, sizeof(args->name), "hid-input-%s", fn);
317    int ret = thrd_create_with_name(&t, hid_input_thread, (void*)args, args->name);
318    if (ret != thrd_success) {
319        printf("hid: input thread %s did not start (error=%d)\n", args->name, ret);
320        close(fd);
321        return thrd_status_to_zx_status(ret);
322    }
323    thrd_detach(t);
324    return ZX_OK;
325}
326
327static int hid_input_devices_poll_thread(void* arg) {
328    int dirfd = open(DEV_INPUT, O_DIRECTORY|O_RDONLY);
329    if (dirfd < 0) {
330        printf("hid: error opening %s\n", DEV_INPUT);
331        return ZX_ERR_INTERNAL;
332    }
333    fdio_watch_directory(dirfd, hid_input_device_added, ZX_TIME_INFINITE, NULL);
334    close(dirfd);
335    return -1;
336}
337
338int read_reports(int argc, const char** argv) {
339    argc--;
340    argv++;
341    if (argc < 1) {
342        usage();
343        return 0;
344    }
345
346    uint32_t tmp = 0xffffffff;
347    if (argc > 1) {
348        zx_status_t res = parse_uint_arg(argv[1], 0, 0xffffffff, &tmp);
349        if (res != ZX_OK) {
350            printf("Failed to parse <num reads> (res %d)\n", res);
351            usage();
352            return 0;
353        }
354    }
355
356    int fd = open(argv[0], O_RDWR);
357    if (fd < 0) {
358        printf("could not open %s: %d\n", argv[0], errno);
359        return -1;
360    }
361
362    input_args_t* args = new input_args_t {};
363    args->fd = fd;
364    args->num_reads = tmp;
365
366    strlcpy(args->name, argv[0], sizeof(args->name));
367    thrd_t t;
368    int ret = thrd_create_with_name(&t, hid_input_thread, (void*)args, args->name);
369    if (ret != thrd_success) {
370        printf("hid: input thread %s did not start (error=%d)\n", args->name, ret);
371        delete args;
372        close(fd);
373        return -1;
374    }
375    thrd_join(t, NULL);
376    return 0;
377}
378
379int readall_reports(int argc, const char** argv) {
380    int ret = thrd_create_with_name(&input_poll_thread,
381                                    hid_input_devices_poll_thread,
382                                    NULL,
383                                    "hid-inputdev-poll");
384    if (ret != thrd_success) {
385        return -1;
386    }
387
388    thrd_join(input_poll_thread, NULL);
389    return 0;
390}
391
392int get_report(int argc, const char** argv) {
393    argc--;
394    argv++;
395    if (argc < 3) {
396        usage();
397        return 0;
398    }
399
400    input_get_report_size_t size_arg;
401    zx_status_t res = parse_set_get_report_args(argc, argv, &size_arg.id, &size_arg.type);
402    if (res != ZX_OK) {
403        printf("Failed to parse type/id for get report operation (res %d)\n", res);
404        usage();
405        return 0;
406    }
407
408    int fd = open(argv[0], O_RDWR);
409    if (fd < 0) {
410        printf("could not open %s: %d\n", argv[0], errno);
411        return -1;
412    }
413
414    xprintf("hid: getting report size for id=0x%02x type=%u\n", size_arg.id, size_arg.type);
415
416    input_report_size_t size;
417    ssize_t rc = ioctl_input_get_report_size(fd, &size_arg, &size);
418    if (rc < 0) {
419        printf("hid: could not get report (id 0x%02x type %u) size from %s (status=%zd)\n",
420                size_arg.id, size_arg.type, argv[0], rc);
421        return static_cast<int>(rc);
422    }
423    xprintf("hid: report size=%u\n", size);
424
425    input_get_report_t rpt_arg;
426    rpt_arg.id = size_arg.id;
427    rpt_arg.type = size_arg.type;
428
429    // TODO(johngro) : Come up with a better policy than this...  While devices
430    // are *supposed* to only deliver a report descriptor's computed size, in
431    // practice they frequently seem to deliver number of bytes either greater
432    // or fewer than the number of bytes originally requested.  For example...
433    //
434    // ++ Sometimes a device is expected to deliver a Report ID byte along with
435    //    the payload contents, but does not do so.
436    // ++ Sometimes it is unclear whether or not a device needs to deliver a
437    //    Report ID byte at all since there is only one report listed (and,
438    //    sometimes the device delivers that ID, and sometimes it chooses not
439    //    to).
440    // ++ Sometimes no bytes at all are returned for a report (this seems to
441    //    be relatively common for input reports)
442    // ++ Sometimes the number of bytes returned has basically nothing to do
443    //    with the expected size of the report (this seems to be relatively
444    //    common for vendor feature reports).
445    //
446    // Because of this uncertainty, we currently just provide a worst-case 4KB
447    // buffer to read into, and report the number of bytes which came back along
448    // with the expected size of the raw report.
449    size_t bufsz = 4u << 10;
450    fbl::unique_ptr<uint8_t[]> buf(new uint8_t[bufsz]);
451    rc = ioctl_input_get_report(fd, &rpt_arg, buf.get(), bufsz);
452    if (rc < 0) {
453        printf("hid: could not get report: %zd\n", rc);
454    } else {
455        printf("hid: got %zu bytes (raw report size %u)\n", rc, size);
456        print_hex(buf.get(), rc);
457    }
458    return static_cast<int>(rc);
459}
460
461int set_report(int argc, const char** argv) {
462    argc--;
463    argv++;
464    if (argc < 4) {
465        usage();
466        return 0;
467    }
468
469    input_get_report_size_t size_arg;
470    zx_status_t res = parse_set_get_report_args(argc, argv, &size_arg.id, &size_arg.type);
471    if (res != ZX_OK) {
472        printf("Failed to parse type/id for get report operation (res %d)\n", res);
473        usage();
474        return 0;
475    }
476
477    xprintf("hid: getting report size for id=0x%02x type=%u\n", size_arg.id, size_arg.type);
478
479    input_set_report_t* arg = NULL;
480    int fd = open(argv[0], O_RDWR);
481    if (fd < 0) {
482        printf("could not open %s: %d\n", argv[0], errno);
483        return -1;
484    }
485
486    // If the set/get report args parsed, then we must have at least 3 arguments.
487    ZX_DEBUG_ASSERT(argc >= 3);
488    input_report_size_t payload_size = static_cast<input_report_size_t>(argc - 3);
489    size_t in_len = sizeof(input_set_report_t) + payload_size;
490
491    input_report_size_t size;
492    ssize_t rc = ioctl_input_get_report_size(fd, &size_arg, &size);
493    if (rc < 0) {
494        printf("hid: could not get report (id 0x%02x type %u) size from %s (status=%zd)\n",
495                size_arg.id, size_arg.type, argv[0], rc);
496        goto finished;
497    }
498
499    xprintf("hid: report size=%u, tx payload size=%u\n", size, payload_size);
500
501    arg = reinterpret_cast<input_set_report_t*>(new char[in_len]);
502    arg->id = size_arg.id;
503    arg->type = size_arg.type;
504    for (int i = 0; i < payload_size; i++) {
505        uint32_t tmp;
506        zx_status_t res = parse_uint_arg(argv[i+3], 0, 255, &tmp);
507        if (res != ZX_OK) {
508            printf("Failed to parse payload byte \"%s\" (res = %d)\n", argv[i+3], res);
509            rc = res;
510            goto finished;
511        }
512
513        arg->data[i] = static_cast<uint8_t>(tmp);
514    }
515
516    rc = ioctl_input_set_report(fd, arg, in_len);
517    if (rc < 0) {
518        printf("hid: could not set report: %zd\n", rc);
519    } else {
520        printf("hid: success\n");
521    }
522
523finished:
524    delete [] reinterpret_cast<char*>(arg);
525    close(fd);
526    return static_cast<int>(rc);
527}
528
529int parse(int argc, const char** argv) {
530    argc--;
531    argv++;
532    if (argc < 1) {
533        usage();
534        return 0;
535    }
536
537    int fd = open(argv[0], O_RDWR);
538    if (fd < 0) {
539        printf("could not open %s: %d\n", argv[0], errno);
540        return -1;
541    }
542
543    ssize_t rc = parse_rpt_descriptor(fd, argv[0]);
544    close(fd);
545
546    return static_cast<int>(rc);
547}
548
549int main(int argc, const char** argv) {
550    if (argc < 2) {
551        usage();
552        return 0;
553    }
554    argc--;
555    argv++;
556    if (!strcmp("-v", argv[0])) {
557        verbose = true;
558        argc--;
559        argv++;
560    }
561    if (!strcmp("read", argv[0])) {
562        if (argc > 1) {
563            return read_reports(argc, argv);
564        } else {
565            return readall_reports(argc, argv);
566        }
567    }
568
569    if (!strcmp("get", argv[0])) {
570        return get_report(argc, argv);
571    }
572
573    if (!strcmp("set", argv[0])) {
574        return set_report(argc, argv);
575    }
576
577    if (!strcmp("parse", argv[0])) {
578        return parse(argc, argv);
579    }
580
581    usage();
582    return 0;
583}
584