1// Copyright 2017 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <lib/fidl/coding.h>
6
7#include <stdalign.h>
8#include <stdint.h>
9#include <stdlib.h>
10
11#include <lib/fidl/internal.h>
12#include <zircon/assert.h>
13#include <zircon/compiler.h>
14#include <zircon/syscalls.h>
15
16// TODO(kulakowski) Design zx_status_t error values.
17
18namespace {
19
20// Some assumptions about data type layout.
21static_assert(offsetof(fidl_string_t, size) == 0u, "");
22static_assert(offsetof(fidl_string_t, data) == 8u, "");
23
24static_assert(offsetof(fidl_vector_t, count) == 0u, "");
25static_assert(offsetof(fidl_vector_t, data) == 8u, "");
26
27class FidlDecoder {
28public:
29    FidlDecoder(const fidl_type_t* type, void* bytes, uint32_t num_bytes,
30                const zx_handle_t* handles, uint32_t num_handles, const char** out_error_msg)
31        : type_(type), bytes_(static_cast<uint8_t*>(bytes)), num_bytes_(num_bytes),
32          handles_(handles), num_handles_(num_handles), out_error_msg_(out_error_msg) {}
33
34    zx_status_t DecodeMessage();
35
36private:
37    zx_status_t WithError(const char* error_msg) {
38        if (out_error_msg_ != nullptr) {
39            *out_error_msg_ = error_msg;
40        }
41        if (handles_) {
42            // Return value intentionally ignored: this is best-effort cleanup.
43            zx_handle_close_many(handles_, num_handles_);
44        }
45        return ZX_ERR_INVALID_ARGS;
46    }
47
48    template <typename T> T* TypedAt(uint32_t offset) const {
49        return reinterpret_cast<T*>(bytes_ + offset);
50    }
51
52    // Returns true when a handle was claimed, and false when the
53    // handles are exhausted.
54    bool ClaimHandle(zx_handle_t* out_handle) {
55        if (handle_idx_ == num_handles_) {
56            return false;
57        }
58        *out_handle = handles_[handle_idx_];
59        ++handle_idx_;
60        return true;
61    }
62
63    // Returns true when the buffer space is claimed, and false when
64    // the requested claim is too large for bytes_.
65    bool ClaimOutOfLineStorage(uint32_t size, uint32_t* out_offset) {
66        static constexpr uint32_t mask = FIDL_ALIGNMENT - 1;
67
68        // We have to manually maintain alignment here. For example, a pointer
69        // to a struct that is 4 bytes still needs to advance the next
70        // out-of-line offset by 8 to maintain the aligned-to-FIDL_ALIGNMENT
71        // property.
72        uint32_t offset = out_of_line_offset_;
73        if (add_overflow(offset, size, &offset) ||
74            add_overflow(offset, mask, &offset)) {
75            return false;
76        }
77        offset &= ~mask;
78
79        if (offset > num_bytes_) {
80            return false;
81        }
82        *out_offset = out_of_line_offset_;
83        out_of_line_offset_ = offset;
84        return true;
85    }
86
87    // Functions that manipulate the decoding stack frames.
88    struct Frame {
89        Frame(const fidl_type_t* fidl_type, uint32_t offset) : offset(offset) {
90            switch (fidl_type->type_tag) {
91            case fidl::kFidlTypeStruct:
92                state = kStateStruct;
93                struct_state.fields = fidl_type->coded_struct.fields;
94                struct_state.field_count = fidl_type->coded_struct.field_count;
95                break;
96            case fidl::kFidlTypeStructPointer:
97                state = kStateStructPointer;
98                struct_pointer_state.struct_type = fidl_type->coded_struct_pointer.struct_type;
99                break;
100            case fidl::kFidlTypeUnion:
101                state = kStateUnion;
102                union_state.types = fidl_type->coded_union.types;
103                union_state.type_count = fidl_type->coded_union.type_count;
104                union_state.data_offset = fidl_type->coded_union.data_offset;
105                break;
106            case fidl::kFidlTypeUnionPointer:
107                state = kStateUnionPointer;
108                union_pointer_state.union_type = fidl_type->coded_union_pointer.union_type;
109                break;
110            case fidl::kFidlTypeArray:
111                state = kStateArray;
112                array_state.element = fidl_type->coded_array.element;
113                array_state.array_size = fidl_type->coded_array.array_size;
114                array_state.element_size = fidl_type->coded_array.element_size;
115                break;
116            case fidl::kFidlTypeString:
117                state = kStateString;
118                string_state.max_size = fidl_type->coded_string.max_size;
119                string_state.nullable = fidl_type->coded_string.nullable;
120                break;
121            case fidl::kFidlTypeHandle:
122                state = kStateHandle;
123                handle_state.nullable = fidl_type->coded_handle.nullable;
124                break;
125            case fidl::kFidlTypeVector:
126                state = kStateVector;
127                vector_state.element = fidl_type->coded_vector.element;
128                vector_state.max_count = fidl_type->coded_vector.max_count;
129                vector_state.element_size = fidl_type->coded_vector.element_size;
130                vector_state.nullable = fidl_type->coded_vector.nullable;
131                break;
132            }
133        }
134
135        Frame(const fidl::FidlCodedStruct* coded_struct, uint32_t offset) : offset(offset) {
136            state = kStateStruct;
137            struct_state.fields = coded_struct->fields;
138            struct_state.field_count = coded_struct->field_count;
139        }
140
141        Frame(const fidl::FidlCodedUnion* coded_union, uint32_t offset) : offset(offset) {
142            state = kStateUnion;
143            union_state.types = coded_union->types;
144            union_state.type_count = coded_union->type_count;
145            union_state.data_offset = coded_union->data_offset;
146        }
147
148        Frame(const fidl_type_t* element, uint32_t array_size, uint32_t element_size,
149              uint32_t offset)
150            : offset(offset) {
151            state = kStateArray;
152            array_state.element = element;
153            array_state.array_size = array_size;
154            array_state.element_size = element_size;
155        }
156
157        // The default constructor does nothing when initializing the stack of frames.
158        Frame() {}
159
160        static Frame DoneSentinel() {
161            Frame frame;
162            frame.state = kStateDone;
163            return frame;
164        }
165
166        uint32_t NextStructField() {
167            ZX_DEBUG_ASSERT(state == kStateStruct);
168
169            uint32_t current = field;
170            field += 1;
171            return current;
172        }
173
174        uint32_t NextArrayOffset() {
175            ZX_DEBUG_ASSERT(state == kStateArray);
176
177            uint32_t current = field;
178            field += array_state.element_size;
179            return current;
180        }
181
182        enum : int {
183            kStateStruct,
184            kStateStructPointer,
185            kStateUnion,
186            kStateUnionPointer,
187            kStateArray,
188            kStateString,
189            kStateHandle,
190            kStateVector,
191
192            kStateDone,
193        } state;
194        // A byte offset into bytes_;
195        uint32_t offset;
196
197        // This is a subset of the information recorded in the
198        // fidl_type structures needed for decoding state. For
199        // example, struct sizes do not need to be present here.
200        union {
201            struct {
202                const fidl::FidlField* fields;
203                uint32_t field_count;
204            } struct_state;
205            struct {
206                const fidl::FidlCodedStruct* struct_type;
207            } struct_pointer_state;
208            struct {
209                const fidl_type_t* const* types;
210                uint32_t type_count;
211                uint32_t data_offset;
212            } union_state;
213            struct {
214                const fidl::FidlCodedUnion* union_type;
215            } union_pointer_state;
216            struct {
217                const fidl_type_t* element;
218                uint32_t array_size;
219                uint32_t element_size;
220            } array_state;
221            struct {
222                uint32_t max_size;
223                bool nullable;
224            } string_state;
225            struct {
226                bool nullable;
227            } handle_state;
228            struct {
229                const fidl_type* element;
230                uint32_t max_count;
231                uint32_t element_size;
232                bool nullable;
233            } vector_state;
234        };
235
236        uint32_t field = 0u;
237    };
238
239    // Returns true on success and false on recursion overflow.
240    bool Push(Frame frame) {
241        if (depth_ == FIDL_RECURSION_DEPTH) {
242            return false;
243        }
244        decoding_frames_[depth_] = frame;
245        ++depth_;
246        return true;
247    }
248
249    void Pop() {
250        ZX_DEBUG_ASSERT(depth_ != 0u);
251        --depth_;
252    }
253
254    Frame* Peek() {
255        ZX_DEBUG_ASSERT(depth_ != 0u);
256        return &decoding_frames_[depth_ - 1];
257    }
258
259    // Message state passed in to the constructor.
260    const fidl_type_t* const type_;
261    uint8_t* const bytes_;
262    const uint32_t num_bytes_;
263    const zx_handle_t* const handles_;
264    const uint32_t num_handles_;
265    const char** out_error_msg_;
266
267    // Internal state.
268    uint32_t handle_idx_ = 0u;
269    uint32_t out_of_line_offset_ = 0u;
270
271    // Decoding stack state.
272    uint32_t depth_ = 0u;
273    Frame decoding_frames_[FIDL_RECURSION_DEPTH];
274};
275
276zx_status_t FidlDecoder::DecodeMessage() {
277    // The first decode is special. It must be a struct. We need to
278    // know the size of the struct to compute the start of the
279    // out-of-line allocations.
280
281    if (type_ == nullptr) {
282        return WithError("Cannot decode a null fidl type");
283    }
284
285    if (bytes_ == nullptr) {
286        return WithError("Cannot decode null bytes");
287    }
288
289    if (handles_ == nullptr && num_handles_ != 0u) {
290        return WithError("Cannot provide non-zero handle count and null handle pointer");
291    }
292
293    if (type_->type_tag != fidl::kFidlTypeStruct) {
294        return WithError("Message must be a struct");
295    }
296
297    if (type_->coded_struct.size > num_bytes_) {
298        return WithError("Message size is smaller than expected");
299    }
300
301    out_of_line_offset_ = static_cast<uint32_t>(fidl::FidlAlign(type_->coded_struct.size));
302
303    Push(Frame::DoneSentinel());
304    Push(Frame(type_, 0u));
305
306    for (;;) {
307        Frame* frame = Peek();
308
309        switch (frame->state) {
310        case Frame::kStateStruct: {
311            uint32_t field_index = frame->NextStructField();
312            if (field_index == frame->struct_state.field_count) {
313                Pop();
314                continue;
315            }
316            const fidl::FidlField& field = frame->struct_state.fields[field_index];
317            const fidl_type_t* field_type = field.type;
318            uint32_t field_offset = frame->offset + field.offset;
319            if (!Push(Frame(field_type, field_offset))) {
320                return WithError("recursion depth exceeded decoding struct");
321            }
322            continue;
323        }
324        case Frame::kStateStructPointer: {
325            switch (*TypedAt<uintptr_t>(frame->offset)) {
326            case FIDL_ALLOC_PRESENT:
327                break;
328            case FIDL_ALLOC_ABSENT:
329                Pop();
330                continue;
331            default:
332                return WithError("Tried to decode a bad struct pointer");
333            }
334            void** struct_ptr_ptr = TypedAt<void*>(frame->offset);
335            if (!ClaimOutOfLineStorage(frame->struct_pointer_state.struct_type->size,
336                                       &frame->offset)) {
337                return WithError("message wanted to store too large of a nullable struct");
338            }
339            *struct_ptr_ptr = TypedAt<void>(frame->offset);
340            const fidl::FidlCodedStruct* coded_struct = frame->struct_pointer_state.struct_type;
341            *frame = Frame(coded_struct, frame->offset);
342            continue;
343        }
344        case Frame::kStateUnion: {
345            fidl_union_tag_t union_tag = *TypedAt<fidl_union_tag_t>(frame->offset);
346            if (union_tag >= frame->union_state.type_count) {
347                return WithError("Tried to decode a bad union discriminant");
348            }
349            const fidl_type_t* member = frame->union_state.types[union_tag];
350            if (!member) {
351                Pop();
352                continue;
353            }
354            frame->offset += frame->union_state.data_offset;
355            *frame = Frame(member, frame->offset);
356            continue;
357        }
358        case Frame::kStateUnionPointer: {
359            fidl_union_tag_t** union_ptr_ptr = TypedAt<fidl_union_tag_t*>(frame->offset);
360            switch (*TypedAt<uintptr_t>(frame->offset)) {
361            case FIDL_ALLOC_PRESENT:
362                break;
363            case FIDL_ALLOC_ABSENT:
364                Pop();
365                continue;
366            default:
367                return WithError("Tried to decode a bad union pointer");
368            }
369            if (!ClaimOutOfLineStorage(frame->union_pointer_state.union_type->size,
370                                       &frame->offset)) {
371                return WithError("message wanted to store too large of a nullable union");
372            }
373            *union_ptr_ptr = TypedAt<fidl_union_tag_t>(frame->offset);
374            const fidl::FidlCodedUnion* coded_union = frame->union_pointer_state.union_type;
375            *frame = Frame(coded_union, frame->offset);
376            continue;
377        }
378        case Frame::kStateArray: {
379            uint32_t element_offset = frame->NextArrayOffset();
380            if (element_offset == frame->array_state.array_size) {
381                Pop();
382                continue;
383            }
384            const fidl_type_t* element_type = frame->array_state.element;
385            uint32_t offset = frame->offset + element_offset;
386            if (!Push(Frame(element_type, offset))) {
387                return WithError("recursion depth exceeded decoding array");
388            }
389            continue;
390        }
391        case Frame::kStateString: {
392            fidl_string_t* string_ptr = TypedAt<fidl_string_t>(frame->offset);
393            // The string storage may be Absent for nullable strings and must
394            // otherwise be Present. No other values are allowed.
395            switch (reinterpret_cast<uintptr_t>(string_ptr->data)) {
396            case FIDL_ALLOC_PRESENT:
397                break;
398            case FIDL_ALLOC_ABSENT:
399                if (!frame->string_state.nullable) {
400                    return WithError("message tried to decode an absent non-nullable string");
401                }
402                if (string_ptr->size != 0u) {
403                    return WithError("message tried to decode an absent string of non-zero length");
404                }
405                Pop();
406                continue;
407            default:
408                return WithError(
409                    "message tried to decode a string that is neither present nor absent");
410            }
411            uint64_t bound = frame->string_state.max_size;
412            uint64_t size = string_ptr->size;
413            if (size > bound) {
414                return WithError("message tried to decode too large of a bounded string");
415            }
416            uint32_t string_data_offset = 0u;
417            if (!ClaimOutOfLineStorage(static_cast<uint32_t>(size), &string_data_offset)) {
418                return WithError("decoding a  string overflowed buffer");
419            }
420            string_ptr->data = TypedAt<char>(string_data_offset);
421            Pop();
422            continue;
423        }
424        case Frame::kStateHandle: {
425            zx_handle_t* handle_ptr = TypedAt<zx_handle_t>(frame->offset);
426            // The handle storage may be Absent for nullable handles and must
427            // otherwise be Present. No other values are allowed.
428            switch (*handle_ptr) {
429            case FIDL_HANDLE_ABSENT:
430                if (frame->handle_state.nullable) {
431                    Pop();
432                    continue;
433                }
434                break;
435            case FIDL_HANDLE_PRESENT:
436                if (!ClaimHandle(handle_ptr)) {
437                    return WithError("message decoded too many handles");
438                }
439                Pop();
440                continue;
441            }
442            // Either the value at the handle was garbage, or was
443            // ABSENT for a nonnullable handle.
444            return WithError("message tried to decode a non-present handle");
445        }
446        case Frame::kStateVector: {
447            fidl_vector_t* vector_ptr = TypedAt<fidl_vector_t>(frame->offset);
448            // The vector storage may be Absent for nullable vectors and must
449            // otherwise be Present. No other values are allowed.
450            switch (reinterpret_cast<uintptr_t>(vector_ptr->data)) {
451            case FIDL_ALLOC_PRESENT:
452                break;
453            case FIDL_ALLOC_ABSENT:
454                if (!frame->vector_state.nullable) {
455                    return WithError("message tried to decode an absent non-nullable vector");
456                }
457                if (vector_ptr->count != 0u) {
458                    return WithError("message tried to decode an absent vector of non-zero elements");
459                }
460                Pop();
461                continue;
462            default:
463                return WithError("message tried to decode a non-present vector");
464            }
465            if (vector_ptr->count > frame->vector_state.max_count) {
466                return WithError("message tried to decode too large of a bounded vector");
467            }
468            uint32_t size;
469            if (mul_overflow(vector_ptr->count, frame->vector_state.element_size, &size)) {
470                return WithError("integer overflow calculating vector size");
471            }
472            if (!ClaimOutOfLineStorage(size, &frame->offset)) {
473                return WithError("message wanted to store too large of a vector");
474            }
475            vector_ptr->data = TypedAt<void>(frame->offset);
476            if (frame->vector_state.element) {
477                // Continue by decoding the vector elements as an array.
478                *frame = Frame(frame->vector_state.element, size,
479                               frame->vector_state.element_size, frame->offset);
480            } else {
481                // If there is no element type pointer, there is
482                // nothing to decode in the vector secondary
483                // payload. So just continue.
484                Pop();
485            }
486            continue;
487        }
488        case Frame::kStateDone: {
489            if (out_of_line_offset_ != num_bytes_) {
490                return WithError("message did not decode all provided bytes");
491            }
492            if (handle_idx_ != num_handles_) {
493                return WithError("message did not contain the specified number of handles");
494            }
495            return ZX_OK;
496        }
497        }
498    }
499}
500
501} // namespace
502
503zx_status_t fidl_decode(const fidl_type_t* type, void* bytes, uint32_t num_bytes,
504                        const zx_handle_t* handles, uint32_t num_handles,
505                        const char** out_error_msg) {
506    FidlDecoder decoder(type, bytes, num_bytes, handles, num_handles, out_error_msg);
507    return decoder.DecodeMessage();
508}
509
510zx_status_t fidl_decode_msg(const fidl_type_t* type, fidl_msg_t* msg,
511                            const char** out_error_msg) {
512    return fidl_decode(type, msg->bytes, msg->num_bytes, msg->handles,
513                       msg->num_handles, out_error_msg);
514}
515