1#
2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6
7from __future__ import print_function
8from __future__ import absolute_import
9import braces
10import re
11import sys
12import os
13import six
14from six.moves import map
15from six.moves import range
16from six.moves import zip
17from functools import reduce
18
19
class Call(object):
    """Options for one parse request.

    An instance is passed to parse/get_lines/def_lines and is also stored
    module-wide by set_global, where later transforms read it.
    """

    def __init__(self):
        # Source file to translate; read by parse() and set_global().
        self.filename = None
        # Optional predicate on Def objects; get_lines keeps a definition
        # only if it is a 'comments' def or restr(d) is true.
        self.restr = None
        # Output-mode flags consulted by def_lines.
        self.all_bits = False
        self.decls_only = False
        self.instanceproofs = False
        self.bodies_only = False
        self.body = False
        # Set by typename_transform when it skips an unsupported synonym.
        self.bad_type_assignment = False
        # Stack of context names; wrap_qualify uses the last one.
        self.current_context = []
30
31
class Def(object):
    """One top-level Haskell definition plus everything derived from it."""

    def __init__(self):
        # Scalar fields start out unset.
        self.type = self.defined = self.body = None
        self.line = self.sig = None
        # Collections get a fresh container per instance.
        self.instance_proofs = []
        self.instance_extras = []
        self.comments = []
        self.primrec = None
        self.deriving = []
        self.instance_defs = {}
46
47
def parse(call):
    """Parse ``call.filename`` and return the requested output.

    The call is first recorded module-wide (set_global).  The file's
    definitions are then rendered to lines and module redirects are
    applied.  Each line is returned with a trailing newline.
    """
    set_global(call)
    definitions = get_defs(call.filename)
    rendered = perform_module_redirects(get_lines(definitions, call), call)
    return ['%s\n' % text for text in rendered]
59
60
def settings_line(l):
    """Adjusts some global settings.

    ``l`` is a comma-separated list of ``kind = value`` entries.  Empty
    entries (e.g. from a trailing comma) are ignored.  The only kind
    understood is ``keep_constructor``, whose value must be a single word.
    That word is recorded in ``keep_conss``.

    Raises:
        AssertionError: for an unknown setting kind.
        ValueError: for an entry without '=' or a keep_constructor value
            that is not exactly one word.
    """
    for bit in l.split(','):
        bit = bit.strip()
        if not bit:
            continue
        # Split once so that the value itself may contain '='.
        (kind, setting) = bit.split('=', 1)
        kind = kind.strip()
        if kind == 'keep_constructor':
            [cons] = setting.split()
            keep_conss[cons] = 1
        else:
            # Explicit raise rather than assert so python -O still rejects it.
            raise AssertionError('setting kind not understood: %r' % bit)
73
74
def set_global(_call):
    """Store ``_call`` and its filename as the module globals ``call`` and ``filename``."""
    global call, filename
    call = _call
    filename = call.filename
80
81
# Parsed definitions keyed by source filename.  get_defs stores every
# result here; its lookup of this cache is currently commented out.
file_defs = {}
83
84
def splitList(list, pred):
    """Split ``list`` into runs of elements separated by elements matching ``pred``.

    Separator elements are dropped, and so are empty runs.
    """
    groups = []
    current = []
    for item in list:
        if not pred(item):
            current.append(item)
            continue
        if current:
            groups.append(current)
        current = []
    if current:
        groups.append(current)
    return groups
99
100
def takeWhile(list, pred):
    """Return the longest prefix of ``list`` whose elements all satisfy ``pred``.

    The result is a slice of ``list``, so it has the same sequence type.
    """
    for index, item in enumerate(list):
        if not pred(item):
            return list[:index]
    return list[:]
112
113
def get_defs(filename):
    """Run the C preprocessor over ``filename`` and parse the result.

    Extra preprocessor flags come from the ``L4CPP`` environment variable.
    It is inserted into the shell command unquoted, so it can carry several
    flags.  The filename itself is shell-quoted, so paths with spaces or
    shell metacharacters are passed through intact.  A non-zero cpp exit
    status produces a warning on stderr.

    Returns the Def list produced by top_transform.  The result is also
    stored in ``file_defs[filename]``.

    Raises:
        KeyError: if ``L4CPP`` is not set.
    """
    # if filename in file_defs:
    #    return file_defs[filename]

    cmdline = os.environ['L4CPP']
    command = 'cpp -Wno-invalid-pp-token -traditional-cpp %s %s' % (
        cmdline, six.moves.shlex_quote(filename))
    f = os.popen(command)
    try:
        input = [line.rstrip() for line in f]
    finally:
        # close() returns None on success, otherwise the exit status.
        status = f.close()
    if status:
        sys.stderr.write('WARN: cpp exited with status %r on %s\n' %
                         (status, filename))
    defs = top_transform(input, filename.endswith(".lhs"))

    file_defs[filename] = defs
    return defs
127
128
def top_transform(input, isLhs):
    """Top level transform, deals with lhs artefacts, divides
        the code up into a series of separate definitions, and
        passes these definitions through the definition transforms.

    Args:
        input: the preprocessed source text, one string per line.
        isLhs: true for literate Haskell (.lhs).  Only lines starting with
            '> ' are then code; every other line is commentary.

    Returns:
        The Def objects produced by defs_transform, with None results
        dropped, reordered by ensure_type_ordering.
    """
    to_process = []  # (code line, line index) pairs fed to offside_tree
    comments = []  # (line index, 'C', text); only the disabled code below uses it
    for n, line in enumerate(input):
        if '\t' in line:
            # n is 0-based; filename is the module global set by set_global.
            sys.stderr.write('WARN: tab in line %d, %s.\n' %
                             (n, filename))
        if isLhs:
            if line.startswith('> '):
                # Drop a trailing '--' comment.  NOTE(review): this also cuts
                # lines where '--' is part of an operator or string literal.
                if '--' in line:
                    line = line.split('--')[0].strip()

                if line[2:].strip() == '':
                    comments.append((n, 'C', ''))
                elif line.startswith('> {-#'):
                    # pragma lines ('{-# ... #-}') are kept only as comments
                    comments.append((n, 'C', '(*' + line + '*)'))
                else:
                    # strip the '> ' bird-track marker
                    to_process.append((line[2:], n))
            else:
                if line.strip():
                    comments.append((n, 'C', '(*' + line + '*)'))
                else:
                    comments.append((n, 'C', ''))
        else:
            # Same naive '--' stripping as above.
            if '--' in line:
                line = line.split('--')[0].rstrip()

            if line.strip() == '':
                comments.append((n, 'C', ''))
            elif line.strip().startswith('{-'):
                # single-line {- -} comments only
                comments.append((n, 'C', '(*' + line + '*)'))
            elif line.startswith('#'):
                # preprocessor directives
                comments.append((n, 'C', '(*' + line + '*)'))
            else:
                to_process.append((line, n))

    def_tree = offside_tree(to_process)
    defs = create_defs(def_tree)
    # merge signature and pattern-matching clauses of the same name
    defs = group_defs(defs)

    #   Forget about the comments for now

    # Disabled: re-attach collected comments to the following definition.
    # defs_plus_comments = [(d.line, d) for d in defs] + comments
    # defs_plus_comments.sort()
    # defs = []
    # prev_comments = []
    # for term in defs_plus_comments:
    #     if term[1] == 'C':
    #         prev_comments.append(term[2])
    #     else:
    #         d = term[1]
    #         d.comments = prev_comments
    #         defs.append(d)
    #         prev_comments = []

    # apply def_transform and cut out any None return values
    defs = [defs_transform(d) for d in defs]
    defs = [d for d in defs if d is not None]

    defs = ensure_type_ordering(defs)

    return defs
196
197
def get_lines(defs, call):
    """Collect the output lines that ``call`` asks for.

    If ``call.restr`` is set, only definitions it accepts are kept; comment
    defs are always kept.  Each definition that yields lines is followed by
    a blank separator line.
    """
    keep = call.restr
    if keep:
        defs = [d for d in defs if d.type == 'comments' or keep(d)]

    output = []
    for d in defs:
        chunk = def_lines(d, call)
        if not chunk:
            continue
        output.extend(chunk)
        output.append('')
    return output
214
215
def offside_tree(input):
    """Group ``(line, n)`` pairs into a tree using the offside rule.

    A line owns every following line that is indented more deeply than it,
    up to the next line whose indent is less or equal.  Returns a list of
    ``(line, n, children)`` triples; ``children`` has the same shape.
    """
    if input == []:
        return []
    groups = []  # [line, n, indent, pending child pairs]
    for text, num in input:
        depth = len(text) - len(text.lstrip())
        if groups and depth > groups[-1][2]:
            groups[-1][3].append((text, num))
        else:
            groups.append((text, num, depth, []))
    return [(text, num, offside_tree(kids))
            for (text, num, _depth, kids) in groups]
237
238
def discard_line_numbers(tree):
    """Convert a ``(line, n, children)`` tree into a ``(line, children)`` tree."""
    return [(text, discard_line_numbers(kids)) for (text, _num, kids) in tree]
247
248
def flatten_tree(tree):
    """Return the lines of a ``(line, children)`` tree in pre-order."""
    lines = []
    pending = list(tree)[::-1]
    while pending:
        text, kids = pending.pop()
        lines.append(text)
        pending.extend(list(kids)[::-1])
    return lines
258
259
def create_defs(tree):
    """Build a Def for each top-level tree element, skipping those that yield None."""
    result = []
    for elt in tree:
        made = create_def(elt)
        if made is not None:
            result.append(made)
    return result
265
266
def group_defs(defs):
    """Merge consecutive 'definitions' Defs that define the same name.

    Type signatures and pattern-matching clauses produce several Defs for
    one constant.  Each later one's body is appended to the first Def of
    the run.  Non-'definitions' Defs always start a new group.
    """
    groups = []
    current = ''
    for d in defs:
        key = d.defined if d.type == 'definitions' else ''
        if key and key == current:
            groups[-1].body.extend(d.body)
            continue
        groups.append(d)
        current = key
    return groups
285
286
def create_def(elt):
    """Build a Def from one ``(line, n, children)`` offside-tree element."""
    line, n, children = elt
    return create_def_2(line, discard_line_numbers(children), n)
293
294
def create_def_2(line, children, n):
    """Classify a top-level line and wrap it in a Def.

    Lines starting with ``import``, ``module`` or ``class`` produce None.
    ``instance C T ...`` gives type 'instance' defining T.  ``type``,
    ``newtype`` and ``data`` give type 'newtype' defining the next word.
    Anything else is a 'definitions' Def named by its first word.
    """
    lead = line.split(None, 3)
    keyword = lead[0]
    if keyword in ('import', 'module', 'class'):
        return None
    if keyword == 'instance':
        kind, defined = 'instance', lead[2]
    elif keyword in ('type', 'newtype', 'data'):
        kind, defined = 'newtype', lead[1]
    else:
        kind, defined = 'definitions', keyword

    d = Def()
    d.body = [(line, children)]
    d.line = n
    d.type = kind
    d.defined = defined
    return d
315
316
def get_primrecs(path='primrecs'):
    """Read the set of definition names to be translated as primrec.

    Args:
        path: file with one name per line.  The default is relative to the
            current directory.  Surrounding whitespace is stripped and blank
            lines are ignored.

    Returns:
        A set of names.
    """
    # 'with' closes the file; the previous version leaked the handle.
    with open(path) as f:
        keys = [line.strip() for line in f]
    return set(key for key in keys if key != '')
321
322
# Names read at import time from the 'primrecs' file (current directory).
# defs_transform hands definitions with these names to primrec_transform.
primrecs = get_primrecs()
324
325
def defs_transform(d):
    """Transforms the set of definitions for a function. This
        may include its type signature, and may include the special
        case of multiple definitions.

    Newtype and instance Defs are delegated to newtype_transform and
    instance_transform.  For ordinary definitions:
    * a leading ``name :: type`` line becomes ``d.sig``;
    * names listed in ``primrecs`` go to primrec_transform;
    * multiple clauses are merged by pattern_match_transform;
    * the remaining body goes through body_transform.

    Returns:
        The transformed Def, or whatever the delegated transform returns
        (possibly None).

    Raises:
        AssertionError: if no definition body remains, e.g. when only a
            type signature was present.
    """
    if d.type == 'newtype':
        return newtype_transform(d)
    elif d.type == 'instance':
        return instance_transform(d)

    # The first tokens of the first line in the first definition.  A line
    # consisting of a single token has no signature, so check the length
    # before indexing.
    lead = d.body[0][0].split(None, 2)
    if len(lead) > 1 and lead[1] == '::':
        d.sig = type_sig_transform(d.body[0])
        d.body.pop(0)

    if d.defined in primrecs:
        return primrec_transform(d)

    if len(d.body) > 1:
        d.body = pattern_match_transform(d.body)

    if len(d.body) == 0:
        print()
        print(d)
        # Explicit raise so the check survives python -O.
        raise AssertionError('no definition body for %s' % d.defined)

    d.body = body_transform(d.body, d.defined, d.sig)
    return d
354
355
def wrap_qualify(lines, deep=True):
    """Close and then re-open a locale so instantiations can go through.

    When a context is active (``call.current_context`` non-empty), ``lines``
    is wrapped in place:
    * an ``end`` / ``qualify <ctx> (in Arch)`` header is inserted before it;
    * an ``end_qualify`` / ``context Arch begin global_naming <ctx>`` trailer
      is appended after it.
    ``<ctx>`` is the last entry of ``call.current_context``.  Empty input is
    returned untouched.

    Args:
        lines: list of output lines; mutated in place.
        deep: kept for backwards compatibility; it does not affect the
            output.

    Returns:
        ``lines`` itself.
    """
    if not lines:
        return lines

    if call.current_context:
        context = call.current_context[-1]
        # The trailing space after '(in Arch)' is preserved from the previous
        # format string, which appended an always-empty suffix there.
        lines.insert(0, 'end\nqualify {} (in Arch) '.format(context))
        lines.append('end_qualify\ncontext Arch begin global_naming %s' % context)
    return lines
371
372
def def_lines(d, call):
    """Produces the set of lines associated with a definition.

    What is produced depends on the flags of ``call``:

    * ``all_bits``: the comments; then the whole definition (or the body,
      for newtypes); then the instance proofs, wrapped by wrap_qualify; then
      the instance extras.
    * ``instanceproofs``: the instance proofs (unless ``bodies_only``) and
      the instance extras (unless ``decls_only``).  A blank line separates
      them when both are present.
    * ``body``: whatever get_lambda_body_lines returns.
    * otherwise, for 'definitions':
      - a ``consts'`` declaration (``decls_only``);
      - a ``defs``/``definition``/``primrec`` body (``bodies_only``);
      - or the full ``primrec``/``definition``.
      'comments' defs yield their comments and newtypes yield their body.

    Returns:
        A list of lines, or None when nothing applies (e.g. a newtype with
        ``bodies_only``, or an instance Def outside the modes above).
    """
    if call.all_bits:
        L = []
        if d.comments:
            L.extend(flatten_tree(d.comments))
            L.append('')
        if d.type == 'definitions':
            L.append('definition')
            if d.sig:
                L.extend(flatten_tree([d.sig]))
                L.append('where')
            # The equations are needed with or without a signature;
            # otherwise only a bare 'definition' keyword would be emitted.
            L.extend(flatten_tree(d.body))
        elif d.type == 'newtype':
            L.extend(flatten_tree(d.body))
        if d.instance_proofs:
            L.extend(wrap_qualify(flatten_tree(d.instance_proofs)))
            L.append('')
        if d.instance_extras:
            L.extend(flatten_tree(d.instance_extras))
            L.append('')
        return L

    if call.instanceproofs:
        if not call.bodies_only:
            instance_proofs = wrap_qualify(flatten_tree(d.instance_proofs))
        else:
            instance_proofs = []

        if not call.decls_only:
            instance_extras = flatten_tree(d.instance_extras)
        else:
            instance_extras = []

        separator = [''] if instance_proofs and instance_extras else []
        return instance_proofs + separator + instance_extras

    if call.body:
        return get_lambda_body_lines(d)

    comments = d.comments
    # d.sig is None when there is no type signature.  flatten_tree cannot
    # unpack that, and the failure is taken to mean "no signature".
    try:
        typesig = flatten_tree([d.sig])
    except (TypeError, ValueError):
        typesig = []
    body = flatten_tree(d.body)
    kind = d.type

    if kind == 'definitions':
        if call.decls_only:
            if typesig:
                return comments + ["consts'"] + typesig
            else:
                return []
        elif call.bodies_only:
            if d.sig:
                defname = '%s_def' % d.defined
                if d.primrec:
                    print('warning body-only primrec:')
                    print(body[0])
                    return comments + ['primrec'] + body
                return comments + ['defs %s:' % defname] + body
            else:
                return comments + ['definition'] + body
        else:
            if d.primrec:
                return comments + ['primrec'] + typesig \
                    + ['where'] + body
            if typesig:
                return comments + ['definition'] + typesig + ['where'] + body
            else:
                return comments + ['definition'] + body
    elif kind == 'comments':
        return comments
    elif kind == 'newtype':
        if not call.bodies_only:
            return body
    return None
451
452
def type_sig_transform(tree_element):
    """Performs transformations on a type signature line
        preceding a function declaration or some such.

    The signature, including any continuation lines, is joined into one
    line and split at its first ``::``.  The type part is converted with
    type_transform and wrapped in double quotes.

    Returns:
        A childless tree element ``(line, [])``.

    Raises:
        ValueError: if the line contains no ``::``.
        AssertionError: if the translated type still contains ``[pp``.
    """
    line = reduce_to_single_line(tree_element)
    # Split only at the first '::' so that a '::' inside the type (e.g. a
    # kind annotation) does not break the unpacking.
    (pre, post) = line.split('::', 1)
    result = type_transform(post)
    if '[pp' in result:
        # NOTE(review): debugging trap.  '[pp' presumably indicates an
        # untranslated list of PPtr (type_conv maps PPtr to 'pptr') — confirm.
        print(line)
        print(pre)
        print(post)
        print(result)
        raise AssertionError('unexpected "[pp" in translated type: %s' % result)
    line = pre + ':: "' + result + '"'

    return (line, [])
469
470
# Haskell classes whose constraints type_transform drops entirely.
ignore_classes = {'Error': 1}
# Haskell classes mapped by hand to the sort names used when type_transform
# annotates the constrained type variable, e.g. ('a :: {minus, ...}).
hand_classes = {'Bits': ['HS_bit'],
                'Num': ['minus', 'one', 'zero', 'plus', 'numeral'],
                'FiniteBits': ['finiteBit']}
475
476
def type_transform(string):
    """Performs transformations on a type signature, whether
        part of a type signature line or occuring in a function.

    Handling of class constraints (``C a => t`` or ``(C a, D b) => t``):
    * each one becomes a sort annotation on the first occurrence of its
      type variable in the translated ``t``;
    * classes in ``ignore_classes`` are dropped;
    * classes in ``hand_classes`` map to fixed sort lists;
    * any other class is converted with type_conv.

    The rest of the type is split on ``->`` (or, if there is none, on
    ``,``).  Each piece is converted by type_bit_transform, and the pieces
    are joined with ``\\<Rightarrow>`` or ``*`` respectively.

    Returns:
        The translated type as a string.
    """
    # deal with type classes by recursion
    bits = string.split('=>', 1)
    if len(bits) == 2:
        lhs = bits[0].strip()
        if lhs.startswith('(') and lhs.endswith(')'):
            instances = lhs[1:-1].split(',')
        else:
            instances = [lhs]
        var_annotes = {}
        for instance in instances:
            (name, var) = instance.split()
            if name in ignore_classes:
                continue
            if name in hand_classes:
                names = hand_classes[name]
            else:
                names = [type_conv(name)]
            var = "'" + var
            var_annotes.setdefault(var, [])
            var_annotes[var].extend(names)
        transformed = type_transform(bits[1])
        for (var, insts) in six.iteritems(var_annotes):
            if len(insts) == 1:
                newvar = '(%s :: %s)' % (var, insts[0])
            else:
                newvar = '(%s :: {%s})' % (var, ', '.join(insts))
            # Annotate the first occurrence of the whole variable only; the
            # lookahead keeps 'a from matching the prefix of 'ab.  A callable
            # replacement avoids any escape processing of newvar.
            pattern = re.escape(var) + r"(?![\w'])"
            transformed = re.sub(pattern, lambda _m, nv=newvar: nv,
                                 transformed, count=1)
        return transformed

    # get rid of (), insert Unit, which converts to unit
    string = 'Unit'.join(string.split('()'))

    # divide up by -> or by , then divide on space.
    # apply everything locally then work back up
    bstring = braces.str(string, '(', ')')
    bits = bstring.split('->')
    r = ' \\<Rightarrow> '
    if len(bits) == 1:
        bits = bstring.split(',')
        r = ' * '
    result = [type_bit_transform(bit) for bit in bits]
    return r.join(result)
524
525
def type_bit_transform(bit):
    """Convert one piece of a Haskell type that has no top-level '->' or ','.

    ``bit`` is presumably a braces.str-like object, since it is split with
    ``braces=True`` (TODO confirm against the braces module).  Cases:
    * ``[t]`` becomes ``<t> list``;
    * ``PPtr t`` becomes machine_word;
    * ``T [a ...]`` becomes ``<a ...> list <T>``;
    * any other application is converted token by token with type_conv and
      reordered by constructor_reversing (e.g. ``T a`` -> ``a T``).
    """
    s = str(bit).strip()
    if s.startswith('['):
        # handling this properly is hard.
        assert s.endswith(']')
        bit2 = braces.str(s[1:-1], '(', ')')
        return '%s list' % type_bit_transform(bit2)
    bits = bit.split(None, braces=True)
    if str(bits[0]) == 'PPtr':
        # every PPtr is a machine word, whatever it points to
        assert len(bits) == 2
        return 'machine_word'
    if len(bits) > 1 and bits[1].startswith('['):
        # constructor applied to a list argument
        assert bits[-1].endswith(']')
        arg = ' '.join([str(bit) for bit in bits[1:]])[1:-1]
        arg = type_transform(arg)
        return ' '.join([arg, 'list', str(type_conv(bits[0]))])
    bits = [type_conv(bit) for bit in bits]
    bits = constructor_reversing(bits)
    # presumably recurses into the contents of each bracketed token — confirm
    bits = [bit.map(type_transform) for bit in bits]
    strs = [str(bit) for bit in bits]
    return ' '.join(strs)
547
548
def reduce_to_single_line(tree_element):
    """Join a ``(line, children)`` tree element into one space-separated line.

    Lines are taken in pre-order: the element's own line first, then each
    child subtree in turn.
    """
    pieces = []
    pending = [tree_element]
    while pending:
        text, kids = pending.pop()
        pieces.append(text)
        pending.extend(list(kids)[::-1])
    return ' '.join(pieces)
557
558
# Haskell type names with fixed translations.  type_conv also stores every
# StudlyCaps -> lower_with_underscores conversion it performs here, so the
# table doubles as a cache.
type_conv_table = {
    'Maybe': 'option',
    'Bool': 'bool',
    'Word': 'machine_word',
    'Int': 'nat',
    'String': 'unit list'}
565
566
def type_conv(string):
    """Converts a type used in Haskell to our equivalent

    ``string`` is a single type token.  It may be a braces.str-like object
    coming from type_bit_transform (TODO confirm).  Cases, in the order
    tested:

    * starts with '(': compound type, returned unchanged (type_transform
      descends into it separately);
    * contains '.': qualified name; the module prefix is kept and the last
      component is converted recursively;
    * starts with a lowercase letter: type variable, prefixed with "'";
    * ``[T]``: becomes ``<T> list``;
    * in type_conv_table: the table entry;
    * otherwise: StudlyCaps -> lower_with_underscores (e.g. 'FooBar' ->
      'foo_bar'), and the result is memoised in type_conv_table.

    An empty string raises IndexError.  The result is passed through
    braces.clone(result, string), presumably to give it the same wrapper
    type as the input (TODO confirm).
    """
    if string.startswith('('):
        # ignore compound types, type_transform will descend into em
        result = string
    elif '.' in string:
        # qualified references
        bits = string.split('.')
        typename = bits[-1]
        module = reduce(lambda x, y: x + '.' + y, bits[:-1])
        typename = type_conv(typename)
        result = module + '.' + typename
    elif string[0].islower():
        # type variable
        result = "'%s" % string
    elif string[0] == '[' and string[-1] == ']':
        # list
        inner = type_conv(string[1:-1])
        result = '%s list' % inner
    elif str(string) in type_conv_table:
        result = type_conv_table[str(string)]
    else:
        # convert StudlyCaps to lower_with_underscores
        # (an underscore goes before an uppercase letter that follows a
        # lowercase one)
        was_lower = False
        s = ''
        for c in string:
            if c.isupper() and was_lower:
                s = s + '_' + c.lower()
            else:
                s = s + c.lower()
            was_lower = c.islower()
        result = s
        type_conv_table[str(string)] = result

    return braces.clone(result, string)
602
603
def constructor_reversing(tokens):
    """Reorder the tokens of a type application from Haskell to Isabelle order.

    Haskell writes type constructors before their arguments (``Maybe a``);
    Isabelle writes them after (``'a option``).  ``tokens`` are the
    type_conv'ed tokens of one application.  Cases:

    * fewer than 2 tokens: unchanged;
    * ``T a``: ``a T``;
    * ``[ a ]``: ``a list``;
    * ``T [ a ]``: ``(List a) T``;
    * ``array i e``: ``i => e`` (a function type);
    * ``either a b``: ``a + b`` (a sum type);
    * ``T a [ b ]``: ``( a , (List b) ) T``;
    * ``T a b``: ``( a , b ) T``.

    Any other shape is reported on stdout and the program exits with status 1.
    """
    if len(tokens) < 2:
        return tokens
    elif len(tokens) == 2:
        return [tokens[1], tokens[0]]
    elif tokens[0] == '[' and tokens[2] == ']':
        return [tokens[1], braces.str('list', '(', ')')]
    elif len(tokens) == 4 and tokens[1] == '[' and tokens[3] == ']':
        listToken = braces.str('(List %s)' % tokens[2], '(', ')')
        return [listToken, tokens[0]]
    elif tokens[0] == 'array':
        # '\\' so the runtime string is a literal backslash: '\<' is not a
        # valid Python escape sequence.
        arrow_token = braces.str('\\<Rightarrow>', '(', ')')
        return [tokens[1], arrow_token, tokens[2]]
    elif tokens[0] == 'either':
        plus_token = braces.str('+', '(', ')')
        return [tokens[1], plus_token, tokens[2]]
    elif len(tokens) == 5 and tokens[2] == '[' and tokens[4] == ']':
        listToken = braces.str('(List %s)' % tokens[3], '(', ')')
        # presumably '+' delimiters keep these punctuation tokens from
        # being treated as brackets by braces — TODO confirm
        lbrack = braces.str('(', '+', '+')
        rbrack = braces.str(')', '+', '+')
        comma = braces.str(',', '+', '+')
        return [lbrack, tokens[1], comma, listToken, rbrack, tokens[0]]
    elif len(tokens) == 3:
        # here comes a fudge
        lbrack = braces.str('(', '+', '+')
        rbrack = braces.str(')', '+', '+')
        comma = braces.str(',', '+', '+')
        return [lbrack, tokens[1], comma, tokens[2], rbrack, tokens[0]]
    else:
        print("Error parsing " + filename)
        print("Can't deal with %s" % tokens)
        sys.exit(1)
636
637
def newtype_transform(d):
    """Takes a Haskell style newtype/data type declaration, whose
        options are divided with | and each of whose options has named
        elements, and forms a datatype statement and definitions for
        the named extractors and their update functions.

    Also handles ``type`` synonyms (typename_transform) and bare
    ``data T`` declarations, which become ``typedecl``.  A final child line
    starting with ``deriving`` is removed and its class names are stored in
    ``d.deriving``.  Sets ``d.typename`` (the converted type name) and
    starts ``d.typedeps`` as an empty set.

    Returns:
        The Def from the specialised transform, or None when
        typename_transform rejects the synonym.
    """
    if len(d.body) != 1:
        # Diagnostic only: the single-element unpacking below then fails.
        print('--- newtype long body ---')
        print(d)
    [(line, children)] = d.body

    # Split off a trailing 'deriving' clause.  Class names are separated by
    # commas, whitespace or parentheses.
    if children and children[-1][0].lstrip().startswith('deriving'):
        l = reduce_to_single_line(children[-1])
        children = children[:-1]
        r = re.compile(r"[,\s\(\)]+")
        bits = r.split(l)
        bits = [bit for bit in bits if bit and bit != 'deriving']
        d.deriving = bits

    line = reduce_to_single_line((line, children))

    # op is the declaring keyword ('type', 'newtype' or 'data', see
    # create_def_2).
    bits = line.split(None, 1)
    op = bits[0]
    line = bits[1]
    bits = line.split('=', 1)
    header = type_conv(bits[0].strip())
    d.typename = header
    d.typedeps = set()
    if len(bits) == 1:
        # line of form 'data Blah' introduces unknown type?
        d.body = [('typedecl %s' % header, [])]
        # record the type with no known constructors
        all_type_arities[header] = []  # HACK
        return d
    line = bits[1]

    if op == 'type':
        # type synonym
        return typename_transform(line, header, d)
    elif line.find('{') == -1:
        # not a record
        return simple_newtype_transform(line, header, d)
    else:
        return named_newtype_transform(line, header, d)
680
681
def typename_transform(line, header, d):
    """Translate a Haskell ``type`` synonym into an Isabelle type_synonym.

    Args:
        line: right-hand side of the synonym; must be a single word.
        header: converted name of the new type.
        d: the Def being built.  Each word of the converted right-hand side
            is added to ``d.typedeps`` and ``d.body`` is replaced.

    Returns:
        ``d``.  If the right-hand side is not a single word, a warning is
        written to stderr, ``call.bad_type_assignment`` is set and None is
        returned.
    """
    try:
        [oldtype] = line.split()
    except ValueError:
        sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d.body)
        call.bad_type_assignment = True
        return None
    word_prefix = 'Data.Word.'
    if oldtype.startswith(word_prefix + 'Word'):
        # take off the prefix, leave Word32 or Word64 etc
        oldtype = oldtype[len(word_prefix):]
    oldtype = type_conv(oldtype)
    for bit in oldtype.split():
        d.typedeps.add(bit)
    lines = [
        'type_synonym %s = "%s"' % (header, oldtype),
    ]
    d.body = [(line, []) for line in lines]
    return d
703
704
# Types that must stay real datatypes even when they have a single
# one-argument constructor (which would otherwise become a type synonym via
# type_wrapper_type).  Filled by settings_line from 'keep_constructor = X'.
# NOTE(review): settings_line calls the value a constructor, but the lookups
# use the converted type name (header) — confirm which is intended.
keep_conss = {}
706
707
def simple_newtype_transform(line, header, d):
    """Translate a non-record newtype/data declaration into a datatype.

    ``line`` is the right-hand side after '=', with constructors separated
    by '|'.  Each constructor's argument types are converted with
    type_transform, quoted when they contain spaces, and added to
    ``d.typedeps``.

    If there is a single constructor name, it takes exactly one argument,
    and ``header`` is not in ``keep_conss``, the type goes to
    type_wrapper_type instead.  Otherwise ``d.body`` becomes a ``datatype``
    declaration and set_instance_proofs receives the
    (constructor, arity) pairs.

    Returns:
        ``d`` (or type_wrapper_type's result).
    """
    lines = []
    arities = []
    for i, bit in enumerate(line.split('|')):
        braced = braces.str(bit, '(', ')')
        bits = braced.split()
        if len(bits) == 2:
            # remembered for the single-constructor wrapper case below
            last_lhs = bits[0]

        if i == 0:
            l = '    %s' % bits[0]
        else:
            l = '  | %s' % bits[0]
        # note: this inner loop reuses the outer loop's name 'bit'
        for bit in bits[1:]:
            if bit.startswith('('):
                # drop the outer parentheses of a compound argument
                bit = bit[1:-1]
            typename = type_transform(str(bit))
            if len(bits) == 2:
                last_rhs = typename
            if ' ' in typename:
                typename = '"%s"' % typename
            l = l + ' ' + typename
            d.typedeps.add(typename)
        lines.append(l)

        arities.append((str(bits[0]), len(bits[1:])))

    if list((dict(arities)).values()) == [1] and header not in keep_conss:
        return type_wrapper_type(header, last_lhs, last_rhs, d)

    d.body = [('datatype %s =' % header, [(line, []) for line in lines])]

    set_instance_proofs(header, arities, d)

    return d
743
744
# Constructor name -> list of (field name, field type) pairs, accumulated
# across all record types by named_newtype_transform.
all_constructor_args = {}
746
747
def named_newtype_transform(line, header, d):
    """Translate a record-syntax newtype/data declaration.

    ``line`` is the right-hand side after '=', with constructors separated
    by '|'.  get_type_map (defined elsewhere) turns each constructor into
    ``(constructor name, [(field name, field type), ...])``.  The result is
    also merged into ``all_constructor_args``.  A field name may be None,
    in which case no ``(name :`` label is emitted.

    After the constructor lines, the following are appended:
    * an ``_update`` primrec per field;
    * a record-syntax abbreviation per constructor that has fields;
    * an ``is<Cons>`` test per constructor, when there are several;
    * for single-constructor types, extractor/update simp lemmas for every
      pair of fields.

    A type with exactly one constructor of exactly one field (and not in
    ``keep_conss``) is handed to type_wrapper_type instead.  The lines
    built here are then discarded.

    Returns:
        ``d``.
    """
    bits = line.split('|')

    constructors = [get_type_map(bit) for bit in bits]
    all_constructor_args.update(dict(constructors))

    # constructor lines of the datatype
    lines = []
    for i, (name, map) in enumerate(constructors):
        if i == 0:
            l = '    %s' % name
        else:
            l = '  | %s' % name
        oname = name
        for name, type in map:
            if type is None:
                print("ERR: None snuck into constructor list for %s" % name)
                print(line, header, oname)
                assert False

            if name is None:
                opt_name = ""
                opt_close = ""
            else:
                opt_name = " (" + name + " :"
                opt_close = ")"

            # quote compound types
            if len(type.split()) == 1 and '(' not in type:
                the_type = type
            else:
                the_type = '"' + type + '"'

            l = l + opt_name + ' ' + the_type + opt_close

            for bit in type.split():
                d.typedeps.add(bit)
        lines.append(l)

    # field name -> {constructor: position}, and field name -> type
    names = {}
    types = {}
    for cons, map in constructors:
        for i, (name, type) in enumerate(map):
            names.setdefault(name, {})
            names[name][cons] = i
            types[name] = type

    for name, map in six.iteritems(names):
        lines.append('')
        lines.extend(named_update_definitions(name, map, types[name], header,
                                              dict(constructors)))

    for name, map in constructors:
        if map == []:
            continue
        lines.append('')
        lines.extend(named_constructor_translation(name, map, header))

    if len(constructors) > 1:
        for name, map in constructors:
            lines.append('')
            check = named_constructor_check(name, map, header)
            lines.extend(check)

    if len(constructors) == 1:
        for ex_name, _ in six.iteritems(names):
            for up_name, _ in six.iteritems(names):
                lines.append('')
                lines.extend(named_extractor_update_lemma(ex_name, up_name))

    arities = [(name, len(map)) for (name, map) in constructors]

    if list((dict(arities)).values()) == [1] and header not in keep_conss:
        [(cons, map)] = constructors
        [(name, type)] = map
        return type_wrapper_type(header, cons, type, d, decons=(name, type))

    set_instance_proofs(header, arities, d)

    d.body = [('datatype %s =' % header, [(line, []) for line in lines])]
    return d
827
828
def named_extractor_update_lemma(ex_name, up_name):
    """Emit a [simp] lemma for extractor ``ex_name`` applied after ``up_name``_update.

    For the same field the update function is applied to the old value.
    For a different field the value is unchanged.  The proof is always
    ``by (cases v) simp``.
    """
    if up_name == ex_name:
        template = '  "%s (%s_update f v) = f (%s v)"'
    else:
        template = '  "%s (%s_update f v) = %s v"'
    return [
        'lemma %s_%s_update [simp]:' % (ex_name, up_name),
        template % (ex_name, up_name, ex_name),
        '  by (cases v) simp',
    ]
843
844
def named_extractor_definitions(name, map, type, header, constructors):
    """Emit a primrec defining the field extractor ``name``.

    Args:
        name: the field name.
        map: dict from constructor name to the field's position among that
            constructor's arguments.
        type: the field's (converted) type.
        header: the datatype's name.
        constructors: dict from constructor name to its argument list.
            Only the lengths are used.

    Returns:
        The Isabelle source lines as a list of strings.
    """
    # '\\<' yields the literal backslash; '\<' is an invalid Python escape.
    lines = ['primrec',
             '  %s :: "%s \\<Rightarrow> %s"' % (name, header, type),
             'where']
    for index, (cons, i) in enumerate(map.items()):
        prefix = '  "' if index == 0 else '| "'
        args = ''.join(' v%d' % k for k in range(len(constructors[cons])))
        lines.append('%s%s (%s%s) = v%d"' % (prefix, name, cons, args, i))
    return lines
865
866
def named_update_definitions(name, map, type, header, constructors):
    """Emit a primrec defining ``<name>_update`` for a record field.

    ``<name>_update f v`` rebuilds ``v`` with ``f`` applied to the field.

    Args:
        name: the field name.
        map: dict from constructor name to the field's position among that
            constructor's arguments.
        type: the field's (converted) type; parenthesised if it has spaces.
        header: the datatype's name.
        constructors: dict from constructor name to its argument list.
            Only the lengths are used.

    Returns:
        The Isabelle source lines as a list of strings.
    """
    # '\\<' yields the literal backslash; '\<' is an invalid Python escape.
    ra = '\\<Rightarrow>'
    if len(type.split()) > 1:
        type = '(%s)' % type
    lines = ['primrec',
             '  %s_update :: "(%s %s %s) %s %s %s %s"'
             % (name, type, ra, type, ra, header, ra, header),
             'where']
    for index, (cons, i) in enumerate(map.items()):
        prefix = '  "' if index == 0 else '| "'
        num = len(constructors[cons])
        params = ''.join(' v%d' % k for k in range(num))
        rebuilt = ''.join(' (f v%d)' % k if k == i else ' v%d' % k
                          for k in range(num))
        lines.append('%s%s_update f (%s%s) = %s%s"'
                     % (prefix, name, cons, params, cons, rebuilt))
    return lines
896
897
def named_constructor_translation(name, map, header):
    """Emit an input abbreviation giving record-style syntax to constructor ``name``.

    ``map`` is a non-empty list of ``(field name, field type)`` pairs in
    argument order.  The abbreviation ``<name>_trans`` takes the fields in
    that order and is written ``<name>_ lparr f1= _, f2= _ rparr``.  It
    expands to ``<name> v0 v1 ...``.

    Returns:
        The Isabelle source lines as a list of strings.

    Raises:
        IndexError: if ``map`` is empty.
    """
    # '\\<' yields the literal backslash; '\<' is an invalid Python escape.
    lines = []
    lines.append('abbreviation (input)')
    l = '  %s_trans :: "' % name
    for n, type in map:
        l = l + '(' + type + ') \\<Rightarrow> '
    l = l + '%s" ("%s\'_ \\<lparr> %s= _' % (header, name, map[0][0])
    for n, type in map[1:]:
        l = l + ', %s= _' % n
    l = l + ' \\<rparr>")'
    lines.append(l)
    lines.append('where')
    l = '  "%s_ \\<lparr> %s= v0' % (name, map[0][0])
    for i, (n, type) in enumerate(map[1:]):
        l = l + ', %s= v%d' % (n, i + 1)
    l = l + ' \\<rparr> == %s' % name
    for i in range(len(map)):
        l = l + ' v%d' % i
    l = l + '"'
    lines.append(l)

    return lines
920
921
def named_constructor_check(name, map, header):
    """Emit a definition of ``is<name>``, true exactly for values built by ``name``.

    ``map`` is the constructor's field list; only the number of entries is
    used.

    Returns:
        The Isabelle source lines as a list of strings.
    """
    # '\\<' yields the literal backslash; '\<' is an invalid Python escape.
    args = ''.join('v%d ' % i for i, _ in enumerate(map))
    return [
        'definition',
        '  is%s :: "%s \\<Rightarrow> bool"' % (name, header),
        'where',
        ' "is%s v \\<equiv> case v of' % name,
        '    %s %s\\<Rightarrow> True' % (name, args),
        '  | _ \\<Rightarrow> False"',
    ]
936
937
def type_wrapper_type(header, cons, rhs, d, decons=None):
    """Replace a single-constructor, single-argument type by a type synonym.

    Emits ``type_synonym header = rhs`` and an identity [simp] definition for
    the constructor ``cons``.  If ``decons=(field, field_type)`` is given (a
    record with one field), it also emits:
    * identity definitions for the field extractor and its ``_update``;
    * the record-syntax abbreviation from named_constructor_translation.
    If ``d.typedeps`` contains ``\\<Rightarrow>``, only a placeholder comment
    is emitted instead.

    Returns:
        ``d`` with its ``body`` replaced.
    """
    if '\\<Rightarrow>' in d.typedeps:
        d.body = [('(* type declaration of %s omitted *)' % header, [])]
        return d
    lines = [
        'type_synonym %s = "%s"' % (header, rhs),
        '',
        'definition',
        '  %s :: "%s \\<Rightarrow> %s"' % (cons, header, header),
        'where %s_def[simp]:' % cons,
        ' "%s \\<equiv> id"' % cons,
    ]
    if decons:
        (decons, decons_type) = decons
        lines.extend([
            '',
            'definition',
            '  %s :: "%s \\<Rightarrow> %s"' % (decons, header, header),
            'where',
            '  %s_def[simp]:' % decons,
            ' "%s \\<equiv> id"' % decons,
            '',
            # 'definition' is its own element, like the definitions above.
            # Without the comma Python would glue it to the next literal.
            'definition',
            '  %s_update :: "(%s \\<Rightarrow> %s) \\<Rightarrow> %s \\<Rightarrow> %s"'
            % (decons, header, header, header, header),
            'where',
            '  %s_update_def[simp]:' % decons,
            ' "%s_update f y \\<equiv> f y"' % decons,
            ''
        ])
        lines.extend(named_constructor_translation(cons, [(decons, decons_type)
                                                          ], header))

    d.body = [(line, []) for line in lines]
    return d
975
976
def instance_transform(d):
    """Translate a Haskell ``instance Class Type [where ...]`` declaration.

    ``Show`` instances and instances for the unit type ``()`` are discarded,
    with a warning on stdout, and None is returned.  Otherwise:
    * the instance's method definitions are parsed like top-level
      definitions and stored in ``d.instance_defs``, keyed by name;
    * ``d.deriving`` is set to the single class, so set_instance_proofs uses
      that class's entry in instance_proof_table.

    The type must already be known in ``all_type_arities``.  If it is not,
    an error is written to stderr and the program exits with status 1.

    Returns:
        ``d``.
    """
    [(line, children)] = d.body
    bits = line.split(None, 3)
    assert bits[0] == 'instance'
    classname = bits[1]
    typename = type_conv(bits[2])
    if classname == 'Show':
        print("Warning: discarding class instance '%s :: Show'" % typename)
        return None
    if typename == '()':
        print("Warning: discarding class instance 'unit :: %s'" % classname)
        return None
    if len(bits) == 3:
        # 'where' (if any) is on its own, more indented, child line
        if children == []:
            defs = []
        else:
            [(l, c)] = children
            assert l.strip() == 'where'
            defs = c
    else:
        # 'instance C T where' on one line
        assert bits[3:] == ['where']
        defs = children
    defs = [create_def_2(l, c, 0) for (l, c) in defs]
    defs = [d2 for d2 in defs if d2 is not None]
    defs = group_defs(defs)
    defs = [defs_transform(d2) for d2 in defs]
    defs_dict = {}
    for d2 in defs:
        if d2 is not None:
            defs_dict[d2.defined] = d2
    d.instance_defs = defs_dict
    d.deriving = [classname]
    if typename not in all_type_arities:
        # NOTE(review): the message says the class is not defined, but the
        # failed lookup is for the type.
        sys.stderr.write('FAIL: attempting %s\n' % d.defined)
        sys.stderr.write('(typename %r)\n' % typename)
        sys.stderr.write('when reading %s\n' % filename)
        sys.stderr.write('but class not defined yet\n')
        sys.stderr.write('perhaps parse in different order?\n')
        sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
        sys.exit(1)
    arities = all_type_arities[typename]
    set_instance_proofs(typename, arities, d)

    return d
1021
1022
# Maps each translated type name to its constructor arities, as recorded by
# set_instance_proofs.
all_type_arities = {}


def set_instance_proofs(header, constructor_arities, d):
    """Record ``header``'s arities and attach its generated instance proofs.

    header: Isabelle type name.
    constructor_arities: list of ``(constructor, arity)`` pairs.
    d: the Def being translated. ``d.deriving`` names the Haskell classes,
        each of which is looked up in ``instance_proof_table``.

    Sets ``d.instance_proofs`` and ``d.instance_extras`` (only when the
    generators produce output).
    """
    all_type_arities[header] = constructor_arities
    pfs = []
    exs = []
    canonical = list(enumerate(constructor_arities))

    # Several classes share one generator (e.g. no_proofs), so dedupe first
    # and only then sort. Sorting before building a set (as was done before)
    # discarded the order and made output depend on function ids. The name
    # breaks ties between generators that share an ``order`` value.
    instance_proof_fns = sorted(
        set(instance_proof_table[classname] for classname in d.deriving),
        key=lambda f: (f.order, f.__name__))
    for f in instance_proof_fns:
        (npfs, nexs) = f(header, canonical, d)
        pfs.extend(npfs)
        exs.extend(nexs)

    # Disabled: finite instance proofs for single-constructor newtypes.
    # finite_instance_proofs returns a (proofs, extras) pair, so take [0].
    if d.type == 'newtype' and len(canonical) == 1 and False:
        [(cons, n)] = constructor_arities
        if n == 1:
            pfs.extend(finite_instance_proofs(header, cons)[0])

    if pfs:
        lead = '(* %s instance proofs *)' % header
        d.instance_proofs = [(lead, [(line, []) for line in pfs])]
    if exs:
        lead = '(* %s extra instance defs *)' % header
        d.instance_extras = [(lead, [(line, []) for line in exs])]
1052
1053
def finite_instance_proofs(header, cons):
    """Emit an Isabelle ``finite`` instance proof for type ``header``.

    The proof is by surjectivity of the single constructor ``cons``.
    Returns a ``(proof_lines, extra_defs)`` pair; ``extra_defs`` is empty.
    """
    proof = ['', 'instance %s :: finite' % header]
    if call.current_context:
        proof.append('interpretation Arch .')
    proof += [
        '  apply (intro_classes)',
        '  apply (rule_tac f="%s" in finite_surj_type)' % cons,
        '  apply (safe, case_tac x, simp_all)',
        '  done',
    ]
    return (proof, [])
1067
# Wondering where the serialisable proofs went? See
# commit 21361f073bbafcfc985934e563868116810d9fa2 for their last known occurrence.
1070
1071
# Leave type tags 0..11 for explicit use outside of this script.
# storable_instance_proofs increments this counter before using it, so the
# first tag it emits is 13.
next_type_tag = 12


def _object_extra_defs(header, defs, load_object_type):
    """Build the extra definitions shared by the two storable generators.

    Emits ``objBits`` and ``makeObject`` definitions when the Haskell
    instance provides them. ``loadObject`` and ``updateObject`` are either
    translated from the instance or fall back to the ``*_default``
    implementations.

    header: Isabelle name of the instance type.
    defs: method name -> translated Def (the instance's ``instance_defs``).
    load_object_type: type used in the default ``loadObject`` ascription
        (the only difference between the two callers).
    Returns the list of output lines.
    """
    extradefs = ['']
    if 'objBits' in defs:
        extradefs.append('definition')
        body = flatten_tree(defs['objBits'].body)
        bits = body[0].split('objBits')
        assert bits[0].strip() == '"'
        if bits[1].strip().startswith('_'):
            # Replace a leading wildcard argument with the name ``x``.
            bits[1] = 'x ' + bits[1].strip()[1:]
        bits = bits[1].split(None, 1)
        body[0] = '  objBits_%s: "objBits (%s :: %s) %s' \
            % (header, bits[0], header, bits[1])
        extradefs.extend(body)

    extradefs.append('')
    if 'makeObject' in defs:
        extradefs.append('definition')
        body = flatten_tree(defs['makeObject'].body)
        bits = body[0].split('makeObject')
        assert bits[0].strip() == '"'
        body[0] = '  makeObject_%s: "(makeObject :: %s) %s' \
            % (header, header, bits[1])
        extradefs.extend(body)

    extradefs.extend(['', 'definition'])
    if 'loadObject' in defs:
        extradefs.append('  loadObject_%s:' % header)
        extradefs.extend(flatten_tree(defs['loadObject'].body))
    else:
        extradefs.extend([
            '  loadObject_%s[simp]:' % header,
            ' "(loadObject p q n obj) :: %s \\<equiv>' % load_object_type,
            '    loadObject_default p q n obj"',
        ])

    extradefs.extend(['', 'definition'])
    if 'updateObject' in defs:
        extradefs.append('  updateObject_%s:' % header)
        body = flatten_tree(defs['updateObject'].body)
        bits = body[0].split('updateObject')
        assert bits[0].strip() == '"'
        bits = bits[1].split(None, 1)
        body[0] = ' "updateObject (%s :: %s) %s' \
            % (bits[0], header, bits[1])
        extradefs.extend(body)
    else:
        extradefs.extend([
            '  updateObject_%s[simp]:' % header,
            ' "updateObject (val :: %s) \\<equiv>' % header,
            '    updateObject_default val"',
        ])

    return extradefs


def storable_instance_proofs(header, canonical, d):
    """Instance-proof generator for the Haskell ``Storable`` class.

    Assigns ``header`` a fresh type tag (bumping the module-level
    ``next_type_tag``) and emits ``dynamic``/``storable`` instances plus the
    object definitions from ``d.instance_defs``.
    Returns a ``(proofs, extra_defs)`` pair of line lists.
    """
    global next_type_tag
    next_type_tag = next_type_tag + 1
    proofs = [
        '', 'defs (overloaded)', '  typetag_%s[simp]:' % header,
        ' "typetag (x :: %s) \\<equiv> %d"' % (header, next_type_tag),
        # This comma was missing before, which glued '' onto the next
        # literal and dropped the blank separator line.
        '',
        'instance %s :: dynamic' % header, '  by (intro_classes, simp)',
        '',
        'instance %s :: storable ..' % header,
    ]
    extradefs = _object_extra_defs(header, d.instance_defs, header)
    return (proofs, extradefs)


storable_instance_proofs.order = 1


def pspace_storable_instance_proofs(header, canonical, d):
    """Instance-proof generator for the Haskell ``PSpaceStorable`` class.

    Emits a ``pre_storable`` instance plus the object definitions from
    ``d.instance_defs``. The default ``loadObject`` is ascribed the type
    ``<header> kernel``.
    Returns a ``(proofs, extra_defs)`` pair of line lists.
    """
    proofs = [
        '',
        'instance %s :: pre_storable' % header,
        '  by (intro_classes,',
        '      auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)',
    ]
    extradefs = _object_extra_defs(header, d.instance_defs,
                                   '%s kernel' % header)
    return (proofs, extradefs)


pspace_storable_instance_proofs.order = 1
1215
1216
def num_instance_proofs(header, canonical, d):
    """Instance-proof generator for the Haskell ``Num`` class.

    Only single-constructor, single-argument wrapper types are supported.
    Emits plus/minus/zero/one/times instances lifted through the
    constructor with bij_instance.
    Returns a ``(proofs, extra_defs)`` pair; ``extra_defs`` is empty.
    """
    assert len(canonical) == 1
    [(_, (cons, arity))] = canonical
    assert arity == 1

    # (Isabelle class / operation name, operation template)
    arithmetic = [
        ('plus', '%s + %s'),
        ('minus', '%s - %s'),
        ('zero', '0'),
        ('one', '1'),
        ('times', '%s * %s'),
    ]
    proofs = []
    for op_name, template in arithmetic:
        proofs.extend(
            bij_instance(op_name, header, cons, [(op_name, template, True)]))
    return (proofs, [])


num_instance_proofs.order = 2
1237
1238
def enum_instance_proofs(header, canonical, d):
    """Instance-proof generator for the Haskell ``Enum`` class.

    Emits Isabelle ``enum``, ``enum_alt`` and ``enumeration_both``
    instantiations for ``header``, wrapped in ``(*<*) ... (*>*)``.

    canonical: list of ``(index, (constructor, arity))``. Constructors may
        take at most one argument; each one-argument constructor contributes
        ``map cons enum`` to the enumeration.
    Returns a ``(proofs, extra_defs)`` pair; ``extra_defs`` is empty.
    """
    def singular_canonical():
        if len(canonical) == 1:
            [(_, (_, n))] = canonical
            return n == 1
        else:
            return False

    lines = ['(*<*)']
    if singular_canonical():
        [(_, (cons, n))] = canonical
        # special case for datatypes with single constructor with one argument
        # The instance proof below uses distinct_map_enum for this case, so no
        # per-constructor distinctness lemmas are emitted. (This name was
        # previously unbound on this path and raised UnboundLocalError.)
        cons_one_arg = []
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append('  enum_%s: "enum_class.enum \\<equiv> map %s enum"'
                     % (header, cons))

    else:
        cons_no_args = [cons for i, (cons, n) in canonical if n == 0]
        cons_one_arg = [cons for i, (cons, n) in canonical if n == 1]
        cons_two_args = [cons for i, (cons, n) in canonical if n > 1]
        assert cons_two_args == []
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append('  enum_%s: "enum_class.enum \\<equiv> ' % header)
        if len(cons_no_args) == 0:
            lines.append('    []')
        else:
            lines.append('    [ ')
            for cons in cons_no_args[:-1]:
                lines.append('      %s,' % cons)
            for cons in cons_no_args[-1:]:
                lines.append('      %s' % cons)
            lines.append('    ]')
        for cons in cons_one_arg:
            lines.append('    @ (map %s enum)' % cons)
        # Close the quoted definition on whatever line came last.
        lines[-1] = lines[-1] + '"'
    lines.append('')
    for cons in cons_one_arg:
        lines.append('lemma %s_map_distinct[simp]: "distinct (map %s enum)"' % (cons, cons))
        lines.append('  apply (simp add: distinct_map)')
        lines.append('  by (meson injI %s.inject)' % header)
    lines.append('')
    lines.append('definition')
    lines.append('  "enum_class.enum_all (P :: %s \\<Rightarrow> bool) \\<longleftrightarrow> Ball UNIV P"'
                 % header)
    lines.append('')
    lines.append('definition')
    lines.append('  "enum_class.enum_ex (P :: %s \\<Rightarrow> bool) \\<longleftrightarrow> Bex UNIV P"'
                 % header)
    lines.append('')
    lines.append('  instance')
    lines.append('  apply intro_classes')
    lines.append('   apply (safe, simp)')
    lines.append('   apply (case_tac x)')
    if len(canonical) == 1:
        lines.append('  apply (auto simp: enum_%s enum_all_%s_def enum_ex_%s_def'
                     % (header, header, header))
        lines.append('    distinct_map_enum)')
        lines.append('  done')
    else:
        lines.append('  apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def)'
                     % (header, header, header))
        lines.append('  by fast+')
    lines.append('end')
    lines.append('')
    lines.append('instantiation %s :: enum_alt' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('definition')
    lines.append('  enum_alt_%s: "enum_alt \\<equiv> ' % header)
    lines.append('    alt_from_ord (enum :: %s list)"' % header)
    lines.append('instance ..')
    lines.append('end')
    lines.append('')
    lines.append('instantiation %s :: enumeration_both' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('instance by (intro_classes, simp add: enum_alt_%s)'
                 % header)
    lines.append('end')
    lines.append('')
    lines.append('(*>*)')

    return (lines, [])


enum_instance_proofs.order = 3
1333
1334
def bits_instance_proofs(header, canonical, d):
    """Instance-proof generator for the Haskell ``Bits`` class.

    Only single-constructor, single-argument wrapper types are supported.
    shiftL/shiftR results are re-wrapped in the constructor; bitSize is not.
    Returns a ``(proofs, extra_defs)`` pair; ``extra_defs`` is empty.
    """
    assert len(canonical) == 1
    [(_, (cons, arity))] = canonical
    assert arity == 1

    bit_operations = [
        ('shiftL', 'shiftL %s n', True),
        ('shiftR', 'shiftR %s n', True),
        ('bitSize', 'bitSize %s', False),
    ]
    proofs = bij_instance('bit', header, cons, bit_operations)
    return (proofs, [])


bits_instance_proofs.order = 5
1347
1348
def no_proofs(header, canonical, d):
    """Instance-proof generator for classes that need no Isabelle output.

    Returns a fresh ``(proofs, extra_defs)`` pair of empty lists.
    """
    proofs, extras = [], []
    return (proofs, extras)


no_proofs.order = 6
1354
# Maps a Haskell type-class name to the generator that emits its Isabelle
# instance proofs. Every generator is called as f(header, canonical, d) and
# returns a (proofs, extra_defs) pair of line lists. Its ``order`` attribute
# is the sort key set_instance_proofs uses to sequence the output. A class
# missing from this table raises KeyError in set_instance_proofs.
#
# FIXME "Bounded" emits enum proofs even if something already has enum proofs
# generated due to "Enum"
# NOTE(review): 'Bounded' currently maps to no_proofs (enum_instance_proofs is
# commented out), so the FIXME above appears to describe that disabled choice.
instance_proof_table = {
    'Eq': no_proofs,
    'Bounded': no_proofs,  # enum_instance_proofs,
    'Enum': enum_instance_proofs,
    'Ix': no_proofs,
    'Ord': no_proofs,
    'Show': no_proofs,
    'Bits': bits_instance_proofs,
    'Real': no_proofs,
    'Num': num_instance_proofs,
    'Integral': no_proofs,
    'Storable': storable_instance_proofs,
    'PSpaceStorable': pspace_storable_instance_proofs,
    'Typeable': no_proofs,
    'Error': no_proofs,
}
1374
1375
def bij_instance(classname, typename, constructor, fns):
    """Lift operations through a single-argument wrapper constructor.

    Emits an Isabelle ``instance typename :: classname ..`` followed by
    overloaded definitions. Each operation is defined by case-splitting
    every argument on ``constructor`` and applying the underlying operation.

    fns: list of ``(fname, fstr, cast_return)``.
        fname: name of the defined operation (used in ``fname_typename``).
        fstr: template with one ``%s`` per argument (at most four).
        cast_return: wrap the result in ``constructor`` when true.
    Returns the list of output lines.
    """
    L = [
        '',
        'instance %s :: %s ..' % (typename, classname),
        'defs (overloaded)'
    ]
    for (fname, fstr, cast_return) in fns:
        n = fstr.count('%s')
        names = ('x', 'y', 'z', 'w')[:n]
        # Primed names stand for the unwrapped arguments.
        names2 = tuple([name + "'" for name in names])
        fstr1 = fstr % names
        fstr2 = fstr % names2
        L.append('  %s_%s: "%s \\<equiv>' % (fname, typename, fstr1))
        for name in names:
            L.append("    case %s of %s %s' \\<Rightarrow>"
                     % (name, constructor, name))
        if cast_return:
            L.append('      %s (%s)"' % (constructor, fstr2))
        else:
            L.append('      %s"' % fstr2)

    return L
1398
1399
def get_type_map(string):
    """Convert Haskell record syntax into a field-name -> type list.

    ``string`` looks like ``Name { a, b :: T1, c :: T2 }``. Returns
    ``(Name, [(field, converted_type), ...])`` in source order. A field
    without its own ``::`` shares the type of the next annotated field. With
    no braces part, the list is empty. Malformed braces abort the process.
    """
    parts = string.split(None, 1)
    header = parts[0].strip()
    if len(parts) < 2:
        return (header, [])

    record = parts[1].strip()
    if not (record.startswith('{') and record.endswith('}')):
        print('Error in ' + filename)
        print('Expected "A { blah :: blah etc }"')
        print('However { and } not found correctly')
        print('When parsing %s' % string)
        sys.exit(1)
    record = record[1:-1]

    entries = braces.str(record, '(', ')').split(',')
    # Walk right-to-left so un-annotated fields inherit the type of the
    # annotated field that follows them.
    field_types = []
    current_type = None
    for entry in reversed(entries):
        pieces = entry.split('::')
        if len(pieces) == 2:
            current_type = type_transform(str(pieces[1]).strip())
            field = str(pieces[0]).strip()
        else:
            field = str(entry).strip()
        field_types.append((field, current_type))
    field_types.reverse()
    return (header, field_types)
1431
1432
# One-element list so body_transform can bump the counter without `global`.
numLiftIO = [0]


def body_transform(body, defined, sig, nopattern=False):
    """Translate a single Haskell definition body into Isabelle syntax.

    body: one-element list ``[(line, children)]``.
    defined: name being defined; only used in the missing-'=' warning.
    sig: raw type signature, passed to do_clauses_transform.
    nopattern: when true, skip pattern_match_transform and lhs_transform.

    The first '=' becomes the Isabelle equivalence symbol. A trailing
    ``where`` clause becomes a ``let ... in`` prefix of the body. The
    transform pipeline below is applied, and the result is wrapped in
    double quotes. Returns a one-element list ``[(line, children)]``.
    """
    # assume single object
    [(line, children)] = body

    # Parenthesised patterns on the left-hand side are rewritten first.
    if '(' in line.split('=')[0] and not nopattern:
        [(line, children)] = \
            pattern_match_transform([(line, children)])

    if 'liftIO' in reduce_to_single_line((line, children)):
        # liftIO is the translation boundary for current
        # SEL4, below which we get into details of interaction
        # with the foreign function interface - axiomatise
        assert '=' in line
        line = line.split('=')[0]
        bits = line.split()
        numLiftIO[0] = numLiftIO[0] + 1
        rhs = '(%d :: Int) %s' % (numLiftIO[0], ' '.join(bits[1:]))
        line = '%s\\<equiv> underlying_arch_op %s' % (line, rhs)
        children = []
    elif '=' in line:
        line = '\\<equiv>'.join(line.split('=', 1))
    elif children and leading_bar.match(children[0][0]):
        # Guarded equations; handled by guarded_body_transform below.
        pass
    elif children and '=' in children[0][0]:
        (nextline, c2) = children[0]
        children[0] = ('\\<equiv>'.join(nextline.split('=', 1)), c2)
    else:
        # Also reached when there are no children at all; previously this
        # case raised IndexError on children[0].
        sys.stderr.write('WARN: def of %s missing =\n' % defined)

    if children and (children[-1][0].endswith('where') or
                     children[-1][0].lstrip().startswith('where')):
        bits = line.split('\\<equiv>')
        where_clause = where_clause_transform(children[-1])
        children = children[:-1]
        if len(bits) == 2 and bits[1].strip():
            # Move the right-hand side onto its own child line so it sits
            # after the let/in block.
            line = bits[0] + '\\<equiv>'
            new_line = ' ' * len(line) + bits[1]
            children = [(new_line, children)]
    else:
        where_clause = []

    (line, children) = zipWith_transforms(line, children)

    (line, children) = supplied_transforms(line, children)

    (line, children) = case_clauses_transform((line, children))

    (line, children) = do_clauses_transform((line, children), sig)

    if children and leading_bar.match(children[0][0]):
        line = line + ' \\<equiv>'
        children = guarded_body_transform(children, ' = ')

    children = where_clause + children

    if not nopattern:
        line = lhs_transform(line)
    line = lhs_de_underscore(line)

    (line, children) = type_assertion_transform(line, children)

    (line, children) = run_regexes((line, children))
    (line, children) = run_ext_regexes((line, children))

    (line, children) = bracket_dollar_lambdas((line, children))

    line = '"' + line
    (line, children) = add_trailing_string('"', (line, children))
    return [(line, children)]
1505
1506
# Matches Haskell's ``$ \x -> ...`` after the lambda has become \<lambda>.
dollar_lambda_regex = re.compile(r"\$\s*\\<lambda>")


def bracket_dollar_lambdas(xxx_todo_changeme):
    """Replace ``$ \\<lambda>`` with a parenthesised lambda, recursively.

    The opening parenthesis replaces the ``$``. The matching ``)`` is
    appended at the end of the node's subtree, placed before a trailing
    ``;`` if there is one. Each line may contain at most one occurrence
    (more raise ValueError on unpacking).
    """
    line, children = xxx_todo_changeme
    if dollar_lambda_regex.search(line):
        before, after = dollar_lambda_regex.split(line)
        node = ('%s(\\<lambda>%s' % (before, after), children)
        if has_trailing_string(';', node):
            line, children = add_trailing_string(
                ');', remove_trailing_string(';', node))
        else:
            line, children = add_trailing_string(')', node)
    return (line, [bracket_dollar_lambdas(child) for child in children])
1523
1524
def zipWith_transforms(line, children):
    """Rewrite ``zipWithM_ f xs ys`` (one list an enumeration) into ``mapM_``.

    Only a node whose own line mentions ``zipWithM_`` is rewritten; other
    nodes are returned with their children processed recursively. The
    rewrite requires the last child line to contain an enumeration ending
    in ``..]`` and to have no children of its own; otherwise the node is
    returned unchanged.

    The result has the shape ``mapM_ / (split f ... ) / (zipEk ...)``. Which
    of zipE1..zipE4 is used depends on which argument is the enumeration and
    whether it has a step (``[a, b ..]``).
    NOTE(review): zipE1..zipE4 and ``split`` are presumably helpers on the
    Isabelle side; their exact semantics are not visible here.
    """
    if 'zipWithM_' not in line:
        children = [zipWith_transforms(l, c) for (l, c) in children]
        return (line, children)

    if children == []:
        return (line, [])

    # If the '..]' enumeration is split over the last two leaf children,
    # join them into a single child line first.
    if len(children) == 2:
        [(l, c), (l2, c2)] = children
        if c == [] and c2 == [] and '..]' in l + l2:
            children = [(l + ' ' + l2.strip(), [])]

    (l, c) = children[-1]
    if c != [] or '..]' not in l:
        return (line, children)

    bits = line.split('zipWithM_', 1)
    line = bits[0] + 'mapM_'
    ws = lead_ws(l)
    # The function argument of zipWithM_ becomes ``(split <f>``; the ')' is
    # closed by line3 below.
    line2 = ws + '(split ' + bits[1]

    # Whitespace split that keeps [...] groups intact (braces.str helper).
    bits = braces.str(l, '[', ']').split(None, braces=True)

    # Everything except the two list arguments stays inside (split ...).
    line3 = ws + ' '.join(bits[:-2]) + ')'
    used_children = children[:-1] + [(line3, [])]

    sndlast = bits[-2]
    last = bits[-1]
    if last.endswith('..]'):
        # Second list is the enumeration: strip '[' and '..]'.
        internal = last[1:-3].strip()
        if ',' in internal:
            bits = internal.split(',')
            l = '%s(zipE4 (%s) (%s) (%s))' \
                % (ws, sndlast, bits[0], bits[-1])
        else:
            l = '%s(zipE3 (%s) (%s))' % (ws, sndlast, internal)
    else:
        # Otherwise the first list is assumed to be the enumeration.
        internal = sndlast[1:-3].strip()
        if ',' in internal:
            bits = internal.split(',')
            l = '%s(zipE2 (%s) (%s) (%s))' \
                % (ws, bits[0], bits[1], last)
        else:
            l = '%s(zipE1 (%s) (%s))' % (ws, internal, last)

    return (line, [(line2, used_children), (l, [])])
1572
1573
def supplied_transforms(line, children):
    """Swap in a hand-written translation for any subtree that has one.

    The lookup key is the whitespace-stripped tuple form of the subtree (see
    convert_to_stripped_tuple). A matching replacement is re-indented to the
    original line's indentation and marked as used in
    ``supplied_transforms_usage``. Unmatched nodes have their children
    processed recursively.
    """
    key = convert_to_stripped_tuple(line, children)
    if key not in supplied_transform_table:
        return (line, [supplied_transforms(l, c) for (l, c) in children])

    replacement = supplied_transform_table[key]
    indent_shift = len(lead_ws(line)) - len(lead_ws(replacement[0]))
    adjusted = adjust_ws(replacement, indent_shift)
    supplied_transforms_usage[key] = 1
    return adjusted
1588
1589
def convert_to_stripped_tuple(line, children):
    """Return a hashable copy of a (line, children) tree.

    Every line is stripped of surrounding whitespace and every child list
    becomes a tuple, so the result can be used as a dictionary key.
    """
    stripped_children = tuple(
        convert_to_stripped_tuple(child_line, grandchildren)
        for (child_line, grandchildren) in children)
    return (line.strip(), stripped_children)
1594
1595
def type_assertion_transform_inner(line):
    """Rewrite every ``type_assertion`` match in ``line`` as ``(var::type)``.

    The type text (regex group 2) is converted with type_transform. The
    remainder of the line after each match is processed recursively, and a
    line without matches is returned unchanged.
    """
    match = type_assertion.search(line)
    if match is None:
        return line
    annotated = '(%s::%s)' % (match.expand('\\1'),
                              type_transform(match.expand('\\2').strip()))
    return (line[:match.start()] + annotated
            + type_assertion_transform_inner(line[match.end():]))
1605
1606
def type_assertion_transform(line, children):
    """Apply type_assertion_transform_inner to every line of a tree.

    Children are processed before the node's own line. Returns a new
    ``(line, children)`` pair.
    """
    new_children = [type_assertion_transform(child_line, grandchildren)
                    for (child_line, grandchildren) in children]
    new_line = type_assertion_transform_inner(line)
    return (new_line, new_children)
1611
1612
def where_clause_guarded_body(xxx_todo_changeme1):
    """Translate a where-binding whose first child is a ``|`` guard.

    When the first child matches ``leading_bar``, the binding line gets a
    trailing `` =`` and the children go through guarded_body_transform.
    Any other node is returned unchanged.
    """
    line, children = xxx_todo_changeme1
    has_guard = bool(children) and leading_bar.match(children[0][0])
    if not has_guard:
        return (line, children)
    return (line + ' =', guarded_body_transform(children, ' = '))
1619
1620
def where_clause_transform(xxx_todo_changeme2):
    """Rewrite a Haskell ``where`` block as an Isabelle ``let ... in`` block.

    Takes one (line, children) node whose line is, or starts with,
    ``where``. Returns a list of nodes: a ``let`` line, the translated
    bindings (every binding but the last gets a trailing ``;``), then an
    ``in`` line. Both keywords use the indentation of ``where``.
    """
    (line, children) = xxx_todo_changeme2
    # Indentation before the 'where' keyword, reused for 'let' and 'in'.
    ws = line.split('where', 1)[0]
    if line.strip() != 'where':
        # 'where foo = ...' on one line: the text after the keyword becomes
        # the first binding.
        assert line.strip().startswith('where')
        children = [(''.join(line.split('where', 1)), [])] + children
    pre = ws + 'let'
    post = ws + 'in'

    # Drop type signatures ('name :: type').
    # NOTE(review): a binding line with fewer than two tokens makes
    # l.split()[1] raise IndexError.
    children = [(l, c) for (l, c) in children if l.split()[1] != '::']
    children = [case_clauses_transform((l, c)) for (l, c) in children]
    children = [do_clauses_transform(
        (l, c),
        None,
        type=0) for (l, c) in children]
    children = list(map(where_clause_guarded_body, children))
    for i, (l, c) in enumerate(children):
        # Split on '=' outside parentheses to see whether the left-hand side
        # has arguments.
        l2 = braces.str(l, '(', ')')
        if len(l2.split('=')[0].split()) > 1:
            # turn f a = b into f = (\a -> b)
            l = '->'.join(l.split('=', 1))
            l = lead_ws(l) + ' = (\\ '.join(l.split(None, 1))
            (l, c) = add_trailing_string(')', (l, c))
            children[i] = (l, c)
    # Order the bindings so each one comes after the bindings it refers to
    # (see order_let_children).
    children = order_let_children(children)
    for i, child in enumerate(children[:-1]):
        children[i] = add_trailing_string(';', child)
    return [(pre, [])] + children + [(post, [])]
1649
1650
varname_re = re.compile(r"\w+")


def order_let_children(L):
    """Order let-bindings so each comes after the bindings it references.

    L: list of (line, children) bindings. The bound name is the first word
    left of the first '=' outside parentheses. Binding i depends on binding
    j when j's name appears as a word anywhere in binding i.

    Bindings are taken in repeated passes in original order, each as soon
    as all of its dependencies are placed. Raises AssertionError (after
    printing diagnostics) on a dependency cycle. Returns the reordered list.
    """
    single_lines = [reduce_to_single_line(elt) for elt in L]
    keys = [str(braces.str(line, '(', ')').split('=')[0]).split()[0]
            for line in single_lines]
    keys = dict((key, n) for (n, key) in enumerate(keys))
    bits = [varname_re.findall(line) for line in single_lines]
    deps = {}
    for i, bs in enumerate(bits):
        for bit in bs:
            if bit in keys:
                j = keys[bit]
                if j != i:
                    deps.setdefault(i, {})[j] = 1
    done = set()
    L2 = []
    todo = dict(enumerate(L))
    n = len(todo)
    while n:
        for key in list(todo):
            if all(dep in done for dep in deps.get(key, {})):
                L2.append(todo.pop(key))
                done.add(key)
        if len(todo) == n:
            print("No progress resolving let deps")
            print()
            print(list(todo.values()))
            print()
            print("Dependencies:")
            print(deps)
            # An explicit raise, not 'assert 0': under 'python -O' the assert
            # is removed and this loop would never terminate.
            raise AssertionError("No progress resolving let deps")
        n = len(todo)
    return L2
1689
1690
def do_clauses_transform(xxx_todo_changeme3, rawsig, type=None):
    """Translate Haskell ``do`` blocks in a (line, children) tree to Isabelle.

    rawsig: raw type signature, used to pick the monad flavour when ``type``
        is None (see monad_type_acquire).
    type: monad flavour. This many 'E' letters are appended to ``do``/``od``
        and ``assert``, and ``returnOk`` is used for lets when it is 1.
        monad_type_transform may adjust it per line.

    Behaviour:
    - A trailing ``where`` child becomes a ``let ... in`` prefix.
    - A lone ``do`` child line is merged into its parent line.
    - ``syscall``/``atomicSyscall`` (5 children) and ``catchFailure``
      (2 children) give their children fixed flavours.
    - A ``do`` block becomes ``(doE ... odE)`` with ``;`` separators. A
      single-statement block is only parenthesised.
    - Nodes after an unmatched ``)`` stay outside the block.
    Returns the new (line, children) pair.
    """
    (line, children) = xxx_todo_changeme3
    if children and children[-1][0].lstrip().startswith('where'):
        where_clause = where_clause_transform(children[-1])
        where_clause = [do_clauses_transform(
            (l, c), rawsig, 0) for (l, c) in where_clause]
        others = (line, children[:-1])
        others = do_clauses_transform(others, rawsig, type)
        (line, children) = where_clause[0]
        return (line, children + where_clause[1:] + [others])

    if children:
        (l, c) = children[0]
        if c == [] and l.endswith('do'):
            line = line + ' ' + l.strip()
            children = children[1:]

    if type is None:
        if not rawsig:
            type = 0
            sig = None
        else:
            sig = ' '.join(flatten_tree([rawsig]))
            type = monad_type_acquire(sig)
    (line, type) = monad_type_transform((line, type))
    if children == []:
        return (line, [])

    rhs = line.split('<-', 1)[-1]
    if rhs.strip() == 'syscall' or rhs.strip() == 'atomicSyscall':
        assert len(children) == 5
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0, 1, 0, type])]
    elif line.strip().endswith('catchFailure'):
        assert len(children) == 2
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0])]
    else:
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=type) for elt in children]

    if not line.endswith('do'):
        return (line, children)

    # Anything after an unmatched ')' belongs to the enclosing expression,
    # not to this do block.
    children, other_children = split_on_unmatched_bracket(children)

    # single statement do clause won't parse in Isabelle
    if len(children) == 1:
        ws = lead_ws(line)
        # other_children must be kept here too; they were previously dropped.
        return (line[:-2] + '(',
                children + [(ws + ')', [])] + other_children)

    # Replace the trailing 'do' with '(do' / '(doE'.
    line = line[:-2] + '(do' + 'E' * type

    children = [(l, c) for (l, c) in children if l.strip() or c]

    children2 = []
    for (l, c) in children:
        if l.lstrip().startswith('let '):
            # 'let x = e' becomes 'x <- return (e)' (or returnOk).
            if '=' not in l:
                extra = reduce_to_single_line(c[0])
                assert '=' in extra
                l = l + ' ' + extra
                c = c[1:]
            l = ''.join(l.split('let ', 1))
            letsubs = '<- return' + 'Ok' * type + ' ('
            l = letsubs.join(l.split('=', 1))
            (l, c) = add_trailing_string(')', (l, c))
            children2.extend(do_clause_pattern(l, c, type))
        else:
            children2.extend(do_clause_pattern(l, c, type))

    # NOTE(review): if every child line was blank, children2 is empty and
    # children2[-1] raises IndexError.
    children = [add_trailing_string(';', child)
                for child in children2[:-1]] + [children2[-1]]

    ws = lead_ws(line)
    children.append((ws + 'od' + 'E' * type + ')', []))

    return (line, children + other_children)
1774
1775
# Constructor patterns on the left of '<-' that are stripped and replaced by
# lifting the given projection function over the right-hand side.
left_start_table = {
    'ASIDPool': '(inv ASIDPool)',
    'HardwareASID': 'id',
    'ArchObjectCap': 'capCap',
    'Just': 'the'
}


def do_clause_pattern(line, children, type, n=0):
    """Expand a pattern-binding ``pat <- rhs`` in a do block into Isabelle.

    Simple binders (a single word, or a tuple) are kept, with ``<-``
    replaced by the Isabelle left-arrow. Haskell list patterns (``x:xs``,
    ``[]``, ``[a]``, ``[a, b]``) become a fresh-variable bind plus
    headM/tailM binds and ``assert``/``assertE`` checks. Constructor patterns
    listed in ``left_start_table`` become a ``liftM``/``liftME`` of the
    matching projection.

    type: monad flavour; this many 'E' letters are appended to ``assert`` /
        ``liftM``.
    n: recursion depth (not read).
    Returns a list of (line, children) statements. Raises AssertionError for
    unsupported patterns.
    """
    bits = line.split('<-')
    default = [('\\<leftarrow>'.join(bits), children)]
    if len(bits) == 1:
        return default
    (left, right) = line.split('<-', 1)
    if ':' not in left and '[' not in left \
            and len(left.split()) == 1:
        return default
    left = left.strip()

    v = 'v%d' % get_next_unique_id()

    ass = 'assert' + ('E' * type)
    ws = lead_ws(line)

    if left.startswith('('):
        assert left.endswith(')')
        if (',' in left):
            return default
        else:
            left = left[1:-1]
    bs = braces.str(left, '[', ']')
    if len(bs.split(':')) > 1:
        # x:xs <- rhs  ==>  v <- rhs; x <- headM v; xs <- tailM v
        bits = [str(bit).strip() for bit in bs.split(':', 1)]
        lines = [('%s%s <- %s' % (ws, v, right), children),
                 ('%s%s <- headM %s' % (ws, bits[0], v), []),
                 ('%s%s <- tailM %s' % (ws, bits[1], v), [])]
        result = []
        for (l, c) in lines:
            result.extend(do_clause_pattern(l, c, type, n + 1))
        return result
    if left == '[]':
        return [('%s%s <- %s' % (ws, v, right), children),
                ('%s%s (%s = []) []' % (ws, ass, v), [])]
    if left.startswith('['):
        # [a] or [a, b, ...]: rewrite as cons patterns and recurse.
        assert left.endswith(']')
        bs = braces.str(left[1:-1], '[', ']')
        bits = bs.split(',', 1)
        if len(bits) == 1:
            new_left = '%s:%s' % (bits[0], v)
            new_line = '%s%s <- %s' % (ws, new_left, right)
            extra = [('%s%s (%s = []) []' % (ws, ass, v), [])]
        else:
            new_left = '%s:[%s]' % (bits[0], bits[1])
            new_line = lead_ws(line) + new_left + ' <- ' + right
            extra = []
        return do_clause_pattern(new_line, children, type, n + 1) \
            + extra
    for lhs in left_start_table:
        if left.startswith(lhs):
            left = left[len(lhs):]
            tab = left_start_table[lhs]
            lM = 'liftM' + 'E' * type
            nl = ('%s <- %s %s $ %s' % (left, lM, tab, right))
            return do_clause_pattern(nl, children, type, n + 1)

    # An explicit raise, not 'assert 0': under 'python -O' the assert is
    # stripped and the function would silently return None.
    raise AssertionError('do_clause_pattern: unsupported pattern %r in %r'
                         % (left, line))
1844
1845
def split_on_unmatched_bracket(elts, n=None):
    """Split a (line, children) forest at the first unmatched ``)``.

    Scans lines depth-first, tracking parenthesis depth from ``n`` (0 when
    omitted). At the first ``)`` that takes the depth below zero, the line
    is cut there. The part after the cut (padded with spaces to keep its
    column), the cut node's children and all following nodes go into the
    second list.

    Returns ``(before, after)`` when called without ``n``. With an explicit
    ``n`` it returns ``(before, after, final_depth)`` for recursive use.
    ``after`` is empty when there is no unmatched bracket.
    """
    if n is None:
        before, after, _depth = split_on_unmatched_bracket(elts, 0)
        return (before, after)

    depth = n
    for idx, (text, kids) in enumerate(elts):
        for pos, ch in enumerate(text):
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth < 0:
                    head = text[:pos]
                    tail = ' ' * len(head) + text[pos:]
                    return (elts[:idx] + [(head, [])],
                            [(tail, kids)] + elts[idx + 1:],
                            depth)
        kids_before, kids_after, depth = \
            split_on_unmatched_bracket(kids, depth)
        if kids_after:
            return (elts[:idx] + [(text, kids_before)],
                    kids_after + elts[idx + 1:],
                    depth)

    return (elts, [], depth)
1871
1872
def monad_type_acquire(sig, type=0):
    """Infer the monad type number of a Haskell type signature.

    Repeatedly finds the first key of the table below that occurs in sig,
    remembers that key's type and continues with the text after the key's
    last occurrence.  Once no key occurs, the most recently remembered type
    is returned (the argument type if no key ever matched).  The table maps
    kernel_f, fault_monad, kernel_init and kernel_p to 1, syscall_monad to 2,
    and kernel_monad and kernel to 0.
    """
    # note kernel appears after kernel_f/kernel_monad
    table = (('kernel_f', 1), ('fault_monad', 1), ('syscall_monad', 2),
             ('kernel_monad', 0), ('kernel_init', 1), ('kernel_p', 1),
             ('kernel', 0))
    while True:
        for key, key_type in table:
            if key in sig:
                sig = sig.split(key)[-1]
                type = key_type
                break
        else:
            return type
1883
1884
def monad_type_transform(xxx_todo_changeme4):
    """Adapt the monadic helpers in a line to the monad type in force.

    The argument is a (line, type) pair.  The first marker from the table
    below that occurs in the line splits it: the text to the left keeps the
    current type, and the text to the right is processed with the marker's
    type.  The exception is `catchFailure`, whose left operand is processed
    as type 1 and right operand as type 0.  Both sides are handled
    recursively.

    Without a marker and with a non-zero type, return/when/unless/mapM/
    forM/liftM/assert/stateAssert get 'Ok' (for return) or 'E' (for the
    others) appended type times.

    Returns (rewritten line, type in force at the end of the line).
    """
    (line, type) = xxx_todo_changeme4
    # (marker, type of the text after it); the first marker found wins, so
    # the order of this table matters.  None flags the `catchFailure` form.
    markers = (
        ('withoutError', 1),
        ('doKernelOp', 0),
        ('runInit', 1),
        ('withoutFailure', 0),
        ('withoutFault', 0),
        ('withoutPreemption', 0),
        ('allowingFaults', 1),
        ('allowingErrors', 2),
        ('`catchFailure`', None),
        ('catchingFailure', 1),
        ('catchF', 1),
        ('emptyOnFailure', 1),
        ('constOnFailure', 1),
        ('nothingOnFailure', 1),
        ('nullCapOnFailure', 1),
        ('`catchFault`', 1),
        ('capFaultOnFailure', 1),
        ('ignoreFailure', 1),
        ('handleInvocation False', 0),  # THIS IS A HACK
    )
    for marker, inner_type in markers:
        if marker not in line:
            continue
        left, right = line.split(marker, 1)
        if inner_type is None:
            left, _ = monad_type_transform((left, 1))
            right, end_type = monad_type_transform((right, 0))
        else:
            left, _ = monad_type_transform((left, type))
            right, end_type = monad_type_transform((right, inner_type))
        return (left + marker + right, end_type)

    if type:
        renames = (('return', 'Ok'), ('when', 'E'), ('unless', 'E'),
                   ('mapM', 'E'), ('forM', 'E'), ('liftM', 'E'),
                   ('assert', 'E'), ('stateAssert', 'E'))
        for word, suffix in renames:
            line = line.replace(word, word + suffix * type)

    return (line, type)
1964
1965
def case_clauses_transform(xxx_todo_changeme5):
    r"""Rewrite Haskell ``case x of`` expressions in a (line, children) tree.

    Children are transformed first.  A line ending in `` of`` has its
    children parsed as ``pattern -> body`` alternatives (pulling in
    continuation lines, guarded bodies and ``where`` clauses).  The tuple of
    patterns is looked up with get_case_conv, and the template found is
    instantiated with subs_nums_and_x and the alternative bodies.

    A ``where`` directly under the case, a blanked (``<X>``) conversion and an
    unknown pattern tuple all replace the case by a commented
    ``undefined``.  An unknown tuple is also appended to the ``caseconvs``
    file as a ``---X>`` entry, at most once per run (tracked in cases_added).
    Exits with status 1 on alternatives that never reach '->' or on
    templates that refer to missing alternatives.
    """
    (line, children) = xxx_todo_changeme5
    children = [case_clauses_transform(child) for child in children]

    if not line.endswith(' of'):
        return (line, children)

    # text emitted in place of a case that cannot be translated
    removed = r'\<comment> \<open>case removed\<close> undefined'

    bits = line.split('case ', 1)
    beforecase = bits[0]
    x = bits[1][:-3]  # the scrutinee: drop the trailing ' of'

    if '::' in x:
        # keep a type annotation on the scrutinee, translating the type
        x2 = braces.str(x, '(', ')')
        bits = x2.split('::', 1)
        if len(bits) == 2:
            x = str(bits[0]) + ':: ' + type_transform(str(bits[1]))

    if children and children[-1][0].strip().startswith('where'):
        sys.stderr.write('Warning: where clause in case: %r\n'
                         % line)
        return (beforecase + removed, [])
        # where_clause = where_clause_transform(children[-1])
        # children = children[:-1]
        # (in_stmt, l) = where_clause[-1]
        # l.append(case_clauses_transform((line, children)))
        # return where_clause

    cases = []
    bodies = []
    for (l, c) in children:
        # extend the alternative with continuation lines until '->' shows up
        bits = l.split('->', 1)
        while len(bits) == 1:
            if c == []:
                sys.stderr.write('wtf %r\n' % l)
                sys.exit(1)
            if c[0][0].strip().startswith('|'):
                # guarded alternative: guards become an if/else-if chain
                bits = [bits[0], '']
                c = guarded_body_transform(c, '->')
            elif c[0][1] == []:
                l = l + ' ' + c.pop(0)[0].strip()
                bits = l.split('->', 1)
            else:
                [(moreline, c)] = c
                l = l + ' ' + moreline.strip()
                bits = l.split('->', 1)
        case = bits[0].strip()
        tail = bits[1]
        if c and c[-1][0].lstrip().startswith('where'):
            # put the alternative's where clause in front of its body
            where_clause = where_clause_transform(c[-1])
            ws = lead_ws(where_clause[0][0])
            c = where_clause + [(ws + tail.strip(), [])] + c[:-1]
            tail = ''
        cases.append(case)
        bodies.append((tail, c))

    cases = tuple(cases)  # used as key of a dict later, needs to be hashable
    # (since lists are mutable, they aren't)
    if not cases:
        print(line)
    conv = get_case_conv(cases)
    if conv == '<X>':
        sys.stderr.write('Warning: blanked case in caseconvs\n')
        return (beforecase + removed, [])
    if not conv:
        sys.stderr.write('Warning: case %r\n' % (cases, ))
        if cases not in cases_added:
            casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> '

            # the context manager closes the file even if the write fails
            with open('caseconvs', 'a') as f:
                f.write('%s ---X>\n\n' % casestr)
            cases_added[cases] = 1
        return (beforecase + removed, [])
    conv = subs_nums_and_x(conv, x)

    new_line = beforecase + '(' + conv[0][0]
    assert conv[0][1] is None

    ws = lead_ws(children[0][0])
    new_children = []
    for (header, endnum) in conv[1:]:
        if endnum is None:
            new_children.append((ws + header, []))
        else:
            if (len(bodies) <= endnum):
                sys.stderr.write('ERROR: index %d out of bounds in case %r\n' %
                                 (endnum,
                                  cases, ))
                sys.exit(1)

            (body, c) = bodies[endnum]
            new_children.append((ws + header + ' ' + body, c))

    # close the '(' opened in new_line, keeping any trailing comma after it
    if has_trailing_string(',', new_children[-1]):
        new_children[-1] = \
            remove_trailing_string(',', new_children[-1])
        new_children.append((ws + '),', []))
    else:
        new_children.append((ws + ')', []))

    return (new_line, new_children)
2067
2068
def guarded_body_transform(body, div):
    """Turn a list of Haskell guard alternatives into an if/else-if chain.

    body is a list of (line, children) pairs, each of the form
    ``| guard <div> expr``, where div is the guard/expression separator
    (callers pass ' = ' or '->').  A line that does not yet contain div is
    extended with its children's lines until it does.  A trailing ``where``
    element is converted with where_clause_transform and emitted before the
    chain.  The chain ends with ``else undefined``.

    Reports on stderr and exits with status 1 if an alternative never
    reaches div or does not contain '| '.
    """
    new_body = []
    if body[-1][0].strip().startswith('where'):
        new_body.extend(where_clause_transform(body[-1]))
        body = body[:-1]
    for i, (line, children) in enumerate(body):
        try:
            while div not in line:
                (l, c) = children[0]
                children = c + children[1:]
                line = line + ' ' + l.strip()
        except IndexError:
            # ran out of continuation lines before finding div
            sys.stderr.write('missing %r in %r\n' % (div, line))
            sys.stderr.write('\nhandling %r\n' % body)
            sys.exit(1)
        try:
            [ws, guts] = line.split('| ', 1)
        except ValueError:
            sys.stderr.write('missing "|" in %r\n' % line)
            sys.stderr.write('\nhandling %r\n' % body)
            sys.exit(1)
        if i == 0:
            new_body.append((ws + 'if', []))
        else:
            new_body.append((ws + 'else if', []))
        guts = ' then '.join(guts.split(div, 1))
        new_body.append((ws + guts, children))
    new_body.append((ws + 'else undefined', []))

    return new_body
2099
2100
def lhs_transform(line):
    r"""Move parenthesised argument patterns of a definition into cases.

    Lines without '(' are returned unchanged.  Otherwise the line is split
    at its first ``\<equiv>``.  Each lhs word (as grouped by braces) that
    starts with '(' is replaced by ``argN`` (N being its word index), and
    ``case argN of PATTERN \<Rightarrow> `` is prepended to the rhs.
    Splitting only at the first ``\<equiv>`` keeps an rhs that mentions
    ``\<equiv>`` intact instead of failing to unpack.
    """
    if '(' not in line:
        return line

    [left, right] = line.split(r'\<equiv>', 1)

    ws = left[:len(left) - len(left.lstrip())]

    left = left.lstrip()

    bits = braces.str(left, '(', ')').split(braces=True)

    for (i, bit) in enumerate(bits):
        if bit.startswith('('):
            bits[i] = 'arg%d' % i
            case = r'case arg%d of %s \<Rightarrow> ' % (i, bit)
            right = case + right

    return ws + ' '.join([str(bit) for bit in bits]) + r'\<equiv>' + right
2120
2121
def lhs_de_underscore(line):
    r"""Give names to wildcard arguments on a definition's left-hand side.

    Lines without '_' are returned unchanged.  Otherwise every lhs word
    (before the first ``\<equiv>``) that is exactly '_' becomes ``argN``, N
    being its word index.  The lhs is re-joined with single spaces, keeping
    its leading indentation, and followed by `` \<equiv>`` and the untouched
    rhs.  Splitting only at the first ``\<equiv>`` keeps an rhs that
    mentions ``\<equiv>`` intact instead of failing to unpack.
    """
    if '_' not in line:
        return line

    [left, right] = line.split(r'\<equiv>', 1)

    ws = left[:len(left) - len(left.lstrip())]

    left = left.lstrip()
    bits = left.split()

    for (i, bit) in enumerate(bits):
        if bit == '_':
            bits[i] = 'arg%d' % i

    return ws + ' '.join([str(bit) for bit in bits]) + r' \<equiv>' + right
2138
2139
2140regexes = [
2141    (re.compile(r" \. "), r" \<circ> "),
2142    (re.compile('-1'), '- 1'),
2143    (re.compile('-2'), '- 2'),
2144    (re.compile(r"\(!(\w+)\)"), r"(flip id \1)"),
2145    (re.compile(r"\(\+(\w+)\)"), r"(\<lambda> x. x + \1)"),
2146    (re.compile(r"\\([^<].*?)\s*->"), r"\<lambda> \1."),
2147    (re.compile('}'), r"\<rparr>"),
2148    (re.compile(r"(\s)!!(\s)"), r"\1LIST_INDEX\2"),
2149    (re.compile(r"(\w)!"), r"\1 "),
2150    (re.compile(r"\s?!"), ''),
2151    (re.compile(r"LIST_INDEX"), r"!"),
2152    (re.compile('`testBit`'), '!!'),
2153    (re.compile(r"//"), ' aLU '),
2154    (re.compile('otherwise'), 'True     '),
2155    (re.compile(r"(^|\W)fail "), r"\1haskell_fail "),
2156    (re.compile('assert '), 'haskell_assert '),
2157    (re.compile('assertE '), 'haskell_assertE '),
2158    (re.compile('=='), '='),
2159    (re.compile(r"\(/="), '(\<lambda>x. x \<noteq>'),
2160    (re.compile('/='), '\<noteq>'),
2161    (re.compile('"([^"])*"'), '[]'),
2162    (re.compile('&&'), '\<and>'),
2163    (re.compile('\|\|'), '\<or>'),
2164    (re.compile(r"(\W)not(\s)"), r"\1Not\2"),
2165    (re.compile(r"(\W)and(\s)"), r"\1andList\2"),
2166    (re.compile(r"(\W)or(\s)"), r"\1orList\2"),
2167    # regex ordering important here
2168    (re.compile(r"\.&\."), '&&'),
2169    (re.compile(r"\(\.\|\.\)"), r"bitOR"),
2170    (re.compile(r"\(\+\)"), r"op +"),
2171    (re.compile(r"\.\|\."), '||'),
2172    (re.compile(r"Data\.Word\.Word"), r"word"),
2173    (re.compile(r"Data\.Map\."), r"data_map_"),
2174    (re.compile(r"Data\.Set\."), r"data_set_"),
2175    (re.compile(r"BinaryTree\."), 'bt_'),
2176    (re.compile("mapM_"), "mapM_x"),
2177    (re.compile("forM_"), "forM_x"),
2178    (re.compile("forME_"), "forME_x"),
2179    (re.compile("mapME_"), "mapME_x"),
2180    (re.compile("zipWithM_"), "zipWithM_x"),
2181    (re.compile(r"bit\s+([0-9]+)(\s)"), r"(1 << \1)\2"),
2182    (re.compile('`mod`'), 'mod'),
2183    (re.compile('`div`'), 'div'),
2184    (re.compile(r"`((\w|\.)*)`"), r"`~\1~`"),
2185    (re.compile('size'), 'magnitude'),
2186    (re.compile('foldr'), 'foldR'),
2187    (re.compile(r"\+\+"), '@'),
2188    (re.compile(r"(\s|\)|\w|\]):(\s|\(|\w|$|\[)"), r"\1#\2"),
2189    (re.compile(r"\[([^]]*)\.\.([^]]*)\]"), r"[\1 .e. \2]"),
2190    (re.compile(r"\[([^]]*)\.\.\s*$"), r"[\1 .e."),
2191    (re.compile(' Right'), ' Inr'),
2192    (re.compile(' Left'), ' Inl'),
2193    (re.compile('\\(Left'), '(Inl'),
2194    (re.compile('\\(Right'), '(Inr'),
2195    (re.compile(r"\$!"), r"$"),
2196    (re.compile('([^>])>='), r'\1\<ge>'),
2197    (re.compile('>>([^=])'), r'>>_\1'),
2198    (re.compile('<='), '\<le>'),
2199    (re.compile(r" \\\\ "), " `~listSubtract~` "),
2200    (re.compile(r"(\s\w+)\s*@\s*\w+\s*{\s*}\s*\<leftarrow>"),
2201     r"\1 \<leftarrow>"),
2202]
2203
# Record-syntax rewrites, applied in order by run_ext_regexes.  Each entry is
# (pattern, template, follow_pattern, follow_replacement): the first match of
# pattern in a line is replaced by its expansion of template, and then
# follow_pattern -> follow_replacement is substituted throughout the rest of
# that line and in its children.
ext_regexes = [
    # `Cons {` after whitespace opens a record (`Cons_ \<lparr>`); the
    # follow-up drops whitespace before '=' in the fields.
    (re.compile(r"(\s[A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="),
     r"\1="),
    # the same, after an opening parenthesis
    (re.compile(r"(\([A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="),
     r"\1="),
    # a single-field `{f = v\<rparr>` becomes `\<lparr>f:=v\<rparr>`; the
    # follow-up pattern is a sentinel that is not expected to match
    (re.compile(r"{([^={<]*[^={<:])=([^={<]*)\\<rparr>"),
     r"\<lparr>\1:=\2\<rparr>",
     re.compile(r"THIS SHOULD NOT APPEAR IN THE SOURCE"), ""),
    # any other `{` opens a record update; later `f = v` become `f := v`
    (re.compile(r"{"), r"\<lparr>", re.compile(r"([^=:])=(\s|$|\w)"),
     r"\1:=\2"),
]

# Optional leading whitespace followed by a guard bar '|'.
leading_bar = re.compile(r"\s*\|")

# A parenthesised type annotation `(expr :: type)`; captures expr and type.
type_assertion = re.compile(r"\(([^(]*)::([^)]*)\)")
2219
2220
def run_regexes(xxx_todo_changeme6, _regexes=regexes):
    """Apply every (pattern, replacement) pair of _regexes, in order, to the
    line of a (line, children) tree and recursively to all its children.

    Returns a new (line, children) pair; the input is not modified.
    """
    text, kids = xxx_todo_changeme6
    for pattern, replacement in _regexes:
        text = pattern.sub(replacement, text)
    return (text, [run_regexes(kid, _regexes=_regexes) for kid in kids])
2227
2228
def run_ext_regexes(xxx_todo_changeme7):
    """Apply the record-syntax rewrites of ext_regexes to a (line, children)
    tree.

    For each entry, in order, only the first match in the line is replaced.
    The entry's follow-up substitution is then run (via run_regexes) over the
    rest of the line and over the children.  Finally every child is processed
    the same way.
    """
    text, kids = xxx_todo_changeme7
    for pattern, template, follow_re, follow_sub in ext_regexes:
        found = pattern.search(text)
        if found is None:
            continue
        head = text[:found.start()]
        rewritten = found.expand(template)
        tail, kids = run_regexes((text[found.end():], kids),
                                 _regexes=[(follow_re, follow_sub)])
        text = head + rewritten + tail
    return (text, [run_ext_regexes(kid) for kid in kids])
2243
2244
def get_case_lhs(lhs):
    r"""Parse the left-hand side of a caseconvs entry into a pattern tuple.

    lhs must start with ``case \x of `` (asserted).  The rest is split on
    '->', and the stripped, non-empty pieces are returned as a tuple.
    """
    prefix = 'case \\x of '
    assert lhs.startswith(prefix)
    pieces = lhs.split(prefix, 1)[1].split('->')
    return tuple(p for p in (raw.strip() for raw in pieces) if p != '')
2254
2255
def get_case_rhs(rhs):
    r"""Parse the right-hand side of a caseconvs entry into a template.

    rhs uses a literal backslash-n as its line separator, and ``->N`` to mark
    that the body of the N-th (1-based) alternative follows the preceding
    text.  N is read with takeWhile over the digits of the next word.

    Returns a list of (text, index) pairs, where index is the 0-based
    alternative whose body follows text, or None.  Empty texts without an
    index are dropped, and a first text ending in 'of' gets a trailing space.
    The first pair must have index None; otherwise the full conversion is
    reported on stderr and the process exits with status 1.
    """
    original_rhs = rhs  # rhs is consumed below; keep it for the error report
    tuples = []
    while '->' in rhs:
        bits = rhs.split('->', 1)
        s = bits[0]
        bits = bits[1].split(None, 1)
        n = int(takeWhile(bits[0], lambda x: x.isdigit())) - 1
        if len(bits) > 1:
            rhs = bits[1]
        else:
            rhs = ''
        tuples.append((s, n))
    if rhs != '':
        tuples.append((rhs, None))

    conv = []
    for (string, num) in tuples:
        bits = string.split('\\n')
        bits = [bit.strip() for bit in bits]
        conv.extend([(bit, None) for bit in bits[:-1]])
        conv.append((bits[-1], num))

    conv = [(s, n) for (s, n) in conv if s != '' or n is not None]

    if conv[0][1] is not None:
        sys.stderr.write('%r\n' % conv[0][1])
        sys.stderr.write(
            'For technical reasons the first line of this case conversion must be split with \\n: \n')
        sys.stderr.write('  %r\n' % original_rhs)
        sys.stderr.write(
            '(further notes: the rhs of each caseconv must have multiple lines\n'
            'and the first cannot contain any ->1, ->2 etc.)\n')
        sys.exit(1)

    # this is a tad dodgy, but means that case_clauses_transform
    # can be safely run twice on the same input
    if conv[0][0].endswith('of'):
        conv[0] = (conv[0][0] + ' ', conv[0][1])

    return conv
2296
2297
def render_caseconv(cases, conv, f):
    r"""Write one entry in caseconvs format to the file object f.

    The entry is the header ``case \x of P1 -> ... -> `` plus `` --->``,
    immediately followed by the non-empty pieces of conv (which separates
    them with a literal backslash-n), one per line, then a blank line.
    conv must contain at least one non-empty piece (asserted).
    """
    lines = [piece for piece in conv.split('\\n') if piece]
    assert lines
    header = 'case \\x of ' + ' -> '.join(cases) + ' -> '
    f.write(header + ' --->' + ''.join(piece + '\n' for piece in lines) + '\n')
2307
2308
def get_case_conv_table():
    """Load the case conversions listed in the ``caseconvs`` file.

    Entries are separated by blank lines, and an entry's lines are joined
    with a literal backslash-n.  ``LHS --->RHS`` maps the pattern tuple of
    LHS (get_case_lhs) to the template parsed from RHS (get_case_rhs).
    ``LHS ---X>`` maps it to the marker "<X>".  Conversions whose patterns
    are neither all-constructor nor extended are also re-rendered into
    ``caseconvs-useful``, which is truncated on every call.  A parse error
    is reported on stderr and exits with status 1.
    """
    result = {}
    # the context managers close both files (flushing caseconvs-useful) even
    # when a parse error exits the process
    with open('caseconvs') as f, open('caseconvs-useful', 'w') as f2:
        stripped = map(str.rstrip, f)
        entries = ("\\n".join(lines)
                   for lines in splitList(stripped, lambda s: s == ''))

        for line in entries:
            if line.strip() == '':
                continue
            try:
                if '---X>' in line:
                    [from_case, _] = line.split('---X>')
                    cases = get_case_lhs(from_case)
                    result[cases] = "<X>"
                else:
                    [from_case, to_case] = line.split('--->')
                    cases = get_case_lhs(from_case)
                    conv = get_case_rhs(to_case)
                    result[cases] = conv
                    if (not all_constructor_patterns(cases) and
                            not is_extended_pattern(cases)):
                        render_caseconv(cases, to_case, f2)
            except Exception as e:
                sys.stderr.write('Error parsing %r\n' % line)
                sys.stderr.write('%s\n ' % e)
                sys.exit(1)

    return result
2341
2342
def all_constructor_patterns(cases):
    """Return True when every pattern in cases is a plain constructor pattern
    (see is_constructor_pattern).  A final catch-all '_' is ignored.
    cases must be non-empty.
    """
    checked = cases[:-1] if cases[-1].strip() == '_' else cases
    return all(is_constructor_pattern(pat) for pat in checked)
2350
2351
def is_constructor_pattern(pat):
    """Return True if pat has the form ``Cons var1 var2 ...``.

    Every whitespace-separated word must be alphanumeric or '_'.  The first
    word must start with an uppercase character, and each later word must
    start with a lowercase one or be '_'.
    """
    words = pat.split()
    if not all(w.isalnum() or w == '_' for w in words):
        return False
    if not words[0][0].isupper():
        return False
    return all(w[0].islower() or w == '_' for w in words[1:])
2366
2367
# Characters allowed in an "extended" pattern: brackets, commas, braces, '=',
# ':', '_', whitespace, and single letters optionally followed by one digit
# or apostrophe.
ext_checker = re.compile(r"^(\(|\)|,|{|}|=|[a-zA-Z][0-9']?|\s|_|:|\[|\])*$")


def is_extended_pattern(cases):
    """Return True when every pattern in cases consists only of the
    characters accepted by ext_checker (vacuously True for no patterns)."""
    return all(ext_checker.match(case) for case in cases)
2376
2377
# Conversions read from the caseconvs file (loaded at import time).
case_conv_table = get_case_conv_table()
# Pattern tuples already appended to caseconvs as ---X> entries this run.
cases_added = dict()


def get_case_conv(cases):
    """Return the conversion template for a tuple of case patterns.

    All-constructor and extended patterns get a generated template.  Any
    other tuple is looked up in case_conv_table, giving None when absent
    (or the "<X>" marker for blanked entries).
    """
    if all_constructor_patterns(cases):
        conv = all_constructor_conv(cases)
    elif is_extended_pattern(cases):
        conv = extended_pattern_conv(cases)
    else:
        conv = case_conv_table.get(cases)
    return conv
2390
2391
# Haskell constructor names and their replacement in generated Isabelle case
# patterns; the last three become Isabelle comments.
constructor_conv_table = dict(
    Just='Some',
    Nothing='None',
    Left='Inl',
    Right='Inr',
    PPtr=r'\<comment> \<open>PPtr\<close>',
    Register=r'\<comment> \<open>Register\<close>',
    Word=r'\<comment> \<open>Word\<close>',
)

unique_ids_per_file = dict()


def get_next_unique_id():
    """Return the next fresh id (1, 2, 3, ...) for the current filename.

    Each file has its own counter, kept in unique_ids_per_file.
    """
    current = unique_ids_per_file.setdefault(filename, 1)
    unique_ids_per_file[filename] = current + 1
    return current
2409
2410
def all_constructor_conv(cases):
    r"""Generate a conversion template for plain constructor patterns.

    The constructor name is renamed through constructor_conv_table, and each
    '_' argument becomes a fresh ``vN`` from get_next_unique_id.  Returns the
    header ('case \x of', None) followed by (alternative text, index) pairs,
    the first indented with two spaces and the others with '| '.
    """
    conv = [('case \\x of', None)]

    for idx, pattern in enumerate(cases):
        words = pattern.split()
        head = constructor_conv_table.get(words[0], words[0])
        args = ['v%d' % get_next_unique_id() if w == '_' else w
                for w in words[1:]]
        rendered = ' '.join([head] + args)
        lead = '  ' if idx == 0 else '| '
        conv.append((lead + r'%s \<Rightarrow> ' % rendered, idx))
    return conv
2427
2428
# Captures runs of ASCII letters and digits; because of the capturing group,
# word_getter.split keeps the words as well as the text between them.
word_getter = re.compile(r"([a-zA-Z0-9]+)")

# A record pattern `Cons {field = pat, ...}` whose braces hold only letters,
# digits, quotes, whitespace and = , _ ( ) : [ ] (so no nested braces).
record_getter = re.compile(r"([a-zA-Z0-9]+\s*{[a-zA-Z0-9'\s=\,_\(\):\]\[]*})")
2432
2433
def extended_pattern_conv(cases):
    r"""Generate a conversion template for extended patterns.

    In each pattern ':' becomes '#', every record pattern is made positional
    by reduce_record_pattern, and every alphanumeric word is renamed through
    constructor_conv_table.  Returns the header ('case \x of', None) followed
    by (alternative text, index) pairs, the first indented with two spaces
    and the others with '| '.
    """
    conv = [('case \\x of', None)]

    for i, pat in enumerate(cases):
        pat = '#'.join(pat.split(':'))
        while record_getter.search(pat):
            # maxsplit=1 rewrites one record per pass; without it a pattern
            # with several records (e.g. a tuple of two) splits into more
            # than three pieces and the unpacking fails
            [left, record, right] = record_getter.split(pat, maxsplit=1)
            record = reduce_record_pattern(record)
            pat = left + record + right
        if '{' in pat:
            print(pat)
        assert '{' not in pat
        bits = word_getter.split(pat)
        bits = [constructor_conv_table.get(bit, bit) for bit in bits]
        pat = ''.join(bits)
        if i == 0:
            conv.append((r'  %s \<Rightarrow> ' % pat, i))
        else:
            conv.append((r'| %s \<Rightarrow> ' % pat, i))
    return conv
2454
2455
def reduce_record_pattern(string):
    """Rewrite a record pattern ``Cons {f = p, ...}`` positionally.

    The field order comes from all_constructor_args[cons], a sequence of
    (field name, field type) pairs maintained elsewhere in this module.
    Named fields contribute their subpattern, parenthesised when it has
    several words; the other fields contribute '_'.  Exits with status 1
    when cons has no entry in all_constructor_args.
    """
    assert string[-1] == '}'
    head, fields = string[:-1].split('{')
    cons = head.strip()
    uses = {}
    for eq in braces.str(fields, '(', ')').split(','):
        eq = str(eq).strip()
        if not eq:
            continue
        field, subpat = eq.split('=')
        field, subpat = field.strip(), subpat.strip()
        if len(subpat.split()) > 1:
            subpat = '(%s)' % subpat
        uses[field] = subpat
    if cons not in all_constructor_args:
        sys.stderr.write('FAIL: trying to build case for %s\n' % cons)
        sys.stderr.write('when reading %s\n' % filename)
        sys.stderr.write('but constructor not seen yet\n')
        sys.stderr.write('perhaps parse in different order?\n')
        sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
        sys.exit(1)
    positional = [uses.get(field, '_')
                  for (field, _) in all_constructor_args[cons]]
    return cons + ' ' + ' '.join(positional)
2482
2483
def subs_nums_and_x(conv, x):
    r"""Instantiate a conversion template.

    In every (text, index) pair, ``\x`` is replaced by x, and each ``\vN``
    (N an integer, ended by a backslash or the end of the text) by ``vK``.
    K is a fresh id from get_next_unique_id, and the same N gets the same K
    throughout conv.  Returns a new list of (text, index) pairs.
    """
    fresh = []  # fresh[N] is the unique id allocated for \vN
    out = []
    for (text, index) in conv:
        pieces = text.replace('\\x', x).split('\\v')
        rebuilt = pieces[0]
        for piece in pieces[1:]:
            number, _, rest = piece.partition('\\')
            slot = int(number)
            while slot >= len(fresh):
                fresh.append(get_next_unique_id())
            rebuilt = rebuilt + 'v%d' % fresh[slot] + rest
        out.append((rebuilt, index))

    return out
2503
2504
def get_supplied_transform_table():
    """Load the hand-written conversions listed in the ``supplied`` file.

    Blank lines are dropped, and the rest are arranged into a tree with
    offside_tree, keeping their line numbers.  Each top-level entry must
    contain ``conv:`` and have exactly two children.  The first child,
    turned into a key by convert_to_stripped_tuple, maps to the second
    (presumably the replacement definition).  Malformed entries are reported
    on stderr and skipped, and lines containing tabs get a warning.
    """
    with open('supplied') as f:
        lines = [line.rstrip() for line in f]

    lines = [(line, n + 1) for (n, line) in enumerate(lines)]
    lines = [(line, n) for (line, n) in lines if line != '']

    # the elements are (text, line number) pairs; test the text itself
    for (line, n) in lines:
        if '\t' in line:
            sys.stderr.write('WARN: tab character in supplied line %d\n' % n)

    tree = offside_tree(lines)

    result = {}

    for line, n, children in tree:
        if ('conv:' not in line) or len(children) != 2:
            sys.stderr.write('WARN: supplied line %d dropped\n'
                             % n)
            if 'conv:' not in line:
                sys.stderr.write('\t\t(token "conv:" missing)\n')
            if len(children) != 2:
                sys.stderr.write('\t\t(%d children != 2)\n' % len(children))
            continue

        children = discard_line_numbers(children)

        before, after = children

        before = convert_to_stripped_tuple(before[0], before[1])

        result[before] = after

    return result
2541
2542
def print_tree(tree, indent=0):
    """Print a (line, children) tree to stdout for debugging.

    Each line is printed stripped, prefixed by one tab per nesting level.
    The whole string is built inside print(); the old
    ``print(...) + line`` added a str to print's None return value and
    raised TypeError.
    """
    for line, children in tree:
        print('\t' * indent + line.strip())
        print_tree(children, indent + 1)
2547
2548
supplied_transform_table = get_supplied_transform_table()
# One counter per supplied conversion, all starting at zero (presumably
# incremented where a supplied conversion gets applied; not shown here).
supplied_transforms_usage = dict.fromkeys(
    six.iterkeys(supplied_transform_table), 0)


def warn_supplied_usage():
    """Write a stderr warning for every supplied conversion whose usage
    counter is still zero, showing the first element of its key."""
    for key, count in six.iteritems(supplied_transforms_usage):
        if count:
            continue
        sys.stderr.write('WARN: supplied conv unused: %s\n'
                         % key[0])
2559
2560
quotes_getter = re.compile('"[^"]+"')


def detect_recursion(body):
    """Report whether a function's defining clauses mention the function.

    Each clause is flattened with reduce_to_single_line and stripped of
    double-quoted strings.  The first word of the first clause is taken as
    the function name, and every clause must start with it (asserted).
    Returns True when the name occurs in the remainder of any clause.

    NOTE(review): this is a substring test, so an identifier that merely
    contains the name (e.g. `fooBar` for `foo`) also counts as recursion.
    """
    flattened = [''.join(quotes_getter.split(reduce_to_single_line(elt)))
                 for elt in body]
    pairs = [text.split(None, 1) for text in flattened]
    name = pairs[0][0]
    assert [head for (head, _) in pairs if head != name] == []
    return [rest for (_, rest) in pairs if name in rest] != []
2573
2574
def primrec_transform(d):
    r"""Rewrite definition d's clauses as primrec equations, in place.

    Each clause goes through body_transform(..., nopattern=True).  Its single
    \<equiv> (asserted) becomes '= (', its trailing '"' becomes ')"', and it
    is prefixed with two spaces (first clause) or '| ' (the others).
    Sets d.primrec to True, replaces d.body and returns d.
    """
    sig = d.sig
    defn = d.defined
    equations = []
    for index, (text, kids) in enumerate(d.body):
        [(text, kids)] = body_transform([(text, kids)], defn, sig,
                                        nopattern=True)
        text = ('| ' if index else '  ') + text
        halves = text.split(r'\<equiv>')
        assert len(halves) == 2
        text = '= ('.join(halves)
        (text, kids) = remove_trailing_string('"', (text, kids))
        (text, kids) = add_trailing_string(')"', (text, kids))
        equations.append((text, kids))
    d.primrec = True
    d.body = equations
    return d
2596
2597
# A lowercase ASCII letter followed only by word characters.
variable_name_regex = re.compile(r"^[a-z]\w*$")


def is_variable_name(string):
    """Match string against variable_name_regex.

    Returns the match object (truthy) or None.
    """
    return re.match(variable_name_regex, string)
2603
2604
def pattern_match_transform(body):
    """Converts a body containing possibly multiple definitions
        and containing pattern matches into a normal Isabelle definition
        followed by a big Haskell case expression which is resolved
        elsewhere.

        body is a list of (line, children) clauses ``f arg1 ... = rhs``.
        Argument positions that every clause binds to the same variable
        (or '_') are kept as parameters.  The remaining positions are
        named ``xN`` and matched with ``case xN of`` (or a tuple of them).
        Returns a single-element list [(line, children)]."""
    splits = []
    for (line, children) in body:
        string = braces.str(line, '(', ')')
        # pull in continuation lines until the clause's '=' is reached
        while len(string.split('=')) == 1:
            if len(children) == 1:
                [(moreline, children)] = children
                string = string + ' ' + moreline.strip()
            elif children and leading_bar.match(children[0][0]):
                # guarded clause: its guards become the body
                string = string + ' ='
                children = \
                    guarded_body_transform(children, ' = ')
            elif children and children[0][1] == []:
                # NOTE: pop(0) mutates the children list taken from body
                (moreline, _) = children.pop(0)
                string = string + ' ' + moreline.strip()
            else:
                print()
                print(line)
                print()
                for child in children:
                    print(child)
                assert 0

        [lead, tail] = string.split('=', 1)
        bits = lead.split()
        # NOTE: same list object as bits, so both slices below are equal
        unbraced = bits
        function = str(bits[0])
        splits.append((bits[1:], unbraced[1:], tail, children))

    # common[i]: the name shared by argument i in all clauses, or None when
    # that position has to be pattern matched
    common = splits[0][0][:]
    for i, term in enumerate(common):
        if term.startswith('('):
            common[i] = None
        if '@' in term:
            common[i] = None
        if term[0].isupper():
            common[i] = None

    for (bits, _, _, _) in splits[1:]:
        for i, term in enumerate(bits):
            if i >= len(common):
                # debugging dump; common[i] below then raises IndexError
                print_tree(body)
            if term != common[i]:
                is_var = is_variable_name(str(term))
                if common[i] == '_' and is_var:
                    common[i] = term
                elif term != '_':
                    common[i] = None

    for i, term in enumerate(common):
        if term == '_':
            common[i] = 'x%d' % i

    blanks = [i for (i, n) in enumerate(common) if n is None]

    line = '%s ' % function
    for i, name in enumerate(common):
        if name is None:
            line = line + 'x%d ' % i
        else:
            line = line + '%s ' % name
    if blanks == []:
        # debugging dump; blanks[0] below then raises IndexError
        print(splits)
        print(common)
    if len(blanks) == 1:
        line = line + '= case x%d of' % blanks[0]
    else:
        line = line + '= case (x%d' % blanks[0]
        for i in blanks[1:]:
            line = line + ', x%d' % i
        line = line + ') of'

    # one case alternative per clause: its patterns at the blank positions
    children = []
    for (bits, unbraced, tail, c) in splits:
        if len(blanks) == 1:
            l = '  %s' % unbraced[blanks[0]]
        else:
            l = '  (%s' % unbraced[blanks[0]]
            for i in blanks[1:]:
                l = l + ', %s' % unbraced[i]
            l = l + ')'
        l = l + ' -> %s' % tail
        children.append((l, c))

    return [(line, children)]
2694
2695
def get_lambda_body_lines(d):
    r"""Return the body of single-clause definition d as lambda-term lines.

    d.body must hold exactly one (line, children) clause.  Its first
    character is dropped (presumably an opening '"', matching the '"' that
    must end the last line).  The line reads ``fn args \<equiv> rhs``, with
    the \<equiv> possibly at the start of the first child instead.  The
    result is ``(\<lambda>args. rhs`` followed by the flattened children,
    with the final '"' replaced by ')'.
    """
    fn = d.defined

    [(head, kids)] = d.body

    head = head[1:]
    # the \<equiv> may sit in the first continuation line
    if r'\<equiv>' not in head and r'\<equiv>' in kids[0][0]:
        (extra, grandkids) = kids[0]
        kids = grandkids + kids[1:]
        head = head + extra
    lhs, rhs = head.split(r'\<equiv>', 1)
    words = lhs.split()
    params = words[1:]
    assert fn in words[0]

    result = [r'(\<lambda>' + ' '.join(params) + '. ' + rhs] + flatten_tree(kids)
    assert (result[-1].endswith('"'))
    result[-1] = result[-1][:-1] + ')'

    return result
2721
2722
def add_trailing_string(s, xxx_todo_changeme8):
    """Append s to the textually last line of a (line, children) tree.

    The last line is the deepest last child.  Returns a new tree; the lists
    along that path are copied, not mutated.
    """
    line, children = xxx_todo_changeme8
    if children != []:
        last = add_trailing_string(s, children[-1])
        return (line, children[:-1] + [last])
    return (line + s, children)
2730
2731
def remove_trailing_string(s, xxx_todo_changeme9, _handled=False):
    """Remove suffix s from the textually last line of a (line, children)
    tree and return the new tree.

    Raises AssertionError (after writing the offending line to stderr) when
    the last line does not end with s.  The outermost call also writes the
    whole tree to stderr before re-raising.  _handled is internal: it marks
    recursive calls that should not repeat that report.
    """
    (line, children) = xxx_todo_changeme9
    if not _handled:
        try:
            return remove_trailing_string(s, (line, children), _handled=True)
        except Exception:
            sys.stderr.write('handling %s\n' % ((line, children), ))
            raise
    if children == []:
        if not line.endswith(s):
            sys.stderr.write('ERR: expected %r\n' % line)
            sys.stderr.write('to end with %r\n' % s)
            # an explicit raise, since an assert would vanish under -O and
            # the wrong characters would then be cut silently
            raise AssertionError('%r does not end with %r' % (line, s))
        # slice by length: line[:-0] would empty the line when s == ''
        return (line[:len(line) - len(s)], [])
    else:
        modified = remove_trailing_string(s, children[-1], _handled=True)
        return (line, children[0:-1] + [modified])
2750
2751
def get_trailing_string(n, xxx_todo_changeme10):
    """Return the last n characters of the last leaf line of a tree.

    The last child is followed at every level until a node with no
    children is reached. Up to n trailing characters of that node's line
    are returned. If the line is shorter than n, the whole line is
    returned. When n == 0, the result is ''. n is expected to be
    non-negative.
    """
    (line, children) = xxx_todo_changeme10
    while children != []:
        (line, children) = children[-1]
    # Slice by explicit start index: line[-n:] would be line[0:] (the whole
    # line) when n == 0.
    return line[max(len(line) - n, 0):]
2758
2759
def has_trailing_string(s, xxx_todo_changeme11):
    """Tell whether the last leaf line of a (line, children) tree ends with s.

    The last child is followed at every level until a node with no
    children is reached; only that node's line is tested.
    """
    text, subtree = xxx_todo_changeme11
    while subtree != []:
        text, subtree = subtree[-1]
    return text.endswith(s)
2766
2767
def ensure_type_ordering(defs):
    """Reorder defs so every 'newtype' follows the newtypes it depends on.

    Definitions whose ``.type`` is ``'newtype'`` come first. Each is placed
    after every other remaining newtype whose ``.typename`` appears in its
    ``.typedeps``. All other definitions follow, in their original order.

    Raises ValueError if the newtype dependencies contain a cycle (including
    a type that lists itself). Previously this looped forever. On any error,
    the still-unplaced type names are written to stderr before re-raising.
    """
    typedefs = [d for d in defs if d.type == 'newtype']
    other = [d for d in defs if d.type != 'newtype']

    final_typedefs = []
    while typedefs:
        try:
            # Chase dependencies from the first remaining typedef until one
            # with no unplaced dependencies is found, then emit it.
            i = 0
            chain = [i]
            while True:
                deps = typedefs[i].typedeps
                for j, term in enumerate(typedefs):
                    if term.typename in deps:
                        break
                else:
                    break  # typedefs[i] has no unplaced dependencies
                if j in chain:
                    cycle = chain[chain.index(j):] + [j]
                    raise ValueError(
                        'cyclic newtype dependency: %s' % ' -> '.join(
                            str(typedefs[k].typename) for k in cycle))
                chain.append(j)
                i = j
            final_typedefs.append(typedefs.pop(i))
        except Exception:
            sys.stderr.write('Exception hit ordering types:\n')
            for td in typedefs:
                sys.stderr.write('  - %s\n' % td.typename)
            raise

    return final_typedefs + other
2793
2794
def lead_ws(string):
    """Return the run of whitespace that prefixes string ('' if none)."""
    content_len = len(string.lstrip())
    return string[:len(string) - content_len]
2798
2799
def adjust_ws(xxx_todo_changeme12, n):
    """Shift the indentation of every line in a (line, children) tree.

    A positive n prepends n spaces to each line. A negative n removes up to
    -n leading whitespace characters from each line. n == 0 leaves lines
    unchanged. Non-whitespace text is never removed, so a line indented by
    less than -n is dedented to column 0 rather than truncated. Previously,
    -n characters were cut unconditionally.
    """
    (line, children) = xxx_todo_changeme12
    if n > 0:
        line = ' ' * n + line
    elif n < 0:
        available = len(line) - len(line.lstrip())
        line = line[min(-n, available):]

    return (line, [adjust_ws(child, n) for child in children])
2809
2810
# Matches a maximal run of dotted qualifiers such as ``Foo.Bar.``: one or
# more word-character runs, each followed by a '.'. Used by
# subst_module_redirects to locate module prefixes to translate. Note that
# \w also matches digits, so e.g. '1.5' yields the match '1.'.
modulename = re.compile(r"(\w+\.)+")
2812
2813
def perform_module_redirects(lines, call):
    """Return a new list with subst_module_redirects applied to each line."""
    redirected = []
    for text in lines:
        redirected.append(subst_module_redirects(text, call))
    return redirected
2816
2817
def subst_module_redirects(line, call):
    """Rewrite every dotted module prefix in line via call.moduletranslations.

    Each left-to-right, non-overlapping match of ``modulename`` (e.g. 'A.B.')
    is handled as follows:
    - Its trailing dot is dropped.
    - If the remaining name is a key of ``call.moduletranslations``, it is
      replaced by the mapped value.
    - A non-empty result is written back followed by '.'. An empty result
      removes the prefix entirely.
    """
    def _redirect(match):
        name = match.group(0)[:-1]  # drop the trailing '.'
        table = call.moduletranslations
        if name in table:
            name = table[name]
        return name + '.' if name else ''

    return modulename.sub(_redirect, line)
2832