1#
2# Copyright 2014, NICTA
3#
4# This software may be distributed and modified according to the terms of
5# the BSD 2-Clause license. Note that NO WARRANTY is provided.
6# See "LICENSE_BSD2.txt" for details.
7#
8# @TAG(NICTA_BSD)
9#
10
11from __future__ import print_function
12from __future__ import absolute_import
13import braces
14import re
15import sys
16import os
17import six
18from six.moves import map
19from six.moves import range
20from six.moves import zip
21from functools import reduce
22
23
class Call(object):
    """Options describing one request for translated output.

    An instance is filled in by the caller and handed to parse().  The
    ``filename`` attribute is not given a default here; parse() and
    set_global() read it, so the caller has to assign it.
    """

    def __init__(self):
        # optional predicate; get_lines() keeps only definitions it accepts
        self.restr = None
        self.decls_only = False
        self.instanceproofs = False
        self.bodies_only = False
        # set to True by typename_transform() on an unsupported type synonym
        self.bad_type_assignment = False
        self.body = False
        # def_lines() reads this flag unconditionally, so it needs a default
        self.all_bits = False
        # stack of context names; wrap_qualify() uses the last entry
        self.current_context = []
34
35
class Def(object):
    """Mutable record holding one definition as it moves through the
    transforms in this module."""

    def __init__(self):
        # scalar fields start out unset
        self.type = self.defined = self.line = None
        self.sig = self.body = self.primrec = None

        # every instance gets its own fresh containers
        self.comments = []
        self.deriving = []
        self.instance_proofs = []
        self.instance_extras = []
        self.instance_defs = {}
50
51
def parse(call):
    """Parse the file named by call.filename.

    Publishes ``call`` through set_global(), obtains the definitions with
    get_defs(), selects the output lines with get_lines(), applies
    perform_module_redirects() and returns the resulting lines, each
    terminated with a newline."""
    set_global(call)
    parsed = get_defs(call.filename)
    redirected = perform_module_redirects(get_lines(parsed, call), call)
    return ['%s\n' % text for text in redirected]
63
def settings_line(l):
    """Adjusts some global settings.

    ``l`` is a comma-separated list of ``kind = value`` items.  The only
    kind handled is ``keep_constructor``, whose value must be a single
    word; that word is recorded as a key of the module-level keep_conss
    table.

    Raises AssertionError for any other kind, and ValueError when an item
    does not contain exactly one '=' or the value is not exactly one word.
    """
    for bit in l.split(','):
        bit = bit.strip()
        (kind, setting) = bit.split('=')
        kind = kind.strip()
        if kind != 'keep_constructor':
            # raised explicitly (rather than via an assert statement) so
            # that the check is not removed when running with "python -O"
            raise AssertionError('setting kind not understood: %s' % bit)
        [cons] = setting.split()
        keep_conss[cons] = 1
76
def set_global(_call):
    """Publish ``_call`` and its ``filename`` attribute as the module-level
    names ``call`` and ``filename``."""
    global call, filename
    call = _call
    filename = _call.filename
82
83
# Definitions most recently produced by get_defs(), keyed by file name.
file_defs = {}
85
86
def splitList(list, pred):
    """Break ``list`` into runs of consecutive elements.

    Elements for which ``pred`` is true act as separators: they are
    dropped and end the current run.  Empty runs are never returned."""
    groups = []
    current = []
    for item in list:
        if not pred(item):
            current.append(item)
        elif current:
            groups.append(current)
            current = []
    if current:
        groups.append(current)
    return groups
101
102
def takeWhile(list, pred):
    """Return the longest prefix of ``list`` whose elements all satisfy
    ``pred``, obtained by slicing ``list``."""
    count = 0
    for item in list:
        if not pred(item):
            break
        count += 1
    return list[:count]
114
115
def get_defs(filename):
    """Preprocess ``filename`` with cpp and parse the output.

    Extra preprocessor arguments come from the L4CPP environment variable
    (KeyError if it is not set).  The preprocessed lines, with trailing
    whitespace removed, are handed to top_transform(); literate handling
    is selected when ``filename`` ends in ".lhs".  The result is stored
    in file_defs under ``filename`` and returned.
    """
    #if filename in file_defs:
    #    return file_defs[filename]

    cmdline = os.environ['L4CPP']
    # NOTE(review): the command goes through the shell and neither cmdline
    # nor filename is quoted -- only use with trusted values.
    f = os.popen('cpp -Wno-invalid-pp-token -traditional-cpp %s %s' %
                 (cmdline, filename))
    try:
        source_lines = [line.rstrip() for line in f]
    finally:
        # release the pipe even when reading from it fails
        f.close()
    defs = top_transform(source_lines, filename.endswith(".lhs"))

    file_defs[filename] = defs
    return defs
129
130
def top_transform(input, isLhs):
    """Top level transform, deals with lhs artefacts, divides
        the code up into a series of separate definitions, and
        passes these definitions through the definition transforms.

        input is a list of source lines.  When isLhs is true only lines
        starting with '> ' are treated as code; otherwise every line is
        a candidate.  Returns the list of transformed definitions after
        ensure_type_ordering()."""
    # code lines as (text, index) pairs, and everything set aside as a
    # comment as (index, 'C', text) triples
    to_process = []
    comments = []
    for n, line in enumerate(input):
        if '\t' in line:
            # n is the zero-based index into input; filename is the
            # module-level name assigned by set_global()
            sys.stderr.write('WARN: tab in line %d, %s.\n' %
                             (n, filename))
        if isLhs:
            if line.startswith('> '):
                # drop everything from the first '--' onwards
                if '--' in line:
                    line = line.split('--')[0].strip()

                if line[2:].strip() == '':
                    comments.append((n, 'C', ''))
                elif line.startswith('> {-#'):
                    comments.append((n, 'C', '(*' + line + '*)'))
                else:
                    # code: remove the two-character '> ' marker
                    to_process.append((line[2:], n))
            else:
                # lines without the '> ' marker are not code
                if line.strip():
                    comments.append((n, 'C', '(*' + line + '*)'))
                else:
                    comments.append((n, 'C', ''))
        else:
            # drop everything from the first '--' onwards
            if '--' in line:
                line = line.split('--')[0].rstrip()

            if line.strip() == '':
                comments.append((n, 'C', ''))
            elif line.strip().startswith('{-'):
                # single-line {- -} comments only
                comments.append((n, 'C', '(*' + line + '*)'))
            elif line.startswith('#'):
                # preprocessor directives
                comments.append((n, 'C', '(*' + line + '*)'))
            else:
                to_process.append((line, n))

    def_tree = offside_tree(to_process)
    defs = create_defs(def_tree)
    defs = group_defs(defs)

    #   Forget about the comments for now
    #   (the comments list collected above is therefore not used below)

    # defs_plus_comments = [d.line, d) for d in defs] + comments
    # defs_plus_comments.sort()
    # defs = []
    # prev_comments = []
    # for term in defs_plus_comments:
    #     if term[1] == 'C':
    #         prev_comments.append(term[2])
    #     else:
    #         d = term[1]
    #         d.comments = prev_comments
    #         defs.append(d)
    #         prev_comments = []

    # apply def_transform and cut out any None return values
    defs = [defs_transform(d) for d in defs]
    defs = [d for d in defs if d is not None]

    defs = ensure_type_ordering(defs)

    return defs
198
199
def get_lines(defs, call):
    """Collect the output lines wanted by ``call``.

    When call.restr is set, only definitions of type 'comments' or
    accepted by call.restr are considered.  Each definition that yields
    lines from def_lines() contributes those lines followed by one empty
    line."""
    selected = defs
    if call.restr:
        selected = [d for d in defs
                    if d.type == 'comments' or call.restr(d)]

    output = []
    for d in selected:
        rendered = def_lines(d, call)
        if not rendered:
            continue
        output.extend(rendered)
        output.append('')
    return output
216
217
def offside_tree(input):
    """Build a tree from (line, n) pairs using the offside rule.

    A line owns every following line that is indented more deeply than
    it, up to the next line whose indent is less than or equal to its
    own.  Returns a list of (line, n, children) triples where children is
    itself such a list."""
    if input == []:
        return []

    def indent_of(text):
        return len(text) - len(text.lstrip())

    tree = []
    (head, head_n) = input[0]
    head_indent = indent_of(head)
    pending = []
    for (line, n) in input[1:]:
        indent = indent_of(line)
        if indent > head_indent:
            pending.append((line, n))
            continue
        # this line closes the current head and becomes the next one
        tree.append((head, head_n, offside_tree(pending)))
        head, head_n, head_indent = line, n, indent
        pending = []
    tree.append((head, head_n, offside_tree(pending)))
    return tree
239
240
def discard_line_numbers(tree):
    """Convert a tree of (line, n, children) triples into the same tree
        of (line, children) pairs, dropping every n."""
    return [(line, discard_line_numbers(children))
            for (line, _, children) in tree]
249
250
def flatten_tree(tree):
    """Return the lines of a tree of (line, children) pairs as a flat
        list, each line followed by the lines of its children."""
    flat = []

    def walk(nodes):
        for (line, children) in nodes:
            flat.append(line)
            walk(children)

    walk(tree)
    return flat
260
261
def create_defs(tree):
    """Apply create_def() to every element of ``tree`` and return the
    results that are not None."""
    made = (create_def(elt) for elt in tree)
    return [d for d in made if d is not None]
267
268
def group_defs(defs):
    """Merge runs of adjacent 'definitions' entries defining the same name.

        The body of each later entry in a run is appended to the body of
        the first one; entries of any other type are kept as they are and
        end the current run.  Returns the list of remaining entries."""
    groups = []
    previous = ''
    for d in defs:
        name = d.defined if d.type == 'definitions' else ''
        if name and name == previous:
            groups[-1].body.extend(d.body)
            continue
        groups.append(d)
        previous = name
    return groups
287
288
def create_def(elt):
    """Build a definition object from one (line, n, children) element of
        an offside tree; the line numbers of the children are dropped."""
    line, n, subtree = elt
    return create_def_2(line, discard_line_numbers(subtree), n)
295
296
def create_def_2(line, children, n):
    """Classify ``line`` by its first word and wrap it in a Def.

    Returns None for 'import', 'module' and 'class' lines.  Otherwise the
    Def has body [(line, children)], line n, and:
      * type 'instance', defined = third word, for 'instance' lines;
      * type 'newtype', defined = second word, for 'type', 'newtype'
        and 'data' lines;
      * type 'definitions', defined = first word, for everything else.
    """
    d = Def()
    d.body = [(line, children)]
    d.line = n
    lead = line.split(None, 3)
    keyword = lead[0]
    if keyword in ('import', 'module', 'class'):
        return None
    if keyword == 'instance':
        d.type, d.defined = 'instance', lead[2]
    elif keyword in ('type', 'newtype', 'data'):
        d.type, d.defined = 'newtype', lead[1]
    else:
        d.type, d.defined = 'definitions', keyword
    return d
317
318
def get_primrecs():
    """Read the file 'primrecs' (relative to the current directory) and
    return the set of its non-blank lines, stripped of surrounding
    whitespace."""
    # the with-block closes the file, which the previous code never did
    with open('primrecs') as f:
        keys = [line.strip() for line in f]
    return set(key for key in keys if key != '')
323
324
# Loaded once when the module is imported; get_primrecs() opens the file
# 'primrecs' relative to the current working directory.
primrecs = get_primrecs()
326
327
def defs_transform(d):
    """Transform one grouped definition.

        'newtype' and 'instance' entries are delegated to their own
        transforms.  For anything else a leading type signature line
        (second token '::') is moved from the body into d.sig, names
        listed in primrecs go to primrec_transform(), multi-part bodies
        go through pattern_match_transform(), and finally the body is
        rewritten by body_transform()."""
    if d.type == 'newtype':
        return newtype_transform(d)
    if d.type == 'instance':
        return instance_transform(d)

    # the first tokens of the first line in the first definition
    first = d.body[0]
    tokens = first[0].split(None, 2)
    if tokens[1] == '::':
        d.sig = type_sig_transform(first)
        d.body.pop(0)

    if d.defined in primrecs:
        return primrec_transform(d)

    if len(d.body) > 1:
        d.body = pattern_match_transform(d.body)

    if len(d.body) == 0:
        # nothing left to translate: dump the definition and stop
        print()
        print(d)
        assert 0

    d.body = body_transform(d.body, d.defined, d.sig)
    return d
356
357
def wrap_qualify(lines, deep=True):
    """Close and then re-open a locale so instantiations can go through.

    An empty ``lines`` is returned untouched.  Otherwise, when the
    module-level ``call`` has a non-empty current_context, an
    'end / qualify' entry is inserted at the front of ``lines`` and an
    'end_qualify / context ... global_naming' entry is appended, both
    naming the last context.  ``lines`` is modified in place and returned.

    ``deep`` is kept for compatibility with existing callers; it has no
    effect on the result.
    """
    if len(lines) == 0:
        return lines

    if call.current_context:
        context = call.current_context[-1]
        # the trailing space after "(in Arch)" is part of the emitted text
        lines.insert(0, 'end\nqualify {} (in Arch) '.format(context))
        lines.append('end_qualify\ncontext Arch begin global_naming %s' % context)
    return lines
373
def def_lines(d, call):
    """Produces the set of lines associated with a definition.

    The flags on ``call`` are checked in this order:
      * all_bits: comments, the definition or newtype text, instance
        proofs and instance extras are all emitted;
      * instanceproofs: only instance proofs (unless bodies_only) and
        instance extras (unless decls_only) are emitted;
      * body: the result of get_lambda_body_lines(d) is returned;
      * otherwise the output depends on d.type together with the
        decls_only / bodies_only flags.

    Returns a list of lines.  Returns None for a 'newtype' when
    call.bodies_only is set and for any type not handled below.
    """
    if call.all_bits:
        L = []
        if d.comments:
            L.extend(flatten_tree(d.comments))
            L.append('')
        if d.type == 'definitions':
            L.append('definition')
            # NOTE(review): the body is only emitted when a signature
            # exists -- confirm this is intended.
            if d.sig:
                L.extend(flatten_tree([d.sig]))
                L.append('where')
                L.extend(flatten_tree(d.body))
        elif d.type == 'newtype':
            L.extend(flatten_tree(d.body))
        if d.instance_proofs:
            L.extend(wrap_qualify(flatten_tree(d.instance_proofs)))
            L.append('')
        if d.instance_extras:
            L.extend(flatten_tree(d.instance_extras))
            L.append('')
        return L

    if call.instanceproofs:
        if not call.bodies_only:
            instance_proofs = wrap_qualify(flatten_tree(d.instance_proofs))
        else:
            instance_proofs = []

        if not call.decls_only:
            instance_extras = flatten_tree(d.instance_extras)
        else:
            instance_extras = []

        newline_needed = len(instance_proofs) > 0 and len(instance_extras) > 0
        return instance_proofs + (['']
                                  if newline_needed else []) + instance_extras

    if call.body:
        return get_lambda_body_lines(d)

    comments = d.comments
    try:
        typesig = flatten_tree([d.sig])
    except (TypeError, ValueError):
        # d.sig is missing (None) or not a (line, children) pair; a bare
        # except here would also have swallowed KeyboardInterrupt
        typesig = []
    body = flatten_tree(d.body)
    kind = d.type

    if kind == 'definitions':
        if call.decls_only:
            if typesig:
                return comments + ["consts'"] + typesig
            else:
                return []
        elif call.bodies_only:
            if d.sig:
                defname = '%s_def' % d.defined
                if d.primrec:
                    print('warning body-only primrec:')
                    print(body[0])
                    return comments + ['primrec'] + body
                return comments + ['defs %s:' % defname] + body
            else:
                return comments + ['definition'] + body
        else:
            if d.primrec:
                return comments + ['primrec'] + typesig \
                    + ['where'] + body
            if typesig:
                return comments + ['definition'] + typesig + ['where'] + body
            else:
                return comments + ['definition'] + body
    elif kind == 'comments':
        return comments
    elif kind == 'newtype':
        if not call.bodies_only:
            return body
    # falling through returns None; get_lines() skips such results
452
453
def type_sig_transform(tree_element):
    """Transform a type signature given as a (line, children) element.

        The element is joined into one line and split at '::'; the right
        hand side is passed through type_transform() and re-attached in
        double quotes.  Returns a (line, []) pair."""
    flat = reduce_to_single_line(tree_element)
    (name_part, type_part) = flat.split('::')
    converted = type_transform(type_part)
    if '[pp' in converted:
        # dump everything that led here, then stop
        for item in (flat, name_part, type_part, converted):
            print(item)
        assert 0
    return (name_part + ':: "' + converted + '"', [])
470
471
# Class names whose constraints type_transform() skips entirely.
ignore_classes = {'Error': 1}
# Class names that type_transform() replaces with the listed names instead
# of converting them through type_conv().
hand_classes = {'Bits': ['HS_bit'],
                'Num': ['minus', 'one', 'zero', 'plus', 'numeral'],
                'FiniteBits': ['finiteBit']}
476
477
def type_transform(string):
    """Performs transformations on a type signature, whether
        part of a type signature line or occurring in a function.

        A leading class context (the text before the first '=>') is
        handled by recursion: the remainder is transformed and the first
        occurrence of each constrained type variable is replaced by an
        annotated form '(var :: class)' or '(var :: {c1, c2})'.  Without
        a context, '()' becomes 'Unit', and the text is split on '->'
        (or, failing that, on ',') with each piece converted by
        type_bit_transform().  Returns the resulting string."""

    # deal with type classes by recursion
    bits = string.split('=>', 1)
    if len(bits) == 2:
        lhs = bits[0].strip()
        if lhs.startswith('(') and lhs.endswith(')'):
            instances = lhs[1:-1].split(',')
            # NOTE(review): this value is not used before the return below
            string = ' => '.join(instances + [bits[1]])
        else:
            instances = [lhs]
        var_annotes = {}
        for instance in instances:
            (name, var) = instance.split()
            if name in ignore_classes:
                continue
            if name in hand_classes:
                names = hand_classes[name]
            else:
                names = [type_conv(name)]
            var = "'" + var
            var_annotes.setdefault(var, [])
            var_annotes[var].extend(names)
        transformed = type_transform(bits[1])
        for (var, insts) in six.iteritems(var_annotes):
            if len(insts) == 1:
                newvar = '(%s :: %s)' % (var, insts[0])
            else:
                newvar = '(%s :: {%s})' % (var, ', '.join(insts))
            # annotate only the first occurrence of the variable
            transformed = newvar.join(transformed.split(var, 1))
        return transformed

    # get rid of (), insert Unit, which converts to unit
    string = 'Unit'.join(string.split('()'))

    # divide up by -> or by , then divide on space.
    # apply everything locally then work back up
    bstring = braces.str(string, '(', ')')
    bits = bstring.split('->')
    # backslash written as '\\' -- '\<' is an invalid escape sequence
    r = ' \\<Rightarrow> '
    if len(bits) == 1:
        bits = bstring.split(',')
        r = ' * '
    result = [type_bit_transform(bit) for bit in bits]
    return r.join(result)
525
526
def type_bit_transform(bit):
    """Transform one piece of a type, as split out by type_transform().

    ``bit`` is expected to be a braces.str value -- NOTE(review): its
    split(None, braces=True) and map() methods are defined in the braces
    module, not visible here.  Returns a string."""
    s = str(bit).strip()
    if s.startswith('['):
        # handling this properly is hard.
        assert s.endswith(']')
        # a bracketed type becomes '<inner> list'
        bit2 = braces.str(s[1:-1], '(', ')')
        return '%s list' % type_bit_transform(bit2)
    bits = bit.split(None, braces=True)
    if str(bits[0]) == 'PPtr':
        # a PPtr applied to one argument is rendered as machine_word
        assert len(bits) == 2
        return 'machine_word'
    if len(bits) > 1 and bits[1].startswith('['):
        # constructor applied to a bracketed argument
        assert bits[-1].endswith(']')
        arg = ' '.join([str(bit) for bit in bits[1:]])[1:-1]
        arg = type_transform(arg)
        return ' '.join([arg, 'list', str(type_conv(bits[0]))])
    # general case: convert each token, reorder, then recurse via map()
    bits = [type_conv(bit) for bit in bits]
    bits = constructor_reversing(bits)
    bits = [bit.map(type_transform) for bit in bits]
    strs = [str(bit) for bit in bits]
    return ' '.join(strs)
548
549
def reduce_to_single_line(tree_element):
    """Join the line of a (line, children) element and the lines of all
    its descendants, in depth-first order, with single spaces."""
    pieces = []
    pending = [tree_element]
    while pending:
        (line, children) = pending.pop()
        pieces.append(line)
        # push in reverse so the first child is handled next
        pending.extend(list(children)[::-1])
    return ' '.join(pieces)
558
559
# Fixed conversions of type names, consulted by type_conv(); type_conv()
# also adds the names it converts itself to this table.
type_conv_table = {
 'Maybe': 'option',
 'Bool': 'bool',
 'Word': 'machine_word',
 'Int': 'nat',
 'String': 'unit list'}
566
567
def type_conv(string):
    """Converts a type used in Haskell to our equivalent

    The cases are tried in order: text starting with '(' is returned
    unchanged, the last component of a dotted name is converted
    recursively, a name starting with a lower-case letter gets a leading
    apostrophe, '[...]' becomes '... list', names found in
    type_conv_table are looked up, and anything else is lower-cased with
    underscores inserted and remembered in type_conv_table.

    The result is passed through braces.clone(result, string) --
    NOTE(review): presumably this carries brace information over from
    ``string``; confirm against the braces module."""
    if string.startswith('('):
        # ignore compound types, type_transform will descend into em
        result = string
    elif '.' in string:
        # qualified references
        bits = string.split('.')
        typename = bits[-1]
        # re-join everything before the last component with '.'
        module = reduce(lambda x, y: x + '.' + y, bits[:-1])
        typename = type_conv(typename)
        result = module + '.' + typename
    elif string[0].islower():
        # type variable
        result = "'%s" % string
    elif string[0] == '[' and string[-1] == ']':
        # list
        inner = type_conv(string[1:-1])
        result = '%s list' % inner
    elif str(string) in type_conv_table:
        result = type_conv_table[str(string)]
    else:
        # convert StudlyCaps to lower_with_underscores
        was_lower = False
        s = ''
        for c in string:
            # an upper-case letter directly after a lower-case one starts
            # a new underscore-separated word
            if c.isupper() and was_lower:
                s = s + '_' + c.lower()
            else:
                s = s + c.lower()
            was_lower = c.islower()
        result = s
        # remember the conversion for later lookups
        type_conv_table[str(string)] = result

    return braces.clone(result, string)
603
604
def constructor_reversing(tokens):
    """Reorder the tokens of a type application so that the constructor
    comes after its arguments.

    Fewer than two tokens are returned unchanged and two tokens are
    swapped.  Bracketed (list) arguments, 'array' and 'either' have
    dedicated cases, and a three-token application is rendered with its
    two arguments as a parenthesised, comma-separated pair followed by
    the constructor.  Any other shape prints an error naming the
    module-level ``filename`` and exits with status 1."""
    if len(tokens) < 2:
        return tokens
    elif len(tokens) == 2:
        return [tokens[1], tokens[0]]
    elif tokens[0] == '[' and tokens[2] == ']':
        return [tokens[1], braces.str('list', '(', ')')]
    elif len(tokens) == 4 and tokens[1] == '[' and tokens[3] == ']':
        listToken = braces.str('(List %s)' % tokens[2], '(', ')')
        return [listToken, tokens[0]]
    elif tokens[0] == 'array':
        # backslash written as '\\' -- '\<' is an invalid escape sequence
        arrow_token = braces.str('\\<Rightarrow>', '(', ')')
        return [tokens[1], arrow_token, tokens[2]]
    elif tokens[0] == 'either':
        plus_token = braces.str('+', '(', ')')
        return [tokens[1], plus_token, tokens[2]]
    elif len(tokens) == 5 and tokens[2] == '[' and tokens[4] == ']':
        listToken = braces.str('(List %s)' % tokens[3], '(', ')')
        lbrack = braces.str('(', '+', '+')
        rbrack = braces.str(')', '+', '+')
        comma = braces.str(',', '+', '+')
        return [lbrack, tokens[1], comma, listToken, rbrack, tokens[0]]
    elif len(tokens) == 3:
        # here comes a fudge
        lbrack = braces.str('(', '+', '+')
        rbrack = braces.str(')', '+', '+')
        comma = braces.str(',', '+', '+')
        return [lbrack, tokens[1], comma, tokens[2], rbrack, tokens[0]]
    else:
        print("Error parsing " + filename)
        print("Can't deal with %s" % tokens)
        sys.exit(1)
637
638
def newtype_transform(d):
    """Transform a Haskell 'type', 'newtype' or 'data' declaration.

        A trailing 'deriving' child is removed and its class names are
        stored in d.deriving.  The declaration is then joined into one
        line; d.typename and d.typedeps are set; and, depending on the
        form, the work is handed to typename_transform() (type synonym),
        simple_newtype_transform() (no '{' present) or
        named_newtype_transform() (record syntax).  A declaration without
        '=' becomes a 'typedecl'."""
    if len(d.body) != 1:
        print('--- newtype long body ---')
        print(d)
    [(line, children)] = d.body

    if children and children[-1][0].lstrip().startswith('deriving'):
        deriving_text = reduce_to_single_line(children[-1])
        children = children[:-1]
        words = re.split(r"[,\s\(\)]+", deriving_text)
        d.deriving = [word for word in words
                      if word and word != 'deriving']

    flat = reduce_to_single_line((line, children))

    parts = flat.split(None, 1)
    op = parts[0]
    halves = parts[1].split('=', 1)
    header = type_conv(halves[0].strip())
    d.typename = header
    d.typedeps = set()
    if len(halves) == 1:
        # line of form 'data Blah' introduces unknown type?
        d.body = [('typedecl %s' % header, [])]
        all_type_arities[header] = []  # HACK
        return d
    rhs = halves[1]

    if op == 'type':
        # type synonym
        return typename_transform(rhs, header, d)
    if '{' not in rhs:
        # not a record
        return simple_newtype_transform(rhs, header, d)
    return named_newtype_transform(rhs, header, d)
681
682
def typename_transform(line, header, d):
    """Turn a type synonym into a 'type_synonym' line.

    ``line`` is the right-hand side of the synonym and must be a single
    word.  Otherwise a warning is written to stderr,
    call.bad_type_assignment is set and None is returned.  On success the
    words of the converted type are added to d.typedeps, d.body is
    replaced by the single 'type_synonym' line and d is returned."""
    try:
        [oldtype] = line.split()
    except ValueError:
        # zero words or more than one word on the right-hand side; only
        # this unpacking failure is meant to be handled here
        sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d.body)
        call.bad_type_assignment = True
        return
    if oldtype.startswith('Data.Word.Word'):
        # take off the prefix, leave Word32 or Word64 etc
        oldtype = oldtype[10:]
    oldtype = type_conv(oldtype)
    for bit in oldtype.split():
        d.typedeps.add(bit)
    d.body = [('type_synonym %s = "%s"' % (header, oldtype), [])]
    return d
704
705
# Names registered through a 'keep_constructor' setting (see
# settings_line); a type whose name is listed here is never reduced to a
# type synonym by the newtype transforms.
keep_conss = {}
707
708
def simple_newtype_transform(line, header, d):
    """Transform a non-record data declaration.

    ``line`` holds the constructor alternatives separated by '|'.  Each
    alternative becomes one line of a 'datatype' statement; the converted
    argument types are added to d.typedeps.  When the declaration has a
    single one-argument constructor and ``header`` is not in keep_conss,
    the result of type_wrapper_type() is returned instead."""
    rendered = []
    arities = []
    for index, alternative in enumerate(line.split('|')):
        tokens = braces.str(alternative, '(', ')').split()
        single_field = len(tokens) == 2
        if single_field:
            last_lhs = tokens[0]

        prefix = '    ' if index == 0 else '  | '
        text = '%s%s' % (prefix, tokens[0])
        for field in tokens[1:]:
            if field.startswith('('):
                field = field[1:-1]
            typename = type_transform(str(field))
            if single_field:
                last_rhs = typename
            if ' ' in typename:
                typename = '"%s"' % typename
            text = text + ' ' + typename
            d.typedeps.add(typename)
        rendered.append(text)

        arities.append((str(tokens[0]), len(tokens[1:])))

    if list((dict(arities)).values()) == [1] and header not in keep_conss:
        return type_wrapper_type(header, last_lhs, last_rhs, d)

    d.body = [('datatype %s =' % header,
               [(text, []) for text in rendered])]

    set_instance_proofs(header, arities, d)

    return d
744
745
# Accumulates the results of get_type_map() for every record constructor
# processed by named_newtype_transform(), keyed by constructor name.
all_constructor_args = {}
747
748
def named_newtype_transform(line, header, d):
    """Transform a record-style data declaration.

    ``line`` holds the constructor alternatives separated by '|'; each is
    parsed by get_type_map() into (constructor, [(field, type), ...]).
    The generated text consists of the datatype alternatives followed by
    extractor definitions, update definitions, constructor translations,
    'is<Constructor>' checks (several constructors) or extractor/update
    lemmas (one constructor).  When the declaration has a single
    one-field constructor and ``header`` is not in keep_conss, the result
    of type_wrapper_type() is returned instead."""
    constructors = [get_type_map(alternative)
                    for alternative in line.split('|')]
    cons_args = dict(constructors)
    all_constructor_args.update(cons_args)

    lines = []
    for position, (cons_name, fields) in enumerate(constructors):
        text = ('    %s' if position == 0 else '  | %s') % cons_name
        for field_name, field_type in fields:
            if field_type is None:
                print("ERR: None snuck into constructor list for %s"
                      % field_name)
                print(line, header, cons_name)
                assert False

            if len(field_type.split()) == 1 and '(' not in field_type:
                text = text + ' ' + field_type
            else:
                text = text + ' "' + field_type + '"'
            for word in field_type.split():
                d.typedeps.add(word)
        lines.append(text)

    # field name -> {constructor: position}, and field name -> type
    names = {}
    types = {}
    for cons_name, fields in constructors:
        for position, (field_name, field_type) in enumerate(fields):
            names.setdefault(field_name, {})[cons_name] = position
            types[field_name] = field_type

    for field_name, positions in six.iteritems(names):
        lines.append('')
        lines.extend(named_extractor_definitions(
            field_name, positions, types[field_name], header, cons_args))

    for field_name, positions in six.iteritems(names):
        lines.append('')
        lines.extend(named_update_definitions(
            field_name, positions, types[field_name], header, cons_args))

    for cons_name, fields in constructors:
        if fields == []:
            continue
        lines.append('')
        lines.extend(named_constructor_translation(cons_name, fields,
                                                   header))

    if len(constructors) > 1:
        for cons_name, fields in constructors:
            lines.append('')
            lines.extend(named_constructor_check(cons_name, fields, header))

    if len(constructors) == 1:
        for ex_name in names:
            for up_name in names:
                lines.append('')
                lines.extend(named_extractor_update_lemma(ex_name, up_name))

    arities = [(cons_name, len(fields))
               for (cons_name, fields) in constructors]

    if list((dict(arities)).values()) == [1] and header not in keep_conss:
        [(cons, fields)] = constructors
        [(field_name, field_type)] = fields
        return type_wrapper_type(header, cons, field_type, d,
                                 decons=(field_name, field_type))

    set_instance_proofs(header, arities, d)

    d.body = [('datatype %s =' % header, [(text, []) for text in lines])]
    return d
823
824
def named_extractor_update_lemma(ex_name, up_name):
    """Return the three lines of a [simp] lemma stating what extractor
    ``ex_name`` yields after ``up_name``_update: the updated value when
    both names are equal, the untouched value otherwise."""
    if up_name == ex_name:
        rhs = 'f (%s v)' % ex_name
    else:
        rhs = '%s v' % ex_name
    return [
        'lemma %s_%s_update [simp]:' % (ex_name, up_name),
        '  "%s (%s_update f v) = %s"' % (ex_name, up_name, rhs),
        '  by (cases v) simp',
    ]
839
840
def named_extractor_definitions(name, map, type, header, constructors):
    """Return the lines of a 'primrec' definition of field extractor
    ``name``.

    ``map`` is a dict from constructor name to the position of the field
    within that constructor, ``type`` is the field's type, ``header`` the
    datatype name and ``constructors`` a dict from constructor name to
    its field list (only its length is used).  One equation is produced
    per entry of ``map``."""
    lines = [
        'primrec',
        # backslash written as '\\' -- '\<' is an invalid escape sequence
        '  %s :: "%s \\<Rightarrow> %s"' % (name, header, type),
        'where',
    ]
    for position, (cons, i) in enumerate(map.items()):
        # the first equation is indented, later ones are joined with '|'
        lead = '  ' if position == 0 else '| '
        args = ''.join(' v%d' % k for k in range(len(constructors[cons])))
        lines.append('%s"%s (%s%s) = v%d"' % (lead, name, cons, args, i))
    return lines
861
862
def named_update_definitions(name, map, type, header, constructors):
    """Return the lines of a 'primrec' definition of ``name``_update.

    ``map`` is a dict from constructor name to the position of the field
    within that constructor, ``type`` is the field's type (parenthesised
    here when it consists of several words), ``header`` the datatype name
    and ``constructors`` a dict from constructor name to its field list
    (only its length is used).  Each equation rebuilds the constructor
    with ``f`` applied to the selected field."""
    # backslash written as '\\' -- '\<' is an invalid escape sequence
    ra = '\\<Rightarrow>'
    if len(type.split()) > 1:
        type = '(%s)' % type
    lines = [
        'primrec',
        '  %s_update :: "(%s %s %s) %s %s %s %s"'
        % (name, type, ra, type, ra, header, ra, header),
        'where',
    ]
    for position, (cons, i) in enumerate(map.items()):
        # the first equation is indented, later ones are joined with '|'
        lead = '  ' if position == 0 else '| '
        num = len(constructors[cons])
        pattern = ''.join(' v%d' % k for k in range(num))
        rebuilt = ''.join(' (f v%d)' % k if k == i else ' v%d' % k
                          for k in range(num))
        lines.append('%s"%s_update f (%s%s) = %s%s"'
                     % (lead, name, cons, pattern, cons, rebuilt))
    return lines
892
893
def named_constructor_translation(name, map, header):
    """Return the lines of an input abbreviation ``name``_trans that gives
    record-style syntax for constructor ``name``.

    ``map`` is the non-empty list of (field name, field type) pairs of
    the constructor and ``header`` the datatype name.  Raises IndexError
    when ``map`` is empty."""
    first_field = map[0][0]

    # backslashes are written as '\\' throughout -- '\<' is an invalid
    # escape sequence
    sig = '  %s_trans :: "' % name
    for _, type in map:
        sig += '(' + type + ') \\<Rightarrow> '
    sig += '%s" ("%s\'_ \\<lparr> %s= _' % (header, name, first_field)
    for n, _ in map[1:]:
        sig += ', %s= _' % n
    sig += ' \\<rparr>")'

    body = '  "%s_ \\<lparr> %s= v0' % (name, first_field)
    for i, (n, _) in enumerate(map[1:], 1):
        body += ', %s= v%d' % (n, i)
    body += ' \\<rparr> == %s' % name
    body += ''.join(' v%d' % i for i in range(len(map)))
    body += '"'

    return ['abbreviation (input)', sig, 'where', body]
916
917
def named_constructor_check(name, map, header):
    """Return the lines of a definition 'is<name>' that is True exactly
    for values built with constructor ``name``.

    ``map`` is the constructor's field list (only its length is used) and
    ``header`` the datatype name."""
    binders = ''.join('v%d ' % i for i in range(len(map)))
    # backslashes are written as '\\' throughout -- '\<' is an invalid
    # escape sequence
    return [
        'definition',
        '  is%s :: "%s \\<Rightarrow> bool"' % (name, header),
        'where',
        ' "is%s v \\<equiv> case v of' % name,
        '    %s %s\\<Rightarrow> True' % (name, binders),
        '  | _ \\<Rightarrow> False"',
    ]
932
933
def type_wrapper_type(header, cons, rhs, d, decons=None):
    """Render a single-constructor, single-field type as a type synonym.

    ``header`` is the type name, ``cons`` its constructor and ``rhs`` the
    wrapped type; the constructor is defined as the identity.  When
    ``decons`` is given as a (field name, field type) pair, the field
    extractor (also the identity), its update function and a constructor
    translation are emitted as well.  If d.typedeps contains a function
    arrow the declaration is replaced by a comment.  Sets d.body and
    returns d."""
    if '\\<Rightarrow>' in d.typedeps:
        d.body = [('(* type declaration of %s omitted *)' % header, [])]
        return d
    lines = [
        'type_synonym %s = "%s"' % (header, rhs),
        # translations (* TYPE 2 *)',
        # "%s" <=(type) "%s"' % (header, rhs),
        '',
        'definition',
        '  %s :: "%s \\<Rightarrow> %s"' % (cons, header, header),
        'where %s_def[simp]:' % cons,
        ' "%s \\<equiv> id"' % cons,
    ]
    if decons:
        (decons, decons_type) = decons
        lines.extend([
            '',
            'definition',
            '  %s :: "%s \\<Rightarrow> %s"' % (decons, header, header),
            'where',
            '  %s_def[simp]:' % decons,
            ' "%s \\<equiv> id"' % decons,
            '',
            # the comma after 'definition' was missing, which silently
            # concatenated it with the following line
            'definition',
            '  %s_update :: "(%s \\<Rightarrow> %s) \\<Rightarrow> %s \\<Rightarrow> %s"'
            % (decons, header, header, header, header),
            'where',
            '  %s_update_def[simp]:' % decons,
            # backslash written as '\\' -- '\<' is an invalid escape
            ' "%s_update f y \\<equiv> f y"' % decons,
            ''
        ])
        lines.extend(named_constructor_translation(cons, [(decons, decons_type)
                                                          ], header))

    d.body = [(line, []) for line in lines]
    return d
971
972
def instance_transform(d):
    """Transform an 'instance <Class> <Type> [where]' declaration.

    Instances of 'Show' and instances for '()' are discarded with a
    warning (None is returned).  Otherwise the member definitions are
    parsed, grouped and transformed, stored in d.instance_defs keyed by
    the name they define, d.deriving is set to the class name and the
    instance proofs are generated from the arities recorded for the type.
    Exits with status 1 when the type has not been seen yet."""
    [(line, children)] = d.body
    words = line.split(None, 3)
    assert words[0] == 'instance'
    classname = words[1]
    typename = type_conv(words[2])
    if classname == 'Show':
        print("Warning: discarding class instance '%s :: Show'" % typename)
        return None
    if typename == '()':
        print("Warning: discarding class instance 'unit :: %s'" % classname)
        return None

    # locate the member definitions: 'where' is either the fourth word of
    # the instance line or a child line of its own
    if len(words) != 3:
        assert words[3:] == ['where']
        members = children
    elif children == []:
        members = []
    else:
        [(where_line, where_children)] = children
        assert where_line.strip() == 'where'
        members = where_children

    parsed = [create_def_2(text, sub, 0) for (text, sub) in members]
    grouped = group_defs([m for m in parsed if m is not None])
    transformed = [defs_transform(m) for m in grouped]
    d.instance_defs = dict((m.defined, m)
                           for m in transformed if m is not None)
    d.deriving = [classname]

    if typename not in all_type_arities:
        sys.stderr.write(
            'FAIL: attempting %s\n' % d.defined
            + '(typename %r)\n' % typename
            + 'when reading %s\n' % filename
            + 'but class not defined yet\n'
            + 'perhaps parse in different order?\n'
            + 'hint: #INCLUDE_HASKELL_PREPARSE\n')
        sys.exit(1)
    set_instance_proofs(typename, all_type_arities[typename], d)

    return d
1017
1018
# Constructor arities recorded for every type seen so far, keyed by type name.
all_type_arities = {}


def set_instance_proofs(header, constructor_arities, d):
    """Record a type's constructor arities and attach its instance proofs.

    ``constructor_arities`` is stored in ``all_type_arities`` under
    ``header``.  Every class name in ``d.deriving`` is looked up in
    ``instance_proof_table``; each distinct generator is called once as
    ``f(header, canonical, d)``, where ``canonical`` is the enumerated
    arity list, and must return a ``(proof_lines, extra_def_lines)`` pair.

    Generators run in ascending ``.order``; generators with equal ``.order``
    keep the order in which their classes appear in ``d.deriving``.  The
    collected lines are stored in ``d.instance_proofs`` and
    ``d.instance_extras`` (each left untouched when nothing was produced).

    Raises ``KeyError`` for a class missing from ``instance_proof_table``.
    """
    all_type_arities[header] = constructor_arities
    canonical = list(enumerate(constructor_arities))

    # De-duplicate first, then sort.  Wrapping a sorted sequence in set()
    # would discard the ordering again and make the output depend on hash
    # order, so a list is used to keep the result deterministic.
    proof_fns = []
    for classname in d.deriving:
        fn = instance_proof_table[classname]
        if fn not in proof_fns:
            proof_fns.append(fn)
    proof_fns.sort(key=lambda fn: fn.order)

    pfs = []
    exs = []
    for fn in proof_fns:
        (npfs, nexs) = fn(header, canonical, d)
        pfs.extend(npfs)
        exs.extend(nexs)

    # NOTE: 'finite' instance proofs for single-constructor newtypes were
    # switched off here (the branch was guarded by 'and False'), so that
    # unreachable code has been dropped.

    if pfs:
        lead = '(* %s instance proofs *)' % header
        d.instance_proofs = [(lead, [(line, []) for line in pfs])]
    if exs:
        lead = '(* %s extra instance defs *)' % header
        d.instance_extras = [(lead, [(line, []) for line in exs])]
1048
1049
def finite_instance_proofs(header, cons):
    """Build the Isabelle proof that type ``header`` is of class ``finite``.

    The proof applies ``finite_surj_type`` with the constructor ``cons``.
    An ``interpretation Arch .`` line is included when the global
    ``call.current_context`` is non-empty.

    Returns a ``(proof_lines, [])`` pair.
    """
    context_lines = ['interpretation Arch .'] if call.current_context else []
    lines = (
        ['', 'instance %s :: finite' % header]
        + context_lines
        + [
            '  apply (intro_classes)',
            '  apply (rule_tac f="%s" in finite_surj_type)' % cons,
            '  apply (safe, case_tac x, simp_all)',
            '  done',
        ]
    )
    return (lines, [])
1063
# Wondering where the serialisable proofs went? See
# commit 21361f073bbafcfc985934e563868116810d9fa2 for the last known occurrence.
1066
# leave type tags 0..11 for explicit use outside of this script
next_type_tag = 12


def storable_instance_proofs(header, canonical, d):
    """Generate the ``dynamic``/``storable`` instance text for ``header``.

    The module-level ``next_type_tag`` counter is incremented first and the
    new value becomes the type's ``typetag``.

    The extra definitions are built from ``d.instance_defs``: ``objBits``
    and ``makeObject`` are emitted only when the instance defines them,
    while ``loadObject`` and ``updateObject`` fall back to the
    ``*_default`` versions when it does not.  ``canonical`` is unused.

    Returns a ``(proof_lines, extra_definition_lines)`` pair.
    """
    global next_type_tag
    next_type_tag = next_type_tag + 1

    proofs = [
        '',
        'defs (overloaded)',
        '  typetag_%s[simp]:' % header,
        ' "typetag (x :: %s) \\<equiv> %d"' % (header, next_type_tag),
        # Blank separator line; it needs its own comma, otherwise Python
        # joins it with the following literal and the line disappears.
        '',
        'instance %s :: dynamic' % header,
        '  by (intro_classes, simp)',
        '',
        'instance %s :: storable ..' % header,
    ]

    defs = d.instance_defs
    extradefs = ['']
    if 'objBits' in defs:
        extradefs.append('definition')
        body = flatten_tree(defs['objBits'].body)
        bits = body[0].split('objBits')
        assert bits[0].strip() == '"'
        # A wildcard argument is given the name 'x' so it can carry a type.
        if bits[1].strip().startswith('_'):
            bits[1] = 'x ' + bits[1].strip()[1:]
        bits = bits[1].split(None, 1)
        body[0] = '  objBits_%s: "objBits (%s :: %s) %s' \
            % (header, bits[0], header, bits[1])
        extradefs.extend(body)

    extradefs.append('')
    if 'makeObject' in defs:
        extradefs.append('definition')
        body = flatten_tree(defs['makeObject'].body)
        bits = body[0].split('makeObject')
        assert bits[0].strip() == '"'
        body[0] = '  makeObject_%s: "(makeObject :: %s) %s' \
            % (header, header, bits[1])
        extradefs.extend(body)

    extradefs.extend(['', 'definition'])
    if 'loadObject' in defs:
        extradefs.append('  loadObject_%s:' % header)
        extradefs.extend(flatten_tree(defs['loadObject'].body))
    else:
        extradefs.extend([
            '  loadObject_%s[simp]:' % header,
            ' "(loadObject p q n obj) :: %s \\<equiv>' % header,
            '    loadObject_default p q n obj"',
        ])

    extradefs.extend(['', 'definition'])
    if 'updateObject' in defs:
        extradefs.append('  updateObject_%s:' % header)
        body = flatten_tree(defs['updateObject'].body)
        bits = body[0].split('updateObject')
        assert bits[0].strip() == '"'
        bits = bits[1].split(None, 1)
        body[0] = ' "updateObject (%s :: %s) %s' \
            % (bits[0], header, bits[1])
        extradefs.extend(body)
    else:
        extradefs.extend([
            '  updateObject_%s[simp]:' % header,
            ' "updateObject (val :: %s) \\<equiv>' % header,
            '    updateObject_default val"',
        ])

    return (proofs, extradefs)


storable_instance_proofs.order = 1
1142
1143
def pspace_storable_instance_proofs(header, canonical, d):
    """Generate the ``pre_storable`` instance text for ``header``.

    The extra definitions come from ``d.instance_defs``: ``objBits`` and
    ``makeObject`` are emitted only when the instance defines them, while
    ``loadObject`` and ``updateObject`` fall back to the ``*_default``
    versions when it does not.  ``canonical`` is unused.

    Returns a ``(proof_lines, extra_definition_lines)`` pair.
    """
    proofs = [
        '',
        'instance %s :: pre_storable' % header,
        '  by (intro_classes,',
        '      auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)',
    ]

    methods = d.instance_defs
    extras = ['']

    if 'objBits' in methods:
        extras.append('definition')
        obj_bits = flatten_tree(methods['objBits'].body)
        pieces = obj_bits[0].split('objBits')
        assert pieces[0].strip() == '"'
        tail = pieces[1]
        # a wildcard argument is renamed to 'x' so it can be annotated
        if tail.strip().startswith('_'):
            tail = 'x ' + tail.strip()[1:]
        arg_and_rest = tail.split(None, 1)
        obj_bits[0] = '  objBits_%s: "objBits (%s :: %s) %s' % (
            header, arg_and_rest[0], header, arg_and_rest[1])
        extras.extend(obj_bits)

    extras.append('')
    if 'makeObject' in methods:
        extras.append('definition')
        make_object = flatten_tree(methods['makeObject'].body)
        pieces = make_object[0].split('makeObject')
        assert pieces[0].strip() == '"'
        make_object[0] = '  makeObject_%s: "(makeObject :: %s) %s' % (
            header, header, pieces[1])
        extras.extend(make_object)

    extras += ['', 'definition']
    if 'loadObject' not in methods:
        extras += [
            '  loadObject_%s[simp]:' % header,
            ' "(loadObject p q n obj) :: %s kernel \\<equiv>' % header,
            '    loadObject_default p q n obj"',
        ]
    else:
        extras.append('  loadObject_%s:' % header)
        extras.extend(flatten_tree(methods['loadObject'].body))

    extras += ['', 'definition']
    if 'updateObject' not in methods:
        extras += [
            '  updateObject_%s[simp]:' % header,
            ' "updateObject (val :: %s) \\<equiv>' % header,
            '    updateObject_default val"',
        ]
    else:
        extras.append('  updateObject_%s:' % header)
        update_object = flatten_tree(methods['updateObject'].body)
        pieces = update_object[0].split('updateObject')
        assert pieces[0].strip() == '"'
        arg_and_rest = pieces[1].split(None, 1)
        update_object[0] = ' "updateObject (%s :: %s) %s' % (
            arg_and_rest[0], header, arg_and_rest[1])
        extras.extend(update_object)

    return (proofs, extras)


pspace_storable_instance_proofs.order = 1
1210
1211
def num_instance_proofs(header, canonical, d):
    """Generate arithmetic class instances for a one-field wrapper type.

    ``canonical`` must describe exactly one constructor of arity 1.  One
    ``bij_instance`` block is produced for each of ``plus``, ``minus``,
    ``zero``, ``one`` and ``times``, in that order.  ``d`` is unused.

    Returns a ``(proof_lines, [])`` pair.
    """
    assert len(canonical) == 1
    [(_, (cons, n))] = canonical
    assert n == 1

    # (class / function name, expression template) for each operation
    operations = [
        ('plus', '%s + %s'),
        ('minus', '%s - %s'),
        ('zero', '0'),
        ('one', '1'),
        ('times', '%s * %s'),
    ]

    lines = []
    for name, template in operations:
        lines.extend(bij_instance(name, header, cons,
                                  [(name, template, True)]))

    return (lines, [])


num_instance_proofs.order = 2
1232
def enum_instance_proofs(header, canonical, d):
    """Generate the ``enum``, ``enum_alt`` and ``enumeration_both`` instances.

    ``canonical`` is a list of ``(index, (constructor, arity))`` entries.

    * With exactly one constructor (which must have arity 1) the
      enumeration is ``map <constructor> enum``.
    * Otherwise no constructor may have arity above 1; the enumeration
      lists the nullary constructors and appends ``map <constructor> enum``
      for each unary one, and a ``<constructor>_map_distinct`` lemma is
      emitted per unary constructor.

    ``interpretation Arch .`` lines are added when the global
    ``call.current_context`` is non-empty.  ``d`` is unused.

    Returns a ``(proof_lines, [])`` pair.
    """
    lines = ['(*<*)']
    if len(canonical) == 1:
        [(_, (cons, n))] = canonical
        assert n == 1
        # No per-constructor distinctness lemmas in this case: binding the
        # name here keeps the lemma loop below from raising NameError.  The
        # closing proof of this case cites distinct_map_enum instead
        # (presumably sufficient on its own -- TODO confirm in Isabelle).
        cons_one_arg = []
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append('  enum_%s: "enum_class.enum \\<equiv> map %s enum"'
                     % (header, cons))
    else:
        cons_no_args = [cons for i, (cons, n) in canonical if n == 0]
        cons_one_arg = [cons for i, (cons, n) in canonical if n == 1]
        cons_two_args = [cons for i, (cons, n) in canonical if n > 1]
        assert cons_two_args == []
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append('  enum_%s: "enum_class.enum \\<equiv> ' % header)
        lines.append('    [ ')
        # every nullary constructor but the last is followed by a comma
        for cons in cons_no_args[:-1]:
            lines.append('      %s,' % cons)
        for cons in cons_no_args[-1:]:
            lines.append('      %s' % cons)
        lines.append('    ]')
        for cons in cons_one_arg:
            lines.append('    @ (map %s enum)' % cons)
        # close the quoted definition on whichever line came last
        lines[-1] = lines[-1] + '"'
    lines.append('')
    for cons in cons_one_arg:
        lines.append('lemma %s_map_distinct[simp]: "distinct (map %s enum)"'
                     % (cons, cons))
        lines.append('  apply (simp add: distinct_map)')
        lines.append('  by (meson injI %s.inject)' % header)
    lines.append('')
    lines.append('definition')
    lines.append('  "enum_class.enum_all (P :: %s \\<Rightarrow> bool) \\<longleftrightarrow> Ball UNIV P"'
                 % header)
    lines.append('')
    lines.append('definition')
    lines.append('  "enum_class.enum_ex (P :: %s \\<Rightarrow> bool) \\<longleftrightarrow> Bex UNIV P"'
                 % header)
    lines.append('')
    lines.append('  instance')
    lines.append('  apply intro_classes')
    lines.append('   apply (safe, simp)')
    lines.append('   apply (case_tac x)')
    if len(canonical) == 1:
        lines.append('  apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def'
                     % (header, header, header))
        lines.append('    distinct_map_enum)')
        lines.append('  done')
    else:
        lines.append('  apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def)'
                     % (header, header, header))
        lines.append('  by fast+')
    lines.append('end')
    lines.append('')
    lines.append('instantiation %s :: enum_alt' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('definition')
    lines.append('  enum_alt_%s: "enum_alt \\<equiv> ' % header)
    lines.append('    alt_from_ord (enum :: %s list)"' % header)
    lines.append('instance ..')
    lines.append('end')
    lines.append('')
    lines.append('instantiation %s :: enumeration_both' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('instance by (intro_classes, simp add: enum_alt_%s)'
                 % header)
    lines.append('end')
    lines.append('')
    lines.append('(*>*)')

    return (lines, [])


enum_instance_proofs.order = 3
1317
1318
def bits_instance_proofs(header, canonical, d):
    """Generate the ``bit`` class instance for a one-field wrapper type.

    ``canonical`` must describe exactly one constructor of arity 1.  The
    shift operations re-wrap their result in the constructor; ``bitSize``
    does not.  ``d`` is unused.

    Returns a ``(proof_lines, [])`` pair.
    """
    assert len(canonical) == 1
    [(_, (cons, n))] = canonical
    assert n == 1

    # (function name, expression template, wrap result in constructor?)
    operations = [
        ('shiftL', 'shiftL %s n', True),
        ('shiftR', 'shiftR %s n', True),
        ('bitSize', 'bitSize %s', False),
    ]
    proofs = bij_instance('bit', header, cons, operations)
    return (proofs, [])


bits_instance_proofs.order = 5
1331
1332
def no_proofs(header, canonical, d):
    """Proof generator for classes that need no Isabelle output.

    All arguments are ignored; a pair of fresh empty lists is returned.
    """
    proofs, extradefs = [], []
    return (proofs, extradefs)


no_proofs.order = 6
1338
1339# FIXME "Bounded" emits enum proofs even if something already has enum proofs
1340# generated due to "Enum"
1341
# Maps a class name to the function that generates its instance proofs.
# set_instance_proofs looks up every entry of ``d.deriving`` here and calls
# the result as ``f(header, canonical, d)``; each function returns a
# ``(proof_lines, extra_definition_lines)`` pair and carries an ``order``
# attribute.  Classes mapped to no_proofs produce no output.
instance_proof_table = {
    'Eq': no_proofs,
    'Bounded': no_proofs, #enum_instance_proofs,
    'Enum': enum_instance_proofs,
    'Ix': no_proofs,
    'Ord': no_proofs,
    'Show': no_proofs,
    'Bits': bits_instance_proofs,
    'Real': no_proofs,
    'Num': num_instance_proofs,
    'Integral': no_proofs,
    'Storable': storable_instance_proofs,
    'PSpaceStorable': pspace_storable_instance_proofs,
    'Typeable': no_proofs,
    'Error': no_proofs,
    }
1358
def bij_instance(classname, typename, constructor, fns):
    """Define class operations on a wrapper type by unwrapping its arguments.

    ``fns`` is a list of ``(name, template, cast_return)`` triples.
    ``template`` is a ``%s`` format string with one slot per argument (up to
    four, named x, y, z, w).  Each operation is defined by case-splitting
    every argument on ``constructor`` and evaluating the template on the
    unwrapped (primed) values; when ``cast_return`` is true the result is
    wrapped in ``constructor`` again.

    Returns the generated lines as a list of strings.
    """
    out = [
        '',
        'instance %s :: %s ..' % (typename, classname),
        'defs (overloaded)',
    ]
    for fname, template, cast_return in fns:
        arity = template.count('%s')
        plain = ('x', 'y', 'z', 'w')[:arity]
        primed = tuple(var + "'" for var in plain)
        lhs = template % plain
        unwrapped = template % primed

        out.append('  %s_%s: "%s \\<equiv>' % (fname, typename, lhs))
        out.extend("    case %s of %s %s' \\<Rightarrow>" % (var, constructor, var)
                   for var in plain)
        if cast_return:
            unwrapped = '%s (%s)' % (constructor, unwrapped)
        out.append('      %s"' % unwrapped)

    return out
1381
def get_type_map(string):
    """Parse Haskell named-record syntax into a list of typed field names.

    ``string`` has the form ``Header { a, b :: T, c :: U }`` (or just
    ``Header``).  Returns ``(header, fields)`` where ``fields`` is a list of
    ``(name, converted_type)`` tuples in source order.  A name listed
    without its own ``::`` takes the type of the next field to its right
    that has one.  Prints a diagnostic and exits with status 1 when the
    part after the header is not wrapped in braces.
    """
    pieces = string.split(None, 1)
    header = pieces[0].strip()
    if len(pieces) == 1:
        return (header, [])

    record = pieces[1].strip()
    if not record.startswith('{') or not record.endswith('}'):
        print('Error in ' + filename)
        print('Expected "A { blah :: blah etc }"')
        print('However { and } not found correctly')
        print('When parsing %s' % string)
        sys.exit(1)
    record = record[1:-1]

    fields = braces.str(record, '(', ')').split(',')
    # Walk right to left so a type is known before the untyped names that
    # share it are reached.
    fields.reverse()
    current_type = None
    collected = []
    for field in fields:
        parts = field.split('::')
        if len(parts) == 2:
            current_type = type_transform(str(parts[1]).strip())
            name = str(parts[0]).strip()
        else:
            name = str(field).strip()
        collected.append((name, current_type))
    return (header, collected[::-1])
1413
1414
# Number of liftIO definitions axiomatised so far.  Kept in a one-element
# list so body_transform can update it in place.
numLiftIO = [0]


def body_transform(body, defined, sig, nopattern=False):
    """Translate the body of one definition into a quoted Isabelle term.

    ``body`` must hold exactly one ``(line, children)`` node.  ``defined``
    is only used in the warning printed when no ``=`` is found.  ``sig`` is
    handed to ``do_clauses_transform``.  With ``nopattern`` set, the
    pattern-match and left-hand-side transforms are skipped.

    Returns a one-element list holding the transformed node, wrapped in
    double quotes.
    """
    # assume single object
    [(line, children)] = body

    # A parenthesis left of the first '=' is treated as a pattern match.
    if '(' in line.split('=')[0] and not nopattern:
        [(line, children)] = \
            pattern_match_transform([(line, children)])

    if 'liftIO' in reduce_to_single_line((line, children)):
        # liftIO is the translation boundary for current
        # SEL4, below which we get into details of interaction
        # with the foreign function interface - axiomatise
        assert '=' in line
        line = line.split('=')[0]
        bits = line.split()
        numLiftIO[0] = numLiftIO[0] + 1
        # The body is dropped; the counter and the words after the first
        # one on the left-hand side become arguments of underlying_arch_op.
        rhs = '(%d :: Int) %s' % (numLiftIO[0], ' '.join(bits[1:]))
        line = '%s\<equiv> underlying_arch_op %s' % (line, rhs)
        children = []
    elif '=' in line:
        # Only the first '=' becomes the Isabelle equivalence symbol.
        line = '\<equiv>'.join(line.split('=', 1))
    elif leading_bar.match(children[0][0]):
        # Guarded body: handled further down.
        pass
    elif '=' in children[0][0]:
        # The '=' sits on the first child line instead.
        (nextline, c2) = children[0]
        children[0] = ('\<equiv>'.join(nextline.split('=', 1)), c2)
    else:
        sys.stderr.write('WARN: def of %s missing =\n' % defined)

    # A trailing 'where' child is converted separately and re-attached in
    # front of the remaining children below.
    if children and (children[-1][0].endswith('where') or
                     children[-1][0].lstrip().startswith('where')):
        bits = line.split('\<equiv>')
        where_clause = where_clause_transform(children[-1])
        children = children[:-1]
        if len(bits) == 2 and bits[1].strip():
            # Move the right-hand side onto its own child line, indented by
            # the length of the left-hand side, so the where clause can be
            # placed ahead of it.
            line = bits[0] + '\<equiv>'
            new_line = ' ' * len(line) + bits[1]
            children = [(new_line, children)]
    else:
        where_clause = []

    (line, children) = zipWith_transforms(line, children)

    (line, children) = supplied_transforms(line, children)

    (line, children) = case_clauses_transform((line, children))

    (line, children) = do_clauses_transform((line, children), sig)

    if children and leading_bar.match(children[0][0]):
        line = line + ' \<equiv>'
        children = guarded_body_transform(children, ' = ')

    children = where_clause + children

    if not nopattern:
        line = lhs_transform(line)
    line = lhs_de_underscore(line)

    (line, children) = type_assertion_transform(line, children)

    (line, children) = run_regexes((line, children))
    (line, children) = run_ext_regexes((line, children))

    (line, children) = bracket_dollar_lambdas((line, children))

    # Wrap the whole term in double quotes.
    line = '"' + line
    (line, children) = add_trailing_string('"', (line, children))
    return [(line, children)]
1487
1488
dollar_lambda_regex = re.compile(r"\$\s*\\<lambda>")


def bracket_dollar_lambdas(xxx_todo_changeme):
    """Replace ``$`` followed by a lambda with a parenthesised lambda.

    Takes a ``(line, children)`` node.  Where the line contains ``$``
    directly followed (modulo whitespace) by the lambda symbol, the ``$``
    becomes an opening parenthesis and the matching closing parenthesis is
    added at the end of the node, before a trailing ``;`` if there is one.
    The same rewrite is applied recursively to all children.

    Returns the rewritten ``(line, children)`` node.
    """
    line, children = xxx_todo_changeme
    if dollar_lambda_regex.search(line):
        before, after = dollar_lambda_regex.split(line)
        node = ('%s(\\<lambda>%s' % (before, after), children)
        if has_trailing_string(';', node):
            node = add_trailing_string(');', remove_trailing_string(';', node))
        else:
            node = add_trailing_string(')', node)
        line, children = node
    return (line, [bracket_dollar_lambdas(child) for child in children])
1505
1506
def zipWith_transforms(line, children):
    """Rewrite ``zipWithM_`` over a ``[a ..]`` range into ``mapM_``.

    Nodes whose line does not mention ``zipWithM_`` are only recursed into.
    A ``zipWithM_`` node is rewritten when its last child is a childless
    line containing ``..]``; the call then becomes ``mapM_`` applied to a
    ``split`` function and one of the ``zipE1`` .. ``zipE4`` list
    constructors, chosen by which of the last two list arguments is the
    open range and whether that range has a step.  Otherwise the node is
    returned as it is.

    Returns a ``(line, children)`` node.
    """
    if 'zipWithM_' not in line:
        return (line, [zipWith_transforms(l, c) for (l, c) in children])

    if children == []:
        return (line, [])

    # A range spread over two childless lines is first joined into one.
    if len(children) == 2:
        (first, first_kids), (second, second_kids) = children
        if first_kids == [] and second_kids == [] \
                and '..]' in first + second:
            children = [(first + ' ' + second.strip(), [])]

    range_line, range_kids = children[-1]
    if range_kids != [] or '..]' not in range_line:
        return (line, children)

    pieces = line.split('zipWithM_', 1)
    line = pieces[0] + 'mapM_'
    indent = lead_ws(range_line)
    fn_line = indent + '(split ' + pieces[1]

    words = braces.str(range_line, '[', ']').split(None, braces=True)

    # Everything before the two list arguments still belongs to the
    # function being mapped.
    fn_tail = indent + ' '.join(words[:-2]) + ')'
    fn_children = children[:-1] + [(fn_tail, [])]

    second_last = words[-2]
    last = words[-1]
    if last.endswith('..]'):
        inner = last[1:-3].strip()
        if ',' in inner:
            parts = inner.split(',')
            zip_line = '%s(zipE4 (%s) (%s) (%s))' \
                % (indent, second_last, parts[0], parts[-1])
        else:
            zip_line = '%s(zipE3 (%s) (%s))' % (indent, second_last, inner)
    else:
        inner = second_last[1:-3].strip()
        if ',' in inner:
            parts = inner.split(',')
            zip_line = '%s(zipE2 (%s) (%s) (%s))' \
                % (indent, parts[0], parts[1], last)
        else:
            zip_line = '%s(zipE1 (%s) (%s))' % (indent, inner, last)

    return (line, [(fn_line, fn_children), (zip_line, [])])
1554
1555
def supplied_transforms(line, children):
    """Substitute hand-supplied replacements for matching subtrees.

    The node is reduced to a whitespace-stripped tuple and looked up in
    ``supplied_transform_table``.  On a hit the stored replacement is
    re-indented to the node's leading whitespace, the hit is recorded in
    ``supplied_transforms_usage``, and the replacement is returned.
    Otherwise the children are processed recursively.
    """
    key = convert_to_stripped_tuple(line, children)

    if key not in supplied_transform_table:
        return (line, [supplied_transforms(l, c) for (l, c) in children])

    replacement = supplied_transform_table[key]
    shift = len(lead_ws(line)) - len(lead_ws(replacement[0]))
    replacement = adjust_ws(replacement, shift)
    supplied_transforms_usage[key] = 1
    return replacement
1570
1571
def convert_to_stripped_tuple(line, children):
    """Return a hashable copy of a ``(line, children)`` tree.

    Every line is stripped of surrounding whitespace and every child list
    becomes a tuple, recursively.
    """
    stripped_children = tuple(
        convert_to_stripped_tuple(child_line, grandchildren)
        for (child_line, grandchildren) in children)
    return (line.strip(), stripped_children)
1576
1577
def type_assertion_transform_inner(line):
    """Convert every type assertion found in ``line``.

    Each match of the ``type_assertion`` pattern is replaced by
    ``(var::type)``, with the type passed through ``type_transform``.
    Matches are handled left to right; text outside them is kept as is.
    """
    match = type_assertion.search(line)
    if not match:
        return line

    variable = match.expand('\\1')
    converted = type_transform(match.expand('\\2').strip())
    annotated = '(%s::%s)' % (variable, converted)
    remainder = type_assertion_transform_inner(line[match.end():])
    return line[:match.start()] + annotated + remainder
1587
1588
def type_assertion_transform(line, children):
    """Apply ``type_assertion_transform_inner`` to every line of a tree.

    Children are processed (recursively) before the node's own line.
    Returns the rewritten ``(line, children)`` node.
    """
    new_children = [type_assertion_transform(child_line, grandchildren)
                    for (child_line, grandchildren) in children]
    new_line = type_assertion_transform_inner(line)
    return (new_line, new_children)
1593
1594
def where_clause_guarded_body(xxx_todo_changeme1):
    """Convert a guarded definition inside a where clause.

    Takes a ``(line, children)`` node.  When the first child matches
    ``leading_bar``, `` =`` is appended to the line and the children are
    passed through ``guarded_body_transform``; any other node is returned
    unchanged.
    """
    line, children = xxx_todo_changeme1
    if not children or not leading_bar.match(children[0][0]):
        return (line, children)
    return (line + ' =', guarded_body_transform(children, ' = '))
1601
1602
def where_clause_transform(xxx_todo_changeme2):
    """Turn a Haskell ``where`` clause into a ``let ... in`` block.

    Takes the ``(line, children)`` node of the clause.  Bindings written
    on the ``where`` line itself are moved into the children.  Children
    whose second word is ``::`` are dropped; the rest go through the case,
    do and guarded-body transforms.  Bindings of the form ``f a = b``
    become ``f = (\\ a -> b)``.  The bindings are ordered with
    ``order_let_children`` and separated by ``;``.

    Returns the list of nodes ``[let] + bindings + [in]``.
    """
    line, children = xxx_todo_changeme2
    indent = line.split('where', 1)[0]
    if line.strip() != 'where':
        assert line.strip().startswith('where')
        inline_binding = ''.join(line.split('where', 1))
        children = [(inline_binding, [])] + children

    bindings = [(l, c) for (l, c) in children if l.split()[1] != '::']
    bindings = [case_clauses_transform(node) for node in bindings]
    bindings = [do_clauses_transform(node, None, type=0)
                for node in bindings]
    bindings = [where_clause_guarded_body(node) for node in bindings]

    for idx, (l, c) in enumerate(bindings):
        lhs_words = braces.str(l, '(', ')').split('=')[0].split()
        if len(lhs_words) > 1:
            # turn f a = b into f = (\a -> b)
            l = '->'.join(l.split('=', 1))
            l = lead_ws(l) + ' = (\\ '.join(l.split(None, 1))
            bindings[idx] = add_trailing_string(')', (l, c))

    ordered = order_let_children(bindings)
    separated = [add_trailing_string(';', node) for node in ordered[:-1]]
    return [(indent + 'let', [])] + separated + ordered[-1:] \
        + [(indent + 'in', [])]
1631
1632
varname_re = re.compile(r"\w+")


def order_let_children(L):
    """Order let bindings so that each comes after the ones it mentions.

    ``L`` is a list of ``(line, children)`` nodes.  The name a node binds
    is the first word left of its first top-level ``=``; a node depends on
    every other node whose bound name occurs as a word in it.  Nodes are
    emitted in repeated passes, each pass releasing those whose
    dependencies have all been emitted.  If a pass releases nothing,
    diagnostics are printed and an assertion fails.

    Returns the reordered list.
    """
    flat = [reduce_to_single_line(elt) for elt in L]

    bound_names = [str(braces.str(text, '(', ')').split('=')[0]).split()[0]
                   for text in flat]
    # a name bound more than once keeps its last position
    position_of = {}
    for pos, name in enumerate(bound_names):
        position_of[name] = pos

    deps = {}
    for pos, text in enumerate(flat):
        for word in varname_re.findall(text):
            if word in position_of and position_of[word] != pos:
                deps.setdefault(pos, {})[position_of[word]] = 1

    emitted = set()
    ordered = []
    pending = dict(enumerate(L))
    remaining = len(pending)
    while remaining:
        for pos in list(pending.keys()):
            blocked = any(dep not in emitted for dep in deps.get(pos, {}))
            if not blocked:
                ordered.append(pending.pop(pos))
                emitted.add(pos)
        if len(pending) == remaining:
            print("No progress resolving let deps")
            print()
            print(list(pending.values()))
            print()
            print("Dependencies:")
            print(deps)
            assert 0
        remaining = len(pending)
    return ordered
1671
1672
def do_clauses_transform(xxx_todo_changeme3, rawsig, type=None):
    """Translate Haskell ``do`` blocks in a tree into Isabelle do-notation.

    ``xxx_todo_changeme3`` is a ``(line, children)`` node.  ``rawsig`` is
    the signature tree of the enclosing definition (or a false value); it
    is only consulted when ``type`` is None.  ``type`` is an integer used
    as a repeat count for the ``E`` suffix on ``do``/``od`` and the ``Ok``
    suffix on ``return``; when None it is obtained from
    ``monad_type_acquire`` on the flattened signature, or 0 without one.

    Returns the transformed ``(line, children)`` node.
    """
    (line, children) = xxx_todo_changeme3
    # A trailing 'where' child is converted to let/in nodes; the first of
    # them becomes the result's head and the rest of this node, transformed
    # recursively, is appended after them.
    if children and children[-1][0].lstrip().startswith('where'):
        where_clause = where_clause_transform(children[-1])
        where_clause = [do_clauses_transform(
            (l, c), rawsig, 0) for (l, c) in where_clause]
        others = (line, children[:-1])
        others = do_clauses_transform(others, rawsig, type)
        (line, children) = where_clause[0]
        return (line, children + where_clause[1:] + [others])

    # Pull a childless first child that ends in 'do' up onto this line.
    if children:
        (l, c) = children[0]
        if c == [] and l.endswith('do'):
            line = line + ' ' + l.strip()
            children = children[1:]

    if type is None:
        if not rawsig:
            type = 0
            sig = None
        else:
            sig = ' '.join(flatten_tree([rawsig]))
            type = monad_type_acquire(sig)
    # The line itself may change the type that applies from here down.
    (line, type) = monad_type_transform((line, type))
    if children == []:
        return (line, [])

    rhs = line.split('<-', 1)[-1]
    if rhs.strip() == 'syscall' or rhs.strip() == 'atomicSyscall':
        # Exactly five argument blocks, each with its own type.
        assert len(children) == 5
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0, 1, 0, type])]
    elif line.strip().endswith('catchFailure'):
        # Exactly two argument blocks: type 1, then type 0.
        assert len(children) == 2
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0])]
    else:
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=type) for elt in children]

    if not line.endswith('do'):
        return (line, children)

    # Anything from an unmatched ')' onwards is set aside and re-attached
    # after the closing 'od'.
    children, other_children = split_on_unmatched_bracket(children)

    # single statement do clause won't parse in Isabelle
    if len(children) == 1:
        ws = lead_ws(line)
        return (line[:-2] + '(', children + [(ws + ')', [])])

    line = line[:-2] + '(do' + 'E' * type

    # Drop statements that are blank and have no children.
    children = [(l, c) for (l, c) in children if l.strip() or c]

    children2 = []
    for (l, c) in children:
        if l.lstrip().startswith('let '):
            # A let without '=' on its own line takes it from its first
            # child, flattened onto the same line.
            if '=' not in l:
                extra = reduce_to_single_line(c[0])
                assert '=' in extra
                l = l + ' ' + extra
                c = c[1:]
            # 'let x = e' becomes 'x <- return (e)', with 'Ok' appended to
            # 'return' once per unit of type.
            l = ''.join(l.split('let ', 1))
            letsubs = '<- return' + 'Ok' * type + ' ('
            l = letsubs.join(l.split('=', 1))
            (l, c) = add_trailing_string(')', (l, c))
            children2.extend(do_clause_pattern(l, c, type))
        else:
            children2.extend(do_clause_pattern(l, c, type))

    # Every statement except the last is terminated with ';'.
    children = [add_trailing_string(';', child)
                for child in children2[:-1]] + [children2[-1]]

    ws = lead_ws(line)
    children.append((ws + 'od' + 'E' * type + ')', []))

    return (line, children + other_children)
1756
1757
# Constructor patterns allowed on the left of '<-', mapped to the function
# applied to the bound value in their place.
left_start_table = {
    'ASIDPool': '(inv ASIDPool)',
    'HardwareASID': 'id',
    'ArchObjectCap': 'capCap',
    'Just': 'the'
}


def do_clause_pattern(line, children, type, n=0):
    """Expand a pattern on the left of ``<-`` into plain bindings.

    Statements without ``<-``, bindings of a single name, and tuple
    patterns only have ``<-`` replaced by the Isabelle left arrow.  Cons
    patterns ``a:b`` are bound through a fresh variable with ``headM`` and
    ``tailM``; list patterns are reduced to cons patterns, with an
    assertion that the remainder is empty; patterns starting with a key of
    ``left_start_table`` are bound through ``liftM`` of the mapped
    function.  ``type`` is the repeat count for the ``E`` suffix on
    ``assert`` and ``liftM``.  Any other pattern is printed and an
    assertion fails.

    Returns a list of ``(line, children)`` nodes.
    """
    pieces = line.split('<-')
    unchanged = [('\\<leftarrow>'.join(pieces), children)]
    if len(pieces) == 1:
        return unchanged

    (pattern, action) = line.split('<-', 1)
    is_single_name = (':' not in pattern and '[' not in pattern
                      and len(pattern.split()) == 1)
    if is_single_name:
        return unchanged
    pattern = pattern.strip()

    fresh = 'v%d' % get_next_unique_id()

    assertion = 'assert' + ('E' * type)
    indent = lead_ws(line)

    if pattern.startswith('('):
        assert pattern.endswith(')')
        if ',' in pattern:
            return unchanged
        pattern = pattern[1:-1]

    braced = braces.str(pattern, '[', ']')
    if len(braced.split(':')) > 1:
        parts = [str(part).strip() for part in braced.split(':', 1)]
        steps = [('%s%s <- %s' % (indent, fresh, action), children),
                 ('%s%s <- headM %s' % (indent, parts[0], fresh), []),
                 ('%s%s <- tailM %s' % (indent, parts[1], fresh), [])]
        expanded = []
        for (step_line, step_children) in steps:
            expanded.extend(
                do_clause_pattern(step_line, step_children, type, n + 1))
        return expanded

    if pattern == '[]':
        return [('%s%s <- %s' % (indent, fresh, action), children),
                ('%s%s (%s = []) []' % (indent, assertion, fresh), [])]

    if pattern.startswith('['):
        assert pattern.endswith(']')
        elements = braces.str(pattern[1:-1], '[', ']').split(',', 1)
        if len(elements) == 1:
            # single element: bind the tail and assert that it is empty
            cons_pattern = '%s:%s' % (elements[0], fresh)
            new_line = '%s%s <- %s' % (indent, cons_pattern, action)
            trailer = [('%s%s (%s = []) []' % (indent, assertion, fresh), [])]
        else:
            cons_pattern = '%s:[%s]' % (elements[0], elements[1])
            new_line = lead_ws(line) + cons_pattern + ' <- ' + action
            trailer = []
        return do_clause_pattern(new_line, children, type, n + 1) + trailer

    for prefix in left_start_table:
        if pattern.startswith(prefix):
            pattern = pattern[len(prefix):]
            lifted = left_start_table[prefix]
            lift = 'liftM' + 'E' * type
            new_line = '%s <- %s %s $ %s' % (pattern, lift, lifted, action)
            return do_clause_pattern(new_line, children, type, n + 1)

    print(line)
    print(pattern)
    assert 0
1826
1827
def split_on_unmatched_bracket(elts, n=None):
    """Split a tree at the first ``)`` that has no matching ``(``.

    ``elts`` is a list of ``(line, children)`` nodes, scanned depth first
    with a running bracket depth.  At the first ``)`` that takes the depth
    below zero, everything before it is kept and everything from it
    onwards is split off, with the cut line padded with spaces so the
    bracket keeps its column.

    Called without ``n`` it returns ``(kept, split_off)``.  Called with a
    starting depth ``n`` (as the recursion does) it returns
    ``(kept, split_off, depth)``.
    """
    if n is None:
        kept, split_off, _ = split_on_unmatched_bracket(elts, 0)
        return (kept, split_off)

    depth = n
    for idx, (text, kids) in enumerate(elts):
        for col, char in enumerate(text):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    head = text[:col]
                    tail = ' ' * len(head) + text[col:]
                    return (elts[:idx] + [(head, [])],
                            [(tail, kids)] + elts[idx + 1:],
                            depth)
        kept_kids, spilled, depth = split_on_unmatched_bracket(kids, depth)
        if spilled:
            return (elts[:idx] + [(text, kept_kids)],
                    spilled + elts[idx + 1:],
                    depth)

    return (elts, [], depth)
1853
1854
def monad_type_acquire(sig, type=0):
    """Work out the monad level named by a type signature.

    The signature is scanned for the known monad names.  Each time one
    is found the level associated with it replaces ``type`` and the scan
    continues on the text after the last occurrence of that name.  When
    no name is left, the current level is returned.
    """
    # note kernel appears after kernel_f/kernel_monad
    markers = (
        ('kernel_f', 1),
        ('fault_monad', 1),
        ('syscall_monad', 2),
        ('kernel_monad', 0),
        ('kernel_init', 1),
        ('kernel_p', 1),
        ('kernel', 0),
    )
    while True:
        for marker, level in markers:
            if marker in sig:
                sig = sig.split(marker)[-1]
                type = level
                break
        else:
            # No marker remains in what is left of the signature.
            return type
1865
1866
def monad_type_transform(xxx_todo_changeme4):
    """Rename monadic helpers in a line to match the current monad level.

    The argument is a ``(line, type)`` pair.  The line is searched, in a
    fixed priority order, for a keyword that switches the monad level.
    When one is found the text before it is processed at the incoming
    level, the text after it at the keyword's level, and the level that
    results from the right-hand part is returned.  When no keyword is
    present and the level is non-zero, ``return`` gains one ``Ok`` per
    level and ``when``/``unless``/``mapM``/``forM``/``liftM``/``assert``/
    ``stateAssert`` gain one ``E`` per level.

    Returns the ``(new_line, new_type)`` pair.
    """
    (line, type) = xxx_todo_changeme4

    # (keyword, level used for the text after it), in priority order.
    # A level of None marks `catchFailure`, which also forces the text
    # before it to level 1 and the text after it to level 0.
    switches = (
        ('withoutError', 1),
        ('doKernelOp', 0),
        ('runInit', 1),
        ('withoutFailure', 0),
        ('withoutFault', 0),
        ('withoutPreemption', 0),
        ('allowingFaults', 1),
        ('allowingErrors', 2),
        ('`catchFailure`', None),
        ('catchingFailure', 1),
        ('catchF', 1),
        ('emptyOnFailure', 1),
        ('constOnFailure', 1),
        ('nothingOnFailure', 1),
        ('nullCapOnFailure', 1),
        ('`catchFault`', 1),
        ('capFaultOnFailure', 1),
        ('ignoreFailure', 1),
        ('handleInvocation False', 0),  # THIS IS A HACK
    )

    for keyword, inner_level in switches:
        if keyword not in line:
            continue
        before, after = line.split(keyword, 1)
        if inner_level is None:
            before, _ = monad_type_transform((before, 1))
            after, final_level = monad_type_transform((after, 0))
        else:
            before, _ = monad_type_transform((before, type))
            after, final_level = monad_type_transform((after, inner_level))
        return (before + keyword + after, final_level)

    if type:
        # Order matters only in that it mirrors the historical sequence.
        suffixed = (
            ('return', 'Ok'),
            ('when', 'E'),
            ('unless', 'E'),
            ('mapM', 'E'),
            ('forM', 'E'),
            ('liftM', 'E'),
            ('assert', 'E'),
            ('stateAssert', 'E'),
        )
        for word, suffix in suffixed:
            line = line.replace(word, word + suffix * type)

    return (line, type)
1946
1947
def case_clauses_transform(xxx_todo_changeme5):
    """Rewrite a Haskell 'case ... of' node using a case conversion.

    The argument is a ``(line, children)`` tree node.  The children are
    transformed first (recursively).  A node whose line does not end in
    ' of' is returned with only that change.

    For a case node, the alternatives are read from the children: each
    pattern is the text before the first '->' and each body is what
    follows it (continuation lines are joined on, guards are rewritten
    with guarded_body_transform and a trailing 'where' child is moved in
    front of the body).  The tuple of patterns is looked up with
    get_case_conv and the resulting template is filled in with
    subs_nums_and_x.

    Returns a new ``(line, children)`` node.  The node is replaced by a
    '(* case removed *) undefined' placeholder when the case itself has a
    'where' clause, when the conversion is blanked ('<X>') or when no
    conversion is known; in the last situation a stub entry is appended
    to the 'caseconvs' file the first time the pattern tuple is seen.
    Exits the process when an alternative has no '->' or a conversion
    refers to an alternative index that does not exist.
    """
    (line, children) = xxx_todo_changeme5
    children = [case_clauses_transform(child) for child in children]

    if not line.endswith(' of'):
        return (line, children)

    bits = line.split('case ', 1)
    beforecase = bits[0]
    # scrutinee: the text between 'case ' and the trailing ' of'
    x = bits[1][:-3]

    if '::' in x:
        # the scrutinee carries a type annotation; translate the type
        x2 = braces.str(x, '(', ')')
        bits = x2.split('::', 1)
        if len(bits) == 2:
            x = str(bits[0]) + ':: ' + type_transform(str(bits[1]))

    if children and children[-1][0].strip().startswith('where'):
        sys.stderr.write('Warning: where clause in case: %r\n'
                         % line)
        return (beforecase + '(* case removed *) undefined', [])
        # where_clause = where_clause_transform(children[-1])
        # children = children[:-1]
        # (in_stmt, l) = where_clause[-1]
        # l.append(case_clauses_transform((line, children)))
        # return where_clause

    cases = []
    bodies = []
    for (l, c) in children:
        bits = l.split('->', 1)
        # keep absorbing child lines until the alternative's '->' is found
        while len(bits) == 1:
            if c == []:
                sys.stderr.write('wtf %r\n' % l)
                sys.exit(1)
            if c[0][0].strip().startswith('|'):
                # guarded alternative: the guards become the body
                bits = [bits[0], '']
                c = guarded_body_transform(c, '->')
            elif c[0][1] == []:
                # first child is a leaf: treat it as a continuation line
                l = l + ' ' + c.pop(0)[0].strip()
                bits = l.split('->', 1)
            else:
                # a single nested child: splice its line in and descend
                [(moreline, c)] = c
                l = l + ' ' + moreline.strip()
                bits = l.split('->', 1)
        case = bits[0].strip()
        tail = bits[1]
        if c and c[-1][0].lstrip().startswith('where'):
            # move the where clause in front of the body it applies to
            where_clause = where_clause_transform(c[-1])
            ws = lead_ws(where_clause[0][0])
            c = where_clause + [(ws + tail.strip(), [])] + c[:-1]
            tail = ''
        cases.append(case)
        bodies.append((tail, c))

    cases = tuple(cases) # used as key of a dict later, needs to be hashable
                         # (since lists are mutable, they aren't)
    if not cases:
        print(line)
    conv = get_case_conv(cases)
    if conv == '<X>':
        sys.stderr.write('Warning: blanked case in caseconvs\n')
        return (beforecase + '(* case removed *) undefined', [])
    if not conv:
        sys.stderr.write('Warning: case %r\n' % (cases, ))
        if cases not in cases_added:
            # record a stub so the missing conversion can be filled in
            casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> '

            f = open('caseconvs', 'a')
            f.write('%s ---X>\n\n' % casestr)
            f.close()
            cases_added[cases] = 1
        return (beforecase + '(* case removed *) undefined', [])
    conv = subs_nums_and_x(conv, x)

    # the first template entry is the header and carries no body index
    new_line = beforecase + '(' + conv[0][0]
    assert conv[0][1] is None

    ws = lead_ws(children[0][0])
    new_children = []
    for (header, endnum) in conv[1:]:
        if endnum is None:
            new_children.append((ws + header, []))
        else:
            if (len(bodies) <= endnum):
                sys.stderr.write('ERROR: index %d out of bounds in case %r\n' %
                                 (endnum,
                                  cases, ))
                sys.exit(1)

            (body, c) = bodies[endnum]
            new_children.append((ws + header + ' ' + body, c))

    # close the bracket opened in new_line, keeping a trailing comma
    # after the closing bracket rather than before it
    if has_trailing_string(',', new_children[-1]):
        new_children[-1] = \
            remove_trailing_string(',', new_children[-1])
        new_children.append((ws + '),', []))
    else:
        new_children.append((ws + ')', []))

    return (new_line, new_children)
2049
2050
def guarded_body_transform(body, div):
    """Rewrite a block of '|' guards as an if / else-if chain.

    ``body`` is a non-empty list of ``(line, children)`` nodes, one per
    guard, optionally followed by a 'where' node which is converted with
    where_clause_transform and placed first in the result.  ``div`` is
    the token separating a guard's condition from its result (for
    example '->' or ' = ').

    For each guard, child lines are joined onto the guard line until the
    divider appears.  The guard then becomes two nodes: an 'if' (or
    'else if') line and a line in which the first divider is replaced by
    ' then '.  A final 'else undefined' node is appended.

    Exits the process with status 1, after describing the problem on
    stderr, when a guard never reaches its divider or has no '| '.
    Other errors are no longer converted into that exit; they propagate
    to the caller.
    """
    new_body = []
    if body[-1][0].strip().startswith('where'):
        new_body.extend(where_clause_transform(body[-1]))
        body = body[:-1]
    for i, (line, children) in enumerate(body):
        # Pull continuation lines up until the divider shows up.
        while div not in line:
            if not children:
                sys.stderr.write('missing %r in %r\n' % (div, line))
                sys.stderr.write('\nhandling %r\n' % body)
                sys.exit(1)
            (l, c) = children[0]
            children = c + children[1:]
            line = line + ' ' + l.strip()
        if '| ' not in line:
            sys.stderr.write('missing "|" in %r\n' % line)
            sys.stderr.write('\nhandling %r\n' % body)
            sys.exit(1)
        [ws, guts] = line.split('| ', 1)
        if i == 0:
            new_body.append((ws + 'if', []))
        else:
            new_body.append((ws + 'else if', []))
        guts = ' then '.join(guts.split(div, 1))
        new_body.append((ws + guts, children))
    # Reuses the indentation of the last guard.
    new_body.append((ws + 'else undefined', []))

    return new_body
2081
2082
def lhs_transform(line):
    """Replace bracketed arguments on the left of a definition.

    A line without any '(' is returned untouched.  Otherwise the line is
    split at the equiv symbol, the left side is split into brace-aware
    pieces, and every piece starting with '(' is replaced by a fresh name
    ``argN`` (N being the piece's position).  For each replaced piece a
    'case argN of <piece> =>' prefix (written with the Rightarrow symbol)
    is put in front of the right side.

    The line must contain exactly one equiv symbol when it contains a
    '('; otherwise the unpacking of the split raises ValueError.
    """
    if '(' not in line:
        return line

    # '\\<equiv>' is the same text as the former '\<equiv>' literal, but
    # without relying on an invalid escape sequence.
    [left, right] = line.split('\\<equiv>')

    ws = left[:len(left) - len(left.lstrip())]

    left = left.lstrip()

    bits = braces.str(left, '(', ')').split(braces=True)

    for (i, bit) in enumerate(bits):
        if bit.startswith('('):
            bits[i] = 'arg%d' % i
            case = 'case arg%d of %s \\<Rightarrow> ' % (i, bit)
            right = case + right

    return ws + ' '.join([str(bit) for bit in bits]) + '\\<equiv>' + right
2102
2103
def lhs_de_underscore(line):
    """Give names to '_' arguments on the left of a definition.

    A line without any '_' is returned untouched.  Otherwise the line is
    split at the equiv symbol and every whitespace-separated word on the
    left that is exactly '_' is replaced by ``argN`` (N being the word's
    position).  The left side is rebuilt with single spaces between
    words, keeping the original leading whitespace.

    The line must contain exactly one equiv symbol when it contains a
    '_'; otherwise the unpacking of the split raises ValueError.
    """
    if '_' not in line:
        return line

    # '\\<equiv>' is the same text as the former '\<equiv>' literal, but
    # without relying on an invalid escape sequence.
    [left, right] = line.split('\\<equiv>')

    ws = left[:len(left) - len(left.lstrip())]

    bits = left.split()

    for (i, bit) in enumerate(bits):
        if bit == '_':
            bits[i] = 'arg%d' % i

    return ws + ' '.join([str(bit) for bit in bits]) + ' \\<equiv>' + right
2120
2121
2122regexes = [
2123    (re.compile(r" \. "), r" \<circ> "),
2124    (re.compile('-1'), '- 1'),
2125    (re.compile('-2'), '- 2'),
2126    (re.compile(r"\(!(\w+)\)"), r"(flip id \1)"),
2127    (re.compile(r"\(\+(\w+)\)"), r"(\<lambda> x. x + \1)"),
2128    (re.compile(r"\\([^<].*?)\s*->"), r"\<lambda> \1."),
2129    (re.compile('}'), r"\<rparr>"),
2130    (re.compile(r"(\s)!!(\s)"), r"\1LIST_INDEX\2"),
2131    (re.compile(r"(\w)!"), r"\1 "),
2132    (re.compile(r"\s?!"), ''),
2133    (re.compile(r"LIST_INDEX"), r"!"),
2134    (re.compile('`testBit`'), '!!'),
2135    (re.compile(r"//"), ' aLU '),
2136    (re.compile('otherwise'), 'True     '),
2137    (re.compile(r"(^|\W)fail "), r"\1haskell_fail "),
2138    (re.compile('assert '), 'haskell_assert '),
2139    (re.compile('assertE '), 'haskell_assertE '),
2140    (re.compile('=='), '='),
2141    (re.compile(r"\(/="), '(\<lambda>x. x \<noteq>'),
2142    (re.compile('/='), '\<noteq>'),
2143    (re.compile('"([^"])*"'), '[]'),
2144    (re.compile('&&'), '\<and>'),
2145    (re.compile('\|\|'), '\<or>'),
2146    (re.compile(r"(\W)not(\s)"), r"\1Not\2"),
2147    (re.compile(r"(\W)and(\s)"), r"\1andList\2"),
2148    (re.compile(r"(\W)or(\s)"), r"\1orList\2"),
2149    # regex ordering important here
2150    (re.compile(r"\.&\."), '&&'),
2151    (re.compile(r"\(\.\|\.\)"), r"bitOR"),
2152    (re.compile(r"\(\+\)"), r"op +"),
2153    (re.compile(r"\.\|\."), '||'),
2154    (re.compile(r"Data\.Word\.Word"), r"word"),
2155    (re.compile(r"Data\.Map\."), r"data_map_"),
2156    (re.compile(r"Data\.Set\."), r"data_set_"),
2157    (re.compile(r"BinaryTree\."), 'bt_'),
2158    (re.compile("mapM_"), "mapM_x"),
2159    (re.compile("forM_"), "forM_x"),
2160    (re.compile("forME_"), "forME_x"),
2161    (re.compile("mapME_"), "mapME_x"),
2162    (re.compile("zipWithM_"), "zipWithM_x"),
2163    (re.compile(r"bit\s+([0-9]+)(\s)"), r"(1 << \1)\2"),
2164    (re.compile('`mod`'), 'mod'),
2165    (re.compile('`div`'), 'div'),
2166    (re.compile(r"`((\w|\.)*)`"), r"`~\1~`"),
2167    (re.compile('size'), 'magnitude'),
2168    (re.compile('foldr'), 'foldR'),
2169    (re.compile(r"\+\+"), '@'),
2170    (re.compile(r"(\s|\)|\w|\]):(\s|\(|\w|$|\[)"), r"\1#\2"),
2171    (re.compile(r"\[([^]]*)\.\.([^]]*)\]"), r"[\1 .e. \2]"),
2172    (re.compile(r"\[([^]]*)\.\.\s*$"), r"[\1 .e."),
2173    (re.compile(' Right'), ' Inr'),
2174    (re.compile(' Left'), ' Inl'),
2175    (re.compile('\\(Left'), '(Inl'),
2176    (re.compile('\\(Right'), '(Inr'),
2177    (re.compile(r"\$!"), r"$"),
2178    (re.compile('([^>])>='), r'\1\<ge>'),
2179    (re.compile('>>([^=])'), r'>>_\1'),
2180    (re.compile('<='), '\<le>'),
2181    (re.compile(r" \\\\ "), " `~listSubtract~` "),
2182    (re.compile(r"(\s\w+)\s*@\s*\w+\s*{\s*}\s*\<leftarrow>"),
2183     r"\1 \<leftarrow>"),
2184]
2185
# Each entry is (pattern, replacement, follow_pattern, follow_replacement).
# run_ext_regexes replaces the first match of `pattern` in a line with
# `replacement`, then applies the follow-up substitution to the remainder
# of that line (the text after the match) and to the node's children.
ext_regexes = [
    (re.compile(r"(\s[A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="),
     r"\1="),
    (re.compile(r"(\([A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="),
     r"\1="),
    (re.compile(r"{([^={<]*[^={<:])=([^={<]*)\\<rparr>"),
     r"\<lparr>\1:=\2\<rparr>",
     re.compile(r"THIS SHOULD NOT APPEAR IN THE SOURCE"), ""),
    (re.compile(r"{"), r"\<lparr>", re.compile(r"([^=:])=(\s|$|\w)"),
     r"\1:=\2"),
]

# Matches optional leading whitespace followed by a '|' character.
leading_bar = re.compile(r"\s*\|")

# Matches a parenthesised "lhs :: rhs" group, capturing both sides.
type_assertion = re.compile(r"\(([^(]*)::([^)]*)\)")
2201
2202
def run_regexes(xxx_todo_changeme6, _regexes=regexes):
    """Apply a list of substitutions to a tree node and its descendants.

    The argument is a ``(line, children)`` node.  Every
    ``(pattern, replacement)`` pair in ``_regexes`` is applied to the
    line in order, then the same list is applied recursively to each
    child.  Returns the new ``(line, children)`` node.
    """
    (line, children) = xxx_todo_changeme6
    for pattern, replacement in _regexes:
        line = pattern.sub(replacement, line)
    new_children = []
    for child in children:
        new_children.append(run_regexes(child, _regexes=_regexes))
    return (line, new_children)
2209
2210
def run_ext_regexes(xxx_todo_changeme7):
    """Apply the ext_regexes entries to a tree node and its descendants.

    For each entry whose main pattern matches the line, the first match
    is replaced by the expanded replacement, and the entry's follow-up
    substitution is run (via run_regexes) over the text after the match
    and over the node's children.  The children are then processed
    recursively.  Returns the new ``(line, children)`` node.
    """
    (line, children) = xxx_todo_changeme7
    for pattern, template, tail_pattern, tail_template in ext_regexes:
        found = pattern.search(line)
        if found is not None:
            head = line[:found.start()]
            middle = found.expand(template)
            tail = line[found.end():]
            (tail, children) = run_regexes(
                (tail, children), _regexes=[(tail_pattern, tail_template)])
            line = head + middle + tail
    return (line, [run_ext_regexes(child) for child in children])
2225
2226
def get_case_lhs(lhs):
    """Extract the tuple of patterns from the left side of a caseconv.

    ``lhs`` must start with the literal case header (asserted).  The
    text after the header is split on '->'; each piece is stripped and
    empty pieces are dropped.
    """
    header = 'case \\x of '
    assert lhs.startswith(header)
    remainder = lhs.split(header, 1)[1]
    stripped = (piece.strip() for piece in remainder.split('->'))
    return tuple(piece for piece in stripped if piece != '')
2236
2237
def get_case_rhs(rhs):
    """Parse the right side of a caseconv entry into a template.

    ``rhs`` is split at each '->'; the token following an arrow must
    start with digits giving the 1-based number of the case body to
    insert there.  The text pieces are further split on the two-character
    sequence backslash-n into separate template lines.

    Returns a list of ``(text, index)`` pairs, where ``index`` is the
    0-based body number that follows ``text`` or None.  Empty lines
    without an index are dropped.  If the first line ends in 'of' a
    space is appended to it.

    Exits the process with status 1 when the first template line carries
    a body index; the diagnostic shows the complete conversion text (it
    used to show only what was left after parsing).
    """
    # Keep the untouched text for diagnostics: rhs is consumed below.
    original_rhs = rhs

    tuples = []
    while '->' in rhs:
        (s, rest) = rhs.split('->', 1)
        bits = rest.split(None, 1)
        n = int(takeWhile(bits[0], lambda x: x.isdigit())) - 1
        if len(bits) > 1:
            rhs = bits[1]
        else:
            rhs = ''
        tuples.append((s, n))
    if rhs != '':
        tuples.append((rhs, None))

    conv = []
    for (string, num) in tuples:
        bits = [bit.strip() for bit in string.split('\\n')]
        # only the last line of a piece is followed by the body
        conv.extend([(bit, None) for bit in bits[:-1]])
        conv.append((bits[-1], num))

    conv = [(s, n) for (s, n) in conv if s != '' or n is not None]

    if conv[0][1] is not None:
        sys.stderr.write('%r\n' % conv[0][1])
        sys.stderr.write(
            'For technical reasons the first line of this case conversion must be split with \\n: \n')
        sys.stderr.write('  %r\n' % original_rhs)
        sys.stderr.write(
            '(further notes: the rhs of each caseconv must have multiple lines\n'
            'and the first cannot contain any ->1, ->2 etc.)\n')
        sys.exit(1)

    # this is a tad dodgy, but means that case_clauses_transform
    # can be safely run twice on the same input
    if conv[0][0].endswith('of'):
        conv[0] = (conv[0][0] + ' ', conv[0][1])

    return conv
2278
2279
def render_caseconv(cases, conv, f):
    """Write one caseconv entry to the open file ``f``.

    ``cases`` is the sequence of patterns and ``conv`` the conversion
    text with its lines separated by the two-character sequence
    backslash-n.  The entry is the case header with the patterns joined
    by ' -> ', the '--->' marker, then each non-empty conversion line on
    a line of its own, followed by a blank line.  Asserts that the
    conversion has at least one non-empty line.
    """
    pieces = [piece for piece in conv.split('\\n') if piece != '']
    assert pieces
    header = 'case \\x of %s -> ' % ' -> '.join(cases)
    f.write('%s --->' % header)
    for piece in pieces:
        f.write(piece + '\n')
    f.write('\n')
2289
2290
def get_case_conv_table():
    """Load the case conversion table from the 'caseconvs' file.

    The file is read as blocks separated by blank lines (grouped with
    splitList).  A block containing '---X>' marks its patterns as
    blanked ("<X>"); any other block must contain '--->' and maps its
    patterns to the template produced by get_case_rhs.  Entries whose
    patterns are neither all constructor patterns nor extended patterns
    are also written to 'caseconvs-useful', which is rewritten on every
    call.

    Returns a dict keyed by the tuple of patterns.  Exits the process
    with status 1 if a block cannot be parsed.  Both files are closed on
    every exit path.
    """
    result = {}
    with open('caseconvs') as f:
        with open('caseconvs-useful', 'w') as f2:
            # The file is consumed lazily, so it must stay open for the
            # duration of the loop.
            stripped = map(str.rstrip, f)
            entries = ("\\n".join(lines)
                       for lines in splitList(stripped, lambda s: s == ''))

            for line in entries:
                if line.strip() == '':
                    continue
                try:
                    if '---X>' in line:
                        [from_case, _] = line.split('---X>')
                        cases = get_case_lhs(from_case)
                        result[cases] = "<X>"
                    else:
                        [from_case, to_case] = line.split('--->')
                        cases = get_case_lhs(from_case)
                        conv = get_case_rhs(to_case)
                        result[cases] = conv
                        if (not all_constructor_patterns(cases) and
                                not is_extended_pattern(cases)):
                            render_caseconv(cases, to_case, f2)
                except Exception as e:
                    sys.stderr.write('Error parsing %r\n' % line)
                    sys.stderr.write('%s\n ' % e)
                    sys.exit(1)

    return result
2323
2324
def all_constructor_patterns(cases):
    """Tell whether every pattern is a plain constructor pattern.

    A final pattern that is just '_' is ignored.  ``cases`` must be a
    non-empty sequence of strings.
    """
    if cases[-1].strip() == '_':
        cases = cases[:-1]
    return all(is_constructor_pattern(pat) for pat in cases)
2332
2333
def is_constructor_pattern(pat):
    """A constructor pattern takes the form Cons var1 var2 ...

    Every word must be alphanumeric or exactly '_', the first word must
    start with an uppercase character, and every following word must
    start with a lowercase character or be exactly '_'.
    """
    words = pat.split()
    if any(word != '_' and not word.isalnum() for word in words):
        return False
    if not words[0][0].isupper():
        return False
    if any(word != '_' and not word[0].islower() for word in words[1:]):
        return False
    return True
2348
2349
# Accepts strings made only of brackets, braces, commas, '=', ':', '_',
# whitespace and single letters optionally followed by one digit or quote.
ext_checker = re.compile(r"^(\(|\)|,|{|}|=|[a-zA-Z][0-9']?|\s|_|:|\[|\])*$")


def is_extended_pattern(cases):
    """Tell whether every pattern in ``cases`` is accepted by ext_checker."""
    return all(ext_checker.match(case) for case in cases)
2358
2359
# Conversion table built by get_case_conv_table() when this module is
# loaded; keyed by tuples of case patterns.
case_conv_table = get_case_conv_table()
# Pattern tuples for which case_clauses_transform has already appended a
# stub entry to the 'caseconvs' file.
cases_added = {}
2362
2363
def get_case_conv(cases):
    """Find the conversion template for a tuple of case patterns.

    Constructor-only patterns and extended patterns get a generated
    template; anything else is looked up in case_conv_table, giving None
    when there is no entry.
    """
    if all_constructor_patterns(cases):
        conv = all_constructor_conv(cases)
    elif is_extended_pattern(cases):
        conv = extended_pattern_conv(cases)
    else:
        conv = case_conv_table.get(cases)
    return conv
2372
2373
# Replacement text for specific constructor names when generating case
# templates (used by all_constructor_conv and extended_pattern_conv).
# Some names are replaced by a comment only.
constructor_conv_table = {
    'Just': 'Some',
    'Nothing': 'None',
    'Left': 'Inl',
    'Right': 'Inr',
    'PPtr': '(* PPtr *)',
    'Register': '(* Register *)',
    'Word': '(* Word *)',
}

# Next identifier to hand out for each file name; read and advanced by
# get_next_unique_id.
unique_ids_per_file = {}
2385
2386
def get_next_unique_id():
    """Return the next unused number for the current ``filename``.

    Numbering starts at 1 for each file name and the stored counter is
    advanced by one on every call.
    """
    current = unique_ids_per_file.setdefault(filename, 1)
    unique_ids_per_file[filename] = current + 1
    return current
2391
2392
def all_constructor_conv(cases):
    """Build a case template for patterns of the form Cons var1 var2 ...

    The first word of each pattern is replaced through
    constructor_conv_table when it has an entry, and every later word
    that is '_' is replaced by a fresh ``vN`` name.  The result is a list
    of ``(text, index)`` pairs: the case header with index None, then
    one line per pattern (the first prefixed with two spaces, the others
    with '| ') whose index is the pattern's position.
    """
    conv = [('case \\x of', None)]

    for i, pat in enumerate(cases):
        bits = pat.split()
        if bits[0] in constructor_conv_table:
            bits[0] = constructor_conv_table[bits[0]]
        for j, bit in enumerate(bits):
            if j > 0 and bit == '_':
                bits[j] = 'v%d' % get_next_unique_id()
        pat = ' '.join(bits)
        # '\\<Rightarrow>' is the same text as the former '\<Rightarrow>'
        # literal, without relying on an invalid escape sequence.
        if i == 0:
            conv.append(('  %s \\<Rightarrow> ' % pat, i))
        else:
            conv.append(('| %s \\<Rightarrow> ' % pat, i))
    return conv
2409
2410
# Captures a run of ASCII letters and digits; extended_pattern_conv splits
# patterns with it so that whole words can be looked up.
word_getter = re.compile (r"([a-zA-Z0-9]+)")

# Captures a name followed by a brace-delimited body that contains no
# nested braces (the body's character set excludes '{' and '}').
record_getter = re.compile (r"([a-zA-Z0-9]+\s*{[a-zA-Z0-9'\s=\,_\(\):\]\[]*})")
2414
2415
def extended_pattern_conv(cases):
    """Build a case template for extended patterns.

    In each pattern ':' is replaced by '#', record patterns are reduced
    one at a time with reduce_record_pattern until none is left, and
    whole words with an entry in constructor_conv_table are replaced.
    The result is a list of ``(text, index)`` pairs: the case header
    with index None, then one line per pattern (the first prefixed with
    two spaces, the others with '| ') whose index is the pattern's
    position.
    """
    conv = [('case \\x of', None)]

    for i, pat in enumerate(cases):
        pat = '#'.join(pat.split(':'))
        while record_getter.search(pat):
            # Split around the first record only: a pattern with several
            # records would otherwise produce more than three pieces and
            # break the unpacking.  The rest is handled by later passes.
            [left, record, right] = record_getter.split(pat, maxsplit=1)
            record = reduce_record_pattern(record)
            pat = left + record + right
        if '{' in pat:
            print(pat)
        assert '{' not in pat
        bits = word_getter.split(pat)
        bits = [constructor_conv_table.get(bit, bit) for bit in bits]
        pat = ''.join(bits)
        # '\\<Rightarrow>' is the same text as the former '\<Rightarrow>'
        # literal, without relying on an invalid escape sequence.
        if i == 0:
            conv.append(('  %s \\<Rightarrow> ' % pat, i))
        else:
            conv.append(('| %s \\<Rightarrow> ' % pat, i))
    return conv
2436
2437
def reduce_record_pattern(string):
    """Turn a record pattern 'Cons {field = pat, ...}' into positional form.

    ``string`` must end with '}' (asserted) and contain exactly one '{'.
    The field bindings are split on ',' (bracket-aware) and on '='; a
    bound pattern consisting of several words is wrapped in parentheses.
    The constructor's declared arguments are taken from
    all_constructor_args; each becomes its bound pattern, or '_' when
    the record does not mention it.

    Exits the process with status 1 if the constructor is not in
    all_constructor_args.
    """
    assert string[-1] == '}'
    [cons_text, fields_text] = string[:-1].split('{')
    cons = cons_text.strip()

    bound = {}
    for field in braces.str(fields_text, '(', ')').split(','):
        field = str(field).strip()
        if not field:
            continue
        [name, value] = field.split('=')
        name = name.strip()
        value = value.strip()
        if len(value.split()) > 1:
            value = '(%s)' % value
        bound[name] = value

    if cons not in all_constructor_args:
        sys.stderr.write('FAIL: trying to build case for %s\n' % cons)
        sys.stderr.write('when reading %s\n' % filename)
        sys.stderr.write('but constructor not seen yet\n')
        sys.stderr.write('perhaps parse in different order?\n')
        sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
        sys.exit(1)

    positional = [bound.get(name, '_')
                  for (name, _) in all_constructor_args[cons]]
    return cons + ' ' + ' '.join(positional)
2464
2465
def subs_nums_and_x(conv, x):
    """Fill in the placeholders of a case conversion template.

    ``conv`` is a list of ``(text, index)`` pairs.  In every text the
    backslash-x placeholder is replaced by ``x``, and each numbered
    backslash-v placeholder (terminated by a backslash) is replaced by a
    ``vN`` name.  The same placeholder number maps to the same name
    throughout one call; names come from get_next_unique_id.  The index
    of each pair is passed through unchanged.
    """
    ids = []
    result = []
    for (template, num) in conv:
        pieces = template.replace('\\x', x).split('\\v')
        text = pieces[0]
        for piece in pieces[1:]:
            # piece is '<number>' optionally followed by '\' and more text
            number_text, _, remainder = piece.partition('\\')
            number = int(number_text)
            while number >= len(ids):
                ids.append(get_next_unique_id())
            text = text + 'v%d' % ids[number] + remainder
        result.append((text, num))
    return result
2485
2486
def get_supplied_transform_table():
    """Load the table of hand-supplied conversions from the 'supplied' file.

    Non-empty lines (right-stripped, tagged with their 1-based line
    number) are arranged into a tree with offside_tree.  Each top-level
    entry must contain the token 'conv:' and have exactly two children;
    other entries are reported on stderr and dropped.  The first child,
    passed through convert_to_stripped_tuple, is the key and the second
    child is the value.

    A warning is written to stderr for every line containing a tab
    character.  Returns the resulting dict.
    """
    with open('supplied') as f:
        lines = [line.rstrip() for line in f]

    lines = [(line, n + 1) for (n, line) in enumerate(lines)]
    lines = [(line, n) for (line, n) in lines if line != '']

    # Unpack the pair: testing the (line, n) tuple itself for '\t' never
    # matched, so tabs went unreported.
    for (line, n) in lines:
        if '\t' in line:
            sys.stderr.write('WARN: tab character in supplied line %d\n'
                             % n)

    tree = offside_tree(lines)

    result = {}

    for line, n, children in tree:
        if ('conv:' not in line) or len(children) != 2:
            sys.stderr.write('WARN: supplied line %d dropped\n'
                             % n)
            if 'conv:' not in line:
                sys.stderr.write('\t\t(token "conv:" missing)\n')
            if len(children) != 2:
                sys.stderr.write('\t\t(%d children != 2)\n' % len(children))
            continue

        children = discard_line_numbers(children)

        before, after = children

        before = convert_to_stripped_tuple(before[0], before[1])

        result[before] = after

    return result
2523
2524
def print_tree(tree, indent=0):
    """Print a tree of ``(line, children)`` nodes, one line per node.

    Each line is stripped and prefixed with one tab per nesting level,
    starting at ``indent`` tabs for the top level.
    """
    for line, children in tree:
        # Build the whole string before printing: adding to the result
        # of print() raised TypeError because print() returns None.
        print('\t' * indent + line.strip())
        print_tree(children, indent + 1)
2529
2530
# Table of supplied conversions, built by get_supplied_transform_table()
# when this module is loaded.
supplied_transform_table = get_supplied_transform_table()
# One counter per table key, all starting at 0; warn_supplied_usage
# reports the keys whose counter is still zero.
supplied_transforms_usage = dict((
    key, 0) for key in six.iterkeys(supplied_transform_table))
2534
2535
def warn_supplied_usage():
    """Warn on stderr about every supplied conversion with a zero count."""
    unused = [key
              for (key, usage) in six.iteritems(supplied_transforms_usage)
              if not usage]
    for key in unused:
        sys.stderr.write('WARN: supplied conv unused: %s\n'
                         % key[0])
2541
2542
# Matches a double-quoted run of one or more non-quote characters;
# detect_recursion uses it to cut quoted text out of a line.
quotes_getter = re.compile('"[^"]+"')
2544
2545
def detect_recursion(body):
    """Detects whether any of the bodies of the definitions of this
        function recursively refer to it.

        Each element of ``body`` is flattened to a single line and has
        its double-quoted parts removed.  The first word of every line
        must be the same name (asserted); the result is True when that
        name occurs anywhere in the remainder of at least one line."""
    flattened = [reduce_to_single_line(elt) for elt in body]
    unquoted = [''.join(quotes_getter.split(text)) for text in flattened]
    pairs = [text.split(None, 1) for text in unquoted]
    name = pairs[0][0]
    mismatched = [head for (head, _) in pairs if head != name]
    assert mismatched == []
    referencing = [rest for (_, rest) in pairs if name in rest]
    return referencing != []
2555
2556
def primrec_transform(d):
    """Rewrite the body of definition ``d`` as primrec equations.

    Each ``(line, children)`` element of ``d.body`` is passed on its own
    through body_transform (with ``nopattern=True``), which must yield a
    single node containing exactly one equiv symbol (asserted).  That
    symbol is replaced by '= (', the first equation is prefixed with two
    spaces and the others with '| ', and the trailing '"' of the node is
    replaced by ')"'.

    Sets ``d.primrec`` to True, stores the new body in ``d.body`` and
    returns ``d``.
    """
    sig = d.sig
    defn = d.defined
    body = []
    for index, (l, c) in enumerate(d.body):
        [(l, c)] = body_transform([(l, c)], defn, sig, nopattern=True)
        if index:
            l = "| " + l
        else:
            l = "  " + l
        # '\\<equiv>' is the same text as the former '\<equiv>' literal,
        # without relying on an invalid escape sequence.
        halves = l.split('\\<equiv>')
        assert len(halves) == 2
        l = '= ('.join(halves)
        (l, c) = remove_trailing_string('"', (l, c))
        (l, c) = add_trailing_string(')"', (l, c))
        body.append((l, c))
    d.primrec = True
    d.body = body
    return d
2578
2579
# A lowercase ASCII letter followed by any number of word characters.
variable_name_regex = re.compile(r"^[a-z]\w*$")


def is_variable_name(string):
    """Return a match object if ``string`` fits variable_name_regex,
    otherwise None."""
    found = variable_name_regex.match(string)
    return found
2585
2586
def pattern_match_transform(body):
    """Converts a body containing possibly multiple definitions
        and containing pattern matches into a normal Isabelle definition
        followed by a big Haskell case expression which is resolved
        elsewhere.

        ``body`` is a list of ``(line, children)`` nodes, one per
        equation.  The result is a one-element list holding a node whose
        line is '<function> <args> = case ... of' and whose children are
        the alternatives '<patterns> -> <right hand side>'."""
    # One entry per equation:
    # (argument patterns, the same patterns again, text after '=', children)
    splits = []
    for (line, children) in body:
        string = braces.str(line, '(', ')')
        # keep absorbing child lines until the equation's '=' is found
        while len(string.split('=')) == 1:
            if len(children) == 1:
                # a single child: splice its line in and descend
                [(moreline, children)] = children
                string = string + ' ' + moreline.strip()
            elif children and leading_bar.match(children[0][0]):
                # guarded equation: the guards become the right hand side
                string = string + ' ='
                children = \
                    guarded_body_transform(children, ' = ')
            elif children and children[0][1] == []:
                # first child is a leaf: treat it as a continuation line
                (moreline, _) = children.pop(0)
                string = string + ' ' + moreline.strip()
            else:
                # cannot locate the '='; dump the equation and abort
                print()
                print(line)
                print()
                for child in children:
                    print(child)
                assert 0

        [lead, tail] = string.split('=', 1)
        bits = lead.split()
        unbraced = bits
        function = str(bits[0])
        splits.append((bits[1:], unbraced[1:], tail, children))

    # Start from the first equation's arguments.  A position is set to
    # None when it has to be matched on: bracketed patterns, patterns
    # containing '@' and capitalised terms.
    common = splits[0][0][:]
    for i, term in enumerate(common):
        if term.startswith('('):
            common[i] = None
        if '@' in term:
            common[i] = None
        if term[0].isupper():
            common[i] = None

    # Compare with the remaining equations: a '_' gives way to a variable
    # name, and any other disagreement (except a '_' in the later
    # equation) marks the position as one to match on.
    for (bits, _, _, _) in splits[1:]:
        for i, term in enumerate(bits):
            if i >= len(common):
                print_tree(body)
            if term != common[i]:
                is_var = is_variable_name(str(term))
                if common[i] == '_' and is_var:
                    common[i] = term
                elif term != '_':
                    common[i] = None

    # positions that stayed '_' get a generated name
    for i, term in enumerate(common):
        if term == '_':
            common[i] = 'x%d' % i

    # indices of the arguments the case expression scrutinises
    blanks = [i for (i, n) in enumerate(common) if n is None]

    line = '%s ' % function
    for i, name in enumerate(common):
        if name is None:
            line = line + 'x%d ' % i
        else:
            line = line + '%s ' % name
    if blanks == []:
        print(splits)
        print(common)
    if len(blanks) == 1:
        line = line + '= case x%d of' % blanks[0]
    else:
        # several scrutinised arguments are matched as a tuple
        line = line + '= case (x%d' % blanks[0]
        for i in blanks[1:]:
            line = line + ', x%d' % i
        line = line + ') of'

    # one alternative per original equation
    children = []
    for (bits, unbraced, tail, c) in splits:
        if len(blanks) == 1:
            l = '  %s' % unbraced[blanks[0]]
        else:
            l = '  (%s' % unbraced[blanks[0]]
            for i in blanks[1:]:
                l = l + ', %s' % unbraced[i]
            l = l + ')'
        l = l + ' -> %s' % tail
        children.append((l, c))

    return [(line, children)]
2676
2677
def get_lambda_body_lines(d):
    """Returns lines equivalent to the body of the function as
        a lambda expression.

        ``d.body`` must hold exactly one ``(line, children)`` node.  The
        first character of the line is dropped, the text before the
        equiv symbol supplies the function name (asserted to contain
        ``d.defined``) and its arguments, and the text after it becomes
        the lambda body.  The last resulting line must end with '"'
        (asserted); that character is replaced by ')'."""
    fn = d.defined

    [(line, children)] = d.body

    line = line[1:]
    # find \<equiv> in first or 2nd line
    # ('\\<equiv>' is the same text as the former '\<equiv>' literal,
    # without relying on an invalid escape sequence)
    if '\\<equiv>' not in line and '\\<equiv>' in children[0][0]:
        (l, c) = children[0]
        children = c + children[1:]
        line = line + l
    [lhs, rhs] = line.split('\\<equiv>', 1)
    bits = lhs.split()
    args = bits[1:]
    assert fn in bits[0]

    line = '(\\<lambda>' + ' '.join(args) + '. ' + rhs
    # lines = ['(* body of %s *)' % fn, line] + flatten_tree (children)
    lines = [line] + flatten_tree(children)
    assert (lines[-1].endswith('"'))
    lines[-1] = lines[-1][:-1] + ')'

    return lines
2703
2704
def add_trailing_string(s, xxx_todo_changeme8):
    """Append ``s`` to the very last line of a ``(line, children)`` tree.

    The last line is found by following the last child at every level.
    Returns a new node; the lists of the input are not modified.
    """
    (line, children) = xxx_todo_changeme8
    if children != []:
        last = add_trailing_string(s, children[-1])
        return (line, children[:-1] + [last])
    return (line + s, children)
2712
2713
def remove_trailing_string(s, xxx_todo_changeme9, _handled=False):
    """Remove ``s`` from the end of the last line of a tree.

    The argument after ``s`` is a ``(line, children)`` node; the last
    line is found by following the last child at every level.  Returns a
    new node with ``s`` cut off that line.  An empty ``s`` leaves the
    line as it is.

    Raises AssertionError (after describing the problem on stderr) when
    the last line does not end with ``s``.  ``_handled`` is internal: the
    outermost call uses it to report the node being processed before
    re-raising any error.
    """
    (line, children) = xxx_todo_changeme9
    if not _handled:
        try:
            return remove_trailing_string(s, (line, children), _handled=True)
        except Exception:
            sys.stderr.write('handling %s\n' % ((line, children), ))
            raise
    if children == []:
        if not line.endswith(s):
            sys.stderr.write('ERR: expected %r\n' % line)
            sys.stderr.write('to end with %r\n' % s)
            # Raised explicitly so the check also holds under python -O.
            raise AssertionError('expected %r to end with %r' % (line, s))
        # Slice from the front: line[:-len(s)] would wipe the whole line
        # when s is empty.
        return (line[:len(line) - len(s)], [])
    else:
        modified = remove_trailing_string(s, children[-1], _handled=True)
        return (line, children[0:-1] + [modified])
2732
2733
def get_trailing_string(n, xxx_todo_changeme10):
    """Return the last ``n`` characters of the last line of a tree.

    The argument after ``n`` is a ``(line, children)`` node; the last
    line is found by following the last child at every level.  If the
    line is shorter than ``n`` the whole line is returned, and ``n == 0``
    gives an empty string.
    """
    (line, children) = xxx_todo_changeme10
    while children != []:
        (line, children) = children[-1]
    if n == 0:
        # line[-0:] would be the whole line rather than nothing.
        return line[:0]
    return line[-n:]
2740
2741
def has_trailing_string(s, xxx_todo_changeme11):
    """Report whether the final line of a (line, children) tree ends with `s`.

    The final line belongs to the node reached by following the last child
    of each node until one with no children is found.
    """
    node = xxx_todo_changeme11
    while True:
        (text, kids) = node
        if kids == []:
            return text.endswith(s)
        # Descend into the last child and test again.
        node = kids[-1]
2748
2749
def ensure_type_ordering(defs):
    """Return `defs` with newtype definitions first, in dependency order.

    Definitions whose `type` is 'newtype' are moved to the front and ordered
    so that a newtype whose `typename` appears in another newtype's
    `typedeps` is emitted before it.  All remaining definitions follow in
    their original relative order.

    Raises ValueError if the newtype dependencies form a cycle (including a
    type that lists itself), rather than looping forever.  As with any other
    failure while ordering, the not-yet-ordered type names are printed
    before the exception propagates.
    """
    typedefs = [d for d in defs if d.type == 'newtype']
    other = [d for d in defs if d.type != 'newtype']

    final_typedefs = []
    while typedefs:
        try:
            # Start from the first remaining type and keep stepping to the
            # first remaining type it depends on, until reaching one with no
            # remaining dependencies.  That one is safe to emit next.
            i = 0
            chain = [i]
            while True:
                deps = typedefs[i].typedeps
                for j, term in enumerate(typedefs):
                    if term.typename in deps:
                        break
                else:
                    break
                if j in chain:
                    # Revisiting an index means the walk would never end.
                    names = [typedefs[k].typename for k in chain + [j]]
                    raise ValueError('cyclic type dependency: %s'
                                     % ' -> '.join(names))
                chain.append(j)
                i = j
            final_typedefs.append(typedefs.pop(i))
        except Exception:
            print('Exception hit ordering types:')
            for td in typedefs:
                print('  - %s' % td.typename)
            raise

    return final_typedefs + other
2775
2776
def lead_ws(string):
    """Return the run of whitespace at the start of `string`."""
    without_indent = string.lstrip()
    return string[:len(string) - len(without_indent)]
2780
2781
def adjust_ws(xxx_todo_changeme12, n):
    """Shift every line of a (line, children) tree by `n` columns.

    For positive `n`, `n` spaces are prepended to each line.  Otherwise the
    first `-n` characters of each line are dropped, whatever they are.
    Returns a new tree; the input is left untouched.
    """
    (text, kids) = xxx_todo_changeme12
    shifted = (' ' * n + text) if n > 0 else text[-n:]
    shifted_kids = []
    for kid in kids:
        shifted_kids.append(adjust_ws(kid, n))
    return (shifted, shifted_kids)
2791
2792
# Matches a run of one or more dot-terminated word segments, e.g. the
# "Foo.Bar." in "Foo.Bar.baz".  The match includes the final dot.
modulename = re.compile(r"(\w+\.)+")
2794
2795
def perform_module_redirects(lines, call):
    """Apply subst_module_redirects to each line and return the new list."""
    redirected = []
    for line in lines:
        redirected.append(subst_module_redirects(line, call))
    return redirected
2798
2799
def subst_module_redirects(line, call):
    """Rewrite the module prefixes in `line` via call.moduletranslations.

    Every match of the module-level `modulename` pattern is treated as a
    module name followed by a dot.  If the name (without that dot) is a key
    of `call.moduletranslations`, it is replaced by the mapped value; a
    false mapped value (such as an empty string) removes the prefix and its
    dot altogether.  Names that are not keys are left as they are.
    """
    def redirect(match):
        # Strip the trailing dot to obtain the bare module name.
        module = match.group(0)[:-1]
        if module in call.moduletranslations:
            module = call.moduletranslations[module]
        if not module:
            return ''
        return module + '.'

    return modulename.sub(redirect, line)
2814