1#!/usr/bin/env python3
2#
3# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
4#
5# SPDX-License-Identifier: BSD-2-Clause
6#
7
8##
# A tool for generating bitfield structures with get/set/new methods
10# including Isabelle/HOL specifications and correctness proofs.
11##
12
13from __future__ import print_function, division
14import sys
15import os.path
16import optparse
17import re
18import itertools
19import tempfile
20
21from six.moves import range
22from functools import reduce
23
24import lex
25from ply import yacc
26
27import umm
28
# Debugging switch; the --debug command-line option turns it on.
DEBUG = False

# Locale in which the generated bitfield proofs are stated.
loc_name = 'kernel_all_substitute'

# The C parser mangles variable names with a suffix describing their type, so
# that e.g. `char x` and `int x` stay distinct.
#
# These suffixes must agree with the C parser's types for `uint{8,16,32,64}_t`
# as declared in `include/stdint.h`.
var_size_suffix_map = {
    8: 'unsigned_char',
    16: 'unsigned_short',
    32: 'unsigned',
    64: 'unsigned_longlong',
}
43
44
def return_name(base):
    """Return the C-parser name of the return variable for a `base`-bit word.

    For example, a 32-bit return value is called ``ret__unsigned``.
    """
    suffix = var_size_suffix_map[base]
    return 'ret__%s' % suffix
48
49
def var_name(name, base):
    """Mangle C variable `name` into its Isabelle name for a `base`-bit word.

    The type suffix from var_size_suffix_map is joined with three underscores,
    e.g. ``v`` at 64 bits becomes ``v___unsigned_longlong``.
    """
    return '%s___%s' % (name, var_size_suffix_map[base])
53
54
# Headers to include, keyed by the environment we are generating code for.
INCLUDES = {
    'sel4': ['assert.h', 'config.h', 'stdint.h', 'util.h'],
    'libsel4': ['autoconf.h', 'sel4/simple_types.h', 'sel4/debug_assert.h'],
}

# Name of the assertion macro in each environment.
ASSERTS = dict(sel4='assert', libsel4='seL4_DebugAssert')

# Function qualifier used for generated functions in each environment.
INLINE = dict(sel4='static inline', libsel4='LIBSEL4_INLINE_FUNC')

# C spelling of the unsigned integer type for each supported word size.
_WORD_SIZES = (8, 16, 32, 64)
TYPES = {
    "sel4": {bits: "uint%d_t" % bits for bits in _WORD_SIZES},
    "libsel4": {bits: "seL4_Uint%d" % bits for bits in _WORD_SIZES},
}
86
# Parser

# Keywords of the specification language. They are lexed by t_IDENTIFIER and
# then re-typed through reserved_map below.
reserved = tuple(
    'BLOCK BASE FIELD FIELD_HIGH MASK PADDING TAGGED_UNION TAG'.split())

# All token types the lexer can produce: the keywords plus everything else.
tokens = reserved + tuple(
    'IDENTIFIER INTLIT LBRACE RBRACE LPAREN RPAREN COMMA'.split())

# Single-character punctuation tokens (PLY string rules).
t_LBRACE = r'{'
t_RBRACE = r'}'
t_LPAREN = r'\('
t_RPAREN = r'\)'
t_COMMA = r','

# Lower-case source spelling -> keyword token type.
reserved_map = {word.lower(): word for word in reserved}
102
103
def t_IDENTIFIER(t):
    r'[A-Za-z_]\w+|[A-Za-z]'
    # The docstring above is the token regex: multi-character names may start
    # with '_', but single-character names must be a letter.
    # Keywords share this lexeme, so re-type the token when the text is one.
    word = t.value
    if word in reserved_map:
        t.type = reserved_map[word]
    else:
        t.type = 'IDENTIFIER'
    return t
108
109
def t_INTLIT(t):
    r'([1-9][0-9]*|0[oO]?[0-7]+|0[xX][0-9a-fA-F]+|0[bB][01]+|0)[lL]?'
    # The raw-string docstring above is the regex PLY matches for this token.
    #
    # The regex accepts C-style literals that Python 3's int(..., 0) rejects:
    # an optional 'L'/'l' suffix and legacy leading-zero octal such as "017".
    # Normalise both before converting so such inputs lex instead of crashing
    # with ValueError.
    text = t.value.rstrip('lL')
    if len(text) > 1 and text[0] == '0' and text[1].isdigit():
        text = '0o' + text[1:]
    t.value = int(text, 0)
    return t
114
115
def t_NEWLINE(t):
    r'\n+'
    # Newlines produce no token; they only advance the lexer's line counter.
    lexer = t.lexer
    lexer.lineno += t.value.count('\n')


def t_comment(t):
    r'--.*|\#.*'
    # Comments start with "--" or "#" and run to end of line. Returning
    # nothing discards them. '#' is escaped, presumably because PLY compiles
    # token regexes in verbose mode.
    return None


# Spaces and tabs between tokens are skipped.
t_ignore = ' \t'
126
127
def t_error(t):
    """Report a character the lexer cannot match, then abort with status 1.

    The message includes the input line number (kept up to date by
    t_NEWLINE) so the offending character can be located.
    """
    print("%s: line %d: Unexpected character '%s'"
          % (sys.argv[0], t.lineno, t.value[0]),
          file=sys.stderr)
    if DEBUG:
        print('Token: %s' % str(t), file=sys.stderr)
    sys.exit(1)
134
135
def p_start(t):
    """start : entity_list"""
    # The whole input is one entity list; its result is the parse result.
    result = t[1]
    t[0] = result


def p_entity_list_empty(t):
    """entity_list : """
    # Initial state: no current base yet and empty per-base block and union
    # tables.
    no_base = None
    t[0] = (no_base, dict(), dict())


def p_entity_list_base(t):
    """entity_list : entity_list base"""
    # A `base` declaration becomes the current base for subsequent entities.
    # Make sure both tables have an entry for it without discarding anything
    # already recorded under that base.
    _, blocks, unions = t[1]
    new_base = t[2]
    for table in (blocks, unions):
        table.setdefault(new_base, {})
    t[0] = (new_base, blocks, unions)
152
153
def p_entity_list_block(t):
    """entity_list : entity_list block"""
    # Blocks are filed under the most recent `base` declaration. Without one,
    # current_base is still None and has no table, which would otherwise
    # surface as an unhelpful KeyError.
    current_base, block_map, union_map = t[1]
    if current_base is None:
        print("Block '%s' declared before any 'base' declaration"
              % t[2].name, file=sys.stderr)
        sys.exit(1)
    block_map[current_base][t[2].name] = t[2]
    t[0] = (current_base, block_map, union_map)


def p_entity_list_union(t):
    """entity_list : entity_list tagged_union"""
    # Tagged unions are filed under the most recent `base` declaration, with
    # the same no-base check as p_entity_list_block.
    current_base, block_map, union_map = t[1]
    if current_base is None:
        print("Tagged union '%s' declared before any 'base' declaration"
              % t[2].name, file=sys.stderr)
        sys.exit(1)
    union_map[current_base][t[2].name] = t[2]
    t[0] = (current_base, block_map, union_map)
166
167
def p_base_simple(t):
    """base : BASE INTLIT"""
    # `base N` yields the same triple as `base N(N, 0)`.
    width = t[2]
    t[0] = (width, width, 0)


def p_base_mask(t):
    """base : BASE INTLIT LPAREN INTLIT COMMA INTLIT RPAREN"""
    # Keep the three integer literals and drop the punctuation tokens.
    t[0] = tuple(t[i] for i in (2, 4, 6))
176
177
def p_block(t):
    """block : BLOCK IDENTIFIER opt_visible_order_spec""" \
        """ LBRACE fields RBRACE"""
    # block <name> [(<order>)] { <fields> }
    name, visible_order, fields = t[2], t[3], t[5]
    t[0] = Block(name=name, fields=fields, visible_order=visible_order)
182
183
def p_opt_visible_order_spec_empty(t):
    """opt_visible_order_spec : """
    # Without a parenthesised spec there is no explicit visible order.
    t[0] = None


def p_opt_visible_order_spec(t):
    """opt_visible_order_spec : LPAREN visible_order_spec RPAREN"""
    # Keep the identifier list between the parentheses.
    inner = t[2]
    t[0] = inner


def p_visible_order_spec_empty(t):
    """visible_order_spec : """
    t[0] = list()


def p_visible_order_spec_single(t):
    """visible_order_spec : IDENTIFIER"""
    first = t[1]
    t[0] = [first]


def p_visible_order_spec(t):
    """visible_order_spec : visible_order_spec COMMA IDENTIFIER"""
    # Build a new list instead of mutating the one produced so far.
    names = list(t[1])
    names.append(t[3])
    t[0] = names
207
208
def p_fields_empty(t):
    """fields : """
    t[0] = list()


def p_fields_field(t):
    """fields : fields FIELD IDENTIFIER INTLIT"""
    # Field entries are (name, size literal, declared-with-field_high flag).
    name, size = t[3], t[4]
    t[0] = t[1] + [(name, size, False)]


def p_fields_field_high(t):
    """fields : fields FIELD_HIGH IDENTIFIER INTLIT"""
    name, size = t[3], t[4]
    t[0] = t[1] + [(name, size, True)]


def p_fields_padding(t):
    """fields : fields PADDING INTLIT"""
    # Padding is recorded as an anonymous (name None), non-high field.
    size = t[3]
    t[0] = t[1] + [(None, size, False)]
227
228
def p_tagged_union(t):
    """tagged_union : TAGGED_UNION IDENTIFIER IDENTIFIER""" \
        """ LBRACE masks tags RBRACE"""
    # tagged_union <name> <tagname> { mask ... tag ... }
    # The parsed mask list is handed to TaggedUnion as `classes`.
    union_name, tag_field = t[2], t[3]
    t[0] = TaggedUnion(name=union_name, tagname=tag_field,
                       classes=t[5], tags=t[6])
233
234
def p_tags_empty(t):
    """tags :"""
    t[0] = list()


def p_tags(t):
    """tags : tags TAG IDENTIFIER INTLIT"""
    # Each tag is recorded as (name, integer value).
    tag_name, tag_value = t[3], t[4]
    t[0] = t[1] + [(tag_name, tag_value)]


def p_masks_empty(t):
    """masks :"""
    t[0] = list()


def p_masks(t):
    """masks : masks MASK INTLIT INTLIT"""
    # Each mask entry is the pair of integer literals that follows `mask`.
    first, second = t[3], t[4]
    t[0] = t[1] + [(first, second)]
253
254
def p_error(t):
    """Report a syntax error and abort with exit status 1.

    PLY passes ``None`` when the input ends unexpectedly, so that case is
    reported explicitly rather than crashing on ``t.value``.
    """
    if t is None:
        print("Syntax error at end of input", file=sys.stderr)
    else:
        print("Syntax error at token '%s' on line %d" % (t.value, t.lineno),
              file=sys.stderr)
    sys.exit(1)
258
# Templates

# C templates
#
# Each template below is a %-format string with named placeholders such as
# %(name)s, %(block)s, %(type)s, %(mask)x and %(shift)d, so it must be
# formatted with a mapping that supplies those keys.


# Declares the bitfield struct as an array of %(multiple)d words of
# %(type)s, plus a `<name>_t` typedef for it.
typedef_template = \
    """struct %(name)s {
    %(type)s words[%(multiple)d];
};
typedef struct %(name)s %(name)s_t;"""

# By-value constructor `<block>_new(...)`: declares a local struct, splices
# in the caller-built %(asserts)s and %(gen_inits)s fragments, and returns it.
generator_template = \
    """%(inline)s %(block)s_t CONST
%(block)s_new(%(gen_params)s) {
    %(block)s_t %(block)s;

%(asserts)s

%(gen_inits)s

    return %(block)s;
}"""

# In-place constructor `<block>_ptr_new(...)`. It returns void, so the
# %(ptr_inits)s fragment presumably writes through a pointer in
# %(ptr_params)s (TODO confirm at the call site).
ptr_generator_template = \
    """%(inline)s void
%(block)s_ptr_new(%(ptr_params)s) {
%(asserts)s

%(ptr_inits)s
}"""

# By-value getter: mask the field out of words[%(index)d], shift it with
# %(r_shift_op)s by %(shift)d, and, when %(sign_extend)d is non-zero and bit
# %(extend_bit)d of the result is set, OR in %(high_bits)x to sign-extend.
# __builtin_expect hints the branch is taken iff sign extension is enabled.
reader_template = \
    """%(inline)s %(type)s CONST
%(block)s_get_%(field)s(%(block)s_t %(block)s) {
    %(type)s ret;
    ret = (%(block)s.words[%(index)d] & 0x%(mask)x%(suf)s) %(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (__builtin_expect(!!(%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))), %(sign_extend)d)) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Same extraction as reader_template, but reading through a `<block>_t *`
# and marked PURE instead of CONST.
ptr_reader_template = \
    """%(inline)s %(type)s PURE
%(block)s_ptr_get_%(field)s(%(block)s_t *%(block)s_ptr) {
    %(type)s ret;
    ret = (%(block)s_ptr->words[%(index)d] & 0x%(mask)x%(suf)s) """ \
    """%(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (__builtin_expect(!!(%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))), %(sign_extend)d)) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# By-value setter. It asserts that v%(base)d has no bits outside the field,
# apart from the %(high_bits)x sign-extension bits when sign extension
# applies. It then clears the field in words[%(index)d] and ORs in the value
# shifted with %(w_shift_op)s and masked.
writer_template = \
    """%(inline)s %(block)s_t CONST
%(block)s_set_%(field)s(%(block)s_t %(block)s, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d ) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));
    %(block)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(block)s.words[%(index)d] |= (v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
    return %(block)s;
}"""

# In-place version of writer_template, updating the struct through a pointer.
# NOTE(review): unlike every other writer template, the final `& 0x%(mask)x`
# here lacks the %(suf)s suffix. Left unchanged because the generated C text
# must stay stable; confirm whether this is intentional.
ptr_writer_template = \
    """%(inline)s void
%(block)s_ptr_set_%(field)s(%(block)s_t *%(block)s_ptr, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));
    %(block)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(block)s_ptr->words[%(index)d] |= (v%(base)d %(w_shift_op)s """ \
    """%(shift)d) & 0x%(mask)x;
}"""
334
# By-value constructor `<union>_<block>_new(...)` for one variant of a tagged
# union: declares a local union, splices in %(asserts)s and %(gen_inits)s,
# and returns it.
union_generator_template = \
    """%(inline)s %(union)s_t CONST
%(union)s_%(block)s_new(%(gen_params)s) {
    %(union)s_t %(union)s;

%(asserts)s

%(gen_inits)s

    return %(union)s;
}"""

# In-place variant constructor `<union>_<block>_ptr_new(...)` (returns void).
ptr_union_generator_template = \
    """%(inline)s void
%(union)s_%(block)s_ptr_new(%(ptr_params)s) {
%(asserts)s

%(ptr_inits)s
}"""

# Variant field getter. It first asserts that the tag (words[%(tagindex)d]
# shifted right by %(tagshift)d and masked with %(tagmask)x) equals the
# `<union>_<block>` constant, then extracts and optionally sign-extends the
# field exactly as reader_template does.
union_reader_template = \
    """%(inline)s %(type)s CONST
%(union)s_%(block)s_get_%(field)s(%(union)s_t %(union)s) {
    %(type)s ret;
    %(assert)s(((%(union)s.words[%(tagindex)d] >> %(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    ret = (%(union)s.words[%(index)d] & 0x%(mask)x%(suf)s) %(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (__builtin_expect(!!(%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))), %(sign_extend)d)) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Pointer version of union_reader_template (marked PURE instead of CONST).
ptr_union_reader_template = \
    """%(inline)s %(type)s PURE
%(union)s_%(block)s_ptr_get_%(field)s(%(union)s_t *%(union)s_ptr) {
    %(type)s ret;
    %(assert)s(((%(union)s_ptr->words[%(tagindex)d] >> """ \
    """%(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    ret = (%(union)s_ptr->words[%(index)d] & 0x%(mask)x%(suf)s) """ \
    """%(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (__builtin_expect(!!(%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))), %(sign_extend)d)) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Variant field setter: the same tag assertion as union_reader_template,
# then the same range assertion, clear and OR-in as writer_template.
union_writer_template = \
    """%(inline)s %(union)s_t CONST
%(union)s_%(block)s_set_%(field)s(%(union)s_t %(union)s, %(type)s v%(base)d) {
    %(assert)s(((%(union)s.words[%(tagindex)d] >> %(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d ) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s.words[%(index)d] |= (v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
    return %(union)s;
}"""

# In-place version of union_writer_template, updating through a pointer.
ptr_union_writer_template = \
    """%(inline)s void
%(union)s_%(block)s_ptr_set_%(field)s(%(union)s_t *%(union)s_ptr,
                                      %(type)s v%(base)d) {
    %(assert)s(((%(union)s_ptr->words[%(tagindex)d] >> """ \
    """%(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s_ptr->words[%(index)d] |= """ \
    """(v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
}"""
415
# Tag getter `<union>_get_<tagname>`, emitted in pieces: a header (signature
# and opening brace), entries, a final return, and a footer (closing brace).
# These are presumably concatenated in that order by the generator, with one
# entry per tag class (TODO confirm at the call site).
tag_reader_header_template = \
    """%(inline)s %(type)s CONST
%(union)s_get_%(tagname)s(%(union)s_t %(union)s) {
"""

# Returns the tag read from this position when the bits of words[%(index)d]
# selected by %(classmask)x are not all set; otherwise falls through.
tag_reader_entry_template = \
    """    if ((%(union)s.words[%(index)d] & 0x%(classmask)x) != 0x%(classmask)x)
        return (%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;
"""

# Unconditional fallback return for the tag getter.
tag_reader_final_template = \
    """    return (%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;"""

tag_reader_footer_template = \
    """
}"""

# Tag equality test `<union>_<tagname>_equals(u, tag)`, built from the same
# kind of pieces. Here the class check is made on the requested tag value
# rather than on the stored word.
tag_eq_reader_header_template = \
    """%(inline)s int CONST
%(union)s_%(tagname)s_equals(%(union)s_t %(union)s, %(type)s %(union)s_type_tag) {
"""

tag_eq_reader_entry_template = \
    """    if ((%(union)s_type_tag & 0x%(classmask)x) != 0x%(classmask)x)
        return ((%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s) == %(union)s_type_tag;
"""

tag_eq_reader_final_template = \
    """    return ((%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s) == %(union)s_type_tag;"""

tag_eq_reader_footer_template = \
    """
}"""

# Pointer versions of the tag getter pieces (marked PURE instead of CONST).
ptr_tag_reader_header_template = \
    """%(inline)s %(type)s PURE
%(union)s_ptr_get_%(tagname)s(%(union)s_t *%(union)s_ptr) {
"""

ptr_tag_reader_entry_template = \
    """    if ((%(union)s_ptr->words[%(index)d] & 0x%(classmask)x) != 0x%(classmask)x)
        return (%(union)s_ptr->words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;
"""

ptr_tag_reader_final_template = \
    """    return (%(union)s_ptr->words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;"""

ptr_tag_reader_footer_template = \
    """
}"""

# By-value tag setter: range-checks v%(base)d like writer_template, then
# clears the tag bits and ORs in the value shifted left by %(shift)d (always
# `<<`, unlike the field writers' %(w_shift_op)s).
tag_writer_template = \
    """%(inline)s %(union)s_t CONST
%(union)s_set_%(tagname)s(%(union)s_t %(union)s, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s.words[%(index)d] |= (v%(base)d << %(shift)d) & 0x%(mask)x%(suf)s;
    return %(union)s;
}"""

# In-place version of tag_writer_template, updating through a pointer.
ptr_tag_writer_template = \
    """%(inline)s void
%(union)s_ptr_set_%(tagname)s(%(union)s_t *%(union)s_ptr, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s_ptr->words[%(index)d] |= (v%(base)d << %(shift)d) & 0x%(mask)x%(suf)s;
}"""
487
# HOL definition templates
#
# Isabelle symbols such as \<Rightarrow> contain "\<", which is not a valid
# Python escape sequence. Raw strings keep the backslash verbatim (so the
# values are unchanged) without the DeprecationWarning/SyntaxWarning that
# newer Python versions emit for invalid escapes.

# `<name>_lift`: maps the C struct `<name>_C` to the record `<name>_CL`.
lift_def_template = \
    r'''definition
  %(name)s_lift :: "%(name)s_C \<Rightarrow> %(name)s_CL"
where
  "%(name)s_lift %(name)s \<equiv> \<lparr>
       %(fields)s \<rparr>"'''

# Per-variant lift: unwraps the variant record from the union's option lift.
block_lift_def_template = \
    '''definition %(union)s_%(block)s_lift :: ''' \
    r'''"%(union)s_C \<Rightarrow> %(union)s_%(block)s_CL"
where
  "%(union)s_%(block)s_lift %(union)s \<equiv>
    case (%(union)s_lift %(union)s) of ''' \
    r'''Some (%(generator)s rec) \<Rightarrow> rec"'''

# Relates "the tag selects this variant" to the union lift being that variant.
block_lift_lemma_template = \
    '''lemma %(union)s_%(block)s_lift:
  "(%(union)s_get_tag c = scast %(union)s_%(block)s) = ''' \
    '''(%(union)s_lift c = Some (%(generator)s (%(union)s_%(block)s_lift c)))"
  unfolding %(union)s_lift_def %(union)s_%(block)s_lift_def
  by (clarsimp simp: %(union)s_tag_defs Let_def)'''

# `<name>_get_tag` definition, built from header + entries + final + footer.
union_get_tag_def_header_template = \
    r'''definition
  %(name)s_get_tag :: "%(name)s_C \<Rightarrow> word%(base)d"
where
  "%(name)s_get_tag %(name)s \<equiv>
     '''

union_get_tag_def_entry_template = \
    '''if ((index (%(name)s_C.words_C %(name)s) %(tag_index)d)''' \
    r''' AND 0x%(classmask)x \<noteq> 0x%(classmask)x)
      then ((index (%(name)s_C.words_C %(name)s) %(tag_index)d)'''\
''' >> %(tag_shift)d) AND mask %(tag_size)d
      else '''

union_get_tag_def_final_template = \
    '''((index (%(name)s_C.words_C %(name)s) %(tag_index)d)'''\
    ''' >> %(tag_shift)d) AND mask %(tag_size)d'''

union_get_tag_def_footer_template = '''"'''

# `<name>_get_tag_eq_x` lemma, built from the same kind of pieces.
union_get_tag_eq_x_def_header_template = \
    '''lemma %(name)s_get_tag_eq_x:
  "(%(name)s_get_tag c = x) = (('''

union_get_tag_eq_x_def_entry_template = \
    r'''if ((x << %(tag_shift)d) AND 0x%(classmask)x \<noteq> 0x%(classmask)x)
      then ((index (%(name)s_C.words_C c) %(tag_index)d)''' \
''' >> %(tag_shift)d) AND mask %(tag_size)d
      else '''

union_get_tag_eq_x_def_final_template = \
    '''((index (%(name)s_C.words_C c) %(tag_index)d)''' \
    ''' >> %(tag_shift)d) AND mask %(tag_size)d'''

union_get_tag_eq_x_def_footer_template = ''') = x)"
  by (auto simp add: %(name)s_get_tag_def mask_def word_bw_assocs)'''

# `<name>_<block>_tag_mask_helpers` lemma: header, entries, footer.
union_tag_mask_helpers_header_template = \
    '''lemma %(name)s_%(block)s_tag_mask_helpers:'''

union_tag_mask_helpers_entry_template = r'''
  "w && %(full_mask)s = %(full_value)s \<Longrightarrow> w'''\
''' && %(part_mask)s = %(part_value)s"
'''

union_tag_mask_helpers_footer_template = \
    '''  by (auto elim: word_sub_mask simp: mask_def)'''

# Union lift: dispatches on the tag; %(tag_cases)s supplies the variants.
union_lift_def_template = \
    r'''definition
  %(name)s_lift :: "%(name)s_C \<Rightarrow> %(name)s_CL option"
where
  "%(name)s_lift %(name)s \<equiv>
    (let tag = %(name)s_get_tag %(name)s in
     %(tag_cases)s
     else None)"'''

# Applies a function to the variant record inside a `<union>_CL` value.
union_access_def_template = \
    r'''definition
  %(union)s_%(block)s_access :: "(%(union)s_%(block)s_CL \<Rightarrow> 'a)
                                 \<Rightarrow> %(union)s_CL \<Rightarrow> 'a"
where
  "%(union)s_%(block)s_access f %(union)s \<equiv>
     (case %(union)s of %(generator)s rec \<Rightarrow> f rec)"'''

# Updates the variant record inside a `<union>_CL` value.
union_update_def_template = \
    r'''definition
  %(union)s_%(block)s_update :: "(%(union)s_%(block)s_CL \<Rightarrow>''' \
                                r''' %(union)s_%(block)s_CL) \<Rightarrow>
                                 %(union)s_CL \<Rightarrow> %(union)s_CL"
where
  "%(union)s_%(block)s_update f %(union)s \<equiv>
     (case %(union)s of %(generator)s rec \<Rightarrow>
        %(generator)s (f rec))"'''
586
# HOL proof templates

# Groups the three pointer-guard lemmas of a struct into one simp set.
# FIXME: avoid [simp]
struct_lemmas_template = \
    '''
lemmas %(name)s_ptr_guards[simp] =
  %(name)s_ptr_words_NULL
  %(name)s_ptr_words_aligned
  %(name)s_ptr_words_ptr_safe'''

# Generic word lemma used by the tag mask helper proofs. A raw string keeps
# the Isabelle "\<...>" symbols without Python invalid-escape warnings.
# FIXME: move to global theory
defs_global_lemmas = r'''
lemma word_sub_mask:
  "\<lbrakk> w && m1 = v1; m1 && m2 = m2; v1 && m2 = v2 \<rbrakk>
     \<Longrightarrow> w && m2 = v2"
  by (clarsimp simp: word_bw_assocs)
'''
604
# Proof templates are stored as a list of
# (header, body, stuff to go afterwards).
# This makes it easy to replace the proof body with a sorry.

# ptrname should be a function of s


def ptr_basic_template(name, ptrname, retval, args, post):
    """Return the statement of a ``%(name)s_ptr_<name>_spec`` lemma.

    The result is still a %-format template (it contains %(name)s etc.).
    ``ptrname`` is the Isabelle pointer term used in the `defines` and in the
    validity precondition, ``retval`` is placed before the PROC call (e.g. an
    assignment to a result variable), ``args`` is appended to the PROC
    argument list, and ``post`` is the postcondition. Raw strings keep the
    Isabelle symbols without Python invalid-escape warnings.
    """
    return ('''lemma (in ''' + loc_name + ''') %(name)s_ptr_''' + name + r'''_spec:
           defines "ptrval s \<equiv> cslift s ''' + ptrname + r'''"
           shows "\<forall>s. \<Gamma> \<turnstile> \<lbrace>s. s \<Turnstile>\<^sub>c ''' + ptrname + r'''\<rbrace>
            ''' + retval + '''PROC %(name)s_ptr_''' + name + r'''(\<acute>%(name)s_ptr''' + args + ''')
            ''' + post + ''' " ''')
618
619
def ptr_union_basic_template(name, ptrname, retval, args, pre, post):
    """Return the statement of a ``%(name)s_%(block)s_ptr_<name>_spec`` lemma.

    Like ptr_basic_template, but for a tagged-union variant. ``pre`` is an
    extra precondition appended (after a space) to the pointer-validity
    precondition. Raw strings keep the Isabelle symbols without Python
    invalid-escape warnings.
    """
    return ('''lemma (in ''' + loc_name + ''') %(name)s_%(block)s_ptr_''' + name + r'''_spec:
    defines "ptrval s \<equiv> cslift s ''' + ptrname + r'''"
    shows "\<forall>s. \<Gamma> \<turnstile> \<lbrace>s. s \<Turnstile>\<^sub>c ''' + ptrname + " " + pre + r'''\<rbrace>
            ''' + retval + '''PROC %(name)s_%(block)s_ptr_''' + name + r'''(\<acute>%(name)s_ptr''' + args + ''')
            ''' + post + ''' " ''')
626
627
# Pointer terms for the spec lemmas: the argument pointer itself in state s,
# or the enclosing object reached via `cparent` along %(path)s. Raw strings
# avoid Python invalid-escape warnings for the Isabelle "\<...>" symbols.
direct_ptr_name = r'\<^bsup>s\<^esup>%(name)s_ptr'
path_ptr_name = r'(cparent \<^bsup>s\<^esup>%(name)s_ptr [%(path)s] :: %(toptp)s ptr)'
630
631
def ptr_get_template(ptrname):
    """Spec statement for a pointer field getter `<name>_ptr_get_<field>`.

    The result variable must equal the field of the lifted object found at
    %(access_path)s. Raw strings avoid Python invalid-escape warnings.
    """
    return ptr_basic_template('get_%(field)s', ptrname, r'\<acute>%(ret_name)s :== ', '',
                              r'''\<lbrace>\<acute>%(ret_name)s = '''
                              '''%(name)s_CL.%(field)s_CL '''
                              r'''(%(name)s_lift (%(access_path)s))\<rbrace>''')
637
638
def ptr_set_template(name, ptrname):
    """Spec statement for a pointer setter PROC `<name>_ptr_<name>`.

    The postcondition states that the heap is updated at ``ptrname`` with an
    object whose lift equals the old lift with %(field)s_CL replaced by
    %(sign_extend)s(v%(base)d AND %(mask)s). Raw strings avoid Python
    invalid-escape warnings.
    """
    return ptr_basic_template(name, ptrname, '', r', \<acute>v%(base)d',
                              r'''{t. \<exists>%(name)s.
                              %(name)s_lift %(name)s =
                              %(name)s_lift (%(access_path)s) \<lparr> %(name)s_CL.%(field)s_CL '''
                              r''':= %(sign_extend)s(\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr> \<and>
                              t_hrs_' (globals t) = hrs_mem_update (heap_update
                                      (''' + ptrname + ''')
                                      %(update_path)s)
                                  (t_hrs_' (globals s))
                              }''')
650
651
def ptr_new_template(ptrname):
    """Spec statement for the in-place constructor `<name>_ptr_new`.

    The heap is updated at ``ptrname`` with an object whose lift is the
    record given by %(field_eqs)s. Raw strings avoid Python invalid-escape
    warnings.
    """
    return ptr_basic_template('new', ptrname, '', ', %(args)s',
                              r'''{t. \<exists>%(name)s. %(name)s_lift %(name)s = \<lparr>
                              %(field_eqs)s \<rparr> \<and>
                              t_hrs_' (globals t) = hrs_mem_update (heap_update
                                      (''' + ptrname + ''')
                                      %(update_path)s)
                                  (t_hrs_' (globals s))
                              }''')
661
662
def ptr_get_tag_template(ptrname):
    """Spec statement for the pointer tag getter `<name>_ptr_get_<tagname>`.

    The result must equal `<name>_get_tag` of the object at %(access_path)s.
    Raw strings avoid Python invalid-escape warnings.
    """
    return ptr_basic_template('get_%(tagname)s', ptrname, r'\<acute>%(ret_name)s :== ', '',
                              r'''\<lbrace>\<acute>%(ret_name)s = %(name)s_get_tag (%(access_path)s)\<rbrace>''')
666
667
def ptr_empty_union_new_template(ptrname):
    """Spec statement for `<name>_<block>_ptr_new` of a field-less variant.

    Only the tag of the stored object is constrained. Raw strings avoid
    Python invalid-escape warnings.
    """
    return ptr_union_basic_template('new', ptrname, '', '', '',
                                    r'''{t. \<exists>%(name)s. '''
                                    r'''%(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + ''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')
677
678
def ptr_union_new_template(ptrname):
    """Spec statement for `<name>_<block>_ptr_new` of a variant with fields.

    The stored object's variant lift equals the record from %(field_eqs)s,
    and its tag selects `<name>_<block>`. Raw strings avoid Python
    invalid-escape warnings.
    """
    return ptr_union_basic_template('new', ptrname, '', ', %(args)s', '',
                                    r'''{t. \<exists>%(name)s. '''
                                    r'''%(name)s_%(block)s_lift %(name)s = \<lparr>
                                    %(field_eqs)s \<rparr> \<and>
                                    %(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + ''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')
690
691
def ptr_union_get_template(ptrname):
    """Spec statement for a variant field getter through a pointer.

    The precondition requires the object's tag to select `<name>_<block>`.
    Raw strings avoid Python invalid-escape warnings.
    """
    return ptr_union_basic_template('get_%(field)s', ptrname,
                                    r'\<acute>%(ret_name)s :== ', '',
                                    r'\<and> %(name)s_get_tag %(access_path)s = scast %(name)s_%(block)s',
                                    r'''\<lbrace>\<acute>%(ret_name)s = '''
                                    '''%(name)s_%(block)s_CL.%(field)s_CL '''
                                    r'''(%(name)s_%(block)s_lift %(access_path)s)\<rbrace>''')
699
700
def ptr_union_set_template(ptrname):
    """Spec statement for a variant field setter through a pointer.

    The tag must select `<name>_<block>` before and after the call. The
    variant lift is updated at %(field)s_CL with
    %(sign_extend)s(v%(base)d AND %(mask)s). Raw strings avoid Python
    invalid-escape warnings.
    """
    return ptr_union_basic_template('set_%(field)s', ptrname, '', r', \<acute>v%(base)d',
                                    r'\<and> %(name)s_get_tag %(access_path)s = scast %(name)s_%(block)s',
                                    r'''{t. \<exists>%(name)s. '''
                                    '''%(name)s_%(block)s_lift %(name)s =
                                    %(name)s_%(block)s_lift %(access_path)s '''
                                    r'''\<lparr> %(name)s_%(block)s_CL.%(field)s_CL '''
                                    r''':= %(sign_extend)s(\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr> \<and>
                                    %(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + ''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')
715
716
717proof_templates = {
718
719    'lift_collapse_proof': [
720        '''lemma %(name)s_lift_%(block)s:
721  "%(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<Longrightarrow>
722  %(name)s_lift %(name)s =
723  Some (%(value)s)"''',
724        ''' apply(simp add:%(name)s_lift_def %(name)s_tag_defs)
725done'''],
726
727    'words_NULL_proof': [
728        '''lemma %(name)s_ptr_words_NULL:
729  "c_guard (p::%(name)s_C ptr) \<Longrightarrow>
730   0 < &(p\<rightarrow>[''words_C''])"''',
731        ''' apply(fastforce intro:c_guard_NULL_fl simp:typ_uinfo_t_def)
732done'''],
733
734    'words_aligned_proof': [
735        '''lemma %(name)s_ptr_words_aligned:
736  "c_guard (p::%(name)s_C ptr) \<Longrightarrow>
737   ptr_aligned ((Ptr &(p\<rightarrow>[''words_C'']))::'''
738        '''((word%(base)d[%(words)d]) ptr))"''',
739        ''' apply(fastforce intro:c_guard_ptr_aligned_fl simp:typ_uinfo_t_def)
740done'''],
741
742    'words_ptr_safe_proof': [
743        '''lemma %(name)s_ptr_words_ptr_safe:
744  "ptr_safe (p::%(name)s_C ptr) d \<Longrightarrow>
745   ptr_safe (Ptr &(p\<rightarrow>[''words_C''])::'''
746        '''((word%(base)d[%(words)d]) ptr)) d"''',
747        ''' apply(fastforce intro:ptr_safe_mono simp:typ_uinfo_t_def)
748done'''],
749
750    'get_tag_fun_spec_proof': [
751        '''lemma (in ''' + loc_name + ''') fun_spec:
752  "\<Gamma> \<turnstile> {\<sigma>}
753       \<acute>ret__%(rtype)s :== PROC %(name)s_get_%(tag_name)s('''
754        ''' \<acute>%(name))
755       \<lbrace>\<acute>ret__%(rtype)s = %(name)s_get_tag'''
756        '''\<^bsup>\<sigma>\<^esup>\<rbrace>"''',
757        ''' apply(rule allI, rule conseqPre, vcg)
758 apply(clarsimp)
759 apply(simp add:$(name)s_get_tag_def word_sle_def
760                mask_def ucast_def)
761done'''],
762
763    'const_modifies_proof': [
764        '''lemma (in ''' + loc_name + ''') %(fun_name)s_modifies:
765  "\<forall> s. \<Gamma> \<turnstile>\<^bsub>/UNIV\<^esub> {s}
766       PROC %(fun_name)s(%(args)s)
767       {t. t may_not_modify_globals s}"''',
768        ''' by (vcg spec=modifies strip_guards=true)'''],
769
770    'ptr_set_modifies_proof': [
771        '''lemma (in ''' + loc_name + ''') %(fun_name)s_modifies:
772  "\<forall>s. \<Gamma> \<turnstile>\<^bsub>/UNIV\<^esub> {s}
773       PROC %(fun_name)s(%(args)s)
774       {t. t may_only_modify_globals s in [t_hrs]}"''',
775        ''' by (vcg spec=modifies strip_guards=true)'''],
776
777
778    'new_spec': [
779        '''lemma (in ''' + loc_name + ''') %(name)s_new_spec:
780  "\<forall> s. \<Gamma> \<turnstile> {s}
781       \<acute>ret__struct_%(name)s_C :== PROC %(name)s_new(%(args)s)
782       \<lbrace> %(name)s_lift \<acute>ret__struct_%(name)s_C = \<lparr>
783          %(field_eqs)s \<rparr> \<rbrace>"''',
784        '''  apply (rule allI, rule conseqPre, vcg)
785  apply (clarsimp simp: guard_simps)
786  apply (simp add: %(name)s_lift_def)
787  apply ((intro conjI sign_extend_eq)?;
788         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
789                    word_ao_dist word_bw_assocs word_and_max_simps)?)
790  done'''],
791
792    'ptr_new_spec_direct': [
793        ptr_new_template(direct_ptr_name),
794        '''sorry (* ptr_new_spec_direct *)'''],
795
796    'ptr_new_spec_path': [
797        ptr_new_template(path_ptr_name),
798        '''sorry (* ptr_new_spec_path *)'''],
799
800
801    'get_spec': [
802        '''lemma (in ''' + loc_name + ''') %(name)s_get_%(field)s_spec:
803  "\<forall>s. \<Gamma> \<turnstile> {s}
804       \<acute>%(ret_name)s :== '''
805        '''PROC %(name)s_get_%(field)s(\<acute>%(name)s)
806       \<lbrace>\<acute>%(ret_name)s = '''
807        '''%(name)s_CL.%(field)s_CL '''
808        '''(%(name)s_lift \<^bsup>s\<^esup>%(name)s)\<rbrace>"''',
809        '''  apply (rule allI, rule conseqPre, vcg)
810  apply clarsimp
811  apply (simp add: %(name)s_lift_def mask_shift_simps guard_simps)
812  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
813                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
814  done'''],
815
816    'set_spec': [
817        '''lemma (in ''' + loc_name + ''') %(name)s_set_%(field)s_spec:
818  "\<forall>s. \<Gamma> \<turnstile> {s}
819       \<acute>ret__struct_%(name)s_C :== '''
820        '''PROC %(name)s_set_%(field)s(\<acute>%(name)s, \<acute>v%(base)d)
821       \<lbrace>%(name)s_lift \<acute>ret__struct_%(name)s_C = '''
822        '''%(name)s_lift \<^bsup>s\<^esup>%(name)s \<lparr> '''
823        '''%(name)s_CL.%(field)s_CL '''
824        ''':= %(sign_extend)s (\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr>\<rbrace>"''',
825        '''  apply(rule allI, rule conseqPre, vcg)
826  apply(clarsimp simp: guard_simps ucast_id
827                       %(name)s_lift_def
828                       mask_def shift_over_ao_dists
829                       multi_shift_simps word_size
830                       word_ao_dist word_bw_assocs
831                       NOT_eq)
832  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
833                   shift_over_ao_dists word_and_max_simps)?
834  done'''],
835
836    # where the top level type is the bitfield type --- these are split because they have different proofs
837    'ptr_get_spec_direct': [
838        ptr_get_template(direct_ptr_name),
839        '''   unfolding ptrval_def
840  apply (rule allI, rule conseqPre, vcg)
841  apply (clarsimp simp: h_t_valid_clift_Some_iff)
842  apply (simp add: %(name)s_lift_def guard_simps mask_def typ_heap_simps ucast_def)
843  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
844                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
845  done'''],
846
847    'ptr_get_spec_path': [
848        ptr_get_template(path_ptr_name),
849        '''  unfolding ptrval_def
850  apply (rule allI, rule conseqPre, vcg)
851  apply (clarsimp simp: guard_simps)
852  apply (frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
853  apply (frule clift_subtype, simp, simp, simp)
854  apply (simp add: h_val_field_clift' typ_heap_simps)
855  apply (simp add: thread_state_lift_def)
856  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
857                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
858  apply (simp add: mask_shift_simps)?
859  done'''],
860
861    'ptr_set_spec_direct': [
862        ptr_set_template('set_%(field)s', direct_ptr_name),
863        '''  unfolding ptrval_def
864  apply (rule allI, rule conseqPre, vcg)
865  apply (clarsimp simp: guard_simps)
866  apply (clarsimp simp add: packed_heap_update_collapse_hrs typ_heap_simps)?
867  apply (rule exI, rule conjI[rotated], rule refl)
868  apply (clarsimp simp: h_t_valid_clift_Some_iff %(name)s_lift_def typ_heap_simps)
869  apply ((intro conjI sign_extend_eq)?;
870         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
871                    word_ao_dist word_bw_assocs word_and_max_simps))?
872  done'''],
873
874    'ptr_set_spec_path': [
875        ptr_set_template('set_%(field)s', path_ptr_name),
876        '''  (* Invoke vcg *)
877  unfolding ptrval_def
878  apply (rule allI, rule conseqPre, vcg)
879  apply (clarsimp)
880
881  (* Infer h_t_valid for all three levels of indirection *)
882  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
883  apply (frule h_t_valid_c_guard_field[where f="[''words_C'']"],
884                                       simp, simp add: typ_uinfo_t_def)
885
886  (* Discharge guards, including c_guard for pointers *)
887  apply (simp add: h_t_valid_c_guard guard_simps)
888
889  (* Lift field updates to bitfield struct updates *)
890  apply (simp add: heap_update_field_hrs h_t_valid_c_guard typ_heap_simps)
891
892  (* Collapse multiple updates *)
893  apply(simp add: packed_heap_update_collapse_hrs)
894
895  (* Instantiate the toplevel object *)
896  apply(frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
897
898  (* Instantiate the next-level object in terms of the last *)
899  apply(frule clift_subtype, simp+)
900
901  (* Resolve pointer accesses *)
902  apply(simp add: h_val_field_clift')
903
904  (* Rewrite bitfield struct updates as enclosing struct updates *)
905  apply(frule h_t_valid_c_guard)
906  apply(simp add: parent_update_child)
907
908  (* Equate the updated values *)
909  apply(rule exI, rule conjI[rotated], simp add: h_val_clift')
910
911  (* Rewrite struct updates *)
912  apply(simp add: o_def %(name)s_lift_def)
913
914  (* Solve bitwise arithmetic *)
915  apply ((intro conjI sign_extend_eq)?;
916         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
917                    word_ao_dist word_bw_assocs word_and_max_simps))?
918  done'''],
919
920
921    'get_tag_spec': [
922        '''lemma (in ''' + loc_name + ''') %(name)s_get_%(tagname)s_spec:
923  "\<forall>s. \<Gamma> \<turnstile> {s}
924       \<acute>%(ret_name)s :== ''' \
925    '''PROC %(name)s_get_%(tagname)s(\<acute>%(name)s)
926       \<lbrace>\<acute>%(ret_name)s = ''' \
927    '''%(name)s_get_tag \<^bsup>s\<^esup>%(name)s\<rbrace>"''',
928        ''' apply(rule allI, rule conseqPre, vcg)
929 apply(clarsimp)
930 apply(simp add:%(name)s_get_tag_def
931                mask_shift_simps
932                guard_simps)
933done'''],
934
935    'get_tag_equals_spec': [
936        '''lemma (in ''' + loc_name + ''') %(name)s_%(tagname)s_equals_spec:
937  "\<forall>s. \<Gamma> \<turnstile> {s}
938       \<acute>ret__int :==
939       PROC %(name)s_%(tagname)s_equals(\<acute>%(name)s, \<acute>%(name)s_type_tag)
940       \<lbrace>\<acute>ret__int = of_bl [%(name)s_get_tag \<^bsup>s\<^esup>%(name)s = \<^bsup>s\<^esup>%(name)s_type_tag]\<rbrace>"''',
941        ''' apply(rule allI, rule conseqPre, vcg)
942 apply(clarsimp)
943 apply(simp add:%(name)s_get_tag_eq_x
944                mask_shift_simps
945                guard_simps)
946done'''],
947
948    'ptr_get_tag_spec_direct': [
949        ptr_get_tag_template(direct_ptr_name),
950        ''' unfolding ptrval_def
951 apply(rule allI, rule conseqPre, vcg)
952 apply(clarsimp simp:guard_simps)
953 apply(frule h_t_valid_field[where f="[''words_C'']"], simp+)
954 apply(frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
955 apply(simp add:h_val_clift' clift_field)
956 apply(simp add:%(name)s_get_tag_def)
957 apply(simp add:mask_shift_simps)
958done'''],
959
960    'ptr_get_tag_spec_path': [
961        ptr_get_tag_template(path_ptr_name),
962        ''' unfolding ptrval_def
963 apply(rule allI, rule conseqPre, vcg)
964 apply(clarsimp)
965 apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
966 apply(clarsimp simp: typ_heap_simps h_t_valid_clift_Some_iff)
967 apply(frule clift_subtype, simp+)
968 apply(simp add: %(name)s_get_tag_def mask_shift_simps guard_simps)
969done'''],
970
971
972    'empty_union_new_spec': [
973        '''lemma (in ''' + loc_name + ''') ''' \
974        '''%(name)s_%(block)s_new_spec:
975  "\<forall>s. \<Gamma> \<turnstile> {s}
976       \<acute>ret__struct_%(name)s_C :== ''' \
977    '''PROC %(name)s_%(block)s_new()
978       \<lbrace>%(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
979     '''scast %(name)s_%(block)s\<rbrace>"''',
980        ''' apply(rule allI, rule conseqPre, vcg)
981 apply(clarsimp simp:guard_simps
982                     %(name)s_lift_def
983                     Let_def
984                     %(name)s_get_tag_def
985                     mask_shift_simps
986                     %(name)s_tag_defs
987                     word_of_int_hom_syms)
988done'''],
989
990    'union_new_spec': [
991        '''lemma (in ''' + loc_name + ''') ''' \
992        '''%(name)s_%(block)s_new_spec:
993  "\<forall>s. \<Gamma> \<turnstile> {s}
994       \<acute>ret__struct_%(name)s_C :== ''' \
995    '''PROC %(name)s_%(block)s_new(%(args)s)
996       \<lbrace>%(name)s_%(block)s_lift ''' \
997    '''\<acute>ret__struct_%(name)s_C = \<lparr>
998          %(field_eqs)s \<rparr> \<and>
999        %(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
1000     '''scast %(name)s_%(block)s\<rbrace>"''',
1001        '''  apply (rule allI, rule conseqPre, vcg)
1002  apply (clarsimp simp: guard_simps o_def mask_def shift_over_ao_dists)
1003  apply (rule context_conjI[THEN iffD1[OF conj_commute]],
1004         fastforce simp: %(name)s_get_tag_eq_x %(name)s_%(block)s_def
1005                         mask_def shift_over_ao_dists word_bw_assocs word_ao_dist)
1006  apply (simp add: %(name)s_%(block)s_lift_def)
1007  apply (erule %(name)s_lift_%(block)s[THEN subst[OF sym]]; simp?)
1008  apply ((intro conjI sign_extend_eq)?;
1009         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
1010                    word_ao_dist word_bw_assocs word_and_max_simps))?
1011  done'''],
1012
1013    'ptr_empty_union_new_spec_direct': [
1014        ptr_empty_union_new_template(direct_ptr_name),
1015        '''sorry (* ptr_empty_union_new_spec_direct *)'''],
1016
1017    'ptr_empty_union_new_spec_path': [
1018        ptr_empty_union_new_template(path_ptr_name),
1019        ''' unfolding ptrval_def
1020 apply(rule allI, rule conseqPre, vcg)
1021 apply(clarsimp)
1022 apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1023 apply(clarsimp simp: h_t_valid_clift_Some_iff)
1024 apply(frule clift_subtype, simp+)
1025 apply(clarsimp simp: typ_heap_simps c_guard_clift
1026                      packed_heap_update_collapse_hrs)
1027
1028 apply(simp add: parent_update_child[OF c_guard_clift]
1029                 typ_heap_simps c_guard_clift)
1030
1031 apply((simp add: o_def)?, rule exI, rule conjI[OF _ refl])
1032
1033 apply(simp add: %(name)s_get_tag_def %(name)s_tag_defs
1034                 guard_simps mask_shift_simps)
1035done
1036'''],
1037
1038    'ptr_union_new_spec_direct': [
1039        ptr_union_new_template(direct_ptr_name),
1040        '''sorry (* ptr_union_new_spec_direct *)'''],
1041
1042    'ptr_union_new_spec_path': [
1043        ptr_union_new_template(path_ptr_name),
1044        '''  unfolding ptrval_def
1045  apply (rule allI, rule conseqPre, vcg)
1046  apply (clarsimp)
1047  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1048  apply (drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1049  apply (frule clift_subtype, simp, simp)
1050  apply (clarsimp simp: typ_heap_simps c_guard_clift
1051                        packed_heap_update_collapse_hrs)
1052  apply (simp add: guard_simps mask_shift_simps
1053                   %(name)s_tag_defs[THEN tag_eq_to_tag_masked_eq])?
1054  apply (simp add: parent_update_child[OF c_guard_clift]
1055                   typ_heap_simps c_guard_clift)
1056  apply (simp add: o_def %(name)s_%(block)s_lift_def)
1057  apply (simp only: %(name)s_lift_%(block)s cong: rev_conj_cong)
1058  apply (rule exI, rule conjI[rotated], rule conjI[OF _ refl])
1059   apply (simp_all add: %(name)s_get_tag_eq_x %(name)s_tag_defs mask_shift_simps)
1060  apply (intro conjI sign_extend_eq; simp add: mask_def word_ao_dist word_bw_assocs)?
1061  done'''],
1062
1063    'union_get_spec': [
1064        '''lemma (in ''' + loc_name + ''') ''' \
1065        '''%(name)s_%(block)s_get_%(field)s_spec:
1066  "\<forall>s. \<Gamma> \<turnstile> ''' \
1067'''\<lbrace>s. %(name)s_get_tag \<acute>%(name)s = ''' \
1068        '''scast %(name)s_%(block)s\<rbrace>
1069       \<acute>%(ret_name)s :== ''' \
1070       '''PROC %(name)s_%(block)s_get_%(field)s(\<acute>%(name)s)
1071       \<lbrace>\<acute>%(ret_name)s = ''' \
1072       '''%(name)s_%(block)s_CL.%(field)s_CL ''' \
1073       '''(%(name)s_%(block)s_lift \<^bsup>s\<^esup>%(name)s)''' \
1074       '''\<rbrace>"''',
1075        ''' apply(rule allI, rule conseqPre, vcg)
1076 apply(clarsimp simp:guard_simps)
1077 apply(simp add:%(name)s_%(block)s_lift_def)
1078 apply(subst %(name)s_lift_%(block)s)
1079  apply(simp add:o_def
1080                 %(name)s_get_tag_def
1081                 %(name)s_%(block)s_def
1082                 mask_def
1083                 word_size
1084                 shift_over_ao_dists)
1085 apply(subst %(name)s_lift_%(block)s, simp)?
1086 apply(simp add:o_def
1087                %(name)s_get_tag_def
1088                %(name)s_%(block)s_def
1089                mask_def
1090                word_size
1091                shift_over_ao_dists
1092                multi_shift_simps
1093                word_bw_assocs
1094                word_oa_dist
1095                word_and_max_simps
1096                ucast_def
1097                sign_extend_def'
1098                nth_is_and_neq_0)
1099done'''],
1100
1101    'union_set_spec': [
1102        '''lemma (in ''' + loc_name + ''') ''' \
1103        '''%(name)s_%(block)s_set_%(field)s_spec:
1104  "\<forall>s. \<Gamma> \<turnstile> ''' \
1105'''\<lbrace>s. %(name)s_get_tag \<acute>%(name)s = ''' \
1106        '''scast %(name)s_%(block)s\<rbrace>
1107       \<acute>ret__struct_%(name)s_C :== ''' \
1108    '''PROC %(name)s_%(block)s_set_%(field)s(\<acute>%(name)s, \<acute>v%(base)d)
1109       \<lbrace>%(name)s_%(block)s_lift \<acute>ret__struct_%(name)s_C = ''' \
1110    '''%(name)s_%(block)s_lift \<^bsup>s\<^esup>%(name)s \<lparr> ''' \
1111        '''%(name)s_%(block)s_CL.%(field)s_CL ''' \
1112        ''':= %(sign_extend)s (\<^bsup>s\<^esup>v%(base)d AND %(mask)s)\<rparr> \<and>
1113        %(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
1114     '''scast %(name)s_%(block)s\<rbrace>"''',
1115        '''  apply (rule allI, rule conseqPre, vcg)
1116  apply clarsimp
1117  apply (rule context_conjI[THEN iffD1[OF conj_commute]],
1118         fastforce simp: %(name)s_get_tag_eq_x %(name)s_lift_def %(name)s_tag_defs
1119                         mask_def shift_over_ao_dists multi_shift_simps word_size
1120                         word_ao_dist word_bw_assocs)
1121  apply (simp add: %(name)s_%(block)s_lift_def %(name)s_lift_def %(name)s_tag_defs)
1122  apply ((intro conjI sign_extend_eq)?;
1123         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
1124                    word_ao_dist word_bw_assocs word_and_max_simps))?
1125  done'''],
1126
1127    'ptr_union_get_spec_direct': [
1128        ptr_union_get_template(direct_ptr_name),
1129        ''' unfolding ptrval_def
1130 apply(rule allI, rule conseqPre, vcg)
1131 apply(clarsimp simp: typ_heap_simps h_t_valid_clift_Some_iff guard_simps
1132                      mask_shift_simps sign_extend_def' nth_is_and_neq_0
1133                      %(name)s_lift_%(block)s %(name)s_%(block)s_lift_def)
1134done
1135'''],
1136
1137    'ptr_union_get_spec_path': [
1138        ptr_union_get_template(path_ptr_name),
1139        '''unfolding ptrval_def
1140  apply(rule allI, rule conseqPre, vcg)
1141  apply(clarsimp)
1142  apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1143  apply(drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1144  apply(frule clift_subtype, simp, simp)
1145  apply(clarsimp simp: typ_heap_simps c_guard_clift)
1146  apply(simp add: guard_simps mask_shift_simps)
1147  apply(simp add:%(name)s_%(block)s_lift_def)
1148  apply(subst %(name)s_lift_%(block)s)
1149  apply(simp add: mask_def)+
1150  done
1151  (* ptr_union_get_spec_path *)'''],
1152
1153    'ptr_union_set_spec_direct': [
1154        ptr_union_set_template(direct_ptr_name),
1155        '''sorry (* ptr_union_set_spec_direct *)'''],
1156
1157
1158    'ptr_union_set_spec_path': [
1159        ptr_union_set_template(path_ptr_name),
1160        '''  unfolding ptrval_def
1161  apply (rule allI, rule conseqPre, vcg)
1162  apply (clarsimp)
1163  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1164  apply (drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1165  apply (frule clift_subtype, simp, simp)
1166  apply (clarsimp simp: typ_heap_simps c_guard_clift
1167                        packed_heap_update_collapse_hrs)
1168  apply (simp add: guard_simps mask_shift_simps
1169                   %(name)s_tag_defs[THEN tag_eq_to_tag_masked_eq])?
1170  apply (simp add: parent_update_child[OF c_guard_clift]
1171                   typ_heap_simps c_guard_clift)
1172  apply (simp add: o_def %(name)s_%(block)s_lift_def)
1173  apply (simp only: %(name)s_lift_%(block)s cong: rev_conj_cong)
1174  apply (rule exI, rule conjI[rotated], rule conjI[OF _ refl])
1175   apply (simp_all add: %(name)s_get_tag_eq_x %(name)s_tag_defs mask_shift_simps)
1176  apply (intro conjI sign_extend_eq; simp add: mask_def word_ao_dist word_bw_assocs)?
1177  done'''],
1178
1179}
1180
1181
def make_proof(name, substs, sorry=False):
    """Render the proof template called `name` using the mapping `substs`.

    The first template entry (the lemma statement) is %-formatted and
    followed by a newline.  If `sorry` is true, the proof script (second
    entry) is not formatted; a blank line and 'sorry' are emitted instead.
    Otherwise the second entry is %-formatted.  Any remaining entries are
    joined with newlines, %-formatted together, and appended after one more
    newline.
    """
    template = proof_templates[name]
    pieces = [template[0] % substs, '\n']

    if not sorry:
        pieces.append(template[1] % substs)
    else:
        pieces.append('\nsorry')

    trailing = template[2:]
    if trailing:
        pieces.append('\n' + ('\n'.join(trailing) % substs))

    return ''.join(pieces)
1194
1195# AST objects
1196
1197
def emit_named(name, params, string):
    """Print `string` followed by an empty line to params.output.

    Output only happens when `name` is a member of params.names; for any
    other name this is a no-op.
    """
    if name not in params.names:
        return

    out = params.output
    print(string, file=out)
    print(file=out)
1205
# Note: substs are recomputed separately for each proof, which is probably inefficient.
1207
1208
def emit_named_ptr_proof(fn_name, params, name, type_map, toptps, prf_prefix, substs):
    """Emit the pointer-based proof `prf_prefix` for the C type `name`.

    Nothing is emitted unless `name + '_C'` is a key of `type_map`, whose
    values are (top-level type, field path) pairs.  An empty path selects
    the '<prf_prefix>_direct' template; a non-empty path selects
    '<prf_prefix>_path'.  Before rendering, `substs` is updated in place
    with 'access_path' and 'update_path' (and, for non-empty paths, 'toptp'
    and 'path').  The proof is printed via emit_named under `fn_name`.

    `toptps` is accepted for interface compatibility but is unused.
    """
    name_C = name + '_C'

    if name_C not in type_map:
        return

    toptp, path = type_map[name_C]

    # Wrap field selectors around the dereferenced pointer, first field
    # innermost: ['f', 'g'] gives "(g (f (the (ptrval s))))".
    substs['access_path'] = '(' + reduce(lambda x, y: y + ' (' + x + ')',
                                         ['the (ptrval s)'] + path) + ')'

    if not path:
        substs['update_path'] = name
        emit_named(fn_name, params, make_proof(prf_prefix + '_direct', substs, params.sorry))
        return

    substs['toptp'] = toptp
    # Only the last component of each field name (less any qualifiers) is
    # used, as the typ_heap lemmas don't use the qualifier.
    substs['path'] = ', '.join("''%s''" % field.split('.')[-1] for field in path)

    # Build the nested record update from the last field outwards.  Iterate a
    # reversed *copy*: reversing `path` in place would mutate the list stored
    # in type_map, so every other call for the same type would see it
    # backwards.  `name` here is the variable name (so not the _C type).
    reversed_path = path[::-1]
    substs['update_path'] = '(' + reduce(lambda x, y: y + '_update (' + x + ')',
                                         ['\\<lambda>_. ' + name] + reversed_path) \
        + '(the (ptrval s))' + ')'
    emit_named(fn_name, params, make_proof(prf_prefix + '_path', substs, params.sorry))
1232
1233
def field_mask_proof(base, base_bits, sign_extend, high, size):
    """Return the Isabelle mask term for a field that is `size` bits wide.

    Low-aligned fields get "mask <size>".  High-aligned fields get
    "NOT (mask <base_bits - size>)" when base_bits equals base or
    sign_extend is set.  That form is equivalent, but nicer in proofs.
    Otherwise they get "(mask <size> << <base_bits - size>)".
    """
    if not high:
        return "mask %d" % size

    shift = base_bits - size
    if base_bits == base or sign_extend:
        return "NOT (mask %d)" % shift
    return "(mask %d << %d)" % (size, shift)
1243
1244
def sign_extend_proof(high, base_bits, base_sign_extend):
    """Return the sign_extend prefix (with trailing space) for a field term.

    Only a high-aligned field in a sign-extended base gets
    "sign_extend <base_bits - 1> ".  Every other field gets "".
    """
    needs_extension = high and base_sign_extend
    return "sign_extend %d " % (base_bits - 1) if needs_extension else ""
1250
1251
def det_values(*dicts):
    """Return an iterator over the values of every dict in `dicts`.

    The dicts are visited in argument order.  Within each dict, values come
    out in ascending key order.  Each dict's keys are sorted when this
    function is called; the value lookups happen lazily during iteration.
    """
    def _by_sorted_key(mapping):
        return (mapping[k] for k in sorted(mapping.keys()))

    return itertools.chain(*map(_by_sorted_key, dicts))
1257
1258
1259class TaggedUnion:
1260    def __init__(self, name, tagname, classes, tags):
1261        self.name = name
1262        self.tagname = tagname
1263        self.constant_suffix = ''
1264
1265        # Check for duplicate tags
1266        used_names = set()
1267        used_values = set()
1268        for name, value in tags:
1269            if name in used_names:
1270                raise ValueError("Duplicate tag name %s" % name)
1271            if value in used_values:
1272                raise ValueError("Duplicate tag value %d" % value)
1273
1274            used_names.add(name)
1275            used_values.add(value)
1276        self.classes = dict(classes)
1277        self.tags = tags
1278
1279    def resolve(self, params, symtab):
1280        # Grab block references for tags
1281        self.tags = [(name, value, symtab[name]) for name, value in self.tags]
1282        self.make_classes(params)
1283
1284        # Ensure that block sizes and tag size & position match for
1285        # all tags in the union
1286        union_base = None
1287        union_size = None
1288        for name, value, ref in self.tags:
1289            _tag_offset, _tag_size, _tag_high = ref.field_map[self.tagname]
1290
1291            if union_base is None:
1292                union_base = ref.base
1293            elif union_base != ref.base:
1294                raise ValueError("Base mismatch for element %s"
1295                                 " of tagged union %s" % (name, self.name))
1296
1297            if union_size is None:
1298                union_size = ref.size
1299            elif union_size != ref.size:
1300                raise ValueError("Size mismatch for element %s"
1301                                 " of tagged union %s" % (name, self.name))
1302
1303            if _tag_offset != self.tag_offset[_tag_size]:
1304                raise ValueError("Tag offset mismatch for element %s"
1305                                 " of tagged union %s" % (name, self.name))
1306
1307            self.assert_value_in_class(name, value, _tag_size)
1308
1309            if _tag_high:
1310                raise ValueError("Tag field is high-aligned for element %s"
1311                                 " of tagged union %s" % (name, self.name))
1312
1313            # Flag block as belonging to a tagged union
1314            ref.tagged = True
1315
1316        self.union_base = union_base
1317        self.union_size = union_size
1318
1319    def set_base(self, base, base_bits, base_sign_extend, suffix):
1320        self.base = base
1321        self.multiple = self.union_size // base
1322        self.constant_suffix = suffix
1323        self.base_bits = base_bits
1324        self.base_sign_extend = base_sign_extend
1325
1326        tag_index = None
1327        for w in self.tag_offset:
1328            tag_offset = self.tag_offset[w]
1329
1330            if tag_index is None:
1331                tag_index = tag_offset // base
1332
1333            if (tag_offset // base) != tag_index:
1334                raise ValueError(
1335                    "The tag field of tagged union %s"
1336                    " is in a different word (%s) to the others (%s)."
1337                    % (self.name, hex(tag_offset // base), hex(tag_index)))
1338
1339    def generate_hol_proofs(self, params, type_map):
1340        output = params.output
1341
1342        # Add fixed simp rule for struct
1343        print("lemmas %(name)s_C_words_C_fl_simp[simp] = "
1344              "%(name)s_C_words_C_fl[simplified]" %
1345              {"name": self.name}, file=output)
1346        print(file=output)
1347
1348        # Generate struct field pointer proofs
1349        substs = {"name": self.name,
1350                  "words": self.multiple,
1351                  "base": self.base}
1352
1353        print(make_proof('words_NULL_proof',
1354                         substs, params.sorry), file=output)
1355        print(file=output)
1356
1357        print(make_proof('words_aligned_proof',
1358                         substs, params.sorry), file=output)
1359        print(file=output)
1360
1361        print(make_proof('words_ptr_safe_proof',
1362                         substs, params.sorry), file=output)
1363        print(file=output)
1364
1365        # Generate struct lemmas
1366        print(struct_lemmas_template % {"name": self.name},
1367              file=output)
1368        print(file=output)
1369
1370        # Generate get_tag specs
1371        substs = {"name": self.name,
1372                  "tagname": self.tagname,
1373                  "ret_name": return_name(self.base)}
1374
1375        if not params.skip_modifies:
1376            emit_named("%(name)s_get_%(tagname)s" % substs, params,
1377                       make_proof('const_modifies_proof',
1378                                  {"fun_name": "%(name)s_get_%(tagname)s" % substs,
1379                                   "args": ', '.join(["\<acute>ret__unsigned_long",
1380                                                      "\<acute>%(name)s" % substs])},
1381                                  params.sorry))
1382            emit_named("%(name)s_ptr_get_%(tagname)s" % substs, params,
1383                       make_proof('const_modifies_proof',
1384                                  {"fun_name": "%(name)s_ptr_get_%(tagname)s" % substs,
1385                                   "args": ', '.join(["\<acute>ret__unsigned_long",
1386                                                      "\<acute>%(name)s_ptr" % substs])},
1387                                  params.sorry))
1388
1389        emit_named("%s_get_%s" % (self.name, self.tagname), params,
1390                   make_proof('get_tag_spec', substs, params.sorry))
1391
1392        emit_named("%s_%s_equals" % (self.name, self.tagname), params,
1393                   make_proof('get_tag_equals_spec', substs, params.sorry))
1394
1395        # Only generate ptr lemmas for those types reachable from top level types
1396        emit_named_ptr_proof("%s_ptr_get_%s" % (self.name, self.tagname), params, self.name,
1397                             type_map, params.toplevel_types,
1398                             'ptr_get_tag_spec', substs)
1399
1400        for name, value, ref in self.tags:
1401            # Generate struct_new specs
1402            arg_list = ["\<acute>" + field
1403                        for field in ref.visible_order
1404                        if field != self.tagname]
1405
1406            # Generate modifies proof
1407            if not params.skip_modifies:
1408                emit_named("%s_%s_new" % (self.name, ref.name), params,
1409                           make_proof('const_modifies_proof',
1410                                      {"fun_name": "%s_%s_new" %
1411                                       (self.name, ref.name),
1412                                       "args": ', '.join([
1413                                           "\<acute>ret__struct_%(name)s_C" % substs] +
1414                                           arg_list)},
1415                                      params.sorry))
1416
1417                emit_named("%s_%s_ptr_new" % (self.name, ref.name), params,
1418                           make_proof('ptr_set_modifies_proof',
1419                                      {"fun_name": "%s_%s_ptr_new" %
1420                                       (self.name, ref.name),
1421                                       "args": ', '.join([
1422                                           "\<acute>ret__struct_%(name)s_C" % substs] +
1423                                           arg_list)},
1424                                      params.sorry))
1425
1426            if len(arg_list) == 0:
1427                # For an empty block:
1428                emit_named("%s_%s_new" % (self.name, ref.name), params,
1429                           make_proof('empty_union_new_spec',
1430                                      {"name": self.name,
1431                                       "block": ref.name},
1432                                      params.sorry))
1433
1434                emit_named_ptr_proof("%s_%s_ptr_new" % (self.name, ref.name), params, self.name,
1435                                     type_map, params.toplevel_types,
1436                                     'ptr_empty_union_new_spec',
1437                                     {"name": self.name,
1438                                      "block": ref.name})
1439            else:
1440                field_eq_list = []
1441                for field in ref.visible_order:
1442                    offset, size, high = ref.field_map[field]
1443
1444                    if field == self.tagname:
1445                        continue
1446
1447                    mask = field_mask_proof(self.base, self.base_bits,
1448                                            self.base_sign_extend, high, size)
1449                    sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
1450                    field_eq_list.append(
1451                        "%s_%s_CL.%s_CL = %s(\<^bsup>s\<^esup>%s AND %s)" %
1452                        (self.name, ref.name, field, sign_extend,
1453                            var_name(field, self.base), mask))
1454                field_eqs = ',\n          '.join(field_eq_list)
1455
1456                emit_named("%s_%s_new" % (self.name, ref.name), params,
1457                           make_proof('union_new_spec',
1458                                      {"name": self.name,
1459                                       "block": ref.name,
1460                                       "args": ', '.join(arg_list),
1461                                       "field_eqs": field_eqs},
1462                                      params.sorry))
1463
1464                emit_named_ptr_proof("%s_%s_ptr_new" % (self.name, ref.name), params, self.name,
1465                                     type_map, params.toplevel_types,
1466                                     'ptr_union_new_spec',
1467                                     {"name": self.name,
1468                                      "block": ref.name,
1469                                      "args": ', '.join(arg_list),
1470                                      "field_eqs": field_eqs})
1471
1472            _, size, _ = ref.field_map[self.tagname]
1473            if any([w for w in self.widths if w < size]):
1474                tag_mask_helpers = ("%s_%s_tag_mask_helpers"
1475                                    % (self.name, ref.name))
1476            else:
1477                tag_mask_helpers = ""
1478
1479            # Generate get/set specs
1480            for (field, offset, size, high) in ref.fields:
1481                if field == self.tagname:
1482                    continue
1483
1484                mask = field_mask_proof(self.base, self.base_bits,
1485                                        self.base_sign_extend, high, size)
1486                sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
1487
1488                substs = {"name":  self.name,
1489                          "block": ref.name,
1490                          "field": field,
1491                          "mask":  mask,
1492                          "sign_extend": sign_extend,
1493                          "tag_mask_helpers": tag_mask_helpers,
1494                          "ret_name": return_name(self.base),
1495                          "base": self.base}
1496
1497                # Get modifies spec
1498                if not params.skip_modifies:
1499                    emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
1500                               params,
1501                               make_proof('const_modifies_proof',
1502                                          {"fun_name": "%s_%s_get_%s" %
1503                                           (self.name, ref.name, field),
1504                                              "args": ', '.join([
1505                                                  "\<acute>ret__unsigned_long",
1506                                                  "\<acute>%s" % self.name])},
1507                                          params.sorry))
1508
1509                    emit_named("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
1510                               params,
1511                               make_proof('const_modifies_proof',
1512                                          {"fun_name": "%s_%s_ptr_get_%s" %
1513                                           (self.name, ref.name, field),
1514                                              "args": ', '.join([
1515                                                  "\<acute>ret__unsigned_long",
1516                                                  "\<acute>%s_ptr" % self.name])},
1517                                          params.sorry))
1518
1519                # Get spec
1520                emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
1521                           params,
1522                           make_proof('union_get_spec', substs, params.sorry))
1523
1524                # Set modifies spec
1525                if not params.skip_modifies:
1526                    emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
1527                               params,
1528                               make_proof('const_modifies_proof',
1529                                          {"fun_name": "%s_%s_set_%s" %
1530                                           (self.name, ref.name, field),
1531                                              "args": ', '.join([
1532                                                  "\<acute>ret__struct_%s_C" % self.name,
1533                                                  "\<acute>%s" % self.name,
1534                                                  "\<acute>v%(base)d"])},
1535                                          params.sorry))
1536
1537                    emit_named("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
1538                               params,
1539                               make_proof('ptr_set_modifies_proof',
1540                                          {"fun_name": "%s_%s_ptr_set_%s" %
1541                                           (self.name, ref.name, field),
1542                                              "args": ', '.join([
1543                                                  "\<acute>%s_ptr" % self.name,
1544                                                  "\<acute>v%(base)d"])},
1545                                          params.sorry))
1546
1547                # Set spec
1548                emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
1549                           params,
1550                           make_proof('union_set_spec', substs, params.sorry))
1551
1552                # Pointer get spec
1553                emit_named_ptr_proof("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
1554                                     params, self.name, type_map, params.toplevel_types,
1555                                     'ptr_union_get_spec', substs)
1556
1557                # Pointer set spec
1558                emit_named_ptr_proof("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
1559                                     params, self.name, type_map, params.toplevel_types,
1560                                     'ptr_union_set_spec', substs)
1561
    def generate_hol_defs(self, params):
        """Emit the Isabelle/HOL definitions for this tagged union.

        Everything is printed to ``params.output``, in this order:
          * one record per tagged block with the tag field suppressed
            (via each block's ``generate_hol_defs(..., in_union=True)``);
          * the ``<name>_CL`` datatype with one constructor per tag;
          * the ``get_tag`` definition and the ``get_tag_eq_x`` lemma, each
            built from one template entry per tag width in ``self.widths``;
          * mask helper lemmas for every tag whose field is wider than some
            tag width in use;
          * the lift definition followed by one collapse proof per tag;
          * for each block that is non-empty without its tag: access/update
            definitions, the block lift definition and lemma; and finally a
            ``<name>_lifts`` lemma list.

        NOTE(review): reads ``self.widths``, ``self.tag_offset`` and
        ``self.classes``, which ``make_classes`` populates; presumably that
        runs first -- confirm against the driver.
        """
        output = params.output

        # Blocks whose record is empty once the tag field is suppressed;
        # used as a set (the stored value is always True).
        empty_blocks = {}

        def gen_name(ref_name, capitalise=False):
            # Create datatype generator/type name for a block
            if capitalise:
                return "%s_%s" % \
                       (self.name[0].upper() + self.name[1:], ref_name)
            else:
                return "%s_%s" % (self.name, ref_name)

        # Generate block records with tag field removed
        for name, value, ref in self.tags:
            if ref.generate_hol_defs(params,
                                     suppressed_field=self.tagname,
                                     prefix="%s_" % self.name,
                                     in_union=True):
                empty_blocks[ref] = True

        # One datatype constructor per tag: nullary for empty blocks,
        # otherwise carrying that block's "<union>_<tag>_CL" record.
        constructor_exprs = []
        for name, value, ref in self.tags:
            if ref in empty_blocks:
                constructor_exprs.append(gen_name(name, True))
            else:
                constructor_exprs.append("%s %s_CL" %
                                         (gen_name(name, True), gen_name(name)))

        print("datatype %s_CL =\n    %s\n" %
              (self.name, '\n  | '.join(constructor_exprs)),
              file=output)

        # Generate get_tag definition
        # (one "entry" template per tag width except the last, which uses
        # the "final" template; each is filled with that width's classmask
        # and the word index/shift of the tag field)
        subs = {"name":      self.name,
                "base":      self.base}

        templates = ([union_get_tag_def_entry_template] * (len(self.widths) - 1)
                     + [union_get_tag_def_final_template])

        fs = (union_get_tag_def_header_template % subs
              + "".join([template %
                         dict(subs,
                              tag_size=width,
                              classmask=self.word_classmask(width),
                              tag_index=self.tag_offset[width] // self.base,
                              tag_shift=self.tag_offset[width] % self.base)
                         for template, width in zip(templates, self.widths)])
              + union_get_tag_def_footer_template % subs)

        print(fs, file=output)
        print(file=output)

        # Generate get_tag_eq_x lemma
        # (same per-width structure as the get_tag definition above)
        templates = ([union_get_tag_eq_x_def_entry_template]
                     * (len(self.widths) - 1)
                     + [union_get_tag_eq_x_def_final_template])

        fs = (union_get_tag_eq_x_def_header_template % subs
              + "".join([template %
                         dict(subs,
                              tag_size=width,
                              classmask=self.word_classmask(width),
                              tag_index=self.tag_offset[width] // self.base,
                              tag_shift=self.tag_offset[width] % self.base)
                         for template, width in zip(templates, self.widths)])
              + union_get_tag_eq_x_def_footer_template % subs)

        print(fs, file=output)
        print(file=output)

        # Generate mask helper lemmas
        # For each tag whose field is wider than some tag width in use,
        # relate the full tag mask/value to the value masked down to each
        # narrower width.

        for name, value, ref in self.tags:
            offset, size, _ = ref.field_map[self.tagname]
            part_widths = [w for w in self.widths if w < size]
            if part_widths:
                subs = {"name":         self.name,
                        "block":        name,
                        "full_mask":    hex(2 ** size - 1),
                        "full_value":   hex(value)}

                fs = (union_tag_mask_helpers_header_template % subs
                      + "".join([union_tag_mask_helpers_entry_template %
                                 dict(subs, part_mask=hex(2 ** pw - 1),
                                      part_value=hex(value & (2 ** pw - 1)))
                                 for pw in part_widths])
                      + union_tag_mask_helpers_footer_template)

                print(fs, file=output)
                print(file=output)

        # Generate lift definition
        # Each tag contributes an "if tag = scast ... then Some (...)" case
        # whose record fields are read out of words_C by word index, shift
        # and (for sub-word fields) mask; the tag field itself is skipped.
        collapse_proofs = ""
        tag_cases = []
        for name, value, ref in self.tags:
            field_inits = []

            for field in ref.visible_order:
                offset, size, high = ref.field_map[field]

                if field == self.tagname:
                    continue

                index = offset // self.base
                sign_extend = ""

                if high:
                    # Shift left so the field's top bit lands at
                    # base_bits - 1 (shift right if that would be negative),
                    # optionally sign-extending from that bit.
                    shift_op = "<<"
                    shift = self.base_bits - size - (offset % self.base)
                    if shift < 0:
                        shift = -shift
                        shift_op = ">>"
                    if self.base_sign_extend:
                        sign_extend = "sign_extend %d " % (self.base_bits - 1)
                else:
                    shift_op = ">>"
                    shift = offset % self.base

                initialiser = \
                    "%s_CL.%s_CL = %s(((index (%s_C.words_C %s) %d) %s %d)" % \
                    (gen_name(name), field, sign_extend, self.name, self.name,
                     index, shift_op, shift)

                if size < self.base:
                    mask = field_mask_proof(self.base, self.base_bits,
                                            self.base_sign_extend, high, size)
                    initialiser += " AND " + mask

                field_inits.append("\n       " + initialiser + ")")

            # NOTE: this rebinds the loop variable `value` (the numeric tag
            # value) to the HOL constructor expression used below.
            if len(field_inits) == 0:
                value = gen_name(name, True)
            else:
                value = "%s \<lparr> %s \<rparr>" % \
                    (gen_name(name, True), ','.join(field_inits))

            tag_cases.append("if tag = scast %s then Some (%s)" %
                             (gen_name(name), value))

            collapse_proofs += \
                make_proof("lift_collapse_proof",
                           {"name": self.name,
                            "block": name,
                            "value": value},
                           params.sorry)
            collapse_proofs += "\n\n"

        print(union_lift_def_template %
              {"name": self.name,
               "tag_cases": '\n     else '.join(tag_cases)},
              file=output)
        print(file=output)

        print(collapse_proofs, file=output)

        # Accumulates the "<union>_<block>_lift" names of non-empty blocks.
        block_lift_lemmas = "lemmas %s_lifts = \n" % self.name
        # Generate lifted access/update definitions, and abstract lifters
        for name, value, ref in self.tags:
            # Don't generate accessors if the block (minus tag) is empty
            if ref in empty_blocks:
                continue

            substs = {"union": self.name,
                      "block": name,
                      "generator": gen_name(name, True)}

            for t in [union_access_def_template, union_update_def_template]:
                print(t % substs, file=output)
                print(file=output)

            print(block_lift_def_template % substs, file=output)
            print(file=output)

            print(block_lift_lemma_template % substs, file=output)
            print(file=output)

            block_lift_lemmas += "\t%(union)s_%(block)s_lift\n" % substs

        print(block_lift_lemmas, file=output)
        print(file=output)
1743
    def generate(self, params):
        """Emit the C definitions for this tagged union.

        The typedef (``typedef_template``) and the ``<name>_tag`` enum are
        printed directly to ``params.output``.  The tag readers
        (``<name>_get_<tag>``, ``<name>_<tag>_equals``,
        ``<name>_ptr_get_<tag>``) and, for every tagged block, the
        ``_new``/``_ptr_new`` generators plus the ``get``/``set``/
        ``ptr_get``/``ptr_set`` accessors of each non-tag field are handed
        to ``emit_named``.

        NOTE(review): the target environment is read from the module-level
        ``options`` object rather than from ``params`` -- presumably the
        parsed command-line options; confirm.
        """
        output = params.output

        # Generate typedef
        print(typedef_template %
              {"type": TYPES[options.environment][self.base],
               "name": self.name,
               "multiple": self.multiple}, file=output)
        print(file=output)

        # Generate tag enum
        # (the last enumerator is printed separately so it has no trailing
        # comma)
        print("enum %s_tag {" % self.name, file=output)
        if len(self.tags) > 0:
            for name, value, ref in self.tags[:-1]:
                print("    %s_%s = %d," % (self.name, name, value),
                      file=output)
            name, value, ref = self.tags[-1]
            print("    %s_%s = %d" % (self.name, name, value),
                  file=output)
        print("};", file=output)
        print("typedef enum %s_tag %s_tag_t;" %
              (self.name, self.name), file=output)
        print(file=output)

        subs = {
            'inline': INLINE[options.environment],
            'union': self.name,
            'type':  TYPES[options.environment][self.union_base],
            'tagname': self.tagname,
            'suf': self.constant_suffix}

        # Generate tag reader
        # (one "entry" template per tag width except the last, which uses
        # the "final" template; each gets that width's mask, classmask and
        # the word index/shift of the tag field)
        templates = ([tag_reader_entry_template] * (len(self.widths) - 1)
                     + [tag_reader_final_template])

        fs = (tag_reader_header_template % subs
              + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                         for template, width in zip(templates, self.widths)])
              + tag_reader_footer_template % subs)

        emit_named("%s_get_%s" % (self.name, self.tagname), params, fs)

        # Generate tag eq reader
        templates = ([tag_eq_reader_entry_template] * (len(self.widths) - 1)
                     + [tag_eq_reader_final_template])

        fs = (tag_eq_reader_header_template % subs
              + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                         for template, width in zip(templates, self.widths)])
              + tag_eq_reader_footer_template % subs)

        emit_named("%s_%s_equals" % (self.name, self.tagname), params, fs)

        # Generate pointer lifted tag reader
        templates = ([ptr_tag_reader_entry_template] * (len(self.widths) - 1)
                     + [ptr_tag_reader_final_template])

        fs = (ptr_tag_reader_header_template % subs
              + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                         for template, width in zip(templates, self.widths)])
              + ptr_tag_reader_footer_template % subs)

        emit_named("%s_ptr_get_%s" % (self.name, self.tagname), params, fs)

        for name, value, ref in self.tags:
            # Generate generators
            # The tag is not a parameter: the generator writes the enum
            # value for this block into the tag field itself.
            param_fields = [field for field in ref.visible_order if field != self.tagname]
            param_list = ["%s %s" % (TYPES[options.environment][self.base], field)
                          for field in param_fields]

            if len(param_list) == 0:
                gen_params = 'void'
            else:
                gen_params = ', '.join(param_list)

            ptr_params = ', '.join(["%s_t *%s_ptr" % (self.name, self.name)] + param_list)

            # word index -> list of shifted (and masked) terms OR-ed together
            # in that word's initialiser.
            field_updates = {word: [] for word in range(self.multiple)}
            field_asserts = ["    /* fail if user has passed bits that we will override */"]

            for field in ref.visible_order:
                offset, size, high = ref.field_map[field]

                if field == self.tagname:
                    f_value = "(%s)%s_%s" % (TYPES[options.environment][self.base], self.name, name)
                else:
                    f_value = field

                index = offset // self.base
                if high:
                    # Inverse of the reader: shift right by the distance
                    # between the field's stored position and the top of
                    # the base_bits-wide word (left if that is negative).
                    shift_op = ">>"
                    shift = self.base_bits - size - (offset % self.base)
                    # high_bits: the bits above base_bits that a
                    # sign-extended value carries (checked by the assert
                    # below when bit base_bits - 1 is set).
                    if self.base_sign_extend:
                        high_bits = ((self.base_sign_extend << (
                            self.base - self.base_bits)) - 1) << self.base_bits
                    else:
                        high_bits = 0
                    if shift < 0:
                        shift = -shift
                        shift_op = "<<"
                else:
                    shift_op = "<<"
                    shift = offset % self.base
                    high_bits = 0
                if size < self.base:
                    if high:
                        mask = ((1 << size) - 1) << (self.base_bits - size)
                    else:
                        mask = (1 << size) - 1
                    suf = self.constant_suffix

                    # Generated check: bits outside the mask must be zero,
                    # or equal high_bits when sign extension applies and the
                    # value's bit base_bits - 1 is set.
                    field_asserts.append(
                        "    %s((%s & ~0x%x%s) == ((%d && (%s & (1%s << %d))) ? 0x%x : 0));"
                        % (ASSERTS[options.environment], f_value, mask, suf, self.base_sign_extend,
                           f_value, suf, self.base_bits - 1, high_bits))

                    field_updates[index].append(
                        "(%s & 0x%x%s) %s %d" % (f_value, mask, suf, shift_op, shift))

                else:
                    # A field at least a word wide is written unmasked.
                    field_updates[index].append("%s %s %d" % (f_value, shift_op, shift))

            word_inits = [
                ("words[%d] = 0" % index) + ''.join(["\n        | %s" % up for up in ups]) + ';'
                for (index, ups) in field_updates.items()]

            def mk_inits(prefix):
                return '\n'.join(["    %s%s" % (prefix, word_init) for word_init in word_inits])

            print_params = {
                "inline":     INLINE[options.environment],
                "block":      name,
                "union":      self.name,
                "gen_params": gen_params,
                "ptr_params": ptr_params,
                "gen_inits":  mk_inits("%s." % self.name),
                "ptr_inits":  mk_inits("%s_ptr->" % self.name),
                "asserts": '  \n'.join(field_asserts)
            }

            generator = union_generator_template % print_params
            ptr_generator = ptr_union_generator_template % print_params

            emit_named("%s_%s_new" % (self.name, name), params, generator)
            emit_named("%s_%s_ptr_new" % (self.name, name), params, ptr_generator)

            # Generate field readers/writers
            tagnameoffset, tagnamesize, _ = ref.field_map[self.tagname]
            tagmask = (2 ** tagnamesize) - 1
            for field, offset, size, high in ref.fields:
                # Don't duplicate tag accessors
                if field == self.tagname:
                    continue

                index = offset // self.base
                if high:
                    write_shift = ">>"
                    read_shift = "<<"
                    shift = self.base_bits - size - (offset % self.base)
                    if shift < 0:
                        shift = -shift
                        write_shift = "<<"
                        read_shift = ">>"
                    if self.base_sign_extend:
                        high_bits = ((self.base_sign_extend << (
                            self.base - self.base_bits)) - 1) << self.base_bits
                    else:
                        high_bits = 0
                else:
                    write_shift = "<<"
                    read_shift = ">>"
                    shift = offset % self.base
                    high_bits = 0
                # Mask of the field at its stored position within its word.
                mask = ((1 << size) - 1) << (offset % self.base)

                subs = {
                    "inline": INLINE[options.environment],
                    "block": ref.name,
                    "field": field,
                    "type": TYPES[options.environment][ref.base],
                    "assert": ASSERTS[options.environment],
                    "index": index,
                    "shift": shift,
                    "r_shift_op": read_shift,
                    "w_shift_op": write_shift,
                    "mask": mask,
                    "tagindex": tagnameoffset // self.base,
                    "tagshift": tagnameoffset % self.base,
                    "tagmask": tagmask,
                    "union": self.name,
                    "suf": self.constant_suffix,
                    "high_bits": high_bits,
                    "sign_extend": self.base_sign_extend and high,
                    "extend_bit": self.base_bits - 1,
                    "base": self.base}

                # Reader
                emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
                           params, union_reader_template % subs)

                # Writer
                emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
                           params, union_writer_template % subs)

                # Pointer lifted reader
                emit_named("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
                           params, ptr_union_reader_template % subs)

                # Pointer lifted writer
                emit_named("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
                           params, ptr_union_writer_template % subs)
1970
1971    def make_names(self):
1972        "Return the set of candidate function names for a union"
1973
1974        substs = {"union": self.name,
1975                  "tagname": self.tagname}
1976        names = [t % substs for t in [
1977            "%(union)s_get_%(tagname)s",
1978            "%(union)s_ptr_get_%(tagname)s",
1979            "%(union)s_%(tagname)s_equals"]]
1980
1981        for name, value, ref in self.tags:
1982            names += ref.make_names(self)
1983
1984        return names
1985
1986    def represent_value(self, value, width):
1987        max_width = max(self.classes.keys())
1988
1989        tail_str = ("{:0{}b}".format(value, width)
1990                    + "_" * (self.tag_offset[width] - self.class_offset))
1991        head_str = "_" * ((max_width + self.tag_offset[max_width]
1992                           - self.class_offset) - len(tail_str))
1993
1994        return head_str + tail_str
1995
1996    def represent_class(self, width):
1997        max_width = max(self.classes.keys())
1998
1999        cmask = self.classes[width]
2000        return ("{:0{}b}".format(cmask, max_width).replace("0", "-")
2001                + " ({:#x})".format(cmask))
2002
2003    def represent_field(self, width):
2004        max_width = max(self.classes.keys())
2005        offset = self.tag_offset[width] - self.class_offset
2006
2007        return ("{:0{}b}".format((2 ** width - 1) << offset, max_width)
2008                .replace("0", "-").replace("1", "#"))
2009
2010    def assert_value_in_class(self, name, value, width):
2011        max_width = max(self.classes.keys())
2012        ovalue = value << self.tag_offset[width]
2013        cvalue = value << (self.tag_offset[width] - self.class_offset)
2014
2015        offset_field = (2 ** width - 1) << self.tag_offset[width]
2016        if (ovalue | offset_field) != offset_field:
2017            raise ValueError(
2018                "The value for element %s of tagged union %s,\n"
2019                "    %s\nexceeds the field bounds\n"
2020                "    %s."
2021                % (name, self.name,
2022                   self.represent_value(value, width),
2023                   self.represent_field(width)))
2024
2025        for w, mask in [(lw, self.classes[lw])
2026                        for lw in self.widths if lw < width]:
2027            if (cvalue & mask) != mask:
2028                raise ValueError(
2029                    "The value for element %s of tagged union %s,\n"
2030                    "    %s\nis invalid: it has %d bits but fails"
2031                    " to match the earlier mask at %d bits,\n"
2032                    "    %s."
2033                    % (name, self.name,
2034                       self.represent_value(value, width),
2035                       width, w, self.represent_class(w)))
2036
2037        if (self.widths.index(width) + 1 < len(self.widths) and
2038                (cvalue & self.classes[width]) == self.classes[width]):
2039            raise ValueError(
2040                "The value for element %s of tagged union %s,\n"
2041                "    %s (%d/%s)\nis invalid: it must not match the "
2042                "mask for %d bits,\n    %s."
2043                % (name, self.name,
2044                   ("{:0%db}" % width).format(cvalue),
2045                   value, hex(value),
2046                   width,
2047                   self.represent_class(width)))
2048
2049    def word_classmask(self, width):
2050        "Return a class mask for testing a whole word, i.e., one."
2051        "that is positioned absolutely relative to the lsb of the"
2052        "relevant word."
2053
2054        return (self.classes[width] << (self.class_offset % self.base))
2055
2056    def make_classes(self, params):
2057        "Calculate an encoding for variable width tagnames"
2058
2059        # Check self.classes, which maps from the bit width of tagname in a
2060        # particular block to a classmask that identifies when a value belongs
2061        # to wider tagname.
2062        #
2063        # For example, given three possible field widths -- 4, 8, and 12 bits --
2064        # one possible encoding is:
2065        #
2066        #                       * * _ _     (** != 11)
2067        #             0 _ _ _   1 1 _ _
2068        #   _ _ _ _   1 _ _ _   1 1 _ _
2069        #
2070        # where the 3rd and 4th lsbs signify whether the field should be
2071        # interpreted using a 4-bit mask (if 00, 01, or 10) or as an 8 or 16 bit
2072        # mask (if 11). And, in the latter case, the 8th lsb signifies whether
2073        # to intrepret it as an 8 bit field (if 0) or a 16 bit field (if 1).
2074        #
2075        # In this example we have:
2076        #   4-bit class:  classmask = 0b00001100
2077        #   8-bit class:  classmask = 0b10001100
2078        #  16-bit class:  classmask = 0b10001100
2079        #
2080        # More generally, the fields need not all start at the same offset
2081        # (measured "left" from the lsb), for example:
2082        #
2083        #    ...# ###. .... ....       4-bit field at offset 9
2084        #    ..## #### ##.. ....       8-bit field at offset 6
2085        #    #### #### #### ....      12-bit field at offset 4
2086        #
2087        # In this case, the class_offset is the minimum offset (here 4).
2088        # Classmasks are declared relative to the field, but converted
2089        # internally to be relative to the class_offset; tag_offsets
2090        # are absolute (relative to the lsb); values are relative to
2091        # the tag_offset (i.e., within the field). for example:
2092        #
2093        #    ...1 100. ....    4-bit class: classmask=0xc   tag_offset=9
2094        #    ..01 1000 10..    8-bit class: classmask=0x62  tag_offset=6
2095        #    0001 1000 1000   16-bit class: classmask=0x188 tag_offset=4
2096
2097        used = set()
2098        self.tag_offset = {}
2099        for name, _, ref in self.tags:
2100            offset, size, _ = ref.field_map[self.tagname]
2101            used.add(size)
2102            self.tag_offset[size] = offset
2103
2104        self.class_offset = min(self.tag_offset.values())
2105
2106        # internally, classmasks are relative to the class_offset, so
2107        # that we can compare them to each other.
2108        for w in self.classes:
2109            self.classes[w] <<= self.tag_offset[w] - self.class_offset
2110
2111        used_widths = sorted(list(used))
2112        assert(len(used_widths) > 0)
2113
2114        if not self.classes:
2115            self.classes = {used_widths[0]: 0}
2116
2117        # sanity checks on classes
2118        classes = self.classes
2119        widths = sorted(self.classes.keys())
2120        context = "masks for %s.%s" % (self.name, self.tagname)
2121        class_offset = self.class_offset
2122
2123        for uw in used_widths:
2124            if uw not in classes:
2125                raise ValueError("%s: none defined for a field of %d bits."
2126                                 % (context, uw))
2127
2128        for mw in classes.keys():
2129            if mw not in used_widths:
2130                raise ValueError(
2131                    "%s: there is a mask with %d bits but no corresponding fields."
2132                    % (context, mw))
2133
2134        for w in widths:
2135            offset_field = (2 ** w - 1) << self.tag_offset[w]
2136            if (classes[w] << class_offset) | offset_field != offset_field:
2137                raise ValueError(
2138                    "{:s}: the mask for {:d} bits:\n  {:s}\n"
2139                    "exceeds the field bounds:\n  {:s}."
2140                    .format(context, w, self.represent_class(w),
2141                            self.represent_field(w)))
2142
2143        if len(widths) > 1 and classes[widths[0]] == 0:
2144            raise ValueError("%s: the first (width %d) is zero." % (
2145                context, widths[0]))
2146
2147        if any([classes[widths[i-1]] == classes[widths[i]]
2148                for i in range(1, len(widths) - 1)]):
2149            raise ValueError("%s: there is a non-final duplicate!" % context)
2150
2151        # smaller masks are included within larger masks
2152        pre_mask = None
2153        pre_width = None
2154        for w in widths:
2155            if pre_mask is not None and (classes[w] & pre_mask) != pre_mask:
2156                raise ValueError(
2157                    "{:s}: the mask\n  0b{:b} for width {:d} does not include "
2158                    "the mask\n  0b{:b} for width {:d}.".format(
2159                        context, classes[w], w, pre_mask, pre_width))
2160            pre_width = w
2161            pre_mask = classes[w]
2162
2163        if params.showclasses:
2164            print("-----%s.%s" % (self.name, self.tagname), file=sys.stderr)
2165            for w in widths:
2166                print("{:2d} = {:s}".format(w, self.represent_class(w)),
2167                      file=sys.stderr)
2168
2169        self.widths = widths
2170
2171
2172class Block:
2173    def __init__(self, name, fields, visible_order):
2174        offset = 0
2175        _fields = []
2176        self.size = sum(size for _name, size, _high in fields)
2177        offset = self.size
2178        self.constant_suffix = ''
2179
2180        if visible_order is None:
2181            self.visible_order = []
2182
2183        for _name, _size, _high in fields:
2184            offset -= _size
2185            if not _name is None:
2186                if visible_order is None:
2187                    self.visible_order.append(_name)
2188                _fields.append((_name, offset, _size, _high))
2189
2190        self.name = name
2191        self.tagged = False
2192        self.fields = _fields
2193        self.field_map = dict((name, (offset, size, high))
2194                              for name, offset, size, high in _fields)
2195
2196        if not visible_order is None:
2197            missed_fields = set(self.field_map.keys())
2198
2199            for _name in visible_order:
2200                if _name not in self.field_map:
2201                    raise ValueError("Nonexistent field '%s' in visible_order"
2202                                     % _name)
2203                missed_fields.remove(_name)
2204
2205            if len(missed_fields) > 0:
2206                raise ValueError("Fields %s missing from visible_order" %
2207                                 str([x for x in missed_fields]))
2208
2209            self.visible_order = visible_order
2210
2211    def set_base(self, base, base_bits, base_sign_extend, suffix):
2212        self.base = base
2213        self.constant_suffix = suffix
2214        self.base_bits = base_bits
2215        self.base_sign_extend = base_sign_extend
2216        if self.size % base != 0:
2217            raise ValueError("Size of block %s not a multiple of base"
2218                             % self.name)
2219        self.multiple = self.size // base
2220        for name, offset, size, high in self.fields:
2221            if offset // base != (offset+size-1) // base:
2222                raise ValueError("Field %s of block %s "
2223                                 "crosses a word boundary"
2224                                 % (name, self.name))
2225
2226    def generate_hol_defs(self, params, suppressed_field=None,
2227                          prefix="", in_union=False):
2228        output = params.output
2229
2230        # Don't generate raw records for blocks in tagged unions
2231        if self.tagged and not in_union:
2232            return
2233
2234        _name = prefix + self.name
2235
2236        # Generate record def
2237        out = "record %s_CL =\n" % _name
2238
2239        empty = True
2240
2241        for field in self.visible_order:
2242            if suppressed_field == field:
2243                continue
2244
2245            empty = False
2246
2247            out += '    %s_CL :: "word%d"\n' % (field, self.base)
2248
2249        word_updates = ""
2250
2251        if not empty:
2252            print(out, file=output)
2253
2254        # Generate lift definition
2255        if not in_union:
2256            field_inits = []
2257
2258            for name in self.visible_order:
2259                offset, size, high = self.field_map[name]
2260
2261                index = offset // self.base
2262                sign_extend = ""
2263
2264                if high:
2265                    shift_op = "<<"
2266                    shift = self.base_bits - size - (offset % self.base)
2267                    if shift < 0:
2268                        shift = -shift
2269                        shift_op = ">>"
2270                    if self.base_sign_extend:
2271                        sign_extend = "sign_extend %d " % (self.base_bits - 1)
2272                else:
2273                    shift_op = ">>"
2274                    shift = offset % self.base
2275
2276                initialiser = \
2277                    "%s_CL.%s_CL = %s(((index (%s_C.words_C %s) %d) %s %d)" % \
2278                    (self.name, name, sign_extend, self.name, self.name,
2279                     index, shift_op, shift)
2280
2281                if size < self.base:
2282                    mask = field_mask_proof(self.base, self.base_bits,
2283                                            self.base_sign_extend, high, size)
2284
2285                    initialiser += " AND " + mask
2286
2287                field_inits.append(initialiser + ")")
2288
2289            print(lift_def_template %
2290                  {"name": self.name,
2291                   "fields": ',\n       '.join(field_inits)},
2292                  file=output)
2293            print(file=output)
2294
2295        return empty
2296
2297    def generate_hol_proofs(self, params, type_map):
2298        output = params.output
2299
2300        if self.tagged:
2301            return
2302
2303        # Add fixed simp rule for struct
2304        print("lemmas %(name)s_C_words_C_fl_simp[simp] = "
2305              "%(name)s_C_words_C_fl[simplified]" %
2306              {"name": self.name}, file=output)
2307        print(file=output)
2308
2309        # Generate struct field pointer proofs
2310        substs = {"name": self.name,
2311                  "words": self.multiple,
2312                  "base": self.base}
2313
2314        print(make_proof('words_NULL_proof',
2315                         substs, params.sorry), file=output)
2316        print(file=output)
2317
2318        print(make_proof('words_aligned_proof',
2319                         substs, params.sorry), file=output)
2320        print(file=output)
2321
2322        print(make_proof('words_ptr_safe_proof',
2323                         substs, params.sorry), file=output)
2324        print(file=output)
2325
2326        # Generate struct lemmas
2327        print(struct_lemmas_template % {"name": self.name},
2328              file=output)
2329        print(file=output)
2330
2331        # Generate struct_new specs
2332        arg_list = ["\<acute>" + field for field in self.visible_order]
2333
2334        if not params.skip_modifies:
2335            emit_named("%s_new" % self.name, params,
2336                       make_proof('const_modifies_proof',
2337                                  {"fun_name": "%s_new" % self.name,
2338                                   "args": ', '.join(["\<acute>ret__struct_%s_C" %
2339                                                      self.name] +
2340                                                     arg_list)},
2341                                  params.sorry))
2342            # FIXME: ptr_new (doesn't seem to be used)
2343
2344        field_eq_list = []
2345        for field in self.visible_order:
2346            offset, size, high = self.field_map[field]
2347            mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
2348            sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
2349
2350            field_eq_list.append("%s_CL.%s_CL = %s(\<^bsup>s\<^esup>%s AND %s)" %
2351                                 (self.name, field, sign_extend, field, mask))
2352        field_eqs = ',\n          '.join(field_eq_list)
2353
2354        emit_named("%s_new" % self.name, params,
2355                   make_proof('new_spec',
2356                              {"name": self.name,
2357                               "args": ', '.join(arg_list),
2358                               "field_eqs": field_eqs},
2359                              params.sorry))
2360
2361        emit_named_ptr_proof("%s_ptr_new" % self.name, params, self.name,
2362                             type_map, params.toplevel_types,
2363                             'ptr_new_spec',
2364                             {"name": self.name,
2365                              "args": ', '.join(arg_list),
2366                              "field_eqs": field_eqs})
2367
2368        # Generate get/set specs
2369        for (field, offset, size, high) in self.fields:
2370            mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
2371            sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
2372
2373            substs = {"name": self.name,
2374                      "field": field,
2375                      "mask": mask,
2376                      "sign_extend": sign_extend,
2377                      "ret_name": return_name(self.base),
2378                      "base": self.base}
2379
2380            if not params.skip_modifies:
2381                # Get modifies spec
2382                emit_named("%s_get_%s" % (self.name, field), params,
2383                           make_proof('const_modifies_proof',
2384                                      {"fun_name": "%s_get_%s" % (self.name, field),
2385                                       "args": ', '.join([
2386                                           "\<acute>ret__unsigned_long",
2387                                           "\<acute>%s" % self.name])},
2388                                      params.sorry))
2389
2390                # Ptr get modifies spec
2391                emit_named("%s_ptr_get_%s" % (self.name, field), params,
2392                           make_proof('const_modifies_proof',
2393                                      {"fun_name": "%s_ptr_get_%s" % (self.name, field),
2394                                       "args": ', '.join([
2395                                           "\<acute>ret__unsigned_long",
2396                                           "\<acute>%s_ptr" % self.name])},
2397                                      params.sorry))
2398
2399            # Get spec
2400            emit_named("%s_get_%s" % (self.name, field), params,
2401                       make_proof('get_spec', substs, params.sorry))
2402
2403            if not params.skip_modifies:
2404                # Set modifies spec
2405                emit_named("%s_set_%s" % (self.name, field), params,
2406                           make_proof('const_modifies_proof',
2407                                      {"fun_name": "%s_set_%s" % (self.name, field),
2408                                       "args": ', '.join([
2409                                           "\<acute>ret__struct_%s_C" % self.name,
2410                                           "\<acute>%s" % self.name,
2411                                           "\<acute>v%(base)d"])},
2412                                      params.sorry))
2413
2414                emit_named("%s_ptr_set_%s" % (self.name, field), params,
2415                           make_proof('ptr_set_modifies_proof',
2416                                      {"fun_name": "%s_ptr_set_%s" % (self.name, field),
2417                                       "args": ', '.join([
2418                                           "\<acute>%s_ptr" % self.name,
2419                                           "\<acute>v%(base)d"])},
2420                                      params.sorry))
2421
2422            # Set spec
2423            emit_named("%s_set_%s" % (self.name, field), params,
2424                       make_proof('set_spec', substs, params.sorry))
2425
2426            emit_named_ptr_proof("%s_ptr_get_%s" % (self.name, field), params, self.name,
2427                                 type_map, params.toplevel_types,
2428                                 'ptr_get_spec', substs)
2429            emit_named_ptr_proof("%s_ptr_set_%s" % (self.name, field), params, self.name,
2430                                 type_map, params.toplevel_types,
2431                                 'ptr_set_spec', substs)
2432
2433    def generate(self, params):
2434        output = params.output
2435
2436        # Don't generate raw accessors for blocks in tagged unions
2437        if self.tagged:
2438            return
2439
2440        # Type definition
2441        print(typedef_template %
2442              {"type": TYPES[options.environment][self.base],
2443               "name": self.name,
2444               "multiple": self.multiple}, file=output)
2445        print(file=output)
2446
2447        # Generator
2448        param_fields = [field for field in self.visible_order]
2449        param_list = ["%s %s" % (TYPES[options.environment][self.base], field)
2450                      for field in param_fields]
2451
2452        if len(param_list) == 0:
2453            gen_params = 'void'
2454        else:
2455            gen_params = ', '.join(param_list)
2456
2457        ptr_params = ', '.join(["%s_t *%s_ptr" % (self.name, self.name)] + param_list)
2458
2459        field_updates = {word: [] for word in range(self.multiple)}
2460        field_asserts = ["    /* fail if user has passed bits that we will override */"]
2461
2462        for field, offset, size, high in self.fields:
2463            index = offset // self.base
2464            if high:
2465                shift_op = ">>"
2466                shift = self.base_bits - size - (offset % self.base)
2467                if self.base_sign_extend:
2468                    high_bits = ((self.base_sign_extend << (
2469                        self.base - self.base_bits)) - 1) << self.base_bits
2470                else:
2471                    high_bits = 0
2472                if shift < 0:
2473                    shift = -shift
2474                    shift_op = "<<"
2475            else:
2476                shift_op = "<<"
2477                shift = offset % self.base
2478                high_bits = 0
2479            if size < self.base:
2480                if high:
2481                    mask = ((1 << size) - 1) << (self.base_bits - size)
2482                else:
2483                    mask = (1 << size) - 1
2484                suf = self.constant_suffix
2485
2486                field_asserts.append(
2487                    "    %s((%s & ~0x%x%s) == ((%d && (%s & (1%s << %d))) ? 0x%x : 0));"
2488                    % (ASSERTS[options.environment], field, mask, suf, self.base_sign_extend,
2489                       field, suf, self.base_bits - 1, high_bits))
2490
2491                field_updates[index].append(
2492                    "(%s & 0x%x%s) %s %d" % (field, mask, suf, shift_op, shift))
2493
2494            else:
2495                field_updates[index].append("%s %s %d;" % (field, shift_op, shift))
2496
2497        word_inits = [
2498            ("words[%d] = 0" % index) + ''.join(["\n        | %s" % up for up in ups]) + ';'
2499            for (index, ups) in field_updates.items()]
2500
2501        def mk_inits(prefix):
2502            return '\n'.join(["    %s%s" % (prefix, word_init) for word_init in word_inits])
2503
2504        print_params = {
2505            "inline": INLINE[options.environment],
2506            "block": self.name,
2507            "gen_params": gen_params,
2508            "ptr_params": ptr_params,
2509            "gen_inits": mk_inits("%s." % self.name),
2510            "ptr_inits": mk_inits("%s_ptr->" % self.name),
2511            "asserts": '  \n'.join(field_asserts)
2512        }
2513
2514        generator = generator_template % print_params
2515        ptr_generator = ptr_generator_template % print_params
2516
2517        emit_named("%s_new" % self.name, params, generator)
2518        emit_named("%s_ptr_new" % self.name, params, ptr_generator)
2519
2520        # Accessors
2521        for field, offset, size, high in self.fields:
2522            index = offset // self.base
2523            if high:
2524                write_shift = ">>"
2525                read_shift = "<<"
2526                shift = self.base_bits - size - (offset % self.base)
2527                if shift < 0:
2528                    shift = -shift
2529                    write_shift = "<<"
2530                    read_shift = ">>"
2531                if self.base_sign_extend:
2532                    high_bits = ((self.base_sign_extend << (
2533                        self.base - self.base_bits)) - 1) << self.base_bits
2534                else:
2535                    high_bits = 0
2536            else:
2537                write_shift = "<<"
2538                read_shift = ">>"
2539                shift = offset % self.base
2540                high_bits = 0
2541            mask = ((1 << size) - 1) << (offset % self.base)
2542
2543            subs = {
2544                "inline": INLINE[options.environment],
2545                "block": self.name,
2546                "field": field,
2547                "type": TYPES[options.environment][self.base],
2548                "assert": ASSERTS[options.environment],
2549                "index": index,
2550                "shift": shift,
2551                "r_shift_op": read_shift,
2552                "w_shift_op": write_shift,
2553                "mask": mask,
2554                "suf": self.constant_suffix,
2555                "high_bits": high_bits,
2556                "sign_extend": self.base_sign_extend and high,
2557                "extend_bit": self.base_bits - 1,
2558                "base": self.base}
2559
2560            # Reader
2561            emit_named("%s_get_%s" % (self.name, field), params,
2562                       reader_template % subs)
2563
2564            # Writer
2565            emit_named("%s_set_%s" % (self.name, field), params,
2566                       writer_template % subs)
2567
2568            # Pointer lifted reader
2569            emit_named("%s_ptr_get_%s" % (self.name, field), params,
2570                       ptr_reader_template % subs)
2571
2572            # Pointer lifted writer
2573            emit_named("%s_ptr_set_%s" % (self.name, field), params,
2574                       ptr_writer_template % subs)
2575
2576    def make_names(self, union=None):
2577        "Return the set of candidate function names for a block"
2578
2579        if union is None:
2580            # Don't generate raw accessors for blocks in tagged unions
2581            if self.tagged:
2582                return []
2583
2584            substs = {"block": self.name}
2585
2586            # A standalone block
2587            field_templates = [
2588                "%(block)s_get_%(field)s",
2589                "%(block)s_set_%(field)s",
2590                "%(block)s_ptr_get_%(field)s",
2591                "%(block)s_ptr_set_%(field)s"]
2592
2593            names = [t % substs for t in [
2594                "%(block)s_new",
2595                "%(block)s_ptr_new"]]
2596        else:
2597            substs = {"block": self.name,
2598                      "union": union.name}
2599
2600            # A tagged union block
2601            field_templates = [
2602                "%(union)s_%(block)s_get_%(field)s",
2603                "%(union)s_%(block)s_set_%(field)s",
2604                "%(union)s_%(block)s_ptr_get_%(field)s",
2605                "%(union)s_%(block)s_ptr_set_%(field)s"]
2606
2607            names = [t % substs for t in [
2608                "%(union)s_%(block)s_new",
2609                "%(union)s_%(block)s_ptr_new"]]
2610
2611        for field, offset, size, high in self.fields:
2612            if not union is None and field == union.tagname:
2613                continue
2614
2615            substs["field"] = field
2616            names += [t % substs for t in field_templates]
2617
2618        return names
2619
2620
# OutputFile objects opened atomically whose temporary file has not yet been
# moved onto its final name by finish_output().
temp_output_files = list()
2622
2623
class OutputFile(object):
    """A writable output destination that remembers its absolute path.

    With ``atomic=True`` data is written to a uniquely named temporary file
    created in the target's directory; the object is queued in
    ``temp_output_files`` so that ``finish_output`` can later rename the
    temporary file onto the real name.  With ``atomic=False`` the target
    path is opened directly.
    """

    def __init__(self, filename, mode='w', atomic=True):
        self.filename = os.path.abspath(filename)

        if not atomic:
            self.file = open(filename, mode)
            return

        target_dir, target_base = os.path.split(self.filename)
        self.file = tempfile.NamedTemporaryFile(
            mode=mode, dir=target_dir, prefix=target_base + '.', delete=False)
        if DEBUG:
            print('Temp file: %r -> %r' % (self.file.name, self.filename), file=sys.stderr)
        # Appending mutates the module-level list in place; no rebinding.
        temp_output_files.append(self)

    def write(self, *args, **kwargs):
        """Forward to the underlying file object's ``write``."""
        self.file.write(*args, **kwargs)
2643
2644
def finish_output():
    """Move every pending atomic output file onto its final name.

    Each temporary file is closed first, so its buffered data reaches disk
    before it appears under the target name; closing first also lets the
    rename work on platforms that refuse to rename open files.  ``os.replace``
    is used where available so an existing target is overwritten on every
    platform.  ``tempfile`` creates its files readable/writable by the owner
    only, so the final file's mode is reset to what a plain ``open()`` would
    have produced under the current umask.  The pending list is cleared.
    """
    global temp_output_files
    # os.umask can only be read by setting it, so set and restore it.
    umask = os.umask(0)
    os.umask(umask)
    move = getattr(os, 'replace', os.rename)
    for f in temp_output_files:
        f.file.close()
        move(f.file.name, f.filename)
        os.chmod(f.filename, 0o666 & ~umask)
    temp_output_files = []
2650
2651
2652# Toplevel
if __name__ == '__main__':
    # Parse arguments to set mode and grab I/O filenames
    params = {}
    in_filename = None
    in_file = sys.stdin
    out_file = sys.stdout
    mode = 'c_defs'

    parser = optparse.OptionParser()
    parser.add_option('--c_defs', action='store_true', default=False)
    parser.add_option('--environment', action='store', default='sel4',
                      choices=list(INCLUDES.keys()))
    parser.add_option('--hol_defs', action='store_true', default=False)
    parser.add_option('--hol_proofs', action='store_true', default=False)
    parser.add_option('--sorry_lemmas', action='store_true',
                      dest='sorry', default=False)
    parser.add_option('--prune', action='append',
                      dest="prune_files", default=[])
    parser.add_option('--toplevel', action='append',
                      dest="toplevel_types", default=[])
    parser.add_option('--umm_types', action='store',
                      dest="umm_types_file", default=None)
    parser.add_option('--multifile_base', action='store', default=None)
    parser.add_option('--cspec-dir', action='store', default=None,
                      help="Location of the 'cspec' directory containing 'KernelState_C'.")
    parser.add_option('--thy-output-path', action='store', default=None,
                      help="Path that the output theory files will be located in.")
    parser.add_option('--skip_modifies', action='store_true', default=False)
    parser.add_option('--showclasses', action='store_true', default=False)
    parser.add_option('--debug', action='store_true', default=False)

    options, args = parser.parse_args()
    DEBUG = options.debug

    if len(args) > 0:
        in_filename = args[0]
        in_file = open(in_filename)

        if len(args) > 1:
            out_file = OutputFile(args[1])

    #
    # If generating Isabelle scripts, ensure we have enough information for
    # relative paths required by Isabelle.
    #
    if options.hol_defs or options.hol_proofs:
        # Ensure directory that we need to include is known.
        if options.cspec_dir is None:
            parser.error("'cspec_dir' not defined.")

        # Ensure that if an output file was not specified, an output path was.
        if len(args) <= 1:
            if options.thy_output_path is None:
                parser.error("Theory output path was not specified")
            if out_file == sys.stdout:
                parser.error(
                    'Output file name must be given when generating HOL definitions or proofs')
            out_file.filename = os.path.abspath(options.thy_output_path)

    if options.hol_proofs and not options.umm_types_file:
        parser.error('--umm_types must be specified when generating HOL proofs')

    del parser

    options.output = out_file

    # Parse the spec.  The input file is closed once read (stdin is left
    # open).
    lexer = lex.lex()
    yacc.yacc(debug=0, write_tables=0)
    spec_text = in_file.read()
    if in_file is not sys.stdin:
        in_file.close()
    blocks = {}
    unions = {}
    _, block_map, union_map = yacc.parse(input=spec_text, lexer=lexer)
    base_list = [8, 16, 32, 64]
    # assumes that unsigned int = 32 bit on 32-bit and 64-bit platforms,
    # and that unsigned long long = 64 bit on 64-bit platforms.
    # Should still work fine if ull = 128 bit, but will not work
    # if unsigned int is less than 32 bit.
    suffix_map = {8: 'u', 16: 'u', 32: 'u', 64: 'ull'}
    for base_info, block_list in block_map.items():
        base, base_bits, base_sign_extend = base_info
        for name, b in block_list.items():
            if base not in base_list:
                raise ValueError("Invalid base size: %d" % base)
            suffix = suffix_map[base]
            b.set_base(base, base_bits, base_sign_extend, suffix)
            blocks[name] = b

    symtab = {}
    symtab.update(blocks)
    for base, union_list in union_map.items():
        unions.update(union_list)
    symtab.update(unions)
    for base_info, union_list in union_map.items():
        base, base_bits, base_sign_extend = base_info
        for u in union_list.values():
            if base not in base_list:
                raise ValueError("Invalid base size: %d" % base)
            suffix = suffix_map[base]
            u.resolve(options, symtab)
            u.set_base(base, base_bits, base_sign_extend, suffix)

    if in_filename is not None:
        base_filename = os.path.basename(in_filename).split('.')[0]

        # Generate the module name from the input filename
        module_name = base_filename

    # Prune list of names to generate
    name_list = []
    for e in det_values(blocks, unions):
        name_list += e.make_names()

    name_list = set(name_list)
    if len(options.prune_files) > 0:
        search_re = re.compile('[a-zA-Z0-9_]+')

        pruned_names = set()
        for filename in options.prune_files:
            with open(filename) as f:
                string = f.read()

            matched_tokens = set(search_re.findall(string))
            pruned_names.update(matched_tokens & name_list)
    else:
        pruned_names = name_list

    options.names = pruned_names

    # Generate the output
    if options.hol_defs:
        # Fetch kernel
        if options.multifile_base is None:
            print("theory %s_defs" % module_name, file=out_file)
            print("imports \"%s/KernelState_C\"" % (
                os.path.relpath(options.cspec_dir,
                                os.path.dirname(out_file.filename))), file=out_file)
            print("begin", file=out_file)
            print(file=out_file)

            print(defs_global_lemmas, file=out_file)
            print(file=out_file)

            for e in det_values(blocks, unions):
                e.generate_hol_defs(options)

            print("end", file=out_file)
        else:
            print("theory %s_defs" % module_name, file=out_file)
            print("imports", file=out_file)
            print("  \"%s/KernelState_C\"" % (
                os.path.relpath(options.cspec_dir,
                                os.path.dirname(out_file.filename))), file=out_file)
            for e in det_values(blocks, unions):
                print("  %s_%s_defs" % (module_name, e.name),
                      file=out_file)
            print("begin", file=out_file)
            print("end", file=out_file)

            for e in det_values(blocks, unions):
                base_filename = \
                    os.path.basename(options.multifile_base).split('.')[0]
                submodule_name = base_filename + "_" + \
                    e.name + "_defs"
                out_file = OutputFile(options.multifile_base + "_" +
                                      e.name + "_defs" + ".thy")

                print("theory %s imports \"%s/KernelState_C\" begin" % (
                    submodule_name, os.path.relpath(options.cspec_dir,
                                                    os.path.dirname(out_file.filename))),
                      file=out_file)
                print(file=out_file)

                options.output = out_file
                e.generate_hol_defs(options)

                print("end", file=out_file)
    elif options.hol_proofs:
        # Struct type names of all blocks and unions, computed once.
        bit_type_names = set(e.name + '_C' for e in det_values(blocks, unions))

        def is_bit_type(tp):
            # base_name is only consulted for base types.
            return umm.is_base(tp) and umm.base_name(tp) in bit_type_names

        tps = umm.build_types(options.umm_types_file)
        type_map = {}

        # invert type map
        for toptp in options.toplevel_types:
            paths = umm.paths_to_type(tps, is_bit_type, toptp)

            for path, tp in paths:
                tp = umm.base_name(tp)

                if tp in type_map:
                    raise ValueError("Type %s has multiple parents" % tp)

                type_map[tp] = (toptp, path)

        if options.multifile_base is None:
            print("theory %s_proofs" % module_name, file=out_file)
            print("imports %s_defs" % module_name, file=out_file)
            print("begin", file=out_file)
            print(file=out_file)
            print(file=out_file)

            for e in det_values(blocks, unions):
                e.generate_hol_proofs(options, type_map)

            print("end", file=out_file)
        else:
            # top types are broken here.
            print("theory %s_proofs" % module_name, file=out_file)
            print("imports", file=out_file)
            for e in det_values(blocks, unions):
                print("  %s_%s_proofs" % (module_name, e.name),
                      file=out_file)
            print("begin", file=out_file)
            print("end", file=out_file)

            for e in det_values(blocks, unions):
                base_filename = \
                    os.path.basename(options.multifile_base).split('.')[0]
                submodule_name = base_filename + "_" + \
                    e.name + "_proofs"
                out_file = OutputFile(options.multifile_base + "_" +
                                      e.name + "_proofs" + ".thy")

                print(("theory %s imports "
                       + "%s_%s_defs begin") % (
                    submodule_name, base_filename, e.name),
                    file=out_file)
                print(file=out_file)

                options.output = out_file
                e.generate_hol_proofs(options, type_map)

                print("end", file=out_file)
    else:
        # sys.stdout has no 'filename' attribute; fall back to its 'name'
        # so writing C definitions to stdout does not crash.
        guard_source = getattr(out_file, 'filename', None) or \
            getattr(out_file, 'name', 'stdout')
        guard = re.sub(r'[^a-zA-Z0-9_]', '_', guard_source.upper())
        print("#ifndef %(guard)s\n#define %(guard)s\n" %
              {'guard': guard}, file=out_file)
        print('\n'.join(map(lambda x: '#include <%s>' % x,
                            INCLUDES[options.environment])), file=out_file)
        for e in det_values(blocks, unions):
            e.generate(options)
        print("#endif", file=out_file)

    finish_output()
2899