1// Copyright 2017 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "hid-fifo.h"
6
7#include <ddk/binding.h>
8#include <ddk/debug.h>
9#include <ddk/device.h>
10#include <ddk/driver.h>
11#include <ddk/protocol/hidbus.h>
12#include <zircon/device/input.h>
13
14#include <zircon/assert.h>
15#include <zircon/listnode.h>
16
17#include <assert.h>
18#include <stdlib.h>
19#include <string.h>
20
// Per-instance state bits kept in hid_instance_t.flags.
// Set when the instance is closed or its base device is unbound; reads and
// ioctls on the instance then fail with ZX_ERR_PEER_CLOSED.
#define HID_FLAGS_DEAD         (1 << 0)
// Set after a report could not be written into the instance fifo and cleared
// on the next successful write, so only the first failure in a run is logged.
#define HID_FLAGS_WRITE_FAILED (1 << 1)

// TODO(johngro) : Get this from a standard header instead of defining our own.
// Note: evaluates its arguments more than once.
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

// Iterates |instance| (a hid_instance_t*) over every instance open on |base|.
// The callers in this file hold base->instance_lock while iterating.
#define foreach_instance(base, instance) \
    list_for_every_entry(&base->instance_list, instance, hid_instance_t, node)
// Rounds a size in bits up to a whole number of bytes.
#define bits_to_bytes(n) (((n) + 7) / 8)

// Until we do full HID parsing, we put mouse and keyboard devices into boot
// protocol mode. In particular, a mouse will always send 3-byte reports (see
// ddk/protocol/input.h for the format). This macro makes the report sizes
// returned by the ioctls for boot mouse devices reflect the boot protocol,
// rather than what the device's own report descriptor describes.
// TODO: update this to include keyboards if we find a keyboard in the wild that
// needs a hack as well.
#define BOOT_MOUSE_HACK 1
41
// Accumulated sizes, in bits, of the input, output and feature reports that
// share one report ID. The size queries convert them to bytes with
// bits_to_bytes().
typedef struct hid_report_size {
    int16_t id;                     // report ID these sizes belong to
    input_report_size_t in_size;    // input report size, in bits
    input_report_size_t out_size;   // output report size, in bits
    input_report_size_t feat_size;  // feature report size, in bits
} hid_report_size_t;
48
// State for one bound HID device. It holds:
//  - the parent hidbus protocol;
//  - the raw and parsed report descriptor;
//  - the input report reassembly buffer;
//  - the list of open instances.
typedef struct hid_device {
    zx_device_t* zxdev;

    hid_info_t info;        // result of the hidbus query op (dev_class, boot_device, dev_num)
    hidbus_protocol_t hid;  // parent's hidbus protocol (ops table + ctx)

    // Reassembly buffer for input events too large to fit in a single interrupt
    // transaction.
    uint8_t* rbuf;
    size_t rbuf_size;    // allocated size of rbuf: the largest input report, in bytes
    size_t rbuf_filled;  // bytes of the partially received report already in rbuf
    size_t rbuf_needed;  // bytes still missing; 0 when no reassembly is in progress

    size_t hid_report_desc_len;
    uint8_t* hid_report_desc;  // raw report descriptor from the bus; released with free()

    // TODO(johngro, tkilbourn: Do not hardcode this limit!)
#define HID_MAX_REPORT_IDS 32
    size_t num_reports;                           // number of valid entries in sizes
    hid_report_size_t sizes[HID_MAX_REPORT_IDS];  // per-report-ID sizes, in bits

    struct list_node instance_list;  // open hid_instance_t's, linked via their node
    mtx_t instance_lock;  // guards instance_list; hid_io_queue also holds it around rbuf use

    char name[ZX_DEVICE_NAME_MAX + 1];
} hid_device_t;
75
// One open handle on a HID device. hid_io_queue writes every complete input
// report into each instance's own fifo.
typedef struct hid_instance {
    zx_device_t* zxdev;
    hid_device_t* base;  // the device this instance was opened on

    uint32_t flags;  // HID_FLAGS_* bits

    zx_hid_fifo_t fifo;  // queued input reports; accessed under fifo.lock

    struct list_node node;  // link in base->instance_list
} hid_instance_t;
86
87// Convenience functions for calling hidbus protocol functions
88
89static inline zx_status_t hid_op_query(hid_device_t* hid, uint32_t options, hid_info_t* info) {
90    return hid->hid.ops->query(hid->hid.ctx, options, info);
91}
92
93static inline zx_status_t hid_op_start(hid_device_t* hid, hidbus_ifc_t* ifc, void* cookie) {
94    return hid->hid.ops->start(hid->hid.ctx, ifc, cookie);
95}
96
97static inline void hid_op_stop(hid_device_t* hid) {
98    hid->hid.ops->stop(hid->hid.ctx);
99}
100
101static inline zx_status_t hid_op_get_descriptor(hid_device_t* hid, uint8_t desc_type,
102                                                void** data, size_t* len) {
103    return hid->hid.ops->get_descriptor(hid->hid.ctx, desc_type, data, len);
104}
105
106static inline zx_status_t hid_op_get_report(hid_device_t* hid, uint8_t rpt_type, uint8_t rpt_id,
107                                            void* data, size_t len, size_t* out_len) {
108    return hid->hid.ops->get_report(hid->hid.ctx, rpt_type, rpt_id, data, len, out_len);
109}
110
111static inline zx_status_t hid_op_set_report(hid_device_t* hid, uint8_t rpt_type, uint8_t rpt_id,
112                                            void* data, size_t len) {
113    return hid->hid.ops->set_report(hid->hid.ctx, rpt_type, rpt_id, data, len);
114}
115
116static inline zx_status_t hid_op_get_idle(hid_device_t* hid, uint8_t rpt_id, uint8_t* duration) {
117    return hid->hid.ops->get_idle(hid->hid.ctx, rpt_id, duration);
118}
119
120static inline zx_status_t hid_op_set_idle(hid_device_t* hid, uint8_t rpt_id, uint8_t duration) {
121    return hid->hid.ops->set_idle(hid->hid.ctx, rpt_id, duration);
122}
123
124static inline zx_status_t hid_op_get_protocol(hid_device_t* hid, uint8_t* protocol) {
125    return hid->hid.ops->get_protocol(hid->hid.ctx, protocol);
126}
127
128static inline zx_status_t hid_op_set_protocol(hid_device_t* hid, uint8_t protocol) {
129    return hid->hid.ops->set_protocol(hid->hid.ctx, protocol);
130}
131
132
133static input_report_size_t hid_get_report_size_by_id(hid_device_t* hid,
134                                                     input_report_id_t id,
135                                                     input_report_type_t type) {
136    for (size_t i = 0; i < hid->num_reports; i++) {
137        if ((hid->sizes[i].id == id) || (hid->num_reports == 1)) {
138            switch (type) {
139            case INPUT_REPORT_INPUT:
140                return bits_to_bytes(hid->sizes[i].in_size);
141            case INPUT_REPORT_OUTPUT:
142                return bits_to_bytes(hid->sizes[i].out_size);
143            case INPUT_REPORT_FEATURE:
144                return bits_to_bytes(hid->sizes[i].feat_size);
145            }
146        }
147    }
148
149    return 0;
150}
151
152static zx_status_t hid_get_protocol(hid_device_t* hid, void* out_buf, size_t out_len,
153                                    size_t* out_actual) {
154    if (out_len < sizeof(int)) return ZX_ERR_INVALID_ARGS;
155
156    int* reply = out_buf;
157    *reply = INPUT_PROTO_NONE;
158    if (hid->info.dev_class == HID_DEV_CLASS_KBD || hid->info.dev_class == HID_DEV_CLASS_KBD_POINTER) {
159        *reply = INPUT_PROTO_KBD;
160    } else if (hid->info.dev_class == HID_DEV_CLASS_POINTER) {
161        *reply = INPUT_PROTO_MOUSE;
162    }
163    *out_actual = sizeof(*reply);
164    return ZX_OK;
165}
166
167static zx_status_t hid_get_hid_desc_size(hid_device_t* hid, void* out_buf, size_t out_len,
168                                         size_t* out_actual) {
169    if (out_len < sizeof(size_t)) return ZX_ERR_INVALID_ARGS;
170
171    size_t* reply = out_buf;
172    *reply = hid->hid_report_desc_len;
173    *out_actual = sizeof(*reply);
174    return ZX_OK;
175}
176
177static zx_status_t hid_get_hid_desc(hid_device_t* hid, void* out_buf, size_t out_len,
178                                    size_t* out_actual) {
179    if (out_len < hid->hid_report_desc_len) return ZX_ERR_INVALID_ARGS;
180
181    memcpy(out_buf, hid->hid_report_desc, hid->hid_report_desc_len);
182    *out_actual = hid->hid_report_desc_len;
183    return ZX_OK;
184}
185
186static zx_status_t hid_get_num_reports(hid_device_t* hid, void* out_buf, size_t out_len,
187                                       size_t* out_actual) {
188    if (out_len < sizeof(size_t)) return ZX_ERR_INVALID_ARGS;
189
190    size_t* reply = out_buf;
191    *reply = hid->num_reports;
192    *out_actual = sizeof(*reply);
193    return ZX_OK;
194}
195
196static zx_status_t hid_get_report_ids(hid_device_t* hid, void* out_buf, size_t out_len,
197                                      size_t* out_actual) {
198    if (out_len < hid->num_reports * sizeof(input_report_id_t))
199        return ZX_ERR_INVALID_ARGS;
200
201    input_report_id_t* reply = out_buf;
202    for (size_t i = 0; i < hid->num_reports; i++) {
203        *reply++ = (input_report_id_t)hid->sizes[i].id;
204    }
205    *out_actual =  hid->num_reports * sizeof(input_report_id_t);
206    return ZX_OK;
207}
208
209static zx_status_t hid_get_report_size(hid_device_t* hid, const void* in_buf, size_t in_len,
210                                       void* out_buf, size_t out_len, size_t* out_actual) {
211    if (in_len < sizeof(input_get_report_size_t)) return ZX_ERR_INVALID_ARGS;
212    if (out_len < sizeof(input_report_size_t)) return ZX_ERR_INVALID_ARGS;
213
214    const input_get_report_size_t* inp = in_buf;
215
216    input_report_size_t* reply = out_buf;
217    *reply = hid_get_report_size_by_id(hid, inp->id, inp->type);
218    if (*reply == 0) {
219        return ZX_ERR_INVALID_ARGS;
220    }
221
222    *out_actual = sizeof(*reply);
223    return ZX_OK;
224}
225
226static ssize_t hid_get_max_input_reportsize(hid_device_t* hid, void* out_buf, size_t out_len,
227                                            size_t* out_actual) {
228    if (out_len < sizeof(input_report_size_t)) return ZX_ERR_INVALID_ARGS;
229
230    input_report_size_t* reply = out_buf;
231
232    *reply = 0;
233    for (size_t i = 0; i < hid->num_reports; i++) {
234        if (hid->sizes[i].in_size > *reply)
235            *reply = hid->sizes[i].in_size;
236    }
237
238    *reply = bits_to_bytes(*reply);
239    *out_actual = sizeof(*reply);
240    return ZX_OK;
241}
242
243static zx_status_t hid_get_report(hid_device_t* hid, const void* in_buf, size_t in_len,
244                                  void* out_buf, size_t out_len, size_t* out_actual) {
245    if (in_len < sizeof(input_get_report_t)) return ZX_ERR_INVALID_ARGS;
246    const input_get_report_t* inp = in_buf;
247
248    input_report_size_t needed = hid_get_report_size_by_id(hid, inp->id, inp->type);
249    if (needed == 0) return ZX_ERR_INVALID_ARGS;
250    if (out_len < (size_t)needed) return ZX_ERR_BUFFER_TOO_SMALL;
251
252    return hid_op_get_report(hid, inp->type, inp->id, out_buf, out_len, out_actual);
253}
254
255static zx_status_t hid_set_report(hid_device_t* hid, const void* in_buf, size_t in_len) {
256
257    if (in_len < sizeof(input_set_report_t)) return ZX_ERR_INVALID_ARGS;
258    const input_set_report_t* inp = in_buf;
259
260    input_report_size_t needed = hid_get_report_size_by_id(hid, inp->id, inp->type);
261    if (needed == 0) return ZX_ERR_INVALID_ARGS;
262    if (in_len - sizeof(input_set_report_t) < (size_t)needed) return ZX_ERR_INVALID_ARGS;
263
264    return hid_op_set_report(hid, inp->type, inp->id, (void*)inp->data,
265                             in_len - sizeof(input_set_report_t));
266}
267
268
269static zx_status_t hid_read_instance(void* ctx, void* buf, size_t count, zx_off_t off,
270                                     size_t* actual) {
271    hid_instance_t* hid = ctx;
272
273    if (hid->flags & HID_FLAGS_DEAD) {
274        return ZX_ERR_PEER_CLOSED;
275    }
276
277    size_t left;
278    mtx_lock(&hid->fifo.lock);
279    size_t xfer;
280    uint8_t rpt_id;
281    ssize_t r = zx_hid_fifo_peek(&hid->fifo, &rpt_id);
282    if (r < 1) {
283        // fifo is empty
284        mtx_unlock(&hid->fifo.lock);
285        return ZX_ERR_SHOULD_WAIT;
286    }
287
288    xfer = hid_get_report_size_by_id(hid->base, rpt_id, INPUT_REPORT_INPUT);
289    if (xfer == 0) {
290        zxlogf(ERROR, "error reading hid device: unknown report id (%u)!\n", rpt_id);
291        mtx_unlock(&hid->fifo.lock);
292        return ZX_ERR_BAD_STATE;
293    }
294
295    if (xfer > count) {
296        zxlogf(SPEW, "next report: %zd, read count: %zd\n", xfer, count);
297        mtx_unlock(&hid->fifo.lock);
298        return ZX_ERR_BUFFER_TOO_SMALL;
299    }
300
301    r = zx_hid_fifo_read(&hid->fifo, buf, xfer);
302    left = zx_hid_fifo_size(&hid->fifo);
303    if (left == 0) {
304        device_state_clr(hid->zxdev, DEV_STATE_READABLE);
305    }
306    mtx_unlock(&hid->fifo.lock);
307    if (r > 0) {
308        *actual = r;
309        r = ZX_OK;
310    } else if (r == 0) {
311        r = ZX_ERR_SHOULD_WAIT;
312    }
313    return r;
314}
315
316static zx_status_t hid_ioctl_instance(void* ctx, uint32_t op,
317        const void* in_buf, size_t in_len, void* out_buf, size_t out_len, size_t* out_actual) {
318    hid_instance_t* hid = ctx;
319    if (hid->flags & HID_FLAGS_DEAD) return ZX_ERR_PEER_CLOSED;
320
321    switch (op) {
322    case IOCTL_INPUT_GET_PROTOCOL:
323        return hid_get_protocol(hid->base, out_buf, out_len, out_actual);
324    case IOCTL_INPUT_GET_REPORT_DESC_SIZE:
325        return hid_get_hid_desc_size(hid->base, out_buf, out_len, out_actual);
326    case IOCTL_INPUT_GET_REPORT_DESC:
327        return hid_get_hid_desc(hid->base, out_buf, out_len, out_actual);
328    case IOCTL_INPUT_GET_NUM_REPORTS:
329        return hid_get_num_reports(hid->base, out_buf, out_len, out_actual);
330    case IOCTL_INPUT_GET_REPORT_IDS:
331        return hid_get_report_ids(hid->base, out_buf, out_len, out_actual);
332    case IOCTL_INPUT_GET_REPORT_SIZE:
333        return hid_get_report_size(hid->base, in_buf, in_len, out_buf, out_len, out_actual);
334    case IOCTL_INPUT_GET_MAX_REPORTSIZE:
335        return hid_get_max_input_reportsize(hid->base, out_buf, out_len, out_actual);
336    case IOCTL_INPUT_GET_REPORT:
337        return hid_get_report(hid->base, in_buf, in_len, out_buf, out_len, out_actual);
338    case IOCTL_INPUT_SET_REPORT:
339        return hid_set_report(hid->base, in_buf, in_len);
340    }
341    return ZX_ERR_NOT_SUPPORTED;
342}
343
344static zx_status_t hid_close_instance(void* ctx, uint32_t flags) {
345    hid_instance_t* hid = ctx;
346    hid->flags |= HID_FLAGS_DEAD;
347    mtx_lock(&hid->base->instance_lock);
348    // TODO: refcount the base device and call stop if no instances are open
349    list_delete(&hid->node);
350    mtx_unlock(&hid->base->instance_lock);
351    return ZX_OK;
352}
353
// Forward declaration; the definition appears further down in this file.
static void hid_release_reassembly_buffer(hid_device_t* dev);
355
// Release hook for a HID instance: frees the hid_instance_t that
// hid_open_device allocated with calloc().
static void hid_release_instance(void* ctx) {
    free(ctx);
}
360
361zx_protocol_device_t hid_instance_proto = {
362    .version = DEVICE_OPS_VERSION,
363    .read = hid_read_instance,
364    .ioctl = hid_ioctl_instance,
365    .close = hid_close_instance,
366    .release = hid_release_instance,
367};
368
// Item type, taken from bits 2-3 of an item's prefix byte.
enum {
    HID_ITEM_TYPE_MAIN   = 0x0,
    HID_ITEM_TYPE_GLOBAL = 0x1,
    HID_ITEM_TYPE_LOCAL  = 0x2,
};

// Main item tags (bits 4-7 of the prefix) whose fields count toward the
// input, output and feature report sizes.
enum {
    HID_ITEM_MAIN_TAG_INPUT   = 0x8,
    HID_ITEM_MAIN_TAG_OUTPUT  = 0x9,
    HID_ITEM_MAIN_TAG_FEATURE = 0xB,
};

// Global item tags (bits 4-7 of the prefix) tracked by the descriptor parser.
enum {
    HID_ITEM_GLOBAL_TAG_REPORT_SIZE  = 0x7,
    HID_ITEM_GLOBAL_TAG_REPORT_ID    = 0x8,
    HID_ITEM_GLOBAL_TAG_REPORT_COUNT = 0x9,
    HID_ITEM_GLOBAL_TAG_PUSH         = 0xA,
    HID_ITEM_GLOBAL_TAG_POP          = 0xB,
};
388
389static void hid_dump_hid_report_desc(hid_device_t* dev) {
390    zxlogf(TRACE, "hid: dev %p HID report descriptor\n", dev);
391    for (size_t c = 0; c < dev->hid_report_desc_len; c++) {
392        zxlogf(TRACE, "%02x ", dev->hid_report_desc[c]);
393        if (c % 16 == 15) zxlogf(ERROR, "\n");
394    }
395    zxlogf(TRACE, "\n");
396    zxlogf(TRACE, "hid: num reports: %zd\n", dev->num_reports);
397    for (size_t i = 0; i < dev->num_reports; i++) {
398        zxlogf(TRACE, "  report id: %u  sizes: in %u out %u feat %u\n",
399                dev->sizes[i].id, dev->sizes[i].in_size, dev->sizes[i].out_size,
400                dev->sizes[i].feat_size);
401    }
402}
403
// One item decoded from a HID report descriptor.
typedef struct hid_item {
    uint8_t bSize;  // data byte count encoded in the prefix's low 2 bits (0, 1, 2 or 4)
    uint8_t bType;  // bits 2-3 of the prefix: 0 main, 1 global, 2 local, 3 reserved
    uint8_t bTag;   // bits 4-7 of the prefix
    int64_t data;   // little-endian data bytes, zero-extended; 0 when not decoded
} hid_item_t;

// Decodes the item that starts at |buf| (descriptor ends at |end|).
//
// Returns a pointer to the byte after the item. Data bytes are assembled
// little-endian into item->data. Special cases:
//  - Long items (prefix 0xFE: bDataSize, bLongItemTag, then bDataSize bytes)
//    are reported as RESERVED (bType 3) and skipped as a whole, so their
//    payload is never misread as further items.
//  - If the item would run past |end| (or |buf| is already at/after |end|), the
//    item is reported as RESERVED and |end| is returned to stop parsing.
static const uint8_t* hid_parse_short_item(const uint8_t* buf, const uint8_t* end, hid_item_t* item) {
    static const uint8_t kDataSize[4] = {0, 1, 2, 4};
    static const uint8_t kLongItemPrefix = 0xFE;

    item->data = 0;
    if (buf >= end) {
        item->bSize = 0;
        item->bTag = 0;
        item->bType = 0x03;
        return end;
    }

    const uint8_t prefix = *buf;
    // Compare lengths instead of forming |buf + n|, which could point beyond
    // one-past-the-end of the descriptor.
    const size_t avail = (size_t)(end - buf);

    item->bSize = kDataSize[prefix & 0x3];
    item->bType = (prefix >> 2) & 0x3;
    item->bTag = (prefix >> 4) & 0x0f;

    if (prefix == kLongItemPrefix) {
        // Long item: not interpreted, but its full length must be skipped.
        item->bType = 0x03;
        if (avail < 3 || avail - 3 < buf[1]) {
            return end;
        }
        return buf + 3 + buf[1];
    }

    if (avail <= item->bSize) {
        // Return a RESERVED item type, and point past the end of the buffer to
        // prevent further parsing.
        item->bType = 0x03;
        return end;
    }
    buf++;

    // Widen before shifting: a byte >= 0x80 shifted left by 24 overflows int
    // (undefined behavior) but fits comfortably in int64_t.
    for (uint8_t i = 0; i < item->bSize; i++) {
        item->data |= (int64_t)buf[i] << (8 * i);
    }
    return buf + item->bSize;
}
443
444static int hid_fetch_or_alloc_report_ndx(input_report_id_t report_id, hid_device_t* dev) {
445    ZX_DEBUG_ASSERT(dev->num_reports <= countof(dev->sizes));
446    for (size_t i = 0; i < dev->num_reports; i++) {
447        if (dev->sizes[i].id == report_id)
448            return i;
449    }
450
451    if (dev->num_reports < countof(dev->sizes)) {
452        dev->sizes[dev->num_reports].id = report_id;
453        ZX_DEBUG_ASSERT(dev->sizes[dev->num_reports].in_size == 0);
454        ZX_DEBUG_ASSERT(dev->sizes[dev->num_reports].out_size == 0);
455        ZX_DEBUG_ASSERT(dev->sizes[dev->num_reports].feat_size == 0);
456        return dev->num_reports++;
457    } else {
458        return -1;
459    }
460
461}
462
// The part of the HID global item state that the descriptor parser uses to
// size reports. Copies of it are heap-allocated and linked through |node| to
// form the stack used by Push/Pop items.
typedef struct hid_global_state {
    uint32_t rpt_size;         // last Report Size value (multiplied by rpt_count per main item)
    uint32_t rpt_count;        // last Report Count value
    input_report_id_t rpt_id;  // last Report ID value (0 until one is seen)
    list_node_t node;          // link in the Push/Pop stack
} hid_global_state_t;
469
470static zx_status_t hid_push_global_state(list_node_t* stack, hid_global_state_t* state) {
471    hid_global_state_t* entry = malloc(sizeof(*entry));
472    if (entry == NULL) {
473        return ZX_ERR_NO_MEMORY;
474    }
475    entry->rpt_size = state->rpt_size;
476    entry->rpt_count = state->rpt_count;
477    entry->rpt_id = state->rpt_id;
478    list_add_tail(stack, &entry->node);
479    return ZX_OK;
480}
481
482static zx_status_t hid_pop_global_state(list_node_t* stack, hid_global_state_t* state) {
483    hid_global_state_t* entry = list_remove_tail_type(stack, hid_global_state_t, node);
484    if (entry == NULL) {
485        return ZX_ERR_BAD_STATE;
486    }
487    state->rpt_size = entry->rpt_size;
488    state->rpt_count = entry->rpt_count;
489    state->rpt_id = entry->rpt_id;
490    free(entry);
491    return ZX_OK;
492}
493
494static void hid_clear_global_state(list_node_t* stack) {
495    hid_global_state_t* state, *tmp;
496    list_for_every_entry_safe(stack, state, tmp, hid_global_state_t, node) {
497        list_delete(&state->node);
498        free(state);
499    }
500}
501
502static zx_status_t hid_process_hid_report_desc(hid_device_t* dev) {
503    const uint8_t* buf = dev->hid_report_desc;
504    const uint8_t* end = buf + dev->hid_report_desc_len;
505    zx_status_t status = ZX_OK;
506    hid_item_t item;
507
508    bool has_rpt_id = false;
509    hid_global_state_t state;
510    memset(&state, 0, sizeof(state));
511    list_node_t global_stack;
512    list_initialize(&global_stack);
513    while (buf < end) {
514        buf = hid_parse_short_item(buf, end, &item);
515        switch (item.bType) {
516        case HID_ITEM_TYPE_MAIN: {
517            input_report_size_t inc = state.rpt_size * state.rpt_count;
518            int idx;
519            switch (item.bTag) {
520            case HID_ITEM_MAIN_TAG_INPUT:
521                idx = hid_fetch_or_alloc_report_ndx(state.rpt_id, dev);
522                if (idx < 0) {
523                    status = ZX_ERR_NOT_SUPPORTED;
524                    goto done;
525                }
526                dev->sizes[idx].in_size += inc;
527                break;
528            case HID_ITEM_MAIN_TAG_OUTPUT:
529                idx = hid_fetch_or_alloc_report_ndx(state.rpt_id, dev);
530                if (idx < 0) {
531                    status = ZX_ERR_NOT_SUPPORTED;
532                    goto done;
533                }
534                dev->sizes[idx].out_size += inc;
535                break;
536            case HID_ITEM_MAIN_TAG_FEATURE:
537                idx = hid_fetch_or_alloc_report_ndx(state.rpt_id, dev);
538                if (idx < 0) {
539                    status = ZX_ERR_NOT_SUPPORTED;
540                    goto done;
541                }
542                dev->sizes[idx].feat_size += inc;
543                break;
544            default:
545                break;
546            }
547            break;  // case HID_ITEM_TYPE_MAIN
548        }
549        case HID_ITEM_TYPE_GLOBAL: {
550            switch (item.bTag) {
551            case HID_ITEM_GLOBAL_TAG_REPORT_SIZE:
552                state.rpt_size = (uint32_t)item.data;
553                break;
554            case HID_ITEM_GLOBAL_TAG_REPORT_ID:
555                state.rpt_id = (input_report_id_t)item.data;
556                has_rpt_id = true;
557                break;
558            case HID_ITEM_GLOBAL_TAG_REPORT_COUNT:
559                state.rpt_count = (uint32_t)item.data;
560                break;
561            case HID_ITEM_GLOBAL_TAG_PUSH:
562                status = hid_push_global_state(&global_stack, &state);
563                if (status != ZX_OK) {
564                    goto done;
565                }
566                break;
567            case HID_ITEM_GLOBAL_TAG_POP:
568                status = hid_pop_global_state(&global_stack, &state);
569                if (status != ZX_OK) {
570                    goto done;
571                }
572                break;
573            default:
574                break;
575            }
576            break;  // case HID_ITEM_TYPE_GLOBAL
577        }
578        default:
579            break;
580        }
581    }
582done:
583    hid_clear_global_state(&global_stack);
584
585    if (status == ZX_OK) {
586#if BOOT_MOUSE_HACK
587        // Ignore the HID report descriptor from the device, since we're putting
588        // the device into boot protocol mode.
589        if (dev->info.dev_class == HID_DEV_CLASS_POINTER) {
590            if (dev->info.boot_device) {
591                zxlogf(INFO, "hid: boot mouse hack for \"%s\":  "
592                       "report count (%zu->1), "
593                       "inp sz (%d->24), "
594                       "out sz (%d->0), "
595                       "feat sz (%d->0)\n",
596                       dev->name, dev->num_reports, dev->sizes[0].in_size,
597                       dev->sizes[0].out_size, dev->sizes[0].feat_size);
598                dev->num_reports = 1;
599                dev->sizes[0].id = 0;
600                dev->sizes[0].in_size = 24;
601                dev->sizes[0].out_size = 0;
602                dev->sizes[0].feat_size = 0;
603                has_rpt_id = false;
604            } else {
605                zxlogf(INFO,
606                    "hid: boot mouse hack skipped for \"%s\": does not support protocol.\n",
607                    dev->name);
608            }
609        }
610#endif
611        // If we saw a report ID, adjust the expected report sizes to reflect
612        // the fact that we expect a report ID to be prepended to each report.
613        ZX_DEBUG_ASSERT(dev->num_reports <= countof(dev->sizes));
614        if (has_rpt_id) {
615            for (size_t i = 0; i < dev->num_reports; ++i) {
616                if (dev->sizes[i].in_size)   dev->sizes[i].in_size   += 8;
617                if (dev->sizes[i].out_size)  dev->sizes[i].out_size  += 8;
618                if (dev->sizes[i].feat_size) dev->sizes[i].feat_size += 8;
619            }
620        }
621    }
622
623    return status;
624}
625
626static void hid_release_reassembly_buffer(hid_device_t* dev) {
627    if (dev->rbuf != NULL) {
628        free(dev->rbuf);
629    }
630
631    dev->rbuf = NULL;
632    dev->rbuf_size =  0;
633    dev->rbuf_filled =  0;
634    dev->rbuf_needed =  0;
635}
636
637static zx_status_t hid_init_reassembly_buffer(hid_device_t* dev) {
638    ZX_DEBUG_ASSERT(dev->rbuf == NULL);
639    ZX_DEBUG_ASSERT(dev->rbuf_size == 0);
640    ZX_DEBUG_ASSERT(dev->rbuf_filled == 0);
641    ZX_DEBUG_ASSERT(dev->rbuf_needed == 0);
642
643    // TODO(johngro) : Take into account the underlying transport's ability to
644    // deliver payloads.  For example, if this is a USB HID device operating at
645    // full speed, we can expect it to deliver up to 64 bytes at a time.  If the
646    // maximum HID input report size is only 60 bytes, we should not need a
647    // reassembly buffer.
648    input_report_size_t max_report_size = 0;
649    size_t actual = 0;
650    zx_status_t res = hid_get_max_input_reportsize(dev, &max_report_size, sizeof(max_report_size),
651                                                   &actual);
652    if (res < 0) {
653        return res;
654    } else if (!max_report_size || actual != sizeof(max_report_size)) {
655        return ZX_ERR_INTERNAL;
656    }
657
658    dev->rbuf = malloc(max_report_size);
659    if (dev->rbuf == NULL) {
660        return ZX_ERR_NO_MEMORY;
661    }
662
663    dev->rbuf_size = max_report_size;
664    return ZX_OK;
665}
666
667static void hid_release_device(void* ctx) {
668    hid_device_t* hid = ctx;
669
670    if (hid->hid_report_desc) {
671        free(hid->hid_report_desc);
672        hid->hid_report_desc = NULL;
673        hid->hid_report_desc_len = 0;
674    }
675    hid_release_reassembly_buffer(hid);
676    free(hid);
677}
678
679static zx_status_t hid_open_device(void* ctx, zx_device_t** dev_out, uint32_t flags) {
680    hid_device_t* hid = ctx;
681
682    hid_instance_t* inst = calloc(1, sizeof(hid_instance_t));
683    if (inst == NULL) {
684        return ZX_ERR_NO_MEMORY;
685    }
686    zx_hid_fifo_init(&inst->fifo);
687
688    device_add_args_t args = {
689        .version = DEVICE_ADD_ARGS_VERSION,
690        .name = "hid",
691        .ctx = inst,
692        .ops = &hid_instance_proto,
693        .proto_id = ZX_PROTOCOL_INPUT,
694        .flags = DEVICE_ADD_INSTANCE,
695    };
696
697    zx_status_t status = status = device_add(hid->zxdev, &args, &inst->zxdev);
698    if (status != ZX_OK) {
699        zxlogf(ERROR, "hid: error creating instance %d\n", status);
700        free(inst);
701        return status;
702    }
703    inst->base = hid;
704
705    mtx_lock(&hid->instance_lock);
706    list_add_tail(&hid->instance_list, &inst->node);
707    mtx_unlock(&hid->instance_lock);
708
709    *dev_out = inst->zxdev;
710    return ZX_OK;
711}
712
713static void hid_unbind_device(void* ctx) {
714    hid_device_t* hid = ctx;
715    mtx_lock(&hid->instance_lock);
716    hid_instance_t* instance;
717    foreach_instance(hid, instance) {
718        instance->flags |= HID_FLAGS_DEAD;
719        device_state_set(instance->zxdev, DEV_STATE_READABLE);
720    }
721    mtx_unlock(&hid->instance_lock);
722    device_remove(hid->zxdev);
723}
724
725zx_protocol_device_t hid_device_proto = {
726    .version = DEVICE_OPS_VERSION,
727    .open = hid_open_device,
728    .unbind = hid_unbind_device,
729    .release = hid_release_device,
730};
731
732void hid_io_queue(void* cookie, const uint8_t* buf, size_t len) {
733    hid_device_t* hid = cookie;
734
735    mtx_lock(&hid->instance_lock);
736
737    while (len) {
738        // Start by figuring out if this payload either completes a partially
739        // assembled input report or represents an entire input buffer report on
740        // its own.
741        const uint8_t* rbuf;
742        size_t rlen;
743        size_t consumed;
744
745        if (hid->rbuf_needed) {
746            // Reassembly is in progress, just continue the process.
747            consumed = MIN(len, hid->rbuf_needed);
748            ZX_DEBUG_ASSERT (hid->rbuf_size >= hid->rbuf_filled);
749            ZX_DEBUG_ASSERT((hid->rbuf_size - hid->rbuf_filled) >= consumed);
750
751            memcpy(hid->rbuf + hid->rbuf_filled, buf, consumed);
752
753            if (consumed == hid->rbuf_needed) {
754                // reassembly finished.  Reset the bookkeeping and deliver the
755                // payload.
756                rbuf = hid->rbuf;
757                rlen = hid->rbuf_filled + consumed;
758                hid->rbuf_filled = 0;
759                hid->rbuf_needed = 0;
760            } else {
761                // We have not finished the process yet.  Update the bookkeeping
762                // and get out.
763                hid->rbuf_filled += consumed;
764                hid->rbuf_needed -= consumed;
765                break;
766            }
767        } else {
768            // No reassembly is in progress.  Start by identifying this report's
769            // size.
770            size_t  rpt_sz = hid_get_report_size_by_id(hid, buf[0], INPUT_REPORT_INPUT);
771
772            // If we don't recognize this report ID, we are in trouble.  Drop
773            // the rest of this payload and hope that the next one gets us back
774            // on track.
775            if (!rpt_sz) {
776                zxlogf(ERROR, "%s: failed to find input report size (report id %u)\n",
777                        hid->name, buf[0]);
778                break;
779            }
780
781            // Is the entire report present in this payload?  If so, just go
782            // ahead an deliver it directly from the input buffer.
783            if (len >= rpt_sz) {
784                rbuf = buf;
785                consumed = rlen = rpt_sz;
786            } else {
787                // Looks likes our report is fragmented over multiple buffers.
788                // Start the process of reassembly and get out.
789                ZX_DEBUG_ASSERT(hid->rbuf != NULL);
790                ZX_DEBUG_ASSERT(hid->rbuf_size >= rpt_sz);
791                memcpy(hid->rbuf, buf, len);
792                hid->rbuf_filled = len;
793                hid->rbuf_needed = rpt_sz - len;
794                break;
795            }
796        }
797
798        ZX_DEBUG_ASSERT(rbuf != NULL);
799        ZX_DEBUG_ASSERT(consumed <= len);
800        buf += consumed;
801        len -= consumed;
802
803        hid_instance_t* instance;
804        foreach_instance(hid, instance) {
805            mtx_lock(&instance->fifo.lock);
806            bool was_empty = zx_hid_fifo_size(&instance->fifo) == 0;
807            ssize_t wrote = zx_hid_fifo_write(&instance->fifo, rbuf, rlen);
808
809            if (wrote <= 0) {
810                if (!(instance->flags & HID_FLAGS_WRITE_FAILED)) {
811                    zxlogf(ERROR, "%s: could not write to hid fifo (ret=%zd)\n",
812                            hid->name, wrote);
813                    instance->flags |= HID_FLAGS_WRITE_FAILED;
814                }
815            } else {
816                instance->flags &= ~HID_FLAGS_WRITE_FAILED;
817                if (was_empty) {
818                    device_state_set(instance->zxdev, DEV_STATE_READABLE);
819                }
820            }
821            mtx_unlock(&instance->fifo.lock);
822        }
823    }
824
825    mtx_unlock(&hid->instance_lock);
826}
827
828hidbus_ifc_t hid_ifc = {
829    .io_queue = hid_io_queue,
830};
831
832static zx_status_t hid_bind(void* ctx, zx_device_t* parent) {
833    hid_device_t* hiddev;
834    if ((hiddev = calloc(1, sizeof(hid_device_t))) == NULL) {
835        return ZX_ERR_NO_MEMORY;
836    }
837
838    zx_status_t status;
839    if (device_get_protocol(parent, ZX_PROTOCOL_HIDBUS, &hiddev->hid)) {
840        zxlogf(ERROR, "hid: bind: no hidbus protocol\n");
841        status = ZX_ERR_INTERNAL;
842        goto fail;
843    }
844
845    if ((status = hid_op_query(hiddev, 0, &hiddev->info)) < 0) {
846        zxlogf(ERROR, "hid: bind: hidbus query failed: %d\n", status);
847        goto fail;
848    }
849
850    mtx_init(&hiddev->instance_lock, mtx_plain);
851    list_initialize(&hiddev->instance_list);
852
853    snprintf(hiddev->name, sizeof(hiddev->name), "hid-device-%03d", hiddev->info.dev_num);
854    hiddev->name[ZX_DEVICE_NAME_MAX] = 0;
855
856    if (hiddev->info.boot_device) {
857        status = hid_op_set_protocol(hiddev, HID_PROTOCOL_BOOT);
858        if (status != ZX_OK) {
859            zxlogf(ERROR, "hid: could not put HID device into boot protocol: %d\n", status);
860            goto fail;
861        }
862
863        // Disable numlock
864        if (hiddev->info.dev_class == HID_DEV_CLASS_KBD) {
865            uint8_t zero = 0;
866            hid_op_set_report(hiddev, HID_REPORT_TYPE_OUTPUT, 0, &zero, sizeof(zero));
867            // ignore failure for now
868        }
869    }
870
871    status = hid_op_get_descriptor(hiddev, HID_DESC_TYPE_REPORT,
872            (void**)&hiddev->hid_report_desc, &hiddev->hid_report_desc_len);
873    if (status != ZX_OK) {
874        zxlogf(ERROR, "hid: could not retrieve HID report descriptor: %d\n", status);
875        goto fail;
876    }
877
878    status = hid_process_hid_report_desc(hiddev);
879    if (status != ZX_OK) {
880        zxlogf(ERROR, "hid: could not parse hid report descriptor: %d\n", status);
881        goto fail;
882    }
883    hid_dump_hid_report_desc(hiddev);
884
885    status = hid_init_reassembly_buffer(hiddev);
886    if (status != ZX_OK) {
887        zxlogf(ERROR, "hid: failed to initialize reassembly buffer: %d\n", status);
888        goto fail;
889    }
890
891    device_add_args_t args = {
892        .version = DEVICE_ADD_ARGS_VERSION,
893        .name = hiddev->name,
894        .ctx = hiddev,
895        .ops = &hid_device_proto,
896        .proto_id = ZX_PROTOCOL_INPUT,
897    };
898
899    status = device_add(parent, &args, &hiddev->zxdev);
900    if (status != ZX_OK) {
901        zxlogf(ERROR, "hid: device_add failed for HID device: %d\n", status);
902        goto fail;
903    }
904
905    // TODO: delay calling start until we've been opened by someone
906    status = hid_op_start(hiddev, &hid_ifc, hiddev);
907    if (status != ZX_OK) {
908        zxlogf(ERROR, "hid: could not start hid device: %d\n", status);
909        device_remove(hiddev->zxdev);
910        // Don't fail, since we've been added. Need to let devmgr clean us up.
911        return status;
912    }
913
914    status = hid_op_set_idle(hiddev, 0, 0);
915    if (status != ZX_OK) {
916        zxlogf(TRACE, "hid: [W] set_idle failed for %s: %d\n", hiddev->name, status);
917        // continue anyway
918    }
919    return ZX_OK;
920
921fail:
922    hid_release_reassembly_buffer(hiddev);
923    free(hiddev);
924    return status;
925}
926
// Driver ops: only the bind hook is provided.
static zx_driver_ops_t hid_driver_ops = {
    .version = DRIVER_OPS_VERSION,
    .bind = hid_bind,
};

// Binding program: match any device that publishes the hidbus protocol.
ZIRCON_DRIVER_BEGIN(hid, hid_driver_ops, "zircon", "0.1", 1)
    BI_MATCH_IF(EQ, BIND_PROTOCOL, ZX_PROTOCOL_HIDBUS),
ZIRCON_DRIVER_END(hid)
935