1#!/usr/bin/env python
2
3#
4# Copyright 2017, Data61
5# Commonwealth Scientific and Industrial Research Organisation (CSIRO)
6# ABN 41 687 119 230.
7#
8# This software may be distributed and modified according to the terms of
9# the BSD 2-Clause license. Note that NO WARRANTY is provided.
10# See "LICENSE_BSD2.txt" for details.
11#
12# @TAG(DATA61_BSD)
13#
14
15##
## A tool for generating bitfield structures with get/set/new methods
17## including Isabelle/HOL specifications and correctness proofs.
18##
19
20from __future__ import print_function, division
21import sys
22import os.path
23import optparse
24import re
25import itertools
26import tempfile
27
28from six.moves import range
29from functools import reduce
30
31import lex
32from ply import yacc
33
34import umm
35
# Global debug switch; turned on with the --debug command line option.
DEBUG = False

# Isabelle locale in which the generated bitfield proofs are stated.
loc_name = 'kernel_all_substitute'

# Suffix, keyed by word size in bits, appended to 'ret__unsigned' to form
# the Isabelle name of a return value.
ret_name_suffix_map = dict([(8, ''), (16, ''), (32, ''), (64, '_longlong')])


def return_name(base):
    """Return the Isabelle return-value name for a `base`-bit word.

    `base` must be a key of ret_name_suffix_map (8, 16, 32 or 64); any
    other value raises KeyError.
    """
    suffix = ret_name_suffix_map[base]
    return 'ret__unsigned' + suffix
48
49
# Headers to #include in the generated code, per target environment.
INCLUDES = {
    'sel4': ['assert.h', 'config.h', 'stdint.h', 'util.h'],
    'libsel4': ['autoconf.h', 'sel4/simple_types.h', 'sel4/debug_assert.h'],
}

# Assertion routine per environment (presumably substituted for %(assert)s
# in the C templates -- confirm against the caller).
ASSERTS = dict(sel4='assert', libsel4='seL4_DebugAssert')

# Inline qualifier per environment (presumably substituted for %(inline)s).
INLINE = dict(sel4='static inline', libsel4='LIBSEL4_INLINE_FUNC')

# C type name for each supported word width in bits, per environment.
TYPES = dict(
    sel4=dict(zip((8, 16, 32, 64),
                  ("uint8_t", "uint16_t", "uint32_t", "uint64_t"))),
    libsel4=dict(zip((8, 16, 32, 64),
                     ("seL4_Uint8", "seL4_Uint16", "seL4_Uint32",
                      "seL4_Uint64"))),
)
81
### Parser

# Keywords of the specification language.  A keyword's token type is the
# upper-case spelling of the word as it appears in the source.
reserved = (
    'BLOCK', 'BASE', 'FIELD', 'FIELD_HIGH', 'MASK', 'PADDING',
    'TAGGED_UNION', 'TAG',
)

# Full token list: the keywords followed by identifiers, integer literals
# and punctuation.
tokens = reserved + (
    'IDENTIFIER', 'INTLIT', 'LBRACE', 'RBRACE',
    'LPAREN', 'RPAREN', 'COMMA',
)

# Single-character punctuation tokens, given as regular expressions.
t_LBRACE = r'{'
t_RBRACE = r'}'
t_LPAREN = r'\('
t_RPAREN = r'\)'
t_COMMA = r','

# Lower-case keyword spelling -> token type.  t_IDENTIFIER uses it to
# reclassify identifiers that are really keywords.
reserved_map = {word.lower(): word for word in reserved}
97
def t_IDENTIFIER(t):
    r'[A-Za-z_]\w+|[A-Za-z]'
    # Keywords also match the identifier pattern.  Look the lexeme up in the
    # keyword table and fall back to a plain IDENTIFIER token.
    if t.value in reserved_map:
        t.type = reserved_map[t.value]
    else:
        t.type = 'IDENTIFIER'
    return t
102
def t_INTLIT(t):
    r'([1-9][0-9]*|0[oO]?[0-7]+|0[xX][0-9a-fA-F]+|0[bB][01]+|0)[lL]?'
    # Convert the matched literal text to an int.
    #
    # The pattern accepts two C/Python-2 style spellings that Python 3's
    # int(text, 0) rejects with ValueError:
    #   - a trailing 'l'/'L' long suffix (e.g. '10L'), and
    #   - legacy octal with a bare leading zero (e.g. '017').
    # Strip the suffix and parse legacy octal explicitly so the lexer gives
    # the same value under Python 2 and Python 3 instead of crashing.
    text = t.value.rstrip('lL')
    if len(text) > 1 and text[0] == '0' and text[1] in '01234567':
        t.value = int(text, 8)
    else:
        t.value = int(text, 0)
    return t
107
def t_NEWLINE(t):
    r'\n+'
    # Newlines produce no token.  Advance the lexer's line counter by the
    # number of newline characters consumed so later messages can report
    # accurate lines.
    lexer = t.lexer
    lexer.lineno = lexer.lineno + len(t.value)
111
def t_comment(t):
    r'--.*|\#.*'
    # Both '--' and '#' start a comment that runs to the end of the line.
    # Returning None makes PLY discard the matched text.
    return None

# Characters skipped between tokens.  Newlines are not listed here because
# t_NEWLINE consumes them to keep line numbers up to date.
t_ignore = ' \t'
116
def t_error(t):
    # Lexer error hook: report the first character that no rule matched
    # (and, with DEBUG on, the whole token), then abort with status 1.
    bad_char = t.value[0]
    message = "%s: Unexpected character '%s'" % (sys.argv[0], bad_char)
    print(message, file=sys.stderr)
    if DEBUG:
        print('Token: %s' % str(t), file=sys.stderr)
    sys.exit(1)
123
def p_start(t):
    """start : entity_list"""
    # The parse result is simply the accumulated entity-list state.
    entities = t[1]
    t[0] = entities
127
def p_entity_list_empty(t):
    """entity_list : """
    # The entity list is threaded through the parse as a triple
    #   (current_base, block_map, union_map)
    # where block_map / union_map map each declared base to a dict of the
    # blocks / tagged unions declared under it, keyed by name.  At the start
    # no base has been declared yet.
    t[0] = (None,{},{})

def p_entity_list_base(t):
    """entity_list : entity_list base"""
    # A 'base' declaration becomes the current base for everything after
    # it.  Ensure both maps have an entry for it; an existing entry is kept
    # if the same base is declared again.
    current_base, block_map, union_map = t[1]
    block_map.setdefault(t[2], {})
    union_map.setdefault(t[2], {})
    t[0] = (t[2], block_map, union_map)

def _check_base_declared(current_base, kind, name):
    # Blocks and unions are filed under the current base.  If one appears
    # before any 'base' declaration, report it and exit instead of failing
    # with a bare KeyError on block_map[None] / union_map[None].
    if current_base is None:
        print("%s: %s '%s' declared before any base"
              % (sys.argv[0], kind, name), file=sys.stderr)
        sys.exit(1)

def p_entity_list_block(t):
    """entity_list : entity_list block"""
    current_base, block_map, union_map = t[1]
    _check_base_declared(current_base, 'block', t[2].name)
    block_map[current_base][t[2].name] = t[2]
    t[0] = (current_base, block_map, union_map)

def p_entity_list_union(t):
    """entity_list : entity_list tagged_union"""
    current_base, block_map, union_map = t[1]
    _check_base_declared(current_base, 'tagged_union', t[2].name)
    union_map[current_base][t[2].name] = t[2]
    t[0] = (current_base, block_map, union_map)
150
def p_base_simple(t):
    """base : BASE INTLIT"""
    # 'base N' alone: the second component repeats N and the third is 0.
    size = t[2]
    t[0] = (size, size, 0)

def p_base_mask(t):
    """base : BASE INTLIT LPAREN INTLIT COMMA INTLIT RPAREN"""
    # 'base N(A, B)': keep the three integers in source order.
    size, first, second = t[2], t[4], t[6]
    t[0] = (size, first, second)
158
def p_block(t):
    """block : BLOCK IDENTIFIER opt_visible_order_spec""" \
           """ LBRACE fields RBRACE"""
    # Build the Block from its name, field list and optional visible order
    # (None when no parenthesised order list was given).
    name, order, fields = t[2], t[3], t[5]
    t[0] = Block(name=name, fields=fields, visible_order=order)
163
def p_opt_visible_order_spec_empty(t):
    """opt_visible_order_spec : """
    # No parenthesised list follows the block name.
    t[0] = None

def p_opt_visible_order_spec(t):
    """opt_visible_order_spec : LPAREN visible_order_spec RPAREN"""
    # Unwrap the parenthesised list.
    inner = t[2]
    t[0] = inner

def p_visible_order_spec_empty(t):
    """visible_order_spec : """
    # An explicitly empty list, '()'.
    t[0] = list()

def p_visible_order_spec_single(t):
    """visible_order_spec : IDENTIFIER"""
    name = t[1]
    t[0] = [name]

def p_visible_order_spec(t):
    """visible_order_spec : visible_order_spec COMMA IDENTIFIER"""
    # Build a new list rather than mutating the one from the sub-parse.
    names, extra = t[1], t[3]
    t[0] = names + [extra]
183
def p_fields_empty(t):
    """fields : """
    t[0] = list()

def p_fields_field(t):
    """fields : fields FIELD IDENTIFIER INTLIT"""
    # Each entry is (name, size, is_high); is_high is True only for
    # field_high declarations.
    name, size = t[3], t[4]
    t[0] = t[1] + [(name, size, False)]

def p_fields_field_high(t):
    """fields : fields FIELD_HIGH IDENTIFIER INTLIT"""
    name, size = t[3], t[4]
    t[0] = t[1] + [(name, size, True)]

def p_fields_padding(t):
    """fields : fields PADDING INTLIT"""
    # Padding is recorded as an anonymous (name None), non-high entry.
    size = t[3]
    t[0] = t[1] + [(None, size, False)]
199
def p_tagged_union(t):
    """tagged_union : TAGGED_UNION IDENTIFIER IDENTIFIER""" \
                  """ LBRACE masks tags RBRACE"""
    # tagged_union <name> <tagname> { mask ... tag ... }
    union_name, tag_name = t[2], t[3]
    masks, tags = t[5], t[6]
    t[0] = TaggedUnion(name=union_name, tagname=tag_name,
                       classes=masks, tags=tags)
204
def p_tags_empty(t):
    """tags :"""
    t[0] = list()

def p_tags(t):
    """tags : tags TAG IDENTIFIER INTLIT"""
    # Tags accumulate as (identifier, integer) pairs in declaration order.
    entry = (t[3], t[4])
    t[0] = t[1] + [entry]

def p_masks_empty(t):
    """masks :"""
    t[0] = list()

def p_masks(t):
    """masks : masks MASK INTLIT INTLIT"""
    # Masks accumulate as pairs of integers in declaration order.
    entry = (t[3], t[4])
    t[0] = t[1] + [entry]
220
def p_error(t):
    # Parser error hook.  At an unexpected end of input PLY passes None
    # instead of a token, so t.value must not be dereferenced in that case.
    if t is None:
        print("Syntax error: unexpected end of input", file=sys.stderr)
    else:
        print("Syntax error at token '%s'" % t.value, file=sys.stderr)
    sys.exit(1)
224
### Templates

## C templates

# Each template below is a %-format string filled in from a dict of named
# values (%(block)s, %(field)s, %(type)s, %(mask)x, ...).  The text is
# emitted into the generated C code, so its exact bytes are significant.

# Struct holding an array of %(multiple)d words of type %(type)s, plus a
# `%(name)s_t` typedef for it.
typedef_template = \
"""struct %(name)s {
    %(type)s words[%(multiple)d];
};
typedef struct %(name)s %(name)s_t;"""

# By-value constructor `<block>_new`.  The assertion lines and field
# initialisers arrive pre-rendered through %(asserts)s and %(gen_inits)s.
generator_template = \
"""%(inline)s %(block)s_t CONST
%(block)s_new(%(gen_params)s) {
    %(block)s_t %(block)s;

%(asserts)s

%(gen_inits)s

    return %(block)s;
}"""

# In-place constructor `<block>_ptr_new`, returning void.  Assertions and
# initialisers come pre-rendered through %(asserts)s and %(ptr_inits)s.
ptr_generator_template = \
"""%(inline)s void
%(block)s_ptr_new(%(ptr_params)s) {
%(asserts)s

%(ptr_inits)s
}"""

# By-value field getter.  It masks word %(index)d, shifts with
# %(r_shift_op)s by %(shift)d, then ORs in %(high_bits)x when sign
# extension is enabled and bit %(extend_bit)d of the result is set.
reader_template = \
"""%(inline)s %(type)s CONST
%(block)s_get_%(field)s(%(block)s_t %(block)s) {
    %(type)s ret;
    ret = (%(block)s.words[%(index)d] & 0x%(mask)x%(suf)s) %(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Same extraction as reader_template, but reading through a pointer.  It is
# marked PURE instead of CONST, presumably because it reads memory.
ptr_reader_template = \
"""%(inline)s %(type)s PURE
%(block)s_ptr_get_%(field)s(%(block)s_t *%(block)s_ptr) {
    %(type)s ret;
    ret = (%(block)s_ptr->words[%(index)d] & 0x%(mask)x%(suf)s) """ \
    """%(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""
279
# Field setter templates for plain blocks.  The assertion rejects a new
# value with bits outside the field (computed by undoing the shift on the
# inverted mask).  When sign extension is enabled, the high_bits must be set
# exactly when bit extend_bit is set.  The field is then cleared and
# re-written in word %(index)d.
#
# Every mask literal carries %(suf)s, as in the union and tag writer
# templates.  Without it an unsuffixed constant such as ~0xff000000 is
# evaluated in 32-bit unsigned arithmetic, so after the shift the range
# check can silently accept stray bits of a 64-bit value.  (This assumes
# %(suf)s is the literal suffix for the word size -- confirm at the caller.)
writer_template = \
"""%(inline)s %(block)s_t CONST
%(block)s_set_%(field)s(%(block)s_t %(block)s, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d ) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));
    %(block)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(block)s.words[%(index)d] |= (v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
    return %(block)s;
}"""

# Pointer variant of writer_template: updates the struct in place.
ptr_writer_template = \
"""%(inline)s void
%(block)s_ptr_set_%(field)s(%(block)s_t *%(block)s_ptr, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));
    %(block)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(block)s_ptr->words[%(index)d] |= (v%(base)d %(w_shift_op)s """ \
    """%(shift)d) & 0x%(mask)x%(suf)s;
}"""
299
# Tagged-union member constructor `<union>_<block>_new`, returning the union
# by value.  Assertions and initialisers arrive pre-rendered through
# %(asserts)s and %(gen_inits)s.
union_generator_template = \
"""%(inline)s %(union)s_t CONST
%(union)s_%(block)s_new(%(gen_params)s) {
    %(union)s_t %(union)s;

%(asserts)s

%(gen_inits)s

    return %(union)s;
}"""

# In-place variant `<union>_<block>_ptr_new`, returning void.
ptr_union_generator_template = \
"""%(inline)s void
%(union)s_%(block)s_ptr_new(%(ptr_params)s) {
%(asserts)s

%(ptr_inits)s
}"""

# Union member getters.  They first assert that the stored tag (word
# %(tagindex)d shifted right by %(tagshift)d and masked with %(tagmask)x)
# equals %(union)s_%(block)s.  Then they mask, shift with %(r_shift_op)s and
# optionally sign-extend by ORing in %(high_bits)x.
union_reader_template = \
"""%(inline)s %(type)s CONST
%(union)s_%(block)s_get_%(field)s(%(union)s_t %(union)s) {
    %(type)s ret;
    %(assert)s(((%(union)s.words[%(tagindex)d] >> %(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    ret = (%(union)s.words[%(index)d] & 0x%(mask)x%(suf)s) %(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Pointer variant of union_reader_template (PURE instead of CONST).
ptr_union_reader_template = \
"""%(inline)s %(type)s PURE
%(union)s_%(block)s_ptr_get_%(field)s(%(union)s_t *%(union)s_ptr) {
    %(type)s ret;
    %(assert)s(((%(union)s_ptr->words[%(tagindex)d] >> """ \
    """%(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    ret = (%(union)s_ptr->words[%(index)d] & 0x%(mask)x%(suf)s) """ \
    """%(r_shift_op)s %(shift)d;
    /* Possibly sign extend */
    if (%(sign_extend)d && (ret & (1%(suf)s << (%(extend_bit)d)))) {
        ret |= 0x%(high_bits)x;
    }
    return ret;
}"""

# Union member setters.  They apply the same tag assertion, then a range
# check on the new value, then clear the field bits in word %(index)d and
# OR in the shifted, masked value.
union_writer_template = \
"""%(inline)s %(union)s_t CONST
%(union)s_%(block)s_set_%(field)s(%(union)s_t %(union)s, %(type)s v%(base)d) {
    %(assert)s(((%(union)s.words[%(tagindex)d] >> %(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d ) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s.words[%(index)d] |= (v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
    return %(union)s;
}"""

# Pointer variant of union_writer_template: updates the union in place.
ptr_union_writer_template = \
"""%(inline)s void
%(union)s_%(block)s_ptr_set_%(field)s(%(union)s_t *%(union)s_ptr,
                                      %(type)s v%(base)d) {
    %(assert)s(((%(union)s_ptr->words[%(tagindex)d] >> """ \
    """%(tagshift)d) & 0x%(tagmask)x) ==
           %(union)s_%(block)s);

    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s_ptr->words[%(index)d] |= """ \
    """(v%(base)d %(w_shift_op)s %(shift)d) & 0x%(mask)x%(suf)s;
}"""
380
# Tag getter `<union>_get_<tagname>`, built from a header, entry fragments,
# a final fragment and a footer (presumably concatenated by the caller in
# that order, with one entry per class mask).  An entry returns the tag read
# at its position when the class-mask bits of word %(index)d are not all
# set.  The final fragment reads the tag unconditionally.
tag_reader_header_template = \
"""%(inline)s %(type)s CONST
%(union)s_get_%(tagname)s(%(union)s_t %(union)s) {
"""

tag_reader_entry_template = \
"""    if ((%(union)s.words[%(index)d] & 0x%(classmask)x) != 0x%(classmask)x)
        return (%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;
"""

tag_reader_final_template = \
"""    return (%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;"""

tag_reader_footer_template = \
"""
}"""

# Tag comparison `<union>_<tagname>_equals`.  Unlike the getter, the
# class-mask test is applied to the requested tag value
# (%(union)s_type_tag), not to the stored word.
tag_eq_reader_header_template = \
"""%(inline)s int CONST
%(union)s_%(tagname)s_equals(%(union)s_t %(union)s, %(type)s %(union)s_type_tag) {
"""

tag_eq_reader_entry_template = \
"""    if ((%(union)s_type_tag & 0x%(classmask)x) != 0x%(classmask)x)
        return ((%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s) == %(union)s_type_tag;
"""

tag_eq_reader_final_template = \
"""    return ((%(union)s.words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s) == %(union)s_type_tag;"""

tag_eq_reader_footer_template = \
"""
}"""

# Pointer variant of the tag getter (PURE instead of CONST).
ptr_tag_reader_header_template = \
"""%(inline)s %(type)s PURE
%(union)s_ptr_get_%(tagname)s(%(union)s_t *%(union)s_ptr) {
"""

ptr_tag_reader_entry_template = \
"""    if ((%(union)s_ptr->words[%(index)d] & 0x%(classmask)x) != 0x%(classmask)x)
        return (%(union)s_ptr->words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;
"""

ptr_tag_reader_final_template = \
"""    return (%(union)s_ptr->words[%(index)d] >> %(shift)d) & 0x%(mask)x%(suf)s;"""

ptr_tag_reader_footer_template = \
"""
}"""

# Tag setters: a range check on the new tag value, then clear-and-set of the
# tag bits in word %(index)d.  The write always shifts left by %(shift)d.
tag_writer_template = \
"""%(inline)s %(union)s_t CONST
%(union)s_set_%(tagname)s(%(union)s_t %(union)s, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s.words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s.words[%(index)d] |= (v%(base)d << %(shift)d) & 0x%(mask)x%(suf)s;
    return %(union)s;
}"""

ptr_tag_writer_template = \
"""%(inline)s void
%(union)s_ptr_set_%(tagname)s(%(union)s_t *%(union)s_ptr, %(type)s v%(base)d) {
    /* fail if user has passed bits that we will override */
    %(assert)s((((~0x%(mask)x%(suf)s %(r_shift_op)s %(shift)d) | 0x%(high_bits)x) & v%(base)d) == ((%(sign_extend)d && (v%(base)d & (1%(suf)s << (%(extend_bit)d)))) ? 0x%(high_bits)x : 0));

    %(union)s_ptr->words[%(index)d] &= ~0x%(mask)x%(suf)s;
    %(union)s_ptr->words[%(index)d] |= (v%(base)d << %(shift)d) & 0x%(mask)x%(suf)s;
}"""
452
# HOL definition templates
#
# These contain Isabelle symbols such as \<Rightarrow>, so they are written
# as raw strings.  In a normal literal "\<" is an invalid escape sequence
# (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).  The
# values are unchanged, because Python keeps the backslash of an
# unrecognised escape.

# Definition of `<name>_lift`, from `<name>_C` to the record `<name>_CL`.
# %(fields)s supplies the pre-rendered record fields.
lift_def_template = \
r'''definition
  %(name)s_lift :: "%(name)s_C \<Rightarrow> %(name)s_CL"
where
  "%(name)s_lift %(name)s \<equiv> \<lparr>
       %(fields)s \<rparr>"'''

# Lift of a single tagged-union member: project the %(generator)s case out
# of `<union>_lift`.
block_lift_def_template = \
r'''definition %(union)s_%(block)s_lift :: ''' \
r'''"%(union)s_C \<Rightarrow> %(union)s_%(block)s_CL"
where
  "%(union)s_%(block)s_lift %(union)s \<equiv>
    case (%(union)s_lift %(union)s) of ''' \
    r'''Some (%(generator)s rec) \<Rightarrow> rec"'''

# Lemma: the tag selects member %(block)s iff the union's lift is the
# %(generator)s case wrapping that member's lift.
block_lift_lemma_template = \
r'''lemma %(union)s_%(block)s_lift:
  "(%(union)s_get_tag c = scast %(union)s_%(block)s) = ''' \
 r'''(%(union)s_lift c = Some (%(generator)s (%(union)s_%(block)s_lift c)))"
  unfolding %(union)s_lift_def %(union)s_%(block)s_lift_def
  by (clarsimp simp: %(union)s_tag_defs Let_def)'''
476
# Templates for the `<name>_get_tag` definition and its `_eq_x` lemma.  Each
# is built from a header, per-class-mask entries, a final case and a footer
# (presumably concatenated by the caller in that order).
#
# Pieces containing Isabelle symbols such as \<noteq> are raw strings,
# because "\<" is an invalid escape in a normal literal (DeprecationWarning
# since Python 3.6, SyntaxWarning since 3.12).  Their values are unchanged.

union_get_tag_def_header_template = \
r'''definition
  %(name)s_get_tag :: "%(name)s_C \<Rightarrow> word%(base)d"
where
  "%(name)s_get_tag %(name)s \<equiv>
     '''

# If the class-mask bits of the tag word are not all set, the tag is the
# %(tag_size)d-bit field at %(tag_shift)d; otherwise fall through.
union_get_tag_def_entry_template = \
'''if ((index (%(name)s_C.words_C %(name)s) %(tag_index)d)''' \
r''' AND 0x%(classmask)x \<noteq> 0x%(classmask)x)
      then ((index (%(name)s_C.words_C %(name)s) %(tag_index)d)'''\
''' >> %(tag_shift)d) AND mask %(tag_size)d
      else '''

union_get_tag_def_final_template = \
'''((index (%(name)s_C.words_C %(name)s) %(tag_index)d)'''\
''' >> %(tag_shift)d) AND mask %(tag_size)d'''

union_get_tag_def_footer_template = '''"'''

union_get_tag_eq_x_def_header_template = \
'''lemma %(name)s_get_tag_eq_x:
  "(%(name)s_get_tag c = x) = (('''

# Same case split as the definition, but the class test is applied to the
# candidate tag x (shifted into position).
union_get_tag_eq_x_def_entry_template = \
r'''if ((x << %(tag_shift)d) AND 0x%(classmask)x \<noteq> 0x%(classmask)x)
      then ((index (%(name)s_C.words_C c) %(tag_index)d)''' \
''' >> %(tag_shift)d) AND mask %(tag_size)d
      else '''

union_get_tag_eq_x_def_final_template = \
'''((index (%(name)s_C.words_C c) %(tag_index)d)''' \
''' >> %(tag_shift)d) AND mask %(tag_size)d'''

union_get_tag_eq_x_def_footer_template = ''') = x)"
  by (auto simp add: %(name)s_get_tag_def mask_def word_bw_assocs)'''

# Lemma listing, for one member, that matching the full mask/value implies
# matching each partial mask/value; proved with word_sub_mask.
union_tag_mask_helpers_header_template = \
'''lemma %(name)s_%(block)s_tag_mask_helpers:'''

union_tag_mask_helpers_entry_template = r'''
  "w && %(full_mask)s = %(full_value)s \<Longrightarrow> w'''\
''' && %(part_mask)s = %(part_value)s"
'''

union_tag_mask_helpers_footer_template = \
'''  by (auto elim: word_sub_mask simp: mask_def)'''
524
# Union lift / access / update definitions.  They are raw strings because
# they contain Isabelle symbols such as \<Rightarrow>, and "\<" is an
# invalid escape in a normal literal (DeprecationWarning since Python 3.6,
# SyntaxWarning since 3.12).  The values are unchanged.

# `<name>_lift`: dispatch on the tag through the pre-rendered %(tag_cases)s,
# yielding None for unmatched tags.
union_lift_def_template = \
r'''definition
  %(name)s_lift :: "%(name)s_C \<Rightarrow> %(name)s_CL option"
where
  "%(name)s_lift %(name)s \<equiv>
    (let tag = %(name)s_get_tag %(name)s in
     %(tag_cases)s
     else None)"'''

# Apply f to the %(generator)s member record of an abstract union.
union_access_def_template = \
r'''definition
  %(union)s_%(block)s_access :: "(%(union)s_%(block)s_CL \<Rightarrow> 'a)
                                 \<Rightarrow> %(union)s_CL \<Rightarrow> 'a"
where
  "%(union)s_%(block)s_access f %(union)s \<equiv>
     (case %(union)s of %(generator)s rec \<Rightarrow> f rec)"'''

# Update the %(generator)s member record of an abstract union with f.
union_update_def_template = \
r'''definition
  %(union)s_%(block)s_update :: "(%(union)s_%(block)s_CL \<Rightarrow>''' \
                                r''' %(union)s_%(block)s_CL) \<Rightarrow>
                                 %(union)s_CL \<Rightarrow> %(union)s_CL"
where
  "%(union)s_%(block)s_update f %(union)s \<equiv>
     (case %(union)s of %(generator)s rec \<Rightarrow>
        %(generator)s (f rec))"'''
551
# HOL proof templates

#FIXME: avoid [simp]
# Bundles the three per-struct pointer lemmas (words_NULL, words_aligned,
# words_ptr_safe) under one [simp] name.
struct_lemmas_template = \
'''
lemmas %(name)s_ptr_guards[simp] =
  %(name)s_ptr_words_NULL
  %(name)s_ptr_words_aligned
  %(name)s_ptr_words_ptr_safe'''

# FIXME: move to global theory
# Helper lemma used by the *_tag_mask_helpers proofs (elim: word_sub_mask).
# This is a raw string because "\<" is an invalid escape in a normal literal
# (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).  The
# value is unchanged.
defs_global_lemmas = r'''
lemma word_sub_mask:
  "\<lbrakk> w && m1 = v1; m1 && m2 = m2; v1 && m2 = v2 \<rbrakk>
     \<Longrightarrow> w && m2 = v2"
  by (clarsimp simp: word_bw_assocs)
'''
569
# Proof templates are stored as a list of
# (header, body[, stuff to go afterwards]).
# This makes it easy to replace the proof body with a sorry.
#
# The string pieces below contain Isabelle symbols such as \<equiv>.  They
# are raw strings because "\<" is an invalid escape in a normal literal
# (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).  The
# values are unchanged.

# ptrname should be a function of s
def ptr_basic_template(name, ptrname, retval, args, post):
    # Build the statement of a `%(name)s_ptr_<name>_spec` lemma, in locale
    # loc_name, for a procedure taking a pointer argument.
    #   name    -- procedure suffix appended to `%(name)s_ptr_`
    #   ptrname -- Isabelle term for the pointer
    #   retval  -- text placed before PROC (e.g. a return assignment) or ''
    #   args    -- extra argument text appended after the pointer argument
    #   post    -- postcondition text
    # The result still contains %(...)s placeholders to be filled in later.
    return (r'''lemma (in ''' + loc_name + r''') %(name)s_ptr_''' + name + r'''_spec:
           defines "ptrval s \<equiv> cslift s ''' + ptrname + r'''"
           shows "\<forall>s. \<Gamma> \<turnstile> \<lbrace>s. s \<Turnstile>\<^sub>c ''' + ptrname + r'''\<rbrace>
            ''' + retval + r'''PROC %(name)s_ptr_''' + name + r'''(\<acute>%(name)s_ptr''' + args + r''')
            ''' + post + r''' " ''')

def ptr_union_basic_template(name, ptrname, retval, args, pre, post):
    # Tagged-union variant of ptr_basic_template for
    # `%(name)s_%(block)s_ptr_<name>_spec`; `pre` is extra precondition
    # text appended after the pointer-validity condition.
    return (r'''lemma (in ''' + loc_name + r''') %(name)s_%(block)s_ptr_''' + name + r'''_spec:
    defines "ptrval s \<equiv> cslift s ''' + ptrname + r'''"
    shows "\<forall>s. \<Gamma> \<turnstile> \<lbrace>s. s \<Turnstile>\<^sub>c ''' + ptrname + " " + pre + r'''\<rbrace>
            ''' + retval + r'''PROC %(name)s_%(block)s_ptr_''' + name + r'''(\<acute>%(name)s_ptr''' + args + r''')
            ''' + post + r''' " ''')

# Pointer terms: the pointer argument itself, or the enclosing object of
# type %(toptp)s reached through field path %(path)s.
direct_ptr_name = r'\<^bsup>s\<^esup>%(name)s_ptr'
path_ptr_name = r'(cparent \<^bsup>s\<^esup>%(name)s_ptr [%(path)s] :: %(toptp)s ptr)'
591
# Builders for pointer-procedure spec statements.  Pieces containing
# Isabelle symbols (\<acute>, \<lbrace>, ...) are raw strings, because "\<"
# is an invalid escape in a normal literal (DeprecationWarning since Python
# 3.6, SyntaxWarning since 3.12).  The values are unchanged.

def ptr_get_template(ptrname):
    # Getter spec: the returned value equals the %(field)s_CL field of the
    # lifted record at %(access_path)s.
    return ptr_basic_template('get_%(field)s', ptrname, r'\<acute>%(ret_name)s :== ', '',
                              r'''\<lbrace>\<acute>%(ret_name)s = ''' \
                              r'''%(name)s_CL.%(field)s_CL ''' \
                              r'''(%(name)s_lift (%(access_path)s))\<rbrace>''')

def ptr_set_template(name, ptrname):
    # Setter spec: afterwards the heap at ptrname holds a value whose lift
    # is the old lift with %(field)s_CL replaced by the masked argument
    # (wrapped in %(sign_extend)s).  `name` is the procedure suffix.
    return ptr_basic_template(name, ptrname, '', r', \<acute>v%(base)d',
                              r'''{t. \<exists>%(name)s.
                              %(name)s_lift %(name)s =
                              %(name)s_lift (%(access_path)s) \<lparr> %(name)s_CL.%(field)s_CL ''' \
                              r''':= %(sign_extend)s(\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr> \<and>
                              t_hrs_' (globals t) = hrs_mem_update (heap_update
                                      (''' + ptrname + r''')
                                      %(update_path)s)
                                  (t_hrs_' (globals s))
                              }''')

def ptr_new_template(ptrname):
    # In-place constructor spec: the stored value lifts to the record given
    # by %(field_eqs)s.
    return ptr_basic_template('new', ptrname, '', ', %(args)s',
                              r'''{t. \<exists>%(name)s. %(name)s_lift %(name)s = \<lparr>
                              %(field_eqs)s \<rparr> \<and>
                              t_hrs_' (globals t) = hrs_mem_update (heap_update
                                      (''' + ptrname + r''')
                                      %(update_path)s)
                                  (t_hrs_' (globals s))
                              }''')

def ptr_get_tag_template(ptrname):
    # Tag getter spec: the returned value equals %(name)s_get_tag of the
    # object at %(access_path)s.
    return ptr_basic_template('get_%(tagname)s', ptrname, r'\<acute>%(ret_name)s :== ', '',
                              r'''\<lbrace>\<acute>%(ret_name)s = %(name)s_get_tag (%(access_path)s)\<rbrace>''')


def ptr_empty_union_new_template(ptrname):
    # Constructor spec for a union member with no fields: only the tag of
    # the stored value is constrained.
    return ptr_union_basic_template('new', ptrname, '', '', '',
                                    r'''{t. \<exists>%(name)s. ''' \
                                    r'''%(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + r''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')

def ptr_union_new_template(ptrname):
    # Union member constructor spec: member lift given by %(field_eqs)s and
    # tag equal to %(name)s_%(block)s.
    return ptr_union_basic_template('new', ptrname, '', ', %(args)s', '',
                                    r'''{t. \<exists>%(name)s. ''' \
                                    r'''%(name)s_%(block)s_lift %(name)s = \<lparr>
                                    %(field_eqs)s \<rparr> \<and>
                                    %(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + r''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')

def ptr_union_get_template(ptrname):
    # Union member getter spec; requires the tag to select %(block)s.
    return ptr_union_basic_template('get_%(field)s', ptrname,
                                    r'\<acute>%(ret_name)s :== ', '',
                                    r'\<and> %(name)s_get_tag %(access_path)s = scast %(name)s_%(block)s',
                                    r'''\<lbrace>\<acute>%(ret_name)s = ''' \
                                    r'''%(name)s_%(block)s_CL.%(field)s_CL ''' \
                                    r'''(%(name)s_%(block)s_lift %(access_path)s)\<rbrace>''')

def ptr_union_set_template(ptrname):
    # Union member setter spec; requires the tag to select %(block)s and
    # preserves it, updating %(field)s_CL with the masked argument.
    return ptr_union_basic_template('set_%(field)s', ptrname, '', r', \<acute>v%(base)d',
                                    r'\<and> %(name)s_get_tag %(access_path)s = scast %(name)s_%(block)s',
                                    r'''{t. \<exists>%(name)s. ''' \
                                    r'''%(name)s_%(block)s_lift %(name)s =
                                    %(name)s_%(block)s_lift %(access_path)s ''' \
                                    r'''\<lparr> %(name)s_%(block)s_CL.%(field)s_CL ''' \
                                    r''':= %(sign_extend)s(\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr> \<and>
                                    %(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<and>
                                    t_hrs_' (globals t) = hrs_mem_update (heap_update
                                            (''' + ptrname + r''')
                                            %(update_path)s)
                                        (t_hrs_' (globals s))
                                    }''')
669
670proof_templates = {
671
672'lift_collapse_proof' : [
673'''lemma %(name)s_lift_%(block)s:
674  "%(name)s_get_tag %(name)s = scast %(name)s_%(block)s \<Longrightarrow>
675  %(name)s_lift %(name)s =
676  Some (%(value)s)"''',
677''' apply(simp add:%(name)s_lift_def %(name)s_tag_defs)
678done'''],
679
680'words_NULL_proof' : [
681'''lemma %(name)s_ptr_words_NULL:
682  "c_guard (p::%(name)s_C ptr) \<Longrightarrow>
683   0 < &(p\<rightarrow>[''words_C''])"''',
684''' apply(fastforce intro:c_guard_NULL_fl simp:typ_uinfo_t_def)
685done'''],
686
687'words_aligned_proof' : [
688'''lemma %(name)s_ptr_words_aligned:
689  "c_guard (p::%(name)s_C ptr) \<Longrightarrow>
690   ptr_aligned ((Ptr &(p\<rightarrow>[''words_C'']))::''' \
691               '''((word%(base)d[%(words)d]) ptr))"''',
692''' apply(fastforce intro:c_guard_ptr_aligned_fl simp:typ_uinfo_t_def)
693done'''],
694
695'words_ptr_safe_proof' : [
696'''lemma %(name)s_ptr_words_ptr_safe:
697  "ptr_safe (p::%(name)s_C ptr) d \<Longrightarrow>
698   ptr_safe (Ptr &(p\<rightarrow>[''words_C''])::''' \
699         '''((word%(base)d[%(words)d]) ptr)) d"''',
700''' apply(fastforce intro:ptr_safe_mono simp:typ_uinfo_t_def)
701done'''],
702
703'get_tag_fun_spec_proof' : [
704'''lemma (in ''' + loc_name + ''') fun_spec:
705  "\<Gamma> \<turnstile> {\<sigma>}
706       \<acute>ret__%(rtype)s :== PROC %(name)s_get_%(tag_name)s(''' \
707                                             ''' \<acute>%(name))
708       \<lbrace>\<acute>ret__%(rtype)s = %(name)s_get_tag''' \
709                             '''\<^bsup>\<sigma>\<^esup>\<rbrace>"''',
710''' apply(rule allI, rule conseqPre, vcg)
711 apply(clarsimp)
712 apply(simp add:$(name)s_get_tag_def word_sle_def
713                mask_def ucast_def)
714done'''],
715
716'const_modifies_proof' : [
717'''lemma (in ''' + loc_name + ''') %(fun_name)s_modifies:
718  "\<forall> s. \<Gamma> \<turnstile>\<^bsub>/UNIV\<^esub> {s}
719       PROC %(fun_name)s(%(args)s)
720       {t. t may_not_modify_globals s}"''',
721''' by (vcg spec=modifies strip_guards=true)'''],
722
723'ptr_set_modifies_proof' : [
724'''lemma (in ''' + loc_name + ''') %(fun_name)s_modifies:
725  "\<forall>s. \<Gamma> \<turnstile>\<^bsub>/UNIV\<^esub> {s}
726       PROC %(fun_name)s(%(args)s)
727       {t. t may_only_modify_globals s in [t_hrs]}"''',
728''' by (vcg spec=modifies strip_guards=true)'''],
729
730
731'new_spec' : [
732'''lemma (in ''' + loc_name + ''') %(name)s_new_spec:
733  "\<forall> s. \<Gamma> \<turnstile> {s}
734       \<acute>ret__struct_%(name)s_C :== PROC %(name)s_new(%(args)s)
735       \<lbrace> %(name)s_lift \<acute>ret__struct_%(name)s_C = \<lparr>
736          %(field_eqs)s \<rparr> \<rbrace>"''',
737'''  apply (rule allI, rule conseqPre, vcg)
738  apply (clarsimp simp: guard_simps)
739  apply (simp add: %(name)s_lift_def)
740  apply ((intro conjI sign_extend_eq)?;
741         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
742                    word_ao_dist word_bw_assocs word_and_max_simps)?)
743  done'''],
744
745'ptr_new_spec_direct' : [
746    ptr_new_template(direct_ptr_name),
747'''sorry (* ptr_new_spec_direct *)'''],
748
749'ptr_new_spec_path' : [
750    ptr_new_template(path_ptr_name),
751'''sorry (* ptr_new_spec_path *)'''],
752
753
754'get_spec' : [
755'''lemma (in ''' + loc_name + ''') %(name)s_get_%(field)s_spec:
756  "\<forall>s. \<Gamma> \<turnstile> {s}
757       \<acute>%(ret_name)s :== ''' \
758       '''PROC %(name)s_get_%(field)s(\<acute>%(name)s)
759       \<lbrace>\<acute>%(ret_name)s = ''' \
760       '''%(name)s_CL.%(field)s_CL ''' \
761       '''(%(name)s_lift \<^bsup>s\<^esup>%(name)s)\<rbrace>"''',
762'''  apply (rule allI, rule conseqPre, vcg)
763  apply clarsimp
764  apply (simp add: %(name)s_lift_def mask_shift_simps guard_simps)
765  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
766                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
767  done'''],
768
769'set_spec' : [
770'''lemma (in ''' + loc_name + ''') %(name)s_set_%(field)s_spec:
771  "\<forall>s. \<Gamma> \<turnstile> {s}
772       \<acute>ret__struct_%(name)s_C :== ''' \
773       '''PROC %(name)s_set_%(field)s(\<acute>%(name)s, \<acute>v%(base)d)
774       \<lbrace>%(name)s_lift \<acute>ret__struct_%(name)s_C = ''' \
775       '''%(name)s_lift \<^bsup>s\<^esup>%(name)s \<lparr> ''' \
776       '''%(name)s_CL.%(field)s_CL ''' \
777       ''':= %(sign_extend)s (\<^bsup>s\<^esup>v%(base)d AND %(mask)s) \<rparr>\<rbrace>"''',
778'''  apply(rule allI, rule conseqPre, vcg)
779  apply(clarsimp simp: guard_simps ucast_id
780                       %(name)s_lift_def
781                       mask_def shift_over_ao_dists
782                       multi_shift_simps word_size
783                       word_ao_dist word_bw_assocs
784                       NOT_eq)
785  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
786                   shift_over_ao_dists word_and_max_simps)?
787  done'''],
788
789# where the top level type is the bitfield type --- these are split because they have different proofs
790'ptr_get_spec_direct' : [
791    ptr_get_template(direct_ptr_name),
792'''   unfolding ptrval_def
793  apply (rule allI, rule conseqPre, vcg)
794  apply (clarsimp simp: h_t_valid_clift_Some_iff)
795  apply (simp add: %(name)s_lift_def guard_simps mask_def typ_heap_simps ucast_def)
796  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
797                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
798  done'''],
799
800'ptr_get_spec_path' : [
801    ptr_get_template(path_ptr_name),
802'''  unfolding ptrval_def
803  apply (rule allI, rule conseqPre, vcg)
804  apply (clarsimp simp: guard_simps)
805  apply (frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
806  apply (frule clift_subtype, simp, simp, simp)
807  apply (simp add: h_val_field_clift' typ_heap_simps)
808  apply (simp add: thread_state_lift_def)
809  apply (simp add: sign_extend_def' mask_def nth_is_and_neq_0 word_bw_assocs
810                   shift_over_ao_dists word_oa_dist word_and_max_simps)?
811  apply (simp add: mask_shift_simps)?
812  done'''],
813
814'ptr_set_spec_direct' : [
815    ptr_set_template('set_%(field)s', direct_ptr_name),
816'''  unfolding ptrval_def
817  apply (rule allI, rule conseqPre, vcg)
818  apply (clarsimp simp: guard_simps)
819  apply (clarsimp simp add: packed_heap_update_collapse_hrs typ_heap_simps)?
820  apply (rule exI, rule conjI[rotated], rule refl)
821  apply (clarsimp simp: h_t_valid_clift_Some_iff %(name)s_lift_def typ_heap_simps)
822  apply ((intro conjI sign_extend_eq)?;
823         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
824                    word_ao_dist word_bw_assocs word_and_max_simps))?
825  done'''],
826
827'ptr_set_spec_path' : [
828    ptr_set_template('set_%(field)s', path_ptr_name),
829'''  (* Invoke vcg *)
830  unfolding ptrval_def
831  apply (rule allI, rule conseqPre, vcg)
832  apply (clarsimp)
833
834  (* Infer h_t_valid for all three levels of indirection *)
835  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
836  apply (frule h_t_valid_c_guard_field[where f="[''words_C'']"],
837                                       simp, simp add: typ_uinfo_t_def)
838
839  (* Discharge guards, including c_guard for pointers *)
840  apply (simp add: h_t_valid_c_guard guard_simps)
841
842  (* Lift field updates to bitfield struct updates *)
843  apply (simp add: heap_update_field_hrs h_t_valid_c_guard typ_heap_simps)
844
845  (* Collapse multiple updates *)
846  apply(simp add: packed_heap_update_collapse_hrs)
847
848  (* Instantiate the toplevel object *)
849  apply(frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
850
851  (* Instantiate the next-level object in terms of the last *)
852  apply(frule clift_subtype, simp+)
853
854  (* Resolve pointer accesses *)
855  apply(simp add: h_val_field_clift')
856
857  (* Rewrite bitfield struct updates as enclosing struct updates *)
858  apply(frule h_t_valid_c_guard)
859  apply(simp add: parent_update_child)
860
861  (* Equate the updated values *)
862  apply(rule exI, rule conjI[rotated], simp add: h_val_clift')
863
864  (* Rewrite struct updates *)
865  apply(simp add: o_def %(name)s_lift_def)
866
867  (* Solve bitwise arithmetic *)
868  apply ((intro conjI sign_extend_eq)?;
869         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
870                    word_ao_dist word_bw_assocs word_and_max_simps))?
871  done'''],
872
873
874'get_tag_spec' : [
875'''lemma (in ''' + loc_name + ''') %(name)s_get_%(tagname)s_spec:
876  "\<forall>s. \<Gamma> \<turnstile> {s}
877       \<acute>%(ret_name)s :== ''' \
878    '''PROC %(name)s_get_%(tagname)s(\<acute>%(name)s)
879       \<lbrace>\<acute>%(ret_name)s = ''' \
880    '''%(name)s_get_tag \<^bsup>s\<^esup>%(name)s\<rbrace>"''',
881''' apply(rule allI, rule conseqPre, vcg)
882 apply(clarsimp)
883 apply(simp add:%(name)s_get_tag_def
884                mask_shift_simps
885                guard_simps)
886done'''],
887
888'get_tag_equals_spec' : [
889'''lemma (in ''' + loc_name + ''') %(name)s_%(tagname)s_equals_spec:
890  "\<forall>s. \<Gamma> \<turnstile> {s}
891       \<acute>ret__int :==
892       PROC %(name)s_%(tagname)s_equals(\<acute>%(name)s, \<acute>%(name)s_type_tag)
893       \<lbrace>\<acute>ret__int = of_bl [%(name)s_get_tag \<^bsup>s\<^esup>%(name)s = \<^bsup>s\<^esup>%(name)s_type_tag]\<rbrace>"''',
894''' apply(rule allI, rule conseqPre, vcg)
895 apply(clarsimp)
896 apply(simp add:%(name)s_get_tag_eq_x
897                mask_shift_simps
898                guard_simps)
899done'''],
900
901'ptr_get_tag_spec_direct' : [
902    ptr_get_tag_template(direct_ptr_name),
903''' unfolding ptrval_def
904 apply(rule allI, rule conseqPre, vcg)
905 apply(clarsimp simp:guard_simps)
906 apply(frule h_t_valid_field[where f="[''words_C'']"], simp+)
907 apply(frule iffD1[OF h_t_valid_clift_Some_iff], rule exE, assumption, simp)
908 apply(simp add:h_val_clift' clift_field)
909 apply(simp add:%(name)s_get_tag_def)
910 apply(simp add:mask_shift_simps)
911done'''],
912
913'ptr_get_tag_spec_path' : [
914    ptr_get_tag_template(path_ptr_name),
915''' unfolding ptrval_def
916 apply(rule allI, rule conseqPre, vcg)
917 apply(clarsimp)
918 apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
919 apply(clarsimp simp: typ_heap_simps h_t_valid_clift_Some_iff)
920 apply(frule clift_subtype, simp+)
921 apply(simp add: %(name)s_get_tag_def mask_shift_simps guard_simps)
922done'''],
923
924
925'empty_union_new_spec' : [
926'''lemma (in ''' + loc_name + ''') ''' \
927      '''%(name)s_%(block)s_new_spec:
928  "\<forall>s. \<Gamma> \<turnstile> {s}
929       \<acute>ret__struct_%(name)s_C :== ''' \
930    '''PROC %(name)s_%(block)s_new()
931       \<lbrace>%(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
932     '''scast %(name)s_%(block)s\<rbrace>"''',
933''' apply(rule allI, rule conseqPre, vcg)
934 apply(clarsimp simp:guard_simps
935                     %(name)s_lift_def
936                     Let_def
937                     %(name)s_get_tag_def
938                     mask_shift_simps
939                     %(name)s_tag_defs
940                     word_of_int_hom_syms)
941done'''],
942
943'union_new_spec' : [
944'''lemma (in ''' + loc_name + ''') ''' \
945      '''%(name)s_%(block)s_new_spec:
946  "\<forall>s. \<Gamma> \<turnstile> {s}
947       \<acute>ret__struct_%(name)s_C :== ''' \
948    '''PROC %(name)s_%(block)s_new(%(args)s)
949       \<lbrace>%(name)s_%(block)s_lift ''' \
950    '''\<acute>ret__struct_%(name)s_C = \<lparr>
951          %(field_eqs)s \<rparr> \<and>
952        %(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
953     '''scast %(name)s_%(block)s\<rbrace>"''',
954'''  apply (rule allI, rule conseqPre, vcg)
955  apply (clarsimp simp: guard_simps o_def mask_def shift_over_ao_dists)
956  apply (rule context_conjI[THEN iffD1[OF conj_commute]],
957         fastforce simp: %(name)s_get_tag_eq_x %(name)s_%(block)s_def
958                         mask_def shift_over_ao_dists word_bw_assocs word_ao_dist)
959  apply (simp add: %(name)s_%(block)s_lift_def)
960  apply (erule %(name)s_lift_%(block)s[THEN subst[OF sym]]; simp?)
961  apply ((intro conjI sign_extend_eq)?;
962         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
963                    word_ao_dist word_bw_assocs word_and_max_simps))?
964  done'''],
965
966'ptr_empty_union_new_spec_direct' : [
967    ptr_empty_union_new_template(direct_ptr_name),
968'''sorry (* ptr_empty_union_new_spec_direct *)'''],
969
970'ptr_empty_union_new_spec_path' : [
971    ptr_empty_union_new_template(path_ptr_name),
972''' unfolding ptrval_def
973 apply(rule allI, rule conseqPre, vcg)
974 apply(clarsimp)
975 apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
976 apply(clarsimp simp: h_t_valid_clift_Some_iff)
977 apply(frule clift_subtype, simp+)
978 apply(clarsimp simp: typ_heap_simps c_guard_clift
979                      packed_heap_update_collapse_hrs)
980
981 apply(simp add: guard_simps mask_shift_simps
982                 %(name)s_tag_defs[THEN tag_eq_to_tag_masked_eq])
983
984 apply(simp add: parent_update_child[OF c_guard_clift]
985                 typ_heap_simps c_guard_clift)
986
987 apply(simp add: o_def, rule exI, rule conjI[OF _ refl])
988
989 apply(simp add: %(name)s_get_tag_def %(name)s_tag_defs
990                 guard_simps mask_shift_simps)
991done
992'''],
993
994'ptr_union_new_spec_direct' : [
995    ptr_union_new_template(direct_ptr_name),
996'''sorry (* ptr_union_new_spec_direct *)'''],
997
998'ptr_union_new_spec_path' : [
999    ptr_union_new_template(path_ptr_name),
1000'''  unfolding ptrval_def
1001  apply (rule allI, rule conseqPre, vcg)
1002  apply (clarsimp)
1003  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1004  apply (drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1005  apply (frule clift_subtype, simp, simp)
1006  apply (clarsimp simp: typ_heap_simps c_guard_clift
1007                        packed_heap_update_collapse_hrs)
1008  apply (simp add: guard_simps mask_shift_simps
1009                   %(name)s_tag_defs[THEN tag_eq_to_tag_masked_eq])?
1010  apply (simp add: parent_update_child[OF c_guard_clift]
1011                   typ_heap_simps c_guard_clift)
1012  apply (simp add: o_def %(name)s_%(block)s_lift_def)
1013  apply (simp only: %(name)s_lift_%(block)s cong: rev_conj_cong)
1014  apply (rule exI, rule conjI[rotated], rule conjI[OF _ refl])
1015   apply (simp_all add: %(name)s_get_tag_eq_x %(name)s_tag_defs mask_shift_simps)
1016  apply (intro conjI sign_extend_eq; simp add: mask_def word_ao_dist word_bw_assocs)?
1017  done'''],
1018
1019'union_get_spec' : [
1020'''lemma (in ''' + loc_name + ''') ''' \
1021'''%(name)s_%(block)s_get_%(field)s_spec:
1022  "\<forall>s. \<Gamma> \<turnstile> ''' \
1023'''\<lbrace>s. %(name)s_get_tag \<acute>%(name)s = ''' \
1024            '''scast %(name)s_%(block)s\<rbrace>
1025       \<acute>%(ret_name)s :== ''' \
1026       '''PROC %(name)s_%(block)s_get_%(field)s(\<acute>%(name)s)
1027       \<lbrace>\<acute>%(ret_name)s = ''' \
1028       '''%(name)s_%(block)s_CL.%(field)s_CL ''' \
1029       '''(%(name)s_%(block)s_lift \<^bsup>s\<^esup>%(name)s)''' \
1030       '''\<rbrace>"''',
1031''' apply(rule allI, rule conseqPre, vcg)
1032 apply(clarsimp simp:guard_simps)
1033 apply(simp add:%(name)s_%(block)s_lift_def)
1034 apply(subst %(name)s_lift_%(block)s)
1035  apply(simp add:o_def
1036                 %(name)s_get_tag_def
1037                 %(name)s_%(block)s_def
1038                 mask_def
1039                 word_size
1040                 shift_over_ao_dists)
1041 apply(subst %(name)s_lift_%(block)s, simp)?
1042 apply(simp add:o_def
1043                %(name)s_get_tag_def
1044                %(name)s_%(block)s_def
1045                mask_def
1046                word_size
1047                shift_over_ao_dists
1048                multi_shift_simps
1049                word_bw_assocs
1050                word_oa_dist
1051                word_and_max_simps
1052                ucast_def
1053                sign_extend_def'
1054                nth_is_and_neq_0)
1055done'''],
1056
1057'union_set_spec' : [
1058'''lemma (in ''' + loc_name + ''') ''' \
1059       '''%(name)s_%(block)s_set_%(field)s_spec:
1060  "\<forall>s. \<Gamma> \<turnstile> ''' \
1061'''\<lbrace>s. %(name)s_get_tag \<acute>%(name)s = ''' \
1062            '''scast %(name)s_%(block)s\<rbrace>
1063       \<acute>ret__struct_%(name)s_C :== ''' \
1064    '''PROC %(name)s_%(block)s_set_%(field)s(\<acute>%(name)s, \<acute>v%(base)d)
1065       \<lbrace>%(name)s_%(block)s_lift \<acute>ret__struct_%(name)s_C = ''' \
1066    '''%(name)s_%(block)s_lift \<^bsup>s\<^esup>%(name)s \<lparr> ''' \
1067    '''%(name)s_%(block)s_CL.%(field)s_CL ''' \
1068    ''':= %(sign_extend)s (\<^bsup>s\<^esup>v%(base)d AND %(mask)s)\<rparr> \<and>
1069        %(name)s_get_tag \<acute>ret__struct_%(name)s_C = ''' \
1070     '''scast %(name)s_%(block)s\<rbrace>"''',
1071'''  apply (rule allI, rule conseqPre, vcg)
1072  apply clarsimp
1073  apply (rule context_conjI[THEN iffD1[OF conj_commute]],
1074         fastforce simp: %(name)s_get_tag_eq_x %(name)s_lift_def %(name)s_tag_defs
1075                         mask_def shift_over_ao_dists multi_shift_simps word_size
1076                         word_ao_dist word_bw_assocs)
1077  apply (simp add: %(name)s_%(block)s_lift_def %(name)s_lift_def %(name)s_tag_defs)
1078  apply ((intro conjI sign_extend_eq)?;
1079         (simp add: mask_def shift_over_ao_dists multi_shift_simps word_size
1080                    word_ao_dist word_bw_assocs word_and_max_simps))?
1081  done'''],
1082
1083'ptr_union_get_spec_direct' : [
1084    ptr_union_get_template(direct_ptr_name),
1085''' unfolding ptrval_def
1086 apply(rule allI, rule conseqPre, vcg)
1087 apply(clarsimp simp: typ_heap_simps h_t_valid_clift_Some_iff guard_simps
1088                      mask_shift_simps sign_extend_def' nth_is_and_neq_0
1089                      %(name)s_lift_%(block)s %(name)s_%(block)s_lift_def)
1090done
1091'''],
1092
1093'ptr_union_get_spec_path' : [
1094    ptr_union_get_template(path_ptr_name),
1095'''unfolding ptrval_def
1096  apply(rule allI, rule conseqPre, vcg)
1097  apply(clarsimp)
1098  apply(frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1099  apply(drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1100  apply(frule clift_subtype, simp, simp)
1101  apply(clarsimp simp: typ_heap_simps c_guard_clift)
1102  apply(simp add: guard_simps mask_shift_simps)
1103  apply(simp add:%(name)s_%(block)s_lift_def)
1104  apply(subst %(name)s_lift_%(block)s)
1105  apply(simp add: mask_def)+
1106  done
1107  (* ptr_union_get_spec_path *)'''],
1108
1109'ptr_union_set_spec_direct' : [
1110        ptr_union_set_template(direct_ptr_name),
1111'''sorry (* ptr_union_set_spec_direct *)'''],
1112
1113
1114'ptr_union_set_spec_path' : [
1115        ptr_union_set_template(path_ptr_name),
1116'''  unfolding ptrval_def
1117  apply (rule allI, rule conseqPre, vcg)
1118  apply (clarsimp)
1119  apply (frule h_t_valid_c_guard_cparent, simp, simp add: typ_uinfo_t_def)
1120  apply (drule h_t_valid_clift_Some_iff[THEN iffD1], erule exE)
1121  apply (frule clift_subtype, simp, simp)
1122  apply (clarsimp simp: typ_heap_simps c_guard_clift
1123                        packed_heap_update_collapse_hrs)
1124  apply (simp add: guard_simps mask_shift_simps
1125                   %(name)s_tag_defs[THEN tag_eq_to_tag_masked_eq])?
1126  apply (simp add: parent_update_child[OF c_guard_clift]
1127                   typ_heap_simps c_guard_clift)
1128  apply (simp add: o_def %(name)s_%(block)s_lift_def)
1129  apply (simp only: %(name)s_lift_%(block)s cong: rev_conj_cong)
1130  apply (rule exI, rule conjI[rotated], rule conjI[OF _ refl])
1131   apply (simp_all add: %(name)s_get_tag_eq_x %(name)s_tag_defs mask_shift_simps)
1132  apply (intro conjI sign_extend_eq; simp add: mask_def word_ao_dist word_bw_assocs)?
1133  done'''],
1134
1135}
1136
def make_proof(name, substs, sorry=False):
    """Instantiate the proof template called ``name`` with ``substs``.

    ``proof_templates[name]`` is a list whose element 0 is the lemma
    statement and element 1 the proof script; any further elements are
    joined with newlines and appended after the proof. Every part is
    %-formatted with ``substs``. When ``sorry`` is true the proof script is
    replaced by ``sorry`` (and is not formatted at all).
    """
    template = proof_templates[name]

    pieces = [template[0] % substs, '\n']
    pieces.append('\nsorry' if sorry else template[1] % substs)

    trailer = template[2:]
    if trailer:
        pieces.append('\n' + '\n'.join(trailer) % substs)

    return ''.join(pieces)
1149
1150## AST objects
1151
def emit_named(name, params, string):
    """Print ``string`` followed by a blank line to ``params.output``.

    Nothing is printed unless ``name`` is a member of ``params.names``.
    """
    if name not in params.names:
        return
    out = params.output
    print(string, file=out)
    print(file=out)
1159
1160# This calculates substs for each proof, which is probably inefficient.  Meh
def emit_named_ptr_proof(fn_name, params, name, type_map, toptps, prf_prefix, substs):
    """Emit the pointer-variant proof ``prf_prefix`` for ``fn_name``.

    Nothing is emitted unless ``name + '_C'`` is a key of ``type_map``.  Its
    entry is a ``(toptp, path)`` pair.  With an empty ``path`` the
    ``<prf_prefix>_direct`` template is used; otherwise the
    ``<prf_prefix>_path`` template is used, after filling in the access and
    update expressions for the field path.  The proof is printed through
    ``emit_named`` (so it is further filtered by ``params.names``).

    ``substs`` is updated in place with ``access_path`` and ``update_path``
    (plus ``toptp`` and ``path`` in the path case).  ``toptps`` is unused and
    kept only for interface compatibility.  ``type_map`` is not modified.
    """
    name_C = name + '_C'
    if name_C not in type_map:
        return

    toptp, path = type_map[name_C]

    # Field accessors wrapped around the pointed-to object:
    # path [a, b] gives "(b (a (the (ptrval s))))".
    substs['access_path'] = '(' + reduce(lambda acc, fld: fld + ' (' + acc + ')',
                                         path, 'the (ptrval s)') + ')'

    if not path:
        substs['update_path'] = name
        emit_named(fn_name, params,
                   make_proof(prf_prefix + '_direct', substs, params.sorry))
        return

    substs['toptp'] = toptp
    # the split here gives us the field name (less any qualifiers) as the
    # typ_heap stuff doesn't use the qualifier
    substs['path'] = ', '.join("''%s''" % fld.split('.')[-1] for fld in path)

    # Nested *_update functions wrapped around "\<lambda>_. name", built from
    # the reversed path.  Use reversed() rather than path.reverse(): the
    # latter reversed the list stored inside type_map in place, so
    # successive proofs for the same type alternated between the two orders.
    # (The name here is the variable name, so not _C.)
    substs['update_path'] = ('(' + reduce(lambda acc, fld: fld + '_update (' + acc + ')',
                                          reversed(path), '\\<lambda>_. ' + name)
                             + '(the (ptrval s))' + ')')
    emit_named(fn_name, params,
               make_proof(prf_prefix + '_path', substs, params.sorry))
1183
def field_mask_proof(base, base_bits, sign_extend, high, size):
    """Return the Isabelle mask term ANDed with a field value in the specs.

    A low-aligned field of ``size`` bits uses ``mask size``.  A high-aligned
    field uses ``NOT (mask (base_bits - size))`` when the word is fully used
    (``base_bits == base``) or ``sign_extend`` is set, and the equivalent
    shifted mask ``(mask size << (base_bits - size))`` otherwise.
    """
    if not high:
        return "mask %d" % size

    low_bits = base_bits - size
    if base_bits == base or sign_extend:
        # equivalent to the shifted mask, but nicer in proofs
        return "NOT (mask %d)" % low_bits
    return "(mask %d << %d)" % (size, low_bits)
1193
def sign_extend_proof(high, base_bits, base_sign_extend):
    """Return the ``sign_extend`` prefix used in spec field equations.

    Only high-aligned fields of a sign-extending base get the prefix
    ``"sign_extend <base_bits - 1> "`` (note the trailing space); every
    other field gets the empty string.
    """
    needs_prefix = high and base_sign_extend
    return ("sign_extend %d " % (base_bits - 1)) if needs_prefix else ""
1199
1200class TaggedUnion:
1201    def __init__(self, name, tagname, classes, tags):
1202        self.name = name
1203        self.tagname = tagname
1204        self.constant_suffix = ''
1205
1206        # Check for duplicate tags
1207        used_names = set()
1208        used_values = set()
1209        for name, value in tags:
1210            if name in used_names:
1211                raise ValueError("Duplicate tag name %s" % name)
1212            if value in used_values:
1213                raise ValueError("Duplicate tag value %d" % value)
1214
1215            used_names.add(name)
1216            used_values.add(value)
1217        self.classes = dict(classes)
1218        self.tags = tags
1219
1220    def resolve(self, params, symtab):
1221        # Grab block references for tags
1222        self.tags = [(name, value, symtab[name]) for name, value in self.tags]
1223        self.make_classes(params)
1224
1225        # Ensure that block sizes and tag size & position match for
1226        # all tags in the union
1227        union_base = None
1228        union_size = None
1229        for name, value, ref in self.tags:
1230            _tag_offset, _tag_size, _tag_high = ref.field_map[self.tagname]
1231
1232            if union_base is None:
1233                union_base = ref.base
1234            elif union_base != ref.base:
1235                raise ValueError("Base mismatch for element %s"
1236                                 " of tagged union %s" % (name, self.name))
1237
1238            if union_size is None:
1239                union_size = ref.size
1240            elif union_size != ref.size:
1241                raise ValueError("Size mismatch for element %s"
1242                                 " of tagged union %s" % (name, self.name))
1243
1244            if _tag_offset != self.tag_offset[_tag_size]:
1245                raise ValueError("Tag offset mismatch for element %s"
1246                                 " of tagged union %s" % (name, self.name))
1247
1248            self.assert_value_in_class(name, value, _tag_size)
1249
1250            if _tag_high:
1251                raise ValueError("Tag field is high-aligned for element %s"
1252                                 " of tagged union %s" % (name, self.name))
1253
1254            # Flag block as belonging to a tagged union
1255            ref.tagged = True
1256
1257        self.union_base = union_base
1258        self.union_size = union_size
1259
1260    def set_base(self, base, base_bits, base_sign_extend, suffix):
1261        self.base = base
1262        self.multiple = self.union_size // base
1263        self.constant_suffix = suffix
1264        self.base_bits = base_bits
1265        self.base_sign_extend = base_sign_extend
1266
1267        tag_index = None
1268        for w in self.tag_offset:
1269            tag_offset = self.tag_offset[w]
1270
1271            if tag_index is None:
1272                tag_index = tag_offset // base
1273
1274            if (tag_offset // base) != tag_index:
1275                raise ValueError(
1276                    "The tag field of tagged union %s"
1277                    " is in a different word (%s) to the others (%s)."
1278                    % (self.name, hex(tag_offset // base), hex(tag_index)))
1279
1280    def generate_hol_proofs(self, params, type_map):
1281        output = params.output
1282
1283        # Add fixed simp rule for struct
1284        print("lemmas %(name)s_C_words_C_fl_simp[simp] = "\
1285                        "%(name)s_C_words_C_fl[simplified]" % \
1286                        {"name": self.name}, file=output)
1287        print(file=output)
1288
1289        # Generate struct field pointer proofs
1290        substs = {"name": self.name,
1291                  "words": self.multiple,
1292                  "base": self.base}
1293
1294        print(make_proof('words_NULL_proof',
1295                                   substs, params.sorry), file=output)
1296        print(file=output)
1297
1298        print(make_proof('words_aligned_proof',
1299                                   substs, params.sorry), file=output)
1300        print(file=output)
1301
1302        print(make_proof('words_ptr_safe_proof',
1303                                   substs, params.sorry), file=output)
1304        print(file=output)
1305
1306        # Generate struct lemmas
1307        print(struct_lemmas_template % {"name": self.name},
1308        file=output)
1309        print(file=output)
1310
1311        # Generate get_tag specs
1312        substs = {"name": self.name,
1313                  "tagname": self.tagname,
1314                  "ret_name": return_name(self.base)}
1315
1316        if not params.skip_modifies:
1317            emit_named("%(name)s_get_%(tagname)s" % substs, params,
1318                make_proof('const_modifies_proof',
1319                    {"fun_name": "%(name)s_get_%(tagname)s" % substs, \
1320                     "args": ', '.join(["\<acute>ret__unsigned_long", \
1321                                        "\<acute>%(name)s" % substs])},
1322                    params.sorry))
1323            emit_named("%(name)s_ptr_get_%(tagname)s" % substs, params,
1324                make_proof('const_modifies_proof',
1325                    {"fun_name": "%(name)s_ptr_get_%(tagname)s" % substs, \
1326                     "args": ', '.join(["\<acute>ret__unsigned_long", \
1327                                        "\<acute>%(name)s_ptr" % substs])},
1328                    params.sorry))
1329
1330        emit_named("%s_get_%s" % (self.name, self.tagname), params,
1331                   make_proof('get_tag_spec', substs, params.sorry))
1332
1333        emit_named("%s_%s_equals" % (self.name, self.tagname), params,
1334                   make_proof('get_tag_equals_spec', substs, params.sorry))
1335
1336        # Only generate ptr lemmas for those types reachable from top level types
1337        emit_named_ptr_proof("%s_ptr_get_%s" % (self.name, self.tagname), params, self.name,
1338                             type_map, params.toplevel_types,
1339                             'ptr_get_tag_spec', substs)
1340
1341        for name, value, ref in self.tags:
1342            # Generate struct_new specs
1343            arg_list = ["\<acute>" + field
1344                        for field in ref.visible_order
1345                        if field != self.tagname]
1346
1347            # Generate modifies proof
1348            if not params.skip_modifies:
1349                emit_named("%s_%s_new" % (self.name, ref.name), params,
1350                           make_proof('const_modifies_proof',
1351                               {"fun_name": "%s_%s_new" % \
1352                                            (self.name, ref.name), \
1353                                "args": ', '.join([
1354                                "\<acute>ret__struct_%(name)s_C" % substs] + \
1355                                arg_list)},
1356                               params.sorry))
1357
1358                emit_named("%s_%s_ptr_new" % (self.name, ref.name), params,
1359                           make_proof('ptr_set_modifies_proof',
1360                               {"fun_name": "%s_%s_ptr_new" % \
1361                                            (self.name, ref.name), \
1362                                "args": ', '.join([
1363                                "\<acute>ret__struct_%(name)s_C" % substs] + \
1364                                arg_list)},
1365                               params.sorry))
1366
1367            if len(arg_list) == 0:
1368                # For an empty block:
1369                emit_named("%s_%s_new" % (self.name, ref.name), params,
1370                           make_proof('empty_union_new_spec',
1371                               {"name": self.name, \
1372                                "block": ref.name},
1373                               params.sorry))
1374
1375                emit_named_ptr_proof("%s_%s_ptr_new" % (self.name, ref.name), params, self.name,
1376                                     type_map, params.toplevel_types,
1377                                     'ptr_empty_union_new_spec',
1378                                     {"name": self.name, \
1379                                      "block": ref.name})
1380            else:
1381                field_eq_list = []
1382                for field in ref.visible_order:
1383                    offset, size, high = ref.field_map[field]
1384
1385                    if field == self.tagname:
1386                        continue
1387
1388                    mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
1389                    sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
1390                    field_eq_list.append(
1391                        "%s_%s_CL.%s_CL = %s(\<^bsup>s\<^esup>%s AND %s)" % \
1392                        (self.name, ref.name, field, sign_extend, field, mask))
1393                field_eqs = ',\n          '.join(field_eq_list)
1394
1395                emit_named("%s_%s_new" % (self.name, ref.name), params,
1396                           make_proof('union_new_spec',
1397                               {"name": self.name, \
1398                                "block": ref.name, \
1399                                "args": ', '.join(arg_list), \
1400                                "field_eqs": field_eqs},
1401                               params.sorry))
1402
1403                emit_named_ptr_proof("%s_%s_ptr_new" % (self.name, ref.name), params, self.name,
1404                                     type_map, params.toplevel_types,
1405                                     'ptr_union_new_spec',
1406                                     {"name": self.name, \
1407                                      "block": ref.name, \
1408                                      "args": ', '.join(arg_list), \
1409                                      "field_eqs": field_eqs})
1410
1411            _, size, _ = ref.field_map[self.tagname]
1412            if any([w for w in self.widths if w < size]):
1413                tag_mask_helpers = ("%s_%s_tag_mask_helpers"
1414                                        % (self.name, ref.name))
1415            else:
1416                tag_mask_helpers = ""
1417
1418            # Generate get/set specs
1419            for (field, offset, size, high) in ref.fields:
1420                if field == self.tagname:
1421                    continue
1422
1423                mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
1424                sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
1425
1426                substs = {"name":  self.name,
1427                          "block": ref.name,
1428                          "field": field,
1429                          "mask":  mask,
1430                          "sign_extend": sign_extend,
1431                          "tag_mask_helpers" : tag_mask_helpers,
1432                          "ret_name": return_name(self.base),
1433					      "base" : self.base}
1434
1435                # Get modifies spec
1436                if not params.skip_modifies:
1437                    emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
1438                               params,
1439                               make_proof('const_modifies_proof',
1440                                   {"fun_name": "%s_%s_get_%s" % \
1441                                        (self.name, ref.name, field), \
1442                                    "args": ', '.join([
1443                                    "\<acute>ret__unsigned_long",
1444                                    "\<acute>%s" % self.name] )},
1445                                   params.sorry))
1446
1447                    emit_named("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
1448                               params,
1449                               make_proof('const_modifies_proof',
1450                                   {"fun_name": "%s_%s_ptr_get_%s" % \
1451                                        (self.name, ref.name, field), \
1452                                    "args": ', '.join([
1453                                    "\<acute>ret__unsigned_long",
1454                                    "\<acute>%s_ptr" % self.name] )},
1455                                   params.sorry))
1456
1457                # Get spec
1458                emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
1459                           params,
1460                           make_proof('union_get_spec', substs, params.sorry))
1461
1462                # Set modifies spec
1463                if not params.skip_modifies:
1464                    emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
1465                               params,
1466                               make_proof('const_modifies_proof',
1467                                   {"fun_name": "%s_%s_set_%s" % \
1468                                        (self.name, ref.name, field), \
1469                                    "args": ', '.join([
1470                                    "\<acute>ret__struct_%s_C" % self.name,
1471                                    "\<acute>%s" % self.name,
1472                                    "\<acute>v%(base)d"] )},
1473                                   params.sorry))
1474
1475                    emit_named("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
1476                               params,
1477                               make_proof('ptr_set_modifies_proof',
1478                                   {"fun_name": "%s_%s_ptr_set_%s" % \
1479                                        (self.name, ref.name, field), \
1480                                    "args": ', '.join([
1481                                    "\<acute>%s_ptr" % self.name,
1482                                    "\<acute>v%(base)d"] )},
1483                                   params.sorry))
1484
1485                # Set spec
1486                emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
1487                           params,
1488                           make_proof('union_set_spec', substs, params.sorry))
1489
1490                # Pointer get spec
1491                emit_named_ptr_proof("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
1492                                     params, self.name, type_map, params.toplevel_types,
1493                                     'ptr_union_get_spec', substs)
1494
1495                # Pointer set spec
1496                emit_named_ptr_proof("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
1497                                     params, self.name, type_map, params.toplevel_types,
1498                                     'ptr_union_set_spec', substs)
1499
1500    def generate_hol_defs(self, params):
1501        output = params.output
1502
1503        empty_blocks = {}
1504
1505        def gen_name(ref_name, capitalise = False):
1506            # Create datatype generator/type name for a block
1507            if capitalise:
1508                return "%s_%s" % \
1509                       (self.name[0].upper() + self.name[1:], ref_name)
1510            else:
1511                return "%s_%s" % (self.name, ref_name)
1512
1513        # Generate block records with tag field removed
1514        for name, value, ref in self.tags:
1515            if ref.generate_hol_defs(params, \
1516                                     suppressed_field = self.tagname, \
1517                                     prefix="%s_" % self.name, \
1518                                     in_union = True):
1519                empty_blocks[ref] = True
1520
1521        constructor_exprs = []
1522        for name, value, ref in self.tags:
1523            if ref in empty_blocks:
1524                constructor_exprs.append(gen_name(name, True))
1525            else:
1526                constructor_exprs.append("%s %s_CL" % \
1527                    (gen_name(name, True), gen_name(name)))
1528
1529        print("datatype %s_CL =\n    %s\n" % \
1530                        (self.name, '\n  | '.join(constructor_exprs)),
1531                        file=output)
1532
1533        # Generate get_tag definition
1534        subs = {"name":      self.name,
1535                "base":      self.base }
1536
1537        templates = ([union_get_tag_def_entry_template] * (len(self.widths) - 1)
1538                   + [union_get_tag_def_final_template])
1539
1540        fs = (union_get_tag_def_header_template % subs
1541            + "".join([template %
1542                         dict(subs,
1543                              tag_size=width,
1544                              classmask=self.word_classmask(width),
1545                              tag_index=self.tag_offset[width] // self.base,
1546                              tag_shift=self.tag_offset[width] % self.base)
1547                       for template, width in zip(templates, self.widths)])
1548            + union_get_tag_def_footer_template % subs)
1549
1550        print(fs, file=output)
1551        print(file=output)
1552
1553        # Generate get_tag_eq_x lemma
1554        templates = ([union_get_tag_eq_x_def_entry_template]
1555                        * (len(self.widths) - 1)
1556                   + [union_get_tag_eq_x_def_final_template])
1557
1558        fs = (union_get_tag_eq_x_def_header_template % subs
1559            + "".join([template %
1560                         dict(subs,
1561                              tag_size=width,
1562                              classmask=self.word_classmask(width),
1563                              tag_index=self.tag_offset[width] // self.base,
1564                              tag_shift=self.tag_offset[width] % self.base)
1565                       for template, width in zip(templates, self.widths)])
1566            + union_get_tag_eq_x_def_footer_template % subs)
1567
1568        print(fs, file=output)
1569        print(file=output)
1570
1571        # Generate mask helper lemmas
1572
1573        for name, value, ref in self.tags:
1574            offset, size, _ = ref.field_map[self.tagname]
1575            part_widths = [w for w in self.widths if w < size]
1576            if part_widths:
1577                subs = {"name":         self.name,
1578                        "block":        name,
1579                        "full_mask":    hex(2 ** size - 1),
1580                        "full_value":   hex(value) }
1581
1582                fs = (union_tag_mask_helpers_header_template % subs
1583                    + "".join([union_tag_mask_helpers_entry_template %
1584                               dict(subs, part_mask=hex(2 ** pw - 1),
1585                                          part_value=hex(value & (2 ** pw - 1)))
1586                               for pw in part_widths])
1587                    + union_tag_mask_helpers_footer_template)
1588
1589                print(fs, file=output)
1590                print(file=output)
1591
1592        # Generate lift definition
1593        collapse_proofs = ""
1594        tag_cases = []
1595        for name, value, ref in self.tags:
1596            field_inits = []
1597
1598            for field in ref.visible_order:
1599                offset, size, high = ref.field_map[field]
1600
1601                if field == self.tagname: continue
1602
1603                index = offset // self.base
1604                sign_extend = ""
1605
1606                if high:
1607                    shift_op = "<<"
1608                    shift = self.base_bits - size - (offset % self.base)
1609                    if shift < 0:
1610                        shift = -shift
1611                        shift_op = ">>"
1612                    if self.base_sign_extend:
1613                        sign_extend = "sign_extend %d " % (self.base_bits - 1)
1614                else:
1615                    shift_op = ">>"
1616                    shift = offset % self.base
1617
1618                initialiser = \
1619                    "%s_CL.%s_CL = %s(((index (%s_C.words_C %s) %d) %s %d)" % \
1620                    (gen_name(name), field, sign_extend, self.name, self.name, \
1621                     index, shift_op, shift)
1622
1623                if size < self.base:
1624                    mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
1625                    initialiser += " AND " + mask
1626
1627                field_inits.append("\n       " + initialiser + ")")
1628
1629            if len(field_inits) == 0:
1630                value = gen_name(name, True)
1631            else:
1632                value = "%s \<lparr> %s \<rparr>" % \
1633                    (gen_name(name, True), ','.join(field_inits))
1634
1635            tag_cases.append("if tag = scast %s then Some (%s)" % \
1636                             (gen_name(name), value))
1637
1638            collapse_proofs += \
1639                make_proof("lift_collapse_proof",
1640                           {"name": self.name, \
1641                            "block": name, \
1642                            "value": value},
1643                           params.sorry)
1644            collapse_proofs += "\n\n"
1645
1646        print(union_lift_def_template % \
1647                        {"name": self.name, \
1648                         "tag_cases": '\n     else '.join(tag_cases)},
1649                        file=output)
1650        print(file=output)
1651
1652        print(collapse_proofs, file=output)
1653
1654        block_lift_lemmas = "lemmas %s_lifts = \n" % self.name
1655        # Generate lifted access/update definitions, and abstract lifters
1656        for name, value, ref in self.tags:
1657            # Don't generate accessors if the block (minus tag) is empty
1658            if ref in empty_blocks: continue
1659
1660            substs = {"union": self.name, \
1661                      "block": name, \
1662                      "generator": gen_name(name, True)}
1663
1664            for t in [union_access_def_template, union_update_def_template]:
1665                print(t % substs, file=output)
1666                print(file=output)
1667
1668            print(block_lift_def_template % substs, file=output)
1669            print(file=output)
1670
1671            print(block_lift_lemma_template % substs, file=output)
1672            print(file=output)
1673
1674            block_lift_lemmas += "\t%(union)s_%(block)s_lift\n" % substs
1675
1676        print(block_lift_lemmas, file=output)
1677        print(file=output)
1678
    def generate(self, params):
        """Emit the C definitions for this tagged union.

        Writes, via params.output and emit_named: the struct typedef, the
        ``<union>_tag`` enum and typedef, the tag reader, tag-equality and
        pointer tag reader functions, and, for every tagged block, a value
        constructor, a pointer constructor, and value/pointer getters and
        setters for each non-tag field.

        Type names, the assert macro and the inline qualifier are chosen by
        ``options.environment`` (a key of TYPES / ASSERTS / INLINE, e.g.
        'sel4' or 'libsel4'). NOTE(review): ``options`` is a module-level
        global defined outside this chunk, presumably the parsed
        command-line options.
        """
        output = params.output

        # Generate typedef
        print(typedef_template % \
                        {"type": TYPES[options.environment][self.base], \
                         "name": self.name, \
                         "multiple": self.multiple}, file=output)
        print(file=output)

        # Generate tag enum
        print("enum %s_tag {" % self.name, file=output)
        if len(self.tags) > 0:
            # Every enumerator except the last is followed by a comma.
            for name, value, ref in self.tags[:-1]:
                print("    %s_%s = %d," % (self.name, name, value),
                file=output)
            name, value, ref = self.tags[-1];
            print("    %s_%s = %d" % (self.name, name, value),
            file=output)
        print("};", file=output)
        print("typedef enum %s_tag %s_tag_t;" % \
                        (self.name, self.name), file=output)
        print(file=output)

        # Substitutions shared by the three tag reader templates below.
        subs = {\
            'inline': INLINE[options.environment], \
            'union': self.name, \
            'type':  TYPES[options.environment][self.union_base], \
            'tagname': self.tagname, \
            'suf' : self.constant_suffix}

        # Generate tag reader
        # One entry per tag width in self.widths (sorted ascending by
        # make_classes); the last width uses the *_final_template.
        templates = ([tag_reader_entry_template] * (len(self.widths) - 1)
                   + [tag_reader_final_template])

        fs = (tag_reader_header_template % subs
            + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                       for template, width in zip(templates, self.widths)])
            + tag_reader_footer_template % subs)

        emit_named("%s_get_%s" % (self.name, self.tagname), params, fs)

        # Generate tag eq reader
        templates = ([tag_eq_reader_entry_template] * (len(self.widths) - 1)
                   + [tag_eq_reader_final_template])

        fs = (tag_eq_reader_header_template % subs
            + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                       for template, width in zip(templates, self.widths)])
            + tag_eq_reader_footer_template % subs)

        emit_named("%s_%s_equals" % (self.name, self.tagname), params, fs)

        # Generate pointer lifted tag reader
        templates = ([ptr_tag_reader_entry_template] * (len(self.widths) - 1)
                   + [ptr_tag_reader_final_template])

        fs = (ptr_tag_reader_header_template % subs
            + "".join([template %
                         dict(subs,
                              mask=2 ** width - 1,
                              classmask=self.word_classmask(width),
                              index=self.tag_offset[width] // self.base,
                              shift=self.tag_offset[width] % self.base)
                       for template, width in zip(templates, self.widths)])
            + ptr_tag_reader_footer_template % subs)

        emit_named("%s_ptr_get_%s" % (self.name, self.tagname), params, fs)

        for name, value, ref in self.tags:
            # Generate generators
            # The tag is not a constructor parameter: it is filled in from
            # the enum constant for this block.
            param_fields = [field for field in ref.visible_order if field != self.tagname]
            param_list = ["%s %s" % (TYPES[options.environment][self.base], field)
                          for field in param_fields]

            if len(param_list) == 0:
                gen_params = 'void'
            else:
                gen_params = ', '.join(param_list)

            ptr_params = ', '.join(["%s_t *%s_ptr" % (self.name, self.name)] + param_list)

            # C expressions OR-ed together into each struct word, keyed by
            # word index, plus one range assert per narrower-than-word field.
            field_updates = {word: [] for word in range(self.multiple)}
            field_asserts = ["    /* fail if user has passed bits that we will override */"]

            for field in ref.visible_order:
                offset, size, high = ref.field_map[field]

                if field == self.tagname:
                    f_value = "(%s)%s_%s" % (TYPES[options.environment][self.base], self.name, name)
                else:
                    f_value = field

                index = offset // self.base
                if high:
                    # High fields are passed with their top bit at
                    # base_bits - 1 (see the mask below), so they are
                    # shifted down into place (up if shift is negative).
                    shift_op = ">>"
                    shift = self.base_bits - size - (offset % self.base)
                    if self.base_sign_extend:
                        # Bits above base_bits that the assert tolerates
                        # being set when bit base_bits - 1 of the value is set.
                        high_bits = ((self.base_sign_extend << (self.base - self.base_bits)) - 1) << self.base_bits
                    else:
                        high_bits = 0
                    if shift < 0:
                        shift = -shift
                        shift_op = "<<"
                else:
                    shift_op = "<<"
                    shift = offset % self.base
                    high_bits = 0
                if size < self.base:
                    if high:
                        mask = ((1 << size) - 1) << (self.base_bits - size)
                    else:
                        mask = (1 << size) - 1
                    suf = self.constant_suffix

                    field_asserts.append(
                        "    %s((%s & ~0x%x%s) == ((%d && (%s & (1%s << %d))) ? 0x%x : 0));"
                        % (ASSERTS[options.environment], f_value, mask, suf, self.base_sign_extend,
                           f_value, suf, self.base_bits - 1, high_bits))

                    field_updates[index].append(
                        "(%s & 0x%x%s) %s %d" % (f_value, mask, suf, shift_op, shift))

                else:
                    # Full-word field: no mask or assert needed.
                    field_updates[index].append("%s %s %d" % (f_value, shift_op, shift))

            word_inits = [
                ("words[%d] = 0" % index) + ''.join(["\n        | %s" % up for up in ups]) + ';'
                for (index, ups) in field_updates.items()]

            def mk_inits(prefix):
                # prefix selects value ("x.") or pointer ("x_ptr->") access.
                return '\n'.join(["    %s%s" % (prefix, word_init) for word_init in word_inits])

            # NOTE(review): the asserts are joined with '  \n', leaving two
            # trailing spaces on every assert line but the last in the
            # generated C; harmless, but likely unintended.
            print_params = {
                "inline":     INLINE[options.environment],
                "block":      name,
                "union":      self.name,
                "gen_params": gen_params,
                "ptr_params": ptr_params,
                "gen_inits":  mk_inits("%s." % self.name),
                "ptr_inits":  mk_inits("%s_ptr->" % self.name),
                "asserts": '  \n'.join(field_asserts)
            }

            generator = union_generator_template % print_params
            ptr_generator = ptr_union_generator_template % print_params

            emit_named("%s_%s_new" % (self.name, name), params, generator)
            emit_named("%s_%s_ptr_new" % (self.name, name), params, ptr_generator)

            # Generate field readers/writers
            # The tag location/mask are passed to the accessor templates.
            tagnameoffset, tagnamesize, _ = ref.field_map[self.tagname]
            tagmask = (2 ** tagnamesize) - 1
            for field, offset, size, high in ref.fields:
                # Don't duplicate tag accessors
                if field == self.tagname: continue

                index = offset // self.base
                if high:
                    write_shift = ">>"
                    read_shift = "<<"
                    shift = self.base_bits - size - (offset % self.base)
                    if shift < 0:
                        shift = -shift
                        write_shift = "<<"
                        read_shift = ">>"
                    if self.base_sign_extend:
                        high_bits = ((self.base_sign_extend << (self.base - self.base_bits)) - 1) << self.base_bits
                    else:
                        high_bits = 0
                else:
                    write_shift = "<<"
                    read_shift = ">>"
                    shift = offset % self.base
                    high_bits = 0
                # Mask of the field's bits at its position within its word.
                mask = ((1 << size) - 1) << (offset % self.base)

                subs = {\
                    "inline": INLINE[options.environment], \
                    "block": ref.name, \
                    "field": field, \
                    "type": TYPES[options.environment][ref.base], \
                    "assert": ASSERTS[options.environment], \
                    "index": index, \
                    "shift": shift, \
                    "r_shift_op": read_shift, \
                    "w_shift_op": write_shift, \
                    "mask": mask, \
                    "tagindex": tagnameoffset // self.base, \
                    "tagshift": tagnameoffset % self.base, \
                    "tagmask": tagmask, \
                    "union": self.name, \
                    "suf": self.constant_suffix,
                    "high_bits": high_bits,
                    "sign_extend": self.base_sign_extend and high,
                    "extend_bit": self.base_bits - 1,
                    "base": self.base}

                # Reader
                emit_named("%s_%s_get_%s" % (self.name, ref.name, field),
                           params, union_reader_template % subs)

                # Writer
                emit_named("%s_%s_set_%s" % (self.name, ref.name, field),
                           params, union_writer_template % subs)

                # Pointer lifted reader
                emit_named("%s_%s_ptr_get_%s" % (self.name, ref.name, field),
                           params, ptr_union_reader_template % subs)

                # Pointer lifted writer
                emit_named("%s_%s_ptr_set_%s" % (self.name, ref.name, field),
                           params, ptr_union_writer_template % subs)
1902
1903    def make_names(self):
1904        "Return the set of candidate function names for a union"
1905
1906        substs = {"union" : self.name, \
1907                  "tagname": self.tagname}
1908        names = [t % substs for t in [
1909        "%(union)s_get_%(tagname)s",
1910        "%(union)s_ptr_get_%(tagname)s",
1911	"%(union)s_%(tagname)s_equals"]]
1912
1913        for name, value, ref in self.tags:
1914            names += ref.make_names(self)
1915
1916        return names
1917
1918    def represent_value(self, value, width):
1919        max_width = max(self.classes.keys())
1920
1921        tail_str = ("{:0{}b}".format(value, width)
1922                        + "_" * (self.tag_offset[width] - self.class_offset))
1923        head_str = "_" * ((max_width + self.tag_offset[max_width]
1924                            - self.class_offset) - len(tail_str))
1925
1926        return head_str + tail_str
1927
1928    def represent_class(self, width):
1929        max_width = max(self.classes.keys())
1930
1931        cmask = self.classes[width]
1932        return ("{:0{}b}".format(cmask, max_width).replace("0", "-")
1933                + " ({:#x})".format(cmask))
1934
1935    def represent_field(self, width):
1936        max_width = max(self.classes.keys())
1937        offset = self.tag_offset[width] - self.class_offset
1938
1939        return ("{:0{}b}".format((2 ** width - 1) << offset, max_width)
1940                .replace("0", "-").replace("1", "#"))
1941
1942    def assert_value_in_class(self, name, value, width):
1943        max_width = max(self.classes.keys())
1944        ovalue = value << self.tag_offset[width]
1945        cvalue = value << (self.tag_offset[width] - self.class_offset)
1946
1947        offset_field = (2 ** width - 1) << self.tag_offset[width]
1948        if (ovalue | offset_field) != offset_field:
1949                raise ValueError(
1950                    "The value for element %s of tagged union %s,\n"
1951                    "    %s\nexceeds the field bounds\n"
1952                    "    %s."
1953                    % (name, self.name,
1954                       self.represent_value(value, width),
1955                       self.represent_field(width)))
1956
1957        for w, mask in [(lw, self.classes[lw])
1958                        for lw in self.widths if lw < width]:
1959            if (cvalue & mask) != mask:
1960                raise ValueError(
1961                    "The value for element %s of tagged union %s,\n"
1962                    "    %s\nis invalid: it has %d bits but fails"
1963                    " to match the earlier mask at %d bits,\n"
1964                    "    %s."
1965                    % (name, self.name,
1966                       self.represent_value(value, width),
1967                       width, w, self.represent_class(w)))
1968
1969        if (self.widths.index(width) + 1 < len(self.widths) and
1970            (cvalue & self.classes[width]) == self.classes[width]):
1971            raise ValueError(
1972                "The value for element %s of tagged union %s,\n"
1973                "    %s (%d/%s)\nis invalid: it must not match the "
1974                "mask for %d bits,\n    %s."
1975                % (name, self.name,
1976                   ("{:0%db}" % width).format(cvalue),
1977                   value, hex(value),
1978                   width,
1979                   self.represent_class(width)))
1980
1981    def word_classmask(self, width):
1982        "Return a class mask for testing a whole word, i.e., one."
1983        "that is positioned absolutely relative to the lsb of the"
1984        "relevant word."
1985
1986        return (self.classes[width] << (self.class_offset % self.base))
1987
1988    def make_classes(self, params):
1989        "Calculate an encoding for variable width tagnames"
1990
1991        # Check self.classes, which maps from the bit width of tagname in a
1992        # particular block to a classmask that identifies when a value belongs
1993        # to wider tagname.
1994        #
1995        # For example, given three possible field widths -- 4, 8, and 12 bits --
1996        # one possible encoding is:
1997        #
1998        #                       * * _ _     (** != 11)
1999        #             0 _ _ _   1 1 _ _
2000        #   _ _ _ _   1 _ _ _   1 1 _ _
2001        #
2002        # where the 3rd and 4th lsbs signify whether the field should be
2003        # interpreted using a 4-bit mask (if 00, 01, or 10) or as an 8 or 16 bit
2004        # mask (if 11). And, in the latter case, the 8th lsb signifies whether
2005        # to intrepret it as an 8 bit field (if 0) or a 16 bit field (if 1).
2006        #
2007        # In this example we have:
2008        #   4-bit class:  classmask = 0b00001100
2009        #   8-bit class:  classmask = 0b10001100
2010        #  16-bit class:  classmask = 0b10001100
2011        #
2012        # More generally, the fields need not all start at the same offset
2013        # (measured "left" from the lsb), for example:
2014        #
2015        #    ...# ###. .... ....       4-bit field at offset 9
2016        #    ..## #### ##.. ....       8-bit field at offset 6
2017        #    #### #### #### ....      12-bit field at offset 4
2018        #
2019        # In this case, the class_offset is the minimum offset (here 4).
2020        # Classmasks are declared relative to the field, but converted
2021        # internally to be relative to the class_offset; tag_offsets
2022        # are absolute (relative to the lsb); values are relative to
2023        # the tag_offset (i.e., within the field). for example:
2024        #
2025        #    ...1 100. ....    4-bit class: classmask=0xc   tag_offset=9
2026        #    ..01 1000 10..    8-bit class: classmask=0x62  tag_offset=6
2027        #    0001 1000 1000   16-bit class: classmask=0x188 tag_offset=4
2028
2029        used = set()
2030        self.tag_offset = {}
2031        for name, _, ref in self.tags:
2032            offset, size, _ = ref.field_map[self.tagname]
2033            used.add(size)
2034            self.tag_offset[size] = offset
2035
2036        self.class_offset = min(self.tag_offset.values())
2037
2038        # internally, classmasks are relative to the class_offset, so
2039        # that we can compare them to each other.
2040        for w in self.classes:
2041            self.classes[w] <<= self.tag_offset[w] - self.class_offset
2042
2043        used_widths = sorted(list(used))
2044        assert(len(used_widths) > 0)
2045
2046        if not self.classes:
2047            self.classes = { used_widths[0] : 0 }
2048
2049        # sanity checks on classes
2050        classes = self.classes
2051        widths = sorted(self.classes.keys())
2052        context = "masks for %s.%s" % (self.name, self.tagname)
2053        class_offset = self.class_offset
2054
2055        for uw in used_widths:
2056            if uw not in classes:
2057                raise ValueError("%s: none defined for a field of %d bits."
2058                                    % (context, uw))
2059
2060        for mw in classes.keys():
2061            if mw not in used_widths:
2062                raise ValueError(
2063                    "%s: there is a mask with %d bits but no corresponding fields."
2064                        % (context, mw))
2065
2066        for w in widths:
2067            offset_field = (2 ** w - 1) << self.tag_offset[w]
2068            if (classes[w] << class_offset) | offset_field != offset_field:
2069                raise ValueError(
2070                        "{:s}: the mask for {:d} bits:\n  {:s}\n"
2071                        "exceeds the field bounds:\n  {:s}."
2072                        .format(context, w, self.represent_class(w),
2073                                self.represent_field(w)))
2074
2075        if len(widths) > 1 and classes[widths[0]] == 0:
2076            raise ValueError("%s: the first (width %d) is zero." % (
2077                                context, widths[0]))
2078
2079        if any([classes[widths[i-1]] == classes[widths[i]]
2080                for i in range(1, len(widths) - 1)]):
2081            raise ValueError("%s: there is a non-final duplicate!" % context)
2082
2083        # smaller masks are included within larger masks
2084        pre_mask = None
2085        pre_width = None
2086        for w in widths:
2087            if pre_mask is not None and (classes[w] & pre_mask) != pre_mask:
2088                raise ValueError(
2089                    "{:s}: the mask\n  0b{:b} for width {:d} does not include "
2090                    "the mask\n  0b{:b} for width {:d}.".format(
2091                        context, classes[w], w, pre_mask, pre_width))
2092            pre_width = w
2093            pre_mask = classes[w]
2094
2095        if params.showclasses:
2096            print("-----%s.%s" % (self.name, self.tagname), file=sys.stderr)
2097            for w in widths:
2098                print("{:2d} = {:s}".format( w, self.represent_class(w)),
2099                        file=sys.stderr)
2100
2101        self.widths = widths
2102
2103class Block:
2104    def __init__(self, name, fields, visible_order):
2105        offset = 0
2106        _fields = []
2107        self.size = sum(size for _name, size, _high in fields)
2108        offset = self.size
2109        self.constant_suffix = ''
2110
2111        if visible_order is None:
2112            self.visible_order = []
2113
2114        for _name, _size, _high in fields:
2115            offset -= _size
2116            if not _name is None:
2117                if visible_order is None:
2118                    self.visible_order.append(_name)
2119                _fields.append((_name, offset, _size, _high))
2120
2121        self.name = name
2122        self.tagged = False
2123        self.fields = _fields
2124        self.field_map = dict((name, (offset, size, high)) \
2125                              for name, offset, size, high in _fields)
2126
2127        if not visible_order is None:
2128            missed_fields = set(self.field_map.keys())
2129
2130            for _name in visible_order:
2131                if _name not in self.field_map:
2132                    raise ValueError("Nonexistent field '%s' in visible_order"
2133                                     % _name)
2134                missed_fields.remove(_name)
2135
2136            if len(missed_fields) > 0:
2137                raise ValueError("Fields %s missing from visible_order" % \
2138                                 str([x for x in missed_fields]))
2139
2140            self.visible_order = visible_order
2141
2142    def set_base(self, base, base_bits, base_sign_extend, suffix):
2143        self.base = base
2144        self.constant_suffix = suffix
2145        self.base_bits = base_bits
2146        self.base_sign_extend = base_sign_extend
2147        if self.size % base != 0:
2148            raise ValueError("Size of block %s not a multiple of base" \
2149                             % self.name)
2150        self.multiple = self.size // base
2151        for name, offset, size, high in self.fields:
2152            if offset // base != (offset+size-1) // base:
2153                raise ValueError("Field %s of block %s " \
2154                                 "crosses a word boundary" \
2155                                 % (name, self.name))
2156
2157    def generate_hol_defs(self, params, suppressed_field=None, \
2158                                prefix="", in_union = False):
2159        output = params.output
2160
2161        # Don't generate raw records for blocks in tagged unions
2162        if self.tagged and not in_union: return
2163
2164        _name = prefix + self.name
2165
2166        # Generate record def
2167        out = "record %s_CL =\n" % _name
2168
2169        empty = True
2170
2171        for field in self.visible_order:
2172            if suppressed_field == field:
2173                continue
2174
2175            empty = False
2176
2177            out += '    %s_CL :: "word%d"\n' % (field, self.base)
2178
2179        word_updates = ""
2180
2181        if not empty:
2182            print(out, file=output)
2183
2184        # Generate lift definition
2185        if not in_union:
2186            field_inits = []
2187
2188            for name in self.visible_order:
2189                offset, size, high = self.field_map[name]
2190
2191                index = offset // self.base
2192                sign_extend = ""
2193
2194                if high:
2195                    shift_op = "<<"
2196                    shift = self.base_bits - size - (offset % self.base)
2197                    if shift < 0:
2198                        shift = -shift
2199                        shift_op = ">>"
2200                    if self.base_sign_extend:
2201                        sign_extend = "sign_extend %d " % (self.base_bits - 1)
2202                else:
2203                    shift_op = ">>"
2204                    shift = offset % self.base
2205
2206                initialiser = \
2207                    "%s_CL.%s_CL = %s(((index (%s_C.words_C %s) %d) %s %d)" % \
2208                    (self.name, name, sign_extend, self.name, self.name, \
2209                     index, shift_op, shift)
2210
2211                if size < self.base:
2212                    mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
2213
2214                    initialiser += " AND " + mask
2215
2216                field_inits.append(initialiser + ")")
2217
2218            print(lift_def_template % \
2219                            {"name": self.name, \
2220                             "fields": ',\n       '.join(field_inits)},
2221                            file=output)
2222            print(file=output)
2223
2224        return empty
2225
2226    def generate_hol_proofs(self, params, type_map):
2227        output = params.output
2228
2229        if self.tagged: return
2230
2231        # Add fixed simp rule for struct
2232        print("lemmas %(name)s_C_words_C_fl_simp[simp] = "\
2233                        "%(name)s_C_words_C_fl[simplified]" % \
2234                        {"name": self.name}, file=output)
2235        print(file=output)
2236
2237        # Generate struct field pointer proofs
2238        substs = {"name": self.name,
2239                  "words": self.multiple,
2240                  "base": self.base}
2241
2242        print(make_proof('words_NULL_proof',
2243                                   substs, params.sorry), file=output)
2244        print(file=output)
2245
2246        print(make_proof('words_aligned_proof',
2247                                   substs, params.sorry), file=output)
2248        print(file=output)
2249
2250        print(make_proof('words_ptr_safe_proof',
2251                                   substs, params.sorry), file=output)
2252        print(file=output)
2253
2254        # Generate struct lemmas
2255        print(struct_lemmas_template % {"name": self.name},
2256        file=output)
2257        print(file=output)
2258
2259        # Generate struct_new specs
2260        arg_list = ["\<acute>" + field for
2261                    (field, offset, size, high) in self.fields]
2262
2263        if not params.skip_modifies:
2264            emit_named("%s_new" % self.name, params,
2265                       make_proof('const_modifies_proof',
2266                           {"fun_name": "%s_new" % self.name, \
2267                            "args": ', '.join(["\<acute>ret__struct_%s_C" % \
2268                                               self.name] + \
2269                                              arg_list)},
2270                           params.sorry))
2271            # FIXME: ptr_new (doesn't seem to be used)
2272
2273        field_eq_list = []
2274        for (field, offset, size, high) in self.fields:
2275            mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
2276            sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
2277
2278            field_eq_list.append("%s_CL.%s_CL = %s(\<^bsup>s\<^esup>%s AND %s)" % \
2279                                 (self.name, field, sign_extend, field, mask))
2280        field_eqs = ',\n          '.join(field_eq_list)
2281
2282        emit_named("%s_new" % self.name, params,
2283                   make_proof('new_spec',
2284                       {"name": self.name, \
2285                        "args": ', '.join(arg_list), \
2286                        "field_eqs": field_eqs},
2287                       params.sorry))
2288
2289        emit_named_ptr_proof("%s_ptr_new" % self.name, params, self.name,
2290                             type_map, params.toplevel_types,
2291                             'ptr_new_spec',
2292                             {"name": self.name, \
2293                              "args": ', '.join(arg_list), \
2294                              "field_eqs": field_eqs})
2295
2296        # Generate get/set specs
2297        for (field, offset, size, high) in self.fields:
2298            mask = field_mask_proof(self.base, self.base_bits, self.base_sign_extend, high, size)
2299            sign_extend = sign_extend_proof(high, self.base_bits, self.base_sign_extend)
2300
2301            substs = {"name": self.name, \
2302                      "field": field, \
2303                      "mask": mask,
2304                      "sign_extend": sign_extend,
2305                      "ret_name": return_name(self.base),
2306                      "base": self.base}
2307
2308            if not params.skip_modifies:
2309                # Get modifies spec
2310                emit_named("%s_get_%s" % (self.name, field), params,
2311                           make_proof('const_modifies_proof',
2312                               {"fun_name": "%s_get_%s" % (self.name, field), \
2313                                "args": ', '.join([
2314                                "\<acute>ret__unsigned_long",
2315                                "\<acute>%s" % self.name] )},
2316                               params.sorry))
2317
2318                # Ptr get modifies spec
2319                emit_named("%s_ptr_get_%s" % (self.name, field), params,
2320                           make_proof('const_modifies_proof',
2321                               {"fun_name": "%s_ptr_get_%s" % (self.name, field), \
2322                                "args": ', '.join([
2323                                "\<acute>ret__unsigned_long",
2324                                "\<acute>%s_ptr" % self.name] )},
2325                               params.sorry))
2326
2327
2328            # Get spec
2329            emit_named("%s_get_%s" % (self.name, field), params,
2330                        make_proof('get_spec', substs, params.sorry))
2331
2332            if not params.skip_modifies:
2333                # Set modifies spec
2334                emit_named("%s_set_%s" % (self.name, field), params,
2335                           make_proof('const_modifies_proof',
2336                               {"fun_name": "%s_set_%s" % (self.name, field), \
2337                                "args": ', '.join([
2338                                "\<acute>ret__struct_%s_C" % self.name,
2339                                "\<acute>%s" % self.name,
2340                                "\<acute>v%(base)d"] )},
2341                               params.sorry))
2342
2343                emit_named("%s_ptr_set_%s" % (self.name, field), params,
2344                           make_proof('ptr_set_modifies_proof',
2345                               {"fun_name": "%s_ptr_set_%s" % (self.name, field), \
2346                                "args": ', '.join([
2347                                "\<acute>%s_ptr" % self.name,
2348                                "\<acute>v%(base)d"] )},
2349                               params.sorry))
2350
2351
2352            # Set spec
2353            emit_named("%s_set_%s" % (self.name, field), params,
2354                        make_proof('set_spec', substs, params.sorry))
2355
2356            emit_named_ptr_proof("%s_ptr_get_%s" % (self.name, field), params, self.name,
2357                                 type_map, params.toplevel_types,
2358                                 'ptr_get_spec', substs)
2359            emit_named_ptr_proof("%s_ptr_set_%s" % (self.name, field), params, self.name,
2360                                 type_map, params.toplevel_types,
2361                                 'ptr_set_spec', substs)
2362
2363    def generate(self, params):
2364        output = params.output
2365
2366        # Don't generate raw accessors for blocks in tagged unions
2367        if self.tagged: return
2368
2369        # Type definition
2370        print(typedef_template % \
2371                        {"type": TYPES[options.environment][self.base], \
2372                         "name": self.name, \
2373                         "multiple": self.multiple}, file=output)
2374        print(file=output)
2375
2376        # Generator
2377        param_fields = [field for field in self.visible_order]
2378        param_list = ["%s %s" % (TYPES[options.environment][self.base], field)
2379                      for field in param_fields]
2380
2381        if len(param_list) == 0:
2382            gen_params = 'void'
2383        else:
2384            gen_params = ', '.join(param_list)
2385
2386        ptr_params = ', '.join(["%s_t *%s_ptr" % (self.name, self.name)] + param_list)
2387
2388        field_updates = {word: [] for word in range(self.multiple)}
2389        field_asserts = ["    /* fail if user has passed bits that we will override */"]
2390
2391        for field, offset, size, high in self.fields:
2392            index = offset // self.base
2393            if high:
2394                shift_op = ">>"
2395                shift = self.base_bits - size - (offset % self.base)
2396                if self.base_sign_extend:
2397                    high_bits = ((self.base_sign_extend << (self.base - self.base_bits)) - 1) << self.base_bits
2398                else:
2399                    high_bits = 0
2400                if shift < 0:
2401                    shift = -shift
2402                    shift_op = "<<"
2403            else:
2404                shift_op = "<<"
2405                shift = offset % self.base
2406                high_bits = 0
2407            if size < self.base:
2408                if high:
2409                    mask = ((1 << size) - 1) << (self.base_bits - size)
2410                else:
2411                    mask = (1 << size) - 1
2412                suf = self.constant_suffix
2413
2414                field_asserts.append(
2415                    "    %s((%s & ~0x%x%s) == ((%d && (%s & (1%s << %d))) ? 0x%x : 0));"
2416                    % (ASSERTS[options.environment], field, mask, suf, self.base_sign_extend,
2417                       field, suf, self.base_bits - 1, high_bits))
2418
2419                field_updates[index].append(
2420                    "(%s & 0x%x%s) %s %d" % (field, mask, suf, shift_op, shift))
2421
2422            else:
2423                field_updates[index].append("%s %s %d;" % (field, shift_op, shift))
2424
2425        word_inits = [
2426            ("words[%d] = 0" % index) + ''.join(["\n        | %s" % up for up in ups]) + ';'
2427            for (index, ups) in field_updates.items()]
2428
2429        def mk_inits(prefix):
2430            return '\n'.join(["    %s%s" % (prefix, word_init) for word_init in word_inits])
2431
2432        print_params = {
2433            "inline": INLINE[options.environment],
2434            "block": self.name,
2435            "gen_params": gen_params,
2436            "ptr_params": ptr_params,
2437            "gen_inits": mk_inits("%s." % self.name),
2438            "ptr_inits": mk_inits("%s_ptr->" % self.name),
2439            "asserts": '  \n'.join(field_asserts)
2440        }
2441
2442        generator = generator_template % print_params
2443        ptr_generator = ptr_generator_template % print_params
2444
2445        emit_named("%s_new" % self.name, params, generator)
2446        emit_named("%s_ptr_new" % self.name, params, ptr_generator)
2447
2448        # Accessors
2449        for field, offset, size, high in self.fields:
2450            index = offset // self.base
2451            if high:
2452                write_shift = ">>"
2453                read_shift = "<<"
2454                shift = self.base_bits - size - (offset % self.base)
2455                if shift < 0:
2456                    shift = -shift
2457                    write_shift = "<<"
2458                    read_shift = ">>"
2459                if self.base_sign_extend:
2460                    high_bits = ((self.base_sign_extend << (self.base - self.base_bits)) - 1) << self.base_bits
2461                else:
2462                    high_bits = 0
2463            else:
2464                write_shift = "<<"
2465                read_shift = ">>"
2466                shift = offset % self.base
2467                high_bits = 0
2468            mask = ((1 << size) - 1) << (offset % self.base)
2469
2470            subs = {\
2471                "inline": INLINE[options.environment], \
2472                "block": self.name, \
2473                "field": field, \
2474                "type": TYPES[options.environment][self.base], \
2475                "assert": ASSERTS[options.environment], \
2476                "index": index, \
2477                "shift": shift, \
2478                "r_shift_op": read_shift, \
2479                "w_shift_op": write_shift, \
2480                "mask": mask, \
2481                "suf": self.constant_suffix, \
2482                "high_bits": high_bits, \
2483                "sign_extend": self.base_sign_extend and high,
2484                "extend_bit": self.base_bits - 1,
2485                "base": self.base}
2486
2487            # Reader
2488            emit_named("%s_get_%s" % (self.name, field), params,
2489                       reader_template % subs)
2490
2491            # Writer
2492            emit_named("%s_set_%s" % (self.name, field), params,
2493                       writer_template % subs)
2494
2495            # Pointer lifted reader
2496            emit_named("%s_ptr_get_%s" % (self.name, field), params,
2497                       ptr_reader_template % subs)
2498
2499            # Pointer lifted writer
2500            emit_named("%s_ptr_set_%s" % (self.name, field), params,
2501                       ptr_writer_template % subs)
2502
2503    def make_names(self, union=None):
2504        "Return the set of candidate function names for a block"
2505
2506        if union is None:
2507            # Don't generate raw accessors for blocks in tagged unions
2508            if self.tagged: return []
2509
2510            substs = {"block" : self.name}
2511
2512            # A standalone block
2513            field_templates = [
2514            "%(block)s_get_%(field)s",
2515            "%(block)s_set_%(field)s",
2516            "%(block)s_ptr_get_%(field)s",
2517            "%(block)s_ptr_set_%(field)s"]
2518
2519            names = [t % substs for t in [
2520            "%(block)s_new",
2521            "%(block)s_ptr_new"]]
2522        else:
2523            substs = {"block" : self.name, \
2524                      "union" : union.name}
2525
2526            # A tagged union block
2527            field_templates = [
2528            "%(union)s_%(block)s_get_%(field)s",
2529            "%(union)s_%(block)s_set_%(field)s",
2530            "%(union)s_%(block)s_ptr_get_%(field)s",
2531            "%(union)s_%(block)s_ptr_set_%(field)s"]
2532
2533            names = [t % substs for t in [
2534            "%(union)s_%(block)s_new",
2535            "%(union)s_%(block)s_ptr_new"]]
2536
2537        for field, offset, size, high in self.fields:
2538            if not union is None and field == union.tagname:
2539                continue
2540
2541            substs["field"] = field
2542            names += [t % substs for t in field_templates]
2543
2544        return names
2545
# Atomic OutputFile instances still waiting for finish_output() to move them
# into place.
temp_output_files = []


class OutputFile(object):
    """Writable output file that can be published atomically.

    With ``atomic`` true, data is written to a temporary file in the same
    directory as the destination (its name prefixed by the destination's
    basename), and the object is queued in ``temp_output_files``;
    finish_output() later renames it to ``filename``.  With ``atomic`` false
    the destination is opened directly.
    """

    def __init__(self, filename, mode='w', atomic=True):
        self.filename = os.path.abspath(filename)
        if not atomic:
            self.file = open(filename, mode)
            return

        target_dir, target_base = os.path.split(self.filename)
        self.file = tempfile.NamedTemporaryFile(
            mode=mode, dir=target_dir, prefix=target_base + '.', delete=False)
        if DEBUG:
            print('Temp file: %r -> %r' % (self.file.name, self.filename),
                  file=sys.stderr)
        # Appending mutates the module-level list in place, so no 'global'
        # declaration is required here.
        temp_output_files.append(self)

    def write(self, *args, **kwargs):
        """Forward to the underlying file object's write()."""
        self.file.write(*args, **kwargs)
2565
def finish_output():
    """Publish every pending atomic output file and clear the queue.

    Each temporary file is closed before it is renamed over its destination.
    Closing first flushes any buffered data, so the destination never
    appears with partially written contents.  Closing first is also required
    on platforms that refuse to rename an open file.
    """
    global temp_output_files
    for f in temp_output_files:
        # delete=False on the temp file means closing it does not remove it.
        f.file.close()
        os.rename(f.file.name, f.filename)
    temp_output_files = []
2571
2572## Toplevel
if __name__ == '__main__':
    # Parse arguments to set mode and grab I/O filenames
    params = {}
    in_filename = None
    in_file  = sys.stdin
    out_file = sys.stdout
    mode = 'c_defs'

    parser = optparse.OptionParser()
    parser.add_option('--c_defs', action='store_true', default=False)
    parser.add_option('--environment', action='store', default='sel4',
                      choices=list(INCLUDES.keys()))
    parser.add_option('--hol_defs', action='store_true', default=False)
    parser.add_option('--hol_proofs', action='store_true', default=False)
    parser.add_option('--sorry_lemmas', action='store_true',
                      dest='sorry', default=False)
    parser.add_option('--prune', action='append',
                      dest="prune_files", default = [])
    parser.add_option('--toplevel', action='append',
                      dest="toplevel_types", default = [])
    parser.add_option('--umm_types', action='store',
                      dest="umm_types_file", default = None)
    parser.add_option('--multifile_base', action='store', default=None)
    parser.add_option('--cspec-dir', action='store', default=None,
            help="Location of the 'cspec' directory containing 'KernelState_C'.")
    parser.add_option('--thy-output-path', action='store', default=None,
            help="Path that the output theory files will be located in.")
    parser.add_option('--skip_modifies', action='store_true', default=False)
    parser.add_option('--showclasses', action='store_true', default=False)
    parser.add_option('--debug', action='store_true', default=False)

    options, args = parser.parse_args()
    DEBUG = options.debug

    if len(args) > 0:
        in_filename = args[0]
        in_file = open(in_filename)

        if len(args) > 1:
            out_file = OutputFile(args[1])

    #
    # If generating Isabelle scripts, ensure we have enough information for
    # relative paths required by Isabelle.
    #
    if options.hol_defs or options.hol_proofs:
        # Ensure directory that we need to include is known.
        if options.cspec_dir is None:
            parser.error("'cspec_dir' not defined.")

        # Ensure that if an output file was not specified, an output path was.
        # NOTE(review): with fewer than two positional arguments out_file is
        # still sys.stdout here. One of the two parser.error calls (which
        # exit) therefore always fires, and the final assignment is never
        # reached. In practice HOL output requires an explicit output file.
        if len(args) <= 1:
            if options.thy_output_path is None:
                parser.error("Theory output path was not specified")
            if out_file == sys.stdout:
                parser.error('Output file name must be given when generating HOL definitions or proofs')
            out_file.filename = os.path.abspath(options.thy_output_path)

    if options.hol_proofs and not options.umm_types_file:
        parser.error('--umm_types must be specified when generating HOL proofs')

    del parser

    options.output = out_file

    # Parse the spec
    lexer = lex.lex()
    yacc.yacc(debug=0, write_tables=0)
    blocks = {}
    unions = {}
    spec_text = in_file.read()
    if in_file is not sys.stdin:
        in_file.close()
    _, block_map, union_map = yacc.parse(input=spec_text, lexer=lexer)
    base_list = [8, 16, 32, 64]
    # assumes that unsigned int = 32 bit on 32-bit and 64-bit platforms,
    # and that unsigned long long = 64 bit on 64-bit platforms.
    # Should still work fine if ull = 128 bit, but will not work
    # if unsigned int is less than 32 bit.
    suffix_map = {8 : 'u', 16 : 'u', 32 : 'u', 64 : 'ull'}
    for base_info, block_list in block_map.items():
        base, base_bits, base_sign_extend = base_info
        for name, b in block_list.items():
            if not base in base_list:
                raise ValueError("Invalid base size: %d" % base)
            suffix = suffix_map[base]
            b.set_base(base, base_bits, base_sign_extend, suffix)
            blocks[name] = b

    symtab = {}
    symtab.update(blocks)
    for union_list in union_map.values():
        unions.update(union_list)
    symtab.update(unions)
    for base_info, union_list in union_map.items():
        base, base_bits, base_sign_extend = base_info
        for u in union_list.values():
            if not base in base_list:
                raise ValueError("Invalid base size: %d" % base)
            suffix = suffix_map[base]
            u.resolve(options, symtab)
            u.set_base(base, base_bits, base_sign_extend, suffix)

    # Every block followed by every union. Build a real list once: joining
    # dict views with '+' raises TypeError under Python 3.
    all_types = list(itertools.chain(blocks.values(), unions.values()))

    if not in_filename is None:
        base_filename = os.path.basename(in_filename).split('.')[0]

        # Generate the module name from the input filename
        module_name = base_filename

    # Prune list of names to generate
    name_list = []
    for e in all_types:
        name_list += e.make_names()

    name_list = set(name_list)
    if len(options.prune_files) > 0:
        search_re = re.compile('[a-zA-Z0-9_]+')

        pruned_names = set()
        for filename in options.prune_files:
            with open(filename) as f:
                prune_text = f.read()

            matched_tokens = set(search_re.findall(prune_text))
            pruned_names.update(matched_tokens & name_list)
    else:
        pruned_names = name_list

    options.names = pruned_names

    # Generate the output
    if options.hol_defs:
        # Fetch kernel
        if options.multifile_base is None:
            print("theory %s_defs" % module_name, file=out_file)
            print("imports \"%s/KernelState_C\"" % (
                    os.path.relpath(options.cspec_dir,
                        os.path.dirname(out_file.filename))), file=out_file)
            print("begin", file=out_file)
            print(file=out_file)

            print(defs_global_lemmas, file=out_file)
            print(file=out_file)

            for e in all_types:
                e.generate_hol_defs(options)

            print("end", file=out_file)
        else:
            print("theory %s_defs" % module_name, file=out_file)
            print("imports", file=out_file)
            print("  \"%s/KernelState_C\"" % (
                    os.path.relpath(options.cspec_dir,
                        os.path.dirname(out_file.filename))), file=out_file)
            for e in all_types:
                print("  %s_%s_defs" % (module_name, e.name),
                      file=out_file)
            print("begin", file=out_file)
            print("end", file=out_file)

            base_filename = \
                os.path.basename(options.multifile_base).split('.')[0]
            for e in all_types:
                submodule_name = base_filename + "_" + \
                                 e.name + "_defs"
                out_file = OutputFile(options.multifile_base + "_" +
                                      e.name + "_defs" + ".thy")

                print("theory %s imports \"%s/KernelState_C\" begin" % (
                        submodule_name, os.path.relpath(options.cspec_dir,
                            os.path.dirname(out_file.filename))),
                      file=out_file)
                print(file=out_file)

                options.output = out_file
                e.generate_hol_defs(options)

                print("end", file=out_file)
    elif options.hol_proofs:
        bit_type_names = set(e.name + '_C' for e in all_types)

        def is_bit_type(tp):
            return umm.is_base(tp) and umm.base_name(tp) in bit_type_names

        tps = umm.build_types(options.umm_types_file)
        type_map = {}

        # invert type map
        for toptp in options.toplevel_types:
            paths = umm.paths_to_type(tps, is_bit_type, toptp)

            for path, tp in paths:
                tp = umm.base_name(tp)

                if tp in type_map:
                    raise ValueError("Type %s has multiple parents" % tp)

                type_map[tp] = (toptp, path)

        if options.multifile_base is None:
            print("theory %s_proofs" % module_name, file=out_file)
            print("imports %s_defs" % module_name, file=out_file)
            print("begin", file=out_file)
            print(file=out_file)
            print(file=out_file)

            for e in all_types:
                e.generate_hol_proofs(options, type_map)

            print("end", file=out_file)
        else:
            # top types are broken here.
            print("theory %s_proofs" % module_name, file=out_file)
            print("imports", file=out_file)
            for e in all_types:
                print("  %s_%s_proofs" % (module_name, e.name),
                      file=out_file)
            print("begin", file=out_file)
            print("end", file=out_file)

            base_filename = \
                os.path.basename(options.multifile_base).split('.')[0]
            for e in all_types:
                submodule_name = base_filename + "_" + \
                                 e.name + "_proofs"
                out_file = OutputFile(options.multifile_base + "_" +
                                      e.name + "_proofs" + ".thy")

                print(("theory %s imports "
                        + "%s_%s_defs begin") % (
                            submodule_name, base_filename, e.name),
                      file=out_file)
                print(file=out_file)

                options.output = out_file
                e.generate_hol_proofs(options, type_map)

                print("end", file=out_file)
    else:
        guard_source = getattr(out_file, 'filename', None)
        if guard_source is None:
            # Writing to stdout: sys.stdout has no 'filename', so derive the
            # include guard from the input name instead of crashing.
            guard_source = os.path.basename(in_filename or 'stdin') + '_h'
        guard = re.sub(r'[^a-zA-Z0-9_]', '_', guard_source.upper())
        print("#ifndef %(guard)s\n#define %(guard)s\n" % \
            {'guard':guard}, file=out_file)
        print('\n'.join(map(lambda x: '#include <%s>' % x,
            INCLUDES[options.environment])), file=out_file)
        for e in all_types:
            e.generate(options)
        print("#endif", file=out_file)

    finish_output()
2817