1// Copyright 2017 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <digest/merkle-tree.h>
6
7#include <stdint.h>
8#include <string.h>
9
10#include <digest/digest.h>
11#include <fbl/algorithm.h>
12#include <fbl/alloc_checker.h>
13#include <fbl/unique_ptr.h>
14#include <zircon/assert.h>
15#include <zircon/errors.h>
16
17namespace digest {
18
// Size of a node in bytes.  This is only the out-of-class definition of the
// static member; being constexpr, its value has to be supplied where it is
// declared inside MerkleTree (presumably <digest/merkle-tree.h>, included
// above).  The definition is needed wherever the constant is odr-used, e.g.
// bound to a reference parameter.
constexpr size_t MerkleTree::kNodeSize;

// The number of digests that fit in a node.  Importantly, if L is a
// node-aligned length in one level of the Merkle tree, |L / kDigestsPerNode| is
// the corresponding digest-aligned length in the next level up: L / kNodeSize
// nodes, each contributing one digest of Digest::kLength bytes.
const size_t kDigestsPerNode = MerkleTree::kNodeSize / Digest::kLength;
26
27namespace {
28
29// Digest wrapper functions.  These functions implement how a node in the Merkle
30// tree is hashed:
31//    digest = Hash((offset | level) + length + node_data + padding)
32// where:
33//  * offset is from the start of the VMO.
34//  * level is the height of the node in the tree (data nodes have level == 0).
//  * length is the node size, i.e. kNodeSize except possibly for the last node.
36//  * node_data is the actual bytes from the node.
37//  * padding is |kNodeSize - length| zeros.
38
39// Wrapper for Digest::Init.  This primes the working |digest| initializing it
40// and hashing the |locality| and |length|.
41zx_status_t DigestInit(Digest* digest, uint64_t locality, size_t length) {
42    zx_status_t rc;
43    ZX_DEBUG_ASSERT(digest);
44    ZX_DEBUG_ASSERT(length < UINT32_MAX);
45    if ((rc = digest->Init()) != ZX_OK) {
46        return rc;
47    }
48    digest->Update(&locality, sizeof(locality));
49    uint32_t len32 = static_cast<uint32_t>(fbl::min(length, MerkleTree::kNodeSize));
50    digest->Update(&len32, sizeof(len32));
51    return ZX_OK;
52}
53
54// Wrapper for Digest::Update.  This will hash data from |in|, either |length|
55// bytes or up to the next node boundary, as determined from |offset|.  Returns
56// the number of bytes hashed.
57size_t DigestUpdate(Digest* digest, const uint8_t* in, size_t offset, size_t length) {
58    ZX_DEBUG_ASSERT(digest);
59    // Check if length crosses a node boundary
60    length = fbl::min(length, MerkleTree::kNodeSize - (offset % MerkleTree::kNodeSize));
61    digest->Update(in, length);
62    return length;
63}
64
65// Wrapper for Digest::Final.  This pads the hashed data with zeros up to a
66// node boundary before finalizing the digest.
67void DigestFinal(Digest* digest, size_t offset) {
68    offset = offset % MerkleTree::kNodeSize;
69    if (offset != 0) {
70        size_t pad_len = MerkleTree::kNodeSize - offset;
71        uint8_t pad[pad_len];
72        memset(pad, 0, pad_len);
73        digest->Update(pad, pad_len);
74    }
75    digest->Final();
76}
77
78////////
79// Helper functions for working between levels of the tree.
80
81// Helper function to transform a length in the current level to a length in the
82// next level up.
83size_t NextLength(size_t length) {
84    if (length > MerkleTree::kNodeSize) {
85        return fbl::round_up(length, MerkleTree::kNodeSize) / kDigestsPerNode;
86    } else {
87        return 0;
88    }
89}
90
91// Helper function to transform a length in the current level to a node-aligned
92// length in the next level up.
93size_t NextAligned(size_t length) {
94    return fbl::round_up(NextLength(length), MerkleTree::kNodeSize);
95}
96
97} // namespace
98
99////////
100// Creation methods
101
102size_t MerkleTree::GetTreeLength(size_t data_len) {
103    size_t next_len = NextAligned(data_len);
104    return (next_len == 0 ? 0 : next_len + GetTreeLength(next_len));
105}
106
107zx_status_t MerkleTree::Create(const void* data, size_t data_len, void* tree, size_t tree_len,
108                               Digest* digest) {
109    zx_status_t rc;
110    MerkleTree mt;
111    if ((rc = mt.CreateInit(data_len, tree_len)) != ZX_OK ||
112        (rc = mt.CreateUpdate(data, data_len, tree)) != ZX_OK ||
113        (rc = mt.CreateFinal(tree, digest)) != ZX_OK) {
114        return rc;
115    }
116    return ZX_OK;
117}
118
119MerkleTree::MerkleTree() : initialized_(false), next_(nullptr), level_(0), offset_(0), length_(0) {}
120
121MerkleTree::~MerkleTree() {}
122
123zx_status_t MerkleTree::CreateInit(size_t data_len, size_t tree_len) {
124    initialized_ = true;
125    offset_ = 0;
126    length_ = data_len;
127    // Data fits in a single node, making this the top level of the tree.
128    if (data_len <= kNodeSize) {
129        return ZX_OK;
130    }
131    fbl::AllocChecker ac;
132    next_.reset(new (&ac) MerkleTree());
133    if (!ac.check()) {
134        return ZX_ERR_NO_MEMORY;
135    }
136    next_->level_ = level_ + 1;
137    // Ascend the tree.
138    data_len = NextAligned(data_len);
139    if (tree_len < data_len) {
140        return ZX_ERR_BUFFER_TOO_SMALL;
141    }
142    tree_len -= data_len;
143    return next_->CreateInit(data_len, tree_len);
144}
145
146zx_status_t MerkleTree::CreateUpdate(const void* data, size_t length, void* tree) {
147    ZX_DEBUG_ASSERT(offset_ + length >= offset_);
148    // Must call CreateInit first.
149    if (!initialized_) {
150        return ZX_ERR_BAD_STATE;
151    }
152    // Early exit if no work to do.
153    if (length == 0) {
154        return ZX_OK;
155    }
156    // Must not overrun expected length.
157    if (offset_ + length > length_) {
158        return ZX_ERR_OUT_OF_RANGE;
159    }
160    // Must have data to read and a tree to fill if expecting more than one
161    // digest.
162    if (!data || (!tree && length_ > kNodeSize)) {
163        return ZX_ERR_INVALID_ARGS;
164    }
165    // Save pointers to the data, digest, and the next level tree.
166    const uint8_t* in = static_cast<const uint8_t*>(data);
167    size_t tree_off = (offset_ - (offset_ % kNodeSize)) / kDigestsPerNode;
168    uint8_t* out = static_cast<uint8_t*>(tree) + tree_off;
169    void* next = static_cast<uint8_t*>(tree) + NextAligned(length_);
170    // Consume the data.
171    zx_status_t rc = ZX_OK;
172    while (length > 0 && rc == ZX_OK) {
173        // Check if this is the start of a node.
174        if (offset_ % kNodeSize == 0 &&
175            (rc = DigestInit(&digest_, offset_ | level_, length_ - offset_)) != ZX_OK) {
176            break;
177        }
178        // Hash the node data.
179        size_t chunk = DigestUpdate(&digest_, in, offset_, length);
180        in += chunk;
181        offset_ += chunk;
182        length -= chunk;
183        // Done if not at the end of a node.
184        if (offset_ % kNodeSize != 0 && offset_ != length_) {
185            break;
186        }
187        DigestFinal(&digest_, offset_);
188        // Done if at the top of the tree.
189        if (length_ <= kNodeSize) {
190            break;
191        }
192        // If this is the first digest in a new node, first initialize it.
193        if (tree_off % kNodeSize == 0) {
194            memset(out, 0, kNodeSize);
195        }
196        // Add the digest and ascend the tree.
197        digest_.CopyTo(out, Digest::kLength);
198        rc = next_->CreateUpdate(out, Digest::kLength, next);
199        out += Digest::kLength;
200        tree_off += Digest::kLength;
201    }
202    return rc;
203}
204
205zx_status_t MerkleTree::CreateFinal(void* tree, Digest* root) {
206    return CreateFinalInternal(nullptr, tree, root);
207}
208
209zx_status_t MerkleTree::CreateFinalInternal(const void* data, void* tree, Digest* root) {
210    zx_status_t rc;
211    // Must call CreateInit first.  Must call CreateUpdate with all data first.
212    if (!initialized_ || (level_ == 0 && offset_ != length_)) {
213        return ZX_ERR_BAD_STATE;
214    }
215    // Must have root to write and a tree to fill if expecting more than one
216    // digest.
217    if (!root || (!tree && length_ > kNodeSize)) {
218        return ZX_ERR_INVALID_ARGS;
219    }
220    // Special case: the level is empty.
221    if (length_ == 0) {
222        if ((rc = DigestInit(&digest_, 0, 0)) != ZX_OK) {
223            return rc;
224        }
225        DigestFinal(&digest_, 0);
226    }
227    // Consume padding if needed.
228    const uint8_t* tail = static_cast<const uint8_t*>(data) + offset_;
229    if ((rc = CreateUpdate(tail, length_ - offset_, tree)) != ZX_OK) {
230        return rc;
231    }
232    initialized_ = false;
233    // If the top, save the digest as the Merkle tree root and return.
234    if (length_ <= kNodeSize) {
235        *root = digest_.AcquireBytes();
236        digest_.ReleaseBytes();
237        return ZX_OK;
238    }
239    // Finalize the next level up.
240    uint8_t* next = static_cast<uint8_t*>(tree) + NextAligned(length_);
241    return next_->CreateFinalInternal(tree, next, root);
242}
243
244////////
245// Verification methods
246
247zx_status_t MerkleTree::Verify(const void* data, size_t data_len, const void* tree, size_t tree_len,
248                               size_t offset, size_t length, const Digest& root) {
249    uint64_t level = 0;
250    size_t root_len = data_len;
251    while (data_len > kNodeSize) {
252        zx_status_t rc;
253        // Verify the data in this level.
254        if ((rc = VerifyLevel(data, data_len, tree, offset, length, level)) != ZX_OK) {
255            return rc;
256        }
257        // Ascend to the next level up.
258        data = tree;
259        root_len = NextLength(data_len);
260        data_len = NextAligned(data_len);
261        tree = static_cast<const uint8_t*>(tree) + data_len;
262        if (tree_len < data_len) {
263            return ZX_ERR_BUFFER_TOO_SMALL;
264        }
265        tree_len -= data_len;
266        offset /= kDigestsPerNode;
267        length /= kDigestsPerNode;
268        ++level;
269    }
270    return VerifyRoot(data, root_len, level, root);
271}
272
273zx_status_t MerkleTree::VerifyRoot(const void* data, size_t root_len, uint64_t level,
274                                   const Digest& expected) {
275    zx_status_t rc;
276    // Must have data if length isn't 0.  Must have either zero or one node.
277    if ((!data && root_len != 0) || root_len > kNodeSize) {
278        return ZX_ERR_INVALID_ARGS;
279    }
280    const uint8_t* in = static_cast<const uint8_t*>(data);
281    Digest actual;
282    // We have up to one node if at tree bottom, exactly one node otherwise.
283    if ((rc = DigestInit(&actual, level, (level == 0 ? root_len : kNodeSize))) != ZX_OK) {
284        return rc;
285    }
286    DigestUpdate(&actual, in, 0, root_len);
287    DigestFinal(&actual, root_len);
288    return (actual == expected ? ZX_OK : ZX_ERR_IO_DATA_INTEGRITY);
289}
290
291zx_status_t MerkleTree::VerifyLevel(const void* data, size_t data_len, const void* tree,
292                                    size_t offset, size_t length, uint64_t level) {
293    zx_status_t rc;
294    ZX_DEBUG_ASSERT(offset + length >= offset);
295    // Must have more than one node of data and digests to check against.
296    if (!data || data_len <= kNodeSize || !tree) {
297        return ZX_ERR_INVALID_ARGS;
298    }
299    // Must not overrun expected length.
300    if (offset + length > data_len) {
301        return ZX_ERR_OUT_OF_RANGE;
302    }
303    // Align parameters to node boundaries, but don't exceed data_len
304    offset -= offset % kNodeSize;
305    size_t finish = fbl::round_up(offset + length, kNodeSize);
306    length = fbl::min(finish, data_len) - offset;
307    const uint8_t* in = static_cast<const uint8_t*>(data) + offset;
308    // The digests are in the next level up.
309    Digest actual;
310    const uint8_t* expected = static_cast<const uint8_t*>(tree) + (offset / kDigestsPerNode);
311    // Check the data of this level against the digests.
312    while (length > 0) {
313        if ((rc = DigestInit(&actual, offset | level, data_len - offset)) != ZX_OK) {
314            return rc;
315        }
316        size_t chunk = DigestUpdate(&actual, in, offset, length);
317        in += chunk;
318        offset += chunk;
319        length -= chunk;
320        DigestFinal(&actual, offset);
321        if (actual != expected) {
322            return ZX_ERR_IO_DATA_INTEGRITY;
323        }
324        expected += Digest::kLength;
325    }
326    return ZX_OK;
327}
328
329} // namespace digest
330
331////////
332// C-style wrapper functions
333
334using digest::Digest;
335using digest::MerkleTree;
336
// Handle type used by the C wrapper functions below.  It simply holds the C++
// MerkleTree that does the work: merkle_tree_create_init allocates one and
// merkle_tree_create_final takes ownership of it and frees it.
struct merkle_tree_t {
    MerkleTree obj;
};
340
341size_t merkle_tree_get_tree_length(size_t data_len) {
342    return MerkleTree::GetTreeLength(data_len);
343}
344
345zx_status_t merkle_tree_create_init(size_t data_len, size_t tree_len, merkle_tree_t** out) {
346    zx_status_t rc;
347    // Must have some where to store the wrapper.
348    if (!out) {
349        return ZX_ERR_INVALID_ARGS;
350    }
351    // Allocate the wrapper object using a unique_ptr.  That way, if we hit an
352    // error we'll clean up automatically.
353    fbl::AllocChecker ac;
354    fbl::unique_ptr<merkle_tree_t> mt_uniq(new (&ac) merkle_tree_t());
355    if (!ac.check()) {
356        return ZX_ERR_NO_MEMORY;
357    }
358    // Call the C++ function.
359    if ((rc = mt_uniq->obj.CreateInit(data_len, tree_len)) != ZX_OK) {
360        return rc;
361    }
362    // Release the wrapper object.
363    *out = mt_uniq.release();
364    return ZX_OK;
365}
366
367zx_status_t merkle_tree_create_update(merkle_tree_t* mt, const void* data, size_t length,
368                                      void* tree) {
369    // Must have a wrapper object.
370    if (!mt) {
371        return ZX_ERR_INVALID_ARGS;
372    }
373    // Call the C++ function.
374    zx_status_t rc;
375    if ((rc = mt->obj.CreateUpdate(data, length, tree)) != ZX_OK) {
376        return rc;
377    }
378    return ZX_OK;
379}
380
381zx_status_t merkle_tree_create_final(merkle_tree_t* mt, void* tree, void* out, size_t out_len) {
382    // Must have a wrapper object.
383    if (!mt) {
384        return ZX_ERR_INVALID_ARGS;
385    }
386    // Take possession of the wrapper object. That way, we'll clean up
387    // automatically.
388    fbl::unique_ptr<merkle_tree_t> mt_uniq(mt);
389    // Call the C++ function.
390    zx_status_t rc;
391    Digest digest;
392    if ((rc = mt_uniq->obj.CreateFinal(tree, &digest)) != ZX_OK) {
393        return rc;
394    }
395    return digest.CopyTo(static_cast<uint8_t*>(out), out_len);
396}
397
398zx_status_t merkle_tree_create(const void* data, size_t data_len, void* tree, size_t tree_len,
399                               void* out, size_t out_len) {
400    zx_status_t rc;
401    Digest digest;
402    if ((rc = MerkleTree::Create(data, data_len, tree, tree_len, &digest)) != ZX_OK) {
403        return rc;
404    }
405    return digest.CopyTo(static_cast<uint8_t*>(out), out_len);
406}
407
408zx_status_t merkle_tree_verify(const void* data, size_t data_len, void* tree, size_t tree_len,
409                               size_t offset, size_t length, const void* root, size_t root_len) {
410    // Must have a complete root digest.
411    if (root_len < Digest::kLength) {
412        return ZX_ERR_INVALID_ARGS;
413    }
414    Digest digest(static_cast<const uint8_t*>(root));
415    return MerkleTree::Verify(data, data_len, tree, tree_len, offset, length, digest);
416}
417