1# 2# Copyright 2014, NICTA 3# 4# This software may be distributed and modified according to the terms of 5# the BSD 2-Clause license. Note that NO WARRANTY is provided. 6# See "LICENSE_BSD2.txt" for details. 7# 8# @TAG(NICTA_BSD) 9# 10 11from __future__ import print_function 12from __future__ import absolute_import 13import braces 14import re 15import sys 16import os 17import six 18from six.moves import map 19from six.moves import range 20from six.moves import zip 21from functools import reduce 22 23 24class Call(object): 25 26 def __init__(self): 27 self.restr = None 28 self.decls_only = False 29 self.instanceproofs = False 30 self.bodies_only = False 31 self.bad_type_assignment = False 32 self.body = False 33 self.current_context = [] 34 35 36class Def(object): 37 38 def __init__(self): 39 self.type = None 40 self.defined = None 41 self.body = None 42 self.line = None 43 self.sig = None 44 self.instance_proofs = [] 45 self.instance_extras = [] 46 self.comments = [] 47 self.primrec = None 48 self.deriving = [] 49 self.instance_defs = {} 50 51 52def parse(call): 53 """Parses a file.""" 54 set_global(call) 55 56 defs = get_defs(call.filename) 57 58 lines = get_lines(defs, call) 59 60 lines = perform_module_redirects(lines, call) 61 62 return ['%s\n' % line for line in lines] 63 64def settings_line(l): 65 """Adjusts some global settings.""" 66 bits = l.split (',') 67 for bit in bits: 68 bit = bit.strip () 69 (kind, setting) = bit.split ('=') 70 kind = kind.strip () 71 if kind == 'keep_constructor': 72 [cons] = setting.split () 73 keep_conss[cons] = 1 74 else: 75 assert not "setting kind understood", bit 76 77def set_global(_call): 78 global call 79 call = _call 80 global filename 81 filename = _call.filename 82 83 84file_defs = {} 85 86 87def splitList(list, pred): 88 """Splits a list according to pred.""" 89 result = [] 90 el = [] 91 for l in list: 92 if pred(l): 93 if el != []: 94 result.append(el) 95 el = [] 96 else: 97 el.append(l) 98 if el != []: 99 
result.append(el) 100 return result 101 102 103def takeWhile(list, pred): 104 """Returns the initial portion of the list where each 105element matches pred""" 106 limit = 0 107 108 for l in list: 109 if pred(l): 110 limit = limit + 1 111 else: 112 break 113 return list[0:limit] 114 115 116def get_defs(filename): 117 #if filename in file_defs: 118 # return file_defs[filename] 119 120 cmdline = os.environ['L4CPP'] 121 f = os.popen('cpp -Wno-invalid-pp-token -traditional-cpp %s %s' % 122 (cmdline, filename)) 123 input = [line.rstrip() for line in f] 124 f.close() 125 defs = top_transform(input, filename.endswith(".lhs")) 126 127 file_defs[filename] = defs 128 return defs 129 130 131def top_transform(input, isLhs): 132 """Top level transform, deals with lhs artefacts, divides 133 the code up into a series of seperate definitions, and 134 passes these definitions through the definition transforms.""" 135 to_process = [] 136 comments = [] 137 for n, line in enumerate(input): 138 if '\t' in line: 139 sys.stderr.write('WARN: tab in line %d, %s.\n' % 140 (n, filename)) 141 if isLhs: 142 if line.startswith('> '): 143 if '--' in line: 144 line = line.split('--')[0].strip() 145 146 if line[2:].strip() == '': 147 comments.append((n, 'C', '')) 148 elif line.startswith('> {-#'): 149 comments.append((n, 'C', '(*' + line + '*)')) 150 else: 151 to_process.append((line[2:], n)) 152 else: 153 if line.strip(): 154 comments.append((n, 'C', '(*' + line + '*)')) 155 else: 156 comments.append((n, 'C', '')) 157 else: 158 if '--' in line: 159 line = line.split('--')[0].rstrip() 160 161 if line.strip() == '': 162 comments.append((n, 'C', '')) 163 elif line.strip().startswith('{-'): 164 # single-line {- -} comments only 165 comments.append((n, 'C', '(*' + line + '*)')) 166 elif line.startswith('#'): 167 # preprocessor directives 168 comments.append((n, 'C', '(*' + line + '*)')) 169 else: 170 to_process.append((line, n)) 171 172 def_tree = offside_tree(to_process) 173 defs = 
create_defs(def_tree) 174 defs = group_defs(defs) 175 176 # Forget about the comments for now 177 178 # defs_plus_comments = [d.line, d) for d in defs] + comments 179 # defs_plus_comments.sort() 180 # defs = [] 181 # prev_comments = [] 182 # for term in defs_plus_comments: 183 # if term[1] == 'C': 184 # prev_comments.append(term[2]) 185 # else: 186 # d = term[1] 187 # d.comments = prev_comments 188 # defs.append(d) 189 # prev_comments = [] 190 191 # apply def_transform and cut out any None return values 192 defs = [defs_transform(d) for d in defs] 193 defs = [d for d in defs if d is not None] 194 195 defs = ensure_type_ordering(defs) 196 197 return defs 198 199 200def get_lines(defs, call): 201 """Gets the output lines needed for this call from 202 all the potential output generated at parse time.""" 203 204 if call.restr: 205 defs = [d for d in defs if d.type == 'comments' 206 or call.restr(d)] 207 208 output = [] 209 for d in defs: 210 lines = def_lines(d, call) 211 if lines: 212 output.extend(lines) 213 output.append('') 214 215 return output 216 217 218def offside_tree(input): 219 """Breaks lines up into a tree based on the offside rule. 
220 Each line gets as children the lines following it up until 221 the next line whose indent is less or equal.""" 222 if input == []: 223 return [] 224 head, head_n = input[0] 225 head_indent = len(head) - len(head.lstrip()) 226 children = [] 227 result = [] 228 for line, n in input[1:]: 229 indent = len(line) - len(line.lstrip()) 230 if indent <= head_indent: 231 result.append((head, head_n, offside_tree(children))) 232 head, head_n, head_indent = (line, n, indent) 233 children = [] 234 else: 235 children.append((line, n)) 236 result.append((head, head_n, offside_tree(children))) 237 238 return result 239 240 241def discard_line_numbers(tree): 242 """Takes a tree containing tuples (line, n, children) and 243 discards the n terms, returning a tree with tuples 244 (line, children)""" 245 result = [] 246 for (line, _, children) in tree: 247 result.append((line, discard_line_numbers(children))) 248 return result 249 250 251def flatten_tree(tree): 252 """Returns a tree to the set of numbered lines it was 253 drawn from.""" 254 result = [] 255 for (line, children) in tree: 256 result.append(line) 257 result.extend(flatten_tree(children)) 258 259 return result 260 261 262def create_defs(tree): 263 defs = [create_def(elt) for elt in tree] 264 defs = [d for d in defs if d is not None] 265 266 return defs 267 268 269def group_defs(defs): 270 """Takes a file broken into a series of definitions, and locates 271 multiple definitions of constants, caused by type signatures or 272 pattern matching, and accumulates to a single object per genuine 273 definition""" 274 defgroups = [] 275 defined = '' 276 for d in defs: 277 this_defines = d.defined 278 if d.type != 'definitions': 279 this_defines = '' 280 if this_defines == defined and this_defines: 281 defgroups[-1].body.extend(d.body) 282 else: 283 defgroups.append(d) 284 defined = this_defines 285 286 return defgroups 287 288 289def create_def(elt): 290 """Takes an element of an offside tree and creates 291 a definition 
object.""" 292 (line, n, children) = elt 293 children = discard_line_numbers(children) 294 return create_def_2(line, children, n) 295 296 297def create_def_2(line, children, n): 298 d = Def() 299 d.body = [(line, children)] 300 d.line = n 301 lead = line.split(None, 3) 302 if lead[0] in ['import', 'module', 'class']: 303 return 304 elif lead[0] == 'instance': 305 type = 'instance' 306 defined = lead[2] 307 elif lead[0] in ['type', 'newtype', 'data']: 308 type = 'newtype' 309 defined = lead[1] 310 else: 311 type = 'definitions' 312 defined = lead[0] 313 314 d.type = type 315 d.defined = defined 316 return d 317 318 319def get_primrecs(): 320 f = open('primrecs') 321 keys = [line.strip() for line in f] 322 return set(key for key in keys if key != '') 323 324 325primrecs = get_primrecs() 326 327 328def defs_transform(d): 329 """Transforms the set of definitions for a function. This 330 may include its type signature, and may include the special 331 case of multiple definitions.""" 332 # the first tokens of the first line in the first definition 333 if d.type == 'newtype': 334 return newtype_transform(d) 335 elif d.type == 'instance': 336 return instance_transform(d) 337 338 lead = d.body[0][0].split(None, 2) 339 if lead[1] == '::': 340 d.sig = type_sig_transform(d.body[0]) 341 d.body.pop(0) 342 343 if d.defined in primrecs: 344 return primrec_transform(d) 345 346 if len(d.body) > 1: 347 d.body = pattern_match_transform(d.body) 348 349 if len(d.body) == 0: 350 print() 351 print(d) 352 assert 0 353 354 d.body = body_transform(d.body, d.defined, d.sig) 355 return d 356 357 358def wrap_qualify(lines, deep=True): 359 if len(lines) == 0: 360 return lines 361 362 """Close and then re-open a locale so instantiations can go through""" 363 if deep: 364 asdfextra = "" 365 else: 366 asdfextra = "" 367 368 if call.current_context: 369 lines.insert(0, 'end\nqualify {} (in Arch) {}'.format(call.current_context[-1], 370 asdfextra)) 371 lines.append('end_qualify\ncontext Arch begin 
global_naming %s' % call.current_context[-1]) 372 return lines 373 374def def_lines(d, call): 375 """Produces the set of lines associated with a definition.""" 376 if call.all_bits: 377 L = [] 378 if d.comments: 379 L.extend(flatten_tree(d.comments)) 380 L.append('') 381 if d.type == 'definitions': 382 L.append('definition') 383 if d.sig: 384 L.extend(flatten_tree([d.sig])) 385 L.append('where') 386 L.extend(flatten_tree(d.body)) 387 elif d.type == 'newtype': 388 L.extend(flatten_tree(d.body)) 389 if d.instance_proofs: 390 L.extend(wrap_qualify(flatten_tree(d.instance_proofs))) 391 L.append('') 392 if d.instance_extras: 393 L.extend(flatten_tree(d.instance_extras)) 394 L.append('') 395 return L 396 397 if call.instanceproofs: 398 if not call.bodies_only: 399 instance_proofs = wrap_qualify(flatten_tree(d.instance_proofs)) 400 else: 401 instance_proofs = [] 402 403 if not call.decls_only: 404 instance_extras = flatten_tree(d.instance_extras) 405 else: 406 instance_extras = [] 407 408 newline_needed = len(instance_proofs) > 0 and len(instance_extras) > 0 409 return instance_proofs + ([''] 410 if newline_needed else []) + instance_extras 411 412 if call.body: 413 return get_lambda_body_lines(d) 414 415 comments = d.comments 416 try: 417 typesig = flatten_tree([d.sig]) 418 except: 419 typesig = [] 420 body = flatten_tree(d.body) 421 type = d.type 422 423 if type == 'definitions': 424 if call.decls_only: 425 if typesig: 426 return comments + ["consts'"] + typesig 427 else: 428 return [] 429 elif call.bodies_only: 430 if d.sig: 431 defname = '%s_def' % d.defined 432 if d.primrec: 433 print('warning body-only primrec:') 434 print(body[0]) 435 return comments + ['primrec'] + body 436 return comments + ['defs %s:' % defname] + body 437 else: 438 return comments + ['definition'] + body 439 else: 440 if d.primrec: 441 return comments + ['primrec'] + typesig \ 442 + ['where'] + body 443 if typesig: 444 return comments + ['definition'] + typesig + ['where'] + body 445 else: 446 
return comments + ['definition'] + body 447 elif type == 'comments': 448 return comments 449 elif type == 'newtype': 450 if not call.bodies_only: 451 return body 452 453 454def type_sig_transform(tree_element): 455 """Performs transformations on a type signature line 456 preceding a function declaration or some such.""" 457 458 line = reduce_to_single_line(tree_element) 459 (pre, post) = line.split('::') 460 result = type_transform(post) 461 if '[pp' in result: 462 print(line) 463 print(pre) 464 print(post) 465 print(result) 466 assert 0 467 line = pre + ':: "' + result + '"' 468 469 return (line, []) 470 471 472ignore_classes = {'Error': 1} 473hand_classes = {'Bits': ['HS_bit'], 474 'Num': ['minus', 'one', 'zero', 'plus', 'numeral'], 475 'FiniteBits': ['finiteBit']} 476 477 478def type_transform(string): 479 """Performs transformations on a type signature, whether 480 part of a type signature line or occuring in a function.""" 481 482 # deal with type classes by recursion 483 bits = string.split('=>', 1) 484 if len(bits) == 2: 485 lhs = bits[0].strip() 486 if lhs.startswith('(') and lhs.endswith(')'): 487 instances = lhs[1:-1].split(',') 488 string = ' => '.join(instances + [bits[1]]) 489 else: 490 instances = [lhs] 491 var_annotes = {} 492 for instance in instances: 493 (name, var) = instance.split() 494 if name in ignore_classes: 495 continue 496 if name in hand_classes: 497 names = hand_classes[name] 498 else: 499 names = [type_conv(name)] 500 var = "'" + var 501 var_annotes.setdefault(var, []) 502 var_annotes[var].extend(names) 503 transformed = type_transform(bits[1]) 504 for (var, insts) in six.iteritems(var_annotes): 505 if len(insts) == 1: 506 newvar = '(%s :: %s)' % (var, insts[0]) 507 else: 508 newvar = '(%s :: {%s})' % (var, ', '.join(insts)) 509 transformed = newvar.join(transformed.split(var, 1)) 510 return transformed 511 512 # get rid of (), insert Unit, which converts to unit 513 string = 'Unit'.join(string.split('()')) 514 515 # divide up by -> or 
by , then divide on space. 516 # apply everything locally then work back up 517 bstring = braces.str(string, '(', ')') 518 bits = bstring.split('->') 519 r = ' \<Rightarrow> ' 520 if len(bits) == 1: 521 bits = bstring.split(',') 522 r = ' * ' 523 result = [type_bit_transform(bit) for bit in bits] 524 return r.join(result) 525 526 527def type_bit_transform(bit): 528 s = str(bit).strip() 529 if s.startswith('['): 530 # handling this properly is hard. 531 assert s.endswith(']') 532 bit2 = braces.str(s[1:-1], '(', ')') 533 return '%s list' % type_bit_transform(bit2) 534 bits = bit.split(None, braces=True) 535 if str(bits[0]) == 'PPtr': 536 assert len(bits) == 2 537 return 'machine_word' 538 if len(bits) > 1 and bits[1].startswith('['): 539 assert bits[-1].endswith(']') 540 arg = ' '.join([str(bit) for bit in bits[1:]])[1:-1] 541 arg = type_transform(arg) 542 return ' '.join([arg, 'list', str(type_conv(bits[0]))]) 543 bits = [type_conv(bit) for bit in bits] 544 bits = constructor_reversing(bits) 545 bits = [bit.map(type_transform) for bit in bits] 546 strs = [str(bit) for bit in bits] 547 return ' '.join(strs) 548 549 550def reduce_to_single_line(tree_element): 551 def inner(tree_element, acc): 552 (line, children) = tree_element 553 acc.append(line) 554 for child in children: 555 inner(child, acc) 556 return acc 557 return ' '.join(inner(tree_element, [])) 558 559 560type_conv_table = { 561 'Maybe': 'option', 562 'Bool': 'bool', 563 'Word': 'machine_word', 564 'Int': 'nat', 565 'String': 'unit list'} 566 567 568def type_conv(string): 569 """Converts a type used in Haskell to our equivalent""" 570 if string.startswith('('): 571 # ignore compound types, type_transform will descend into em 572 result = string 573 elif '.' in string: 574 # qualified references 575 bits = string.split('.') 576 typename = bits[-1] 577 module = reduce(lambda x, y: x + '.' + y, bits[:-1]) 578 typename = type_conv(typename) 579 result = module + '.' 
+ typename 580 elif string[0].islower(): 581 # type variable 582 result = "'%s" % string 583 elif string[0] == '[' and string[-1] == ']': 584 # list 585 inner = type_conv(string[1:-1]) 586 result = '%s list' % inner 587 elif str(string) in type_conv_table: 588 result = type_conv_table[str(string)] 589 else: 590 # convert StudlyCaps to lower_with_underscores 591 was_lower = False 592 s = '' 593 for c in string: 594 if c.isupper() and was_lower: 595 s = s + '_' + c.lower() 596 else: 597 s = s + c.lower() 598 was_lower = c.islower() 599 result = s 600 type_conv_table[str(string)] = result 601 602 return braces.clone(result, string) 603 604 605def constructor_reversing(tokens): 606 if len(tokens) < 2: 607 return tokens 608 elif len(tokens) == 2: 609 return [tokens[1], tokens[0]] 610 elif tokens[0] == '[' and tokens[2] == ']': 611 return [tokens[1], braces.str('list', '(', ')')] 612 elif len(tokens) == 4 and tokens[1] == '[' and tokens[3] == ']': 613 listToken = braces.str('(List %s)' % tokens[2], '(', ')') 614 return [listToken, tokens[0]] 615 elif tokens[0] == 'array': 616 arrow_token = braces.str('\<Rightarrow>', '(', ')') 617 return [tokens[1], arrow_token, tokens[2]] 618 elif tokens[0] == 'either': 619 plus_token = braces.str('+', '(', ')') 620 return [tokens[1], plus_token, tokens[2]] 621 elif len(tokens) == 5 and tokens[2] == '[' and tokens[4] == ']': 622 listToken = braces.str('(List %s)' % tokens[3], '(', ')') 623 lbrack = braces.str('(', '+', '+') 624 rbrack = braces.str(')', '+', '+') 625 comma = braces.str(',', '+', '+') 626 return [lbrack, tokens[1], comma, listToken, rbrack, tokens[0]] 627 elif len(tokens) == 3: 628 # here comes a fudge 629 lbrack = braces.str('(', '+', '+') 630 rbrack = braces.str(')', '+', '+') 631 comma = braces.str(',', '+', '+') 632 return [lbrack, tokens[1], comma, tokens[2], rbrack, tokens[0]] 633 else: 634 print("Error parsing " + filename) 635 print("Can't deal with %s" % tokens) 636 sys.exit(1) 637 638 639def 
newtype_transform(d): 640 """Takes a Haskell style newtype/data type declaration, whose 641 options are divided with | and each of whose options has named 642 elements, and forms a datatype statement and definitions for 643 the named extractors and their update functions.""" 644 if len(d.body) != 1: 645 print('--- newtype long body ---') 646 print(d) 647 [(line, children)] = d.body 648 649 if children and children[-1][0].lstrip().startswith('deriving'): 650 l = reduce_to_single_line(children[-1]) 651 children = children[:-1] 652 r = re.compile(r"[,\s\(\)]+") 653 bits = r.split(l) 654 bits = [bit for bit in bits if bit and bit != 'deriving'] 655 d.deriving = bits 656 657 line = reduce_to_single_line((line, children)) 658 659 bits = line.split(None, 1) 660 op = bits[0] 661 line = bits[1] 662 bits = line.split('=', 1) 663 header = type_conv(bits[0].strip()) 664 d.typename = header 665 d.typedeps = set() 666 if len(bits) == 1: 667 # line of form 'data Blah' introduces unknown type? 668 d.body = [('typedecl %s' % header, [])] 669 all_type_arities[header] = [] # HACK 670 return d 671 line = bits[1] 672 673 if op == 'type': 674 # type synonym 675 return typename_transform(line, header, d) 676 elif line.find('{') == -1: 677 # not a record 678 return simple_newtype_transform(line, header, d) 679 else: 680 return named_newtype_transform(line, header, d) 681 682 683def typename_transform(line, header, d): 684 try: 685 [oldtype] = line.split() 686 except: 687 sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d.body) 688 call.bad_type_assignment = True 689 return 690 if oldtype.startswith('Data.Word.Word'): 691 # take off the prefix, leave Word32 or Word64 etc 692 oldtype = oldtype[10:] 693 oldtype = type_conv(oldtype) 694 bits = oldtype.split() 695 for bit in bits: 696 d.typedeps.add(bit) 697 lines = [ 698 'type_synonym %s = "%s"' % (header, oldtype), 699 # translations (* TYPE 1 *)', 700 # "%s" <=(type) "%s"' % (oldtype, header) 701 ] 702 d.body 
= [(line, []) for line in lines] 703 return d 704 705 706keep_conss = {} 707 708 709def simple_newtype_transform(line, header, d): 710 lines = [] 711 arities = [] 712 for i, bit in enumerate(line.split('|')): 713 braced = braces.str(bit, '(', ')') 714 bits = braced.split() 715 if len(bits) == 2: 716 last_lhs = bits[0] 717 718 if i == 0: 719 l = ' %s' % bits[0] 720 else: 721 l = ' | %s' % bits[0] 722 for bit in bits[1:]: 723 if bit.startswith('('): 724 bit = bit[1:-1] 725 typename = type_transform(str(bit)) 726 if len(bits) == 2: 727 last_rhs = typename 728 if ' ' in typename: 729 typename = '"%s"' % typename 730 l = l + ' ' + typename 731 d.typedeps.add(typename) 732 lines.append(l) 733 734 arities.append((str(bits[0]), len(bits[1:]))) 735 736 if list((dict(arities)).values()) == [1] and header not in keep_conss: 737 return type_wrapper_type(header, last_lhs, last_rhs, d) 738 739 d.body = [('datatype %s =' % header, [(line, []) for line in lines])] 740 741 set_instance_proofs(header, arities, d) 742 743 return d 744 745 746all_constructor_args = {} 747 748 749def named_newtype_transform(line, header, d): 750 bits = line.split('|') 751 752 constructors = [get_type_map(bit) for bit in bits] 753 all_constructor_args.update(dict(constructors)) 754 755 lines = [] 756 for i, (name, map) in enumerate(constructors): 757 if i == 0: 758 l = ' %s' % name 759 else: 760 l = ' | %s' % name 761 oname = name 762 for name, type in map: 763 if type is None: 764 print("ERR: None snuck into constructor list for %s" % name) 765 print(line, header, oname) 766 assert False 767 768 if len(type.split()) == 1 and '(' not in type: 769 l = l + ' ' + type 770 else: 771 l = l + ' "' + type + '"' 772 for bit in type.split(): 773 d.typedeps.add(bit) 774 lines.append(l) 775 776 names = {} 777 types = {} 778 for cons, map in constructors: 779 for i, (name, type) in enumerate(map): 780 names.setdefault(name, {}) 781 names[name][cons] = i 782 types[name] = type 783 784 for name, map in 
six.iteritems(names): 785 lines.append('') 786 lines.extend(named_extractor_definitions(name, map, types[name], 787 header, dict(constructors))) 788 789 for name, map in six.iteritems(names): 790 lines.append('') 791 lines.extend(named_update_definitions(name, map, types[name], header, 792 dict(constructors))) 793 794 for name, map in constructors: 795 if map == []: 796 continue 797 lines.append('') 798 lines.extend(named_constructor_translation(name, map, header)) 799 800 if len(constructors) > 1: 801 for name, map in constructors: 802 lines.append('') 803 check = named_constructor_check(name, map, header) 804 lines.extend(check) 805 806 if len(constructors) == 1: 807 for ex_name, _ in six.iteritems(names): 808 for up_name, _ in six.iteritems(names): 809 lines.append('') 810 lines.extend(named_extractor_update_lemma(ex_name, up_name)) 811 812 arities = [(name, len(map)) for (name, map) in constructors] 813 814 if list((dict(arities)).values()) == [1] and header not in keep_conss: 815 [(cons, map)] = constructors 816 [(name, type)] = map 817 return type_wrapper_type(header, cons, type, d, decons=(name, type)) 818 819 set_instance_proofs(header, arities, d) 820 821 d.body = [('datatype %s =' % header, [(line, []) for line in lines])] 822 return d 823 824 825def named_extractor_update_lemma(ex_name, up_name): 826 lines = [] 827 lines.append('lemma %s_%s_update [simp]:' % (ex_name, up_name)) 828 829 if up_name == ex_name: 830 lines.append(' "%s (%s_update f v) = f (%s v)"' % 831 (ex_name, up_name, ex_name)) 832 else: 833 lines.append(' "%s (%s_update f v) = %s v"' % 834 (ex_name, up_name, ex_name)) 835 836 lines.append(' by (cases v) simp') 837 838 return lines 839 840 841def named_extractor_definitions(name, map, type, header, constructors): 842 lines = [] 843 lines.append('primrec') 844 lines.append(' %s :: "%s \<Rightarrow> %s"' 845 % (name, header, type)) 846 lines.append('where') 847 is_first = True 848 for cons, i in six.iteritems(map): 849 if is_first: 850 l = 
' "%s (%s' % (name, cons) 851 is_first = False 852 else: 853 l = '| "%s (%s' % (name, cons) 854 num = len(constructors[cons]) 855 for k in range(num): 856 l = l + ' v%d' % k 857 l = l + ') = v%d"' % i 858 lines.append(l) 859 860 return lines 861 862 863def named_update_definitions(name, map, type, header, constructors): 864 lines = [] 865 lines.append('primrec') 866 ra = '\<Rightarrow>' 867 if len(type.split()) > 1: 868 type = '(%s)' % type 869 lines.append(' %s_update :: "(%s %s %s) %s %s %s %s"' 870 % (name, type, ra, type, ra, header, ra, header)) 871 lines.append('where') 872 is_first = True 873 for cons, i in six.iteritems(map): 874 if is_first: 875 l = ' "%s_update f (%s' % (name, cons) 876 is_first = False 877 else: 878 l = '| "%s_update f (%s' % (name, cons) 879 num = len(constructors[cons]) 880 for k in range(num): 881 l = l + ' v%d' % k 882 l = l + ') = %s' % cons 883 for k in range(num): 884 if k == i: 885 l = l + ' (f v%d)' % k 886 else: 887 l = l + ' v%d' % k 888 l = l + '"' 889 lines.append(l) 890 891 return lines 892 893 894def named_constructor_translation(name, map, header): 895 lines = [] 896 lines.append('abbreviation (input)') 897 l = ' %s_trans :: "' % name 898 for n, type in map: 899 l = l + '(' + type + ') \<Rightarrow> ' 900 l = l + '%s" ("%s\'_ \<lparr> %s= _' % (header, name, map[0][0]) 901 for n, type in map[1:]: 902 l = l + ', %s= _' % n 903 l = l + ' \<rparr>")' 904 lines.append(l) 905 lines.append('where') 906 l = ' "%s_ \<lparr> %s= v0' % (name, map[0][0]) 907 for i, (n, type) in enumerate(map[1:]): 908 l = l + ', %s= v%d' % (n, i + 1) 909 l = l + ' \<rparr> == %s' % name 910 for i in range(len(map)): 911 l = l + ' v%d' % i 912 l = l + '"' 913 lines.append(l) 914 915 return lines 916 917 918def named_constructor_check(name, map, header): 919 lines = [] 920 lines.append('definition') 921 lines.append(' is%s :: "%s \<Rightarrow> bool"' % (name, header)) 922 lines.append('where') 923 lines.append(' "is%s v \<equiv> case v of' % name) 924 
l = ' %s ' % name 925 for i, _ in enumerate(map): 926 l = l + 'v%d ' % i 927 l = l + '\<Rightarrow> True' 928 lines.append(l) 929 lines.append(' | _ \<Rightarrow> False"') 930 931 return lines 932 933 934def type_wrapper_type(header, cons, rhs, d, decons=None): 935 if '\\<Rightarrow>' in d.typedeps: 936 d.body = [('(* type declaration of %s omitted *)' % header, [])] 937 return d 938 lines = [ 939 'type_synonym %s = "%s"' % (header, rhs), 940 # translations (* TYPE 2 *)', 941 # "%s" <=(type) "%s"' % (header, rhs), 942 '', 943 'definition', 944 ' %s :: "%s \\<Rightarrow> %s"' % (cons, header, header), 945 'where %s_def[simp]:' % cons, 946 ' "%s \\<equiv> id"' % cons, 947 ] 948 if decons: 949 (decons, decons_type) = decons 950 lines.extend([ 951 '', 952 'definition', 953 ' %s :: "%s \\<Rightarrow> %s"' % (decons, header, header), 954 'where', 955 ' %s_def[simp]:' % decons, 956 ' "%s \\<equiv> id"' % decons, 957 '', 958 'definition' 959 ' %s_update :: "(%s \\<Rightarrow> %s) \\<Rightarrow> %s \\<Rightarrow> %s"' 960 % (decons, header, header, header, header), 961 'where', 962 ' %s_update_def[simp]:' % decons, 963 ' "%s_update f y \<equiv> f y"' % decons, 964 '' 965 ]) 966 lines.extend(named_constructor_translation(cons, [(decons, decons_type) 967 ], header)) 968 969 d.body = [(line, []) for line in lines] 970 return d 971 972 973def instance_transform(d): 974 [(line, children)] = d.body 975 bits = line.split(None, 3) 976 assert bits[0] == 'instance' 977 classname = bits[1] 978 typename = type_conv(bits[2]) 979 if classname == 'Show': 980 print("Warning: discarding class instance '%s :: Show'" % typename) 981 return None 982 if typename == '()': 983 print("Warning: discarding class instance 'unit :: %s'" % classname) 984 return None 985 if len(bits) == 3: 986 if children == []: 987 defs = [] 988 else: 989 [(l, c)] = children 990 assert l.strip() == 'where' 991 defs = c 992 else: 993 assert bits[3:] == ['where'] 994 defs = children 995 defs = [create_def_2(l, c, 0) for 
(l, c) in defs] 996 defs = [d2 for d2 in defs if d2 is not None] 997 defs = group_defs(defs) 998 defs = [defs_transform(d2) for d2 in defs] 999 defs_dict = {} 1000 for d2 in defs: 1001 if d2 is not None: 1002 defs_dict[d2.defined] = d2 1003 d.instance_defs = defs_dict 1004 d.deriving = [classname] 1005 if typename not in all_type_arities: 1006 sys.stderr.write('FAIL: attempting %s\n' % d.defined) 1007 sys.stderr.write('(typename %r)\n' % typename) 1008 sys.stderr.write('when reading %s\n' % filename) 1009 sys.stderr.write('but class not defined yet\n') 1010 sys.stderr.write('perhaps parse in different order?\n') 1011 sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n') 1012 sys.exit(1) 1013 arities = all_type_arities[typename] 1014 set_instance_proofs(typename, arities, d) 1015 1016 return d 1017 1018 1019all_type_arities = {} 1020 1021 1022def set_instance_proofs(header, constructor_arities, d): 1023 all_type_arities[header] = constructor_arities 1024 pfs = [] 1025 exs = [] 1026 canonical = list(enumerate(constructor_arities)) 1027 1028 classes = d.deriving 1029 instance_proof_fns = set( 1030 sorted((instance_proof_table[classname] for classname in classes), 1031 key=lambda x: x.order)) 1032 for f in instance_proof_fns: 1033 (npfs, nexs) = f(header, canonical, d) 1034 pfs.extend(npfs) 1035 exs.extend(nexs) 1036 1037 if d.type == 'newtype' and len(canonical) == 1 and False: 1038 [(cons, n)] = constructor_arities 1039 if n == 1: 1040 pfs.extend(finite_instance_proofs(header, cons)) 1041 1042 if pfs: 1043 lead = '(* %s instance proofs *)' % header 1044 d.instance_proofs = [(lead, [(line, []) for line in pfs])] 1045 if exs: 1046 lead = '(* %s extra instance defs *)' % header 1047 d.instance_extras = [(lead, [(line, []) for line in exs])] 1048 1049 1050def finite_instance_proofs(header, cons): 1051 lines = [] 1052 lines.append('') 1053 lines.append('instance %s :: finite' % header) 1054 if call.current_context: 1055 lines.append('interpretation Arch .') 1056 
lines.append(' apply (intro_classes)') 1057 lines.append(' apply (rule_tac f="%s" in finite_surj_type)' 1058 % cons) 1059 lines.append(' apply (safe, case_tac x, simp_all)') 1060 lines.append(' done') 1061 1062 return (lines, []) 1063 1064# wondering where the serialisable proofs went? see 1065# commit 21361f073bbafcfc985934e563868116810d9fa2 for last known occurence. 1066 1067# leave type tags 0..11 for explicit use outside of this script 1068next_type_tag = 12 1069 1070 1071def storable_instance_proofs(header, canonical, d): 1072 proofs = [] 1073 extradefs = [] 1074 1075 global next_type_tag 1076 next_type_tag = next_type_tag + 1 1077 proofs.extend([ 1078 '', 'defs (overloaded)', ' typetag_%s[simp]:' % header, 1079 ' "typetag (x :: %s) \<equiv> %d"' % (header, next_type_tag), '' 1080 'instance %s :: dynamic' % header, ' by (intro_classes, simp)' 1081 ]) 1082 1083 proofs.append('') 1084 proofs.append('instance %s :: storable ..' % header) 1085 1086 defs = d.instance_defs 1087 extradefs.append('') 1088 if 'objBits' in defs: 1089 extradefs.append('definition') 1090 body = flatten_tree(defs['objBits'].body) 1091 bits = body[0].split('objBits') 1092 assert bits[0].strip() == '"' 1093 if bits[1].strip().startswith('_'): 1094 bits[1] = 'x ' + bits[1].strip()[1:] 1095 bits = bits[1].split(None, 1) 1096 body[0] = ' objBits_%s: "objBits (%s :: %s) %s' \ 1097 % (header, bits[0], header, bits[1]) 1098 extradefs.extend(body) 1099 1100 extradefs.append('') 1101 if 'makeObject' in defs: 1102 extradefs.append('definition') 1103 body = flatten_tree(defs['makeObject'].body) 1104 bits = body[0].split('makeObject') 1105 assert bits[0].strip() == '"' 1106 body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \ 1107 % (header, header, bits[1]) 1108 extradefs.extend(body) 1109 1110 extradefs.extend(['', 'definition', ]) 1111 if 'loadObject' in defs: 1112 extradefs.append(' loadObject_%s:' % header) 1113 extradefs.extend(flatten_tree(defs['loadObject'].body)) 1114 else: 1115 
extradefs.extend([ 1116 ' loadObject_%s[simp]:' % header, 1117 ' "(loadObject p q n obj) :: %s \<equiv>' % header, 1118 ' loadObject_default p q n obj"', 1119 ]) 1120 1121 extradefs.extend(['', 'definition', ]) 1122 if 'updateObject' in defs: 1123 extradefs.append(' updateObject_%s:' % header) 1124 body = flatten_tree(defs['updateObject'].body) 1125 bits = body[0].split('updateObject') 1126 assert bits[0].strip() == '"' 1127 bits = bits[1].split(None, 1) 1128 body[0] = ' "updateObject (%s :: %s) %s' \ 1129 % (bits[0], header, bits[1]) 1130 extradefs.extend(body) 1131 else: 1132 extradefs.extend([ 1133 ' updateObject_%s[simp]:' % header, 1134 ' "updateObject (val :: %s) \<equiv>' % header, 1135 ' updateObject_default val"', 1136 ]) 1137 1138 return (proofs, extradefs) 1139 1140 1141storable_instance_proofs.order = 1 1142 1143 1144def pspace_storable_instance_proofs(header, canonical, d): 1145 proofs = [] 1146 extradefs = [] 1147 1148 proofs.append('') 1149 proofs.append('instance %s :: pre_storable' % header) 1150 proofs.append(' by (intro_classes,') 1151 proofs.append( 1152 ' auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)') 1153 1154 defs = d.instance_defs 1155 extradefs.append('') 1156 if 'objBits' in defs: 1157 extradefs.append('definition') 1158 body = flatten_tree(defs['objBits'].body) 1159 bits = body[0].split('objBits') 1160 assert bits[0].strip() == '"' 1161 if bits[1].strip().startswith('_'): 1162 bits[1] = 'x ' + bits[1].strip()[1:] 1163 bits = bits[1].split(None, 1) 1164 body[0] = ' objBits_%s: "objBits (%s :: %s) %s' \ 1165 % (header, bits[0], header, bits[1]) 1166 extradefs.extend(body) 1167 1168 extradefs.append('') 1169 if 'makeObject' in defs: 1170 extradefs.append('definition') 1171 body = flatten_tree(defs['makeObject'].body) 1172 bits = body[0].split('makeObject') 1173 assert bits[0].strip() == '"' 1174 body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \ 1175 % (header, header, bits[1]) 1176 
extradefs.extend(body) 1177 1178 extradefs.extend(['', 'definition', ]) 1179 if 'loadObject' in defs: 1180 extradefs.append(' loadObject_%s:' % header) 1181 extradefs.extend(flatten_tree(defs['loadObject'].body)) 1182 else: 1183 extradefs.extend([ 1184 ' loadObject_%s[simp]:' % header, 1185 ' "(loadObject p q n obj) :: %s kernel \<equiv>' % header, 1186 ' loadObject_default p q n obj"', 1187 ]) 1188 1189 extradefs.extend(['', 'definition', ]) 1190 if 'updateObject' in defs: 1191 extradefs.append(' updateObject_%s:' % header) 1192 body = flatten_tree(defs['updateObject'].body) 1193 bits = body[0].split('updateObject') 1194 assert bits[0].strip() == '"' 1195 bits = bits[1].split(None, 1) 1196 body[0] = ' "updateObject (%s :: %s) %s' \ 1197 % (bits[0], header, bits[1]) 1198 extradefs.extend(body) 1199 else: 1200 extradefs.extend([ 1201 ' updateObject_%s[simp]:' % header, 1202 ' "updateObject (val :: %s) \<equiv>' % header, 1203 ' updateObject_default val"', 1204 ]) 1205 1206 return (proofs, extradefs) 1207 1208 1209pspace_storable_instance_proofs.order = 1 1210 1211 1212def num_instance_proofs(header, canonical, d): 1213 assert len(canonical) == 1 1214 [(_, (cons, n))] = canonical 1215 assert n == 1 1216 lines = [] 1217 1218 def add_bij_instance(classname, fns): 1219 ins = bij_instance(classname, header, cons, fns) 1220 lines.extend(ins) 1221 1222 add_bij_instance('plus', [('plus', '%s + %s', True)]) 1223 add_bij_instance('minus', [('minus', '%s - %s', True)]) 1224 add_bij_instance('zero', [('zero', '0', True)]) 1225 add_bij_instance('one', [('one', '1', True)]) 1226 add_bij_instance('times', [('times', '%s * %s', True)]) 1227 1228 return (lines, []) 1229 1230 1231num_instance_proofs.order = 2 1232 1233def enum_instance_proofs (header, canonical, d): 1234 lines = ['(*<*)'] 1235 if len(canonical) == 1: 1236 [(_, (cons, n))] = canonical 1237 assert n == 1 1238 lines.append('instantiation %s :: enum begin' % header) 1239 if call.current_context: 1240 
lines.append('interpretation Arch .') 1241 lines.append('definition') 1242 lines.append(' enum_%s: "enum_class.enum \<equiv> map %s enum"' \ 1243 % (header, cons)) 1244 1245 else: 1246 cons_no_args = [cons for i, (cons, n) in canonical if n == 0] 1247 cons_one_arg = [cons for i, (cons, n) in canonical if n == 1] 1248 cons_two_args = [cons for i, (cons, n) in canonical if n > 1] 1249 assert cons_two_args == [] 1250 lines.append ('instantiation %s :: enum begin' % header) 1251 if call.current_context: 1252 lines.append('interpretation Arch .') 1253 lines.append ('definition') 1254 lines.append (' enum_%s: "enum_class.enum \<equiv> ' % header) 1255 lines.append (' [ ') 1256 for cons in cons_no_args[:-1]: 1257 lines.append (' %s,' % cons) 1258 for cons in cons_no_args[-1:]: 1259 lines.append (' %s' % cons) 1260 lines.append (' ]') 1261 for cons in cons_one_arg: 1262 lines.append (' @ (map %s enum)' % cons) 1263 lines[-1] = lines[-1] + '"' 1264 lines.append ('') 1265 for cons in cons_one_arg: 1266 lines.append('lemma %s_map_distinct[simp]: "distinct (map %s enum)"' % (cons, cons)) 1267 lines.append(' apply (simp add: distinct_map)') 1268 lines.append(' by (meson injI %s.inject)' % header) 1269 lines.append('') 1270 lines.append('definition') 1271 lines.append(' "enum_class.enum_all (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Ball UNIV P"' \ 1272 % header) 1273 lines.append('') 1274 lines.append('definition') 1275 lines.append(' "enum_class.enum_ex (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Bex UNIV P"' \ 1276 % header) 1277 lines.append('') 1278 lines.append(' instance') 1279 lines.append(' apply intro_classes') 1280 lines.append(' apply (safe, simp)') 1281 lines.append(' apply (case_tac x)') 1282 if len(canonical) == 1: 1283 lines.append(' apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def' \ 1284 % (header, header, header)) 1285 lines.append(' distinct_map_enum)') 1286 lines.append(' done') 1287 else: 1288 lines.append(' apply (simp_all add: 
enum_%s enum_all_%s_def enum_ex_%s_def)' \ 1289 % (header, header, header)) 1290 lines.append(' by fast+') 1291 lines.append('end') 1292 lines.append('') 1293 lines.append('instantiation %s :: enum_alt' % header) 1294 lines.append('begin') 1295 if call.current_context: 1296 lines.append('interpretation Arch .') 1297 lines.append('definition') 1298 lines.append(' enum_alt_%s: "enum_alt \<equiv> ' % header) 1299 lines.append(' alt_from_ord (enum :: %s list)"' % header) 1300 lines.append('instance ..') 1301 lines.append('end') 1302 lines.append('') 1303 lines.append('instantiation %s :: enumeration_both' % header) 1304 lines.append('begin') 1305 if call.current_context: 1306 lines.append('interpretation Arch .') 1307 lines.append('instance by (intro_classes, simp add: enum_alt_%s)' \ 1308 % header) 1309 lines.append('end') 1310 lines.append('') 1311 lines.append('(*>*)') 1312 1313 return (lines, []) 1314 1315 1316enum_instance_proofs.order = 3 1317 1318 1319def bits_instance_proofs(header, canonical, d): 1320 assert len(canonical) == 1 1321 [(_, (cons, n))] = canonical 1322 assert n == 1 1323 1324 return (bij_instance('bit', header, cons, 1325 [('shiftL', 'shiftL %s n', True), 1326 ('shiftR', 'shiftR %s n', True), 1327 ('bitSize', 'bitSize %s', False)]), []) 1328 1329 1330bits_instance_proofs.order = 5 1331 1332 1333def no_proofs(header, canonical, d): 1334 return ([], []) 1335 1336 1337no_proofs.order = 6 1338 1339# FIXME "Bounded" emits enum proofs even if something already has enum proofs 1340# generated due to "Enum" 1341 1342instance_proof_table = { 1343 'Eq': no_proofs, 1344 'Bounded': no_proofs, #enum_instance_proofs, 1345 'Enum': enum_instance_proofs, 1346 'Ix': no_proofs, 1347 'Ord': no_proofs, 1348 'Show': no_proofs, 1349 'Bits': bits_instance_proofs, 1350 'Real': no_proofs, 1351 'Num': num_instance_proofs, 1352 'Integral': no_proofs, 1353 'Storable': storable_instance_proofs, 1354 'PSpaceStorable': pspace_storable_instance_proofs, 1355 'Typeable': 
no_proofs, 1356 'Error': no_proofs, 1357 } 1358 1359def bij_instance (classname, typename, constructor, fns): 1360 L = [ 1361 '', 1362 'instance %s :: %s ..' % (typename, classname), 1363 'defs (overloaded)' 1364 ] 1365 for (fname, fstr, cast_return) in fns: 1366 n = len (fstr.split('%s')) - 1 1367 names = ('x', 'y', 'z', 'w')[:n] 1368 names2 = tuple ([name + "'" for name in names]) 1369 fstr1 = fstr % names 1370 fstr2 = fstr % names2 1371 L.append (' %s_%s: "%s \<equiv>' % (fname, typename, fstr1)) 1372 for name in names: 1373 L.append(" case %s of %s %s' \<Rightarrow>" \ 1374 % (name, constructor, name)) 1375 if cast_return: 1376 L.append (' %s (%s)"' % (constructor, fstr2)) 1377 else: 1378 L.append (' %s"' % fstr2) 1379 1380 return L 1381 1382def get_type_map (string): 1383 """Takes Haskell named record syntax and produces 1384 a map (in the form of a list of tuples) defining 1385 the converted types of the names.""" 1386 bits = string.split(None, 1) 1387 header = bits[0].strip() 1388 if len(bits) == 1: 1389 return (header, []) 1390 actual_map = bits[1].strip() 1391 if not (actual_map.startswith('{') and actual_map.endswith('}')): 1392 print('Error in ' + filename) 1393 print('Expected "A { blah :: blah etc }"') 1394 print('However { and } not found correctly') 1395 print('When parsing %s' % string) 1396 sys.exit(1) 1397 actual_map = actual_map[1:-1] 1398 1399 bits = braces.str(actual_map, '(', ')').split(',') 1400 bits.reverse() 1401 type = None 1402 map = [] 1403 for bit in bits: 1404 bits = bit.split('::') 1405 if len(bits) == 2: 1406 type = type_transform(str(bits[1]).strip()) 1407 name = str(bits[0]).strip() 1408 else: 1409 name = str(bit).strip() 1410 map.append((name, type)) 1411 map.reverse() 1412 return (header, map) 1413 1414 1415numLiftIO = [0] 1416 1417 1418def body_transform(body, defined, sig, nopattern=False): 1419 # assume single object 1420 [(line, children)] = body 1421 1422 if '(' in line.split('=')[0] and not nopattern: 1423 [(line, 
children)] = \ 1424 pattern_match_transform([(line, children)]) 1425 1426 if 'liftIO' in reduce_to_single_line((line, children)): 1427 # liftIO is the translation boundary for current 1428 # SEL4, below which we get into details of interaction 1429 # with the foreign function interface - axiomatise 1430 assert '=' in line 1431 line = line.split('=')[0] 1432 bits = line.split() 1433 numLiftIO[0] = numLiftIO[0] + 1 1434 rhs = '(%d :: Int) %s' % (numLiftIO[0], ' '.join(bits[1:])) 1435 line = '%s\<equiv> underlying_arch_op %s' % (line, rhs) 1436 children = [] 1437 elif '=' in line: 1438 line = '\<equiv>'.join(line.split('=', 1)) 1439 elif leading_bar.match(children[0][0]): 1440 pass 1441 elif '=' in children[0][0]: 1442 (nextline, c2) = children[0] 1443 children[0] = ('\<equiv>'.join(nextline.split('=', 1)), c2) 1444 else: 1445 sys.stderr.write('WARN: def of %s missing =\n' % defined) 1446 1447 if children and (children[-1][0].endswith('where') or 1448 children[-1][0].lstrip().startswith('where')): 1449 bits = line.split('\<equiv>') 1450 where_clause = where_clause_transform(children[-1]) 1451 children = children[:-1] 1452 if len(bits) == 2 and bits[1].strip(): 1453 line = bits[0] + '\<equiv>' 1454 new_line = ' ' * len(line) + bits[1] 1455 children = [(new_line, children)] 1456 else: 1457 where_clause = [] 1458 1459 (line, children) = zipWith_transforms(line, children) 1460 1461 (line, children) = supplied_transforms(line, children) 1462 1463 (line, children) = case_clauses_transform((line, children)) 1464 1465 (line, children) = do_clauses_transform((line, children), sig) 1466 1467 if children and leading_bar.match(children[0][0]): 1468 line = line + ' \<equiv>' 1469 children = guarded_body_transform(children, ' = ') 1470 1471 children = where_clause + children 1472 1473 if not nopattern: 1474 line = lhs_transform(line) 1475 line = lhs_de_underscore(line) 1476 1477 (line, children) = type_assertion_transform(line, children) 1478 1479 (line, children) = 
run_regexes((line, children)) 1480 (line, children) = run_ext_regexes((line, children)) 1481 1482 (line, children) = bracket_dollar_lambdas((line, children)) 1483 1484 line = '"' + line 1485 (line, children) = add_trailing_string('"', (line, children)) 1486 return [(line, children)] 1487 1488 1489dollar_lambda_regex = re.compile(r"\$\s*\\<lambda>") 1490 1491 1492def bracket_dollar_lambdas(xxx_todo_changeme): 1493 (line, children) = xxx_todo_changeme 1494 if dollar_lambda_regex.search(line): 1495 [left, right] = dollar_lambda_regex.split(line) 1496 line = '%s(\<lambda>%s' % (left, right) 1497 both = (line, children) 1498 if has_trailing_string(';', both): 1499 both = remove_trailing_string(';', both) 1500 (line, children) = add_trailing_string(');', both) 1501 else: 1502 (line, children) = add_trailing_string(')', both) 1503 children = [bracket_dollar_lambdas(elt) for elt in children] 1504 return (line, children) 1505 1506 1507def zipWith_transforms(line, children): 1508 if 'zipWithM_' not in line: 1509 children = [zipWith_transforms(l, c) for (l, c) in children] 1510 return (line, children) 1511 1512 if children == []: 1513 return (line, []) 1514 1515 if len(children) == 2: 1516 [(l, c), (l2, c2)] = children 1517 if c == [] and c2 == [] and '..]' in l + l2: 1518 children = [(l + ' ' + l2.strip(), [])] 1519 1520 (l, c) = children[-1] 1521 if c != [] or '..]' not in l: 1522 return (line, children) 1523 1524 bits = line.split('zipWithM_', 1) 1525 line = bits[0] + 'mapM_' 1526 ws = lead_ws(l) 1527 line2 = ws + '(split ' + bits[1] 1528 1529 bits = braces.str(l, '[', ']').split(None, braces=True) 1530 1531 line3 = ws + ' '.join(bits[:-2]) + ')' 1532 used_children = children[:-1] + [(line3, [])] 1533 1534 sndlast = bits[-2] 1535 last = bits[-1] 1536 if last.endswith('..]'): 1537 internal = last[1:-3].strip() 1538 if ',' in internal: 1539 bits = internal.split(',') 1540 l = '%s(zipE4 (%s) (%s) (%s))' \ 1541 % (ws, sndlast, bits[0], bits[-1]) 1542 else: 1543 l = '%s(zipE3 
(%s) (%s))' % (ws, sndlast, internal) 1544 else: 1545 internal = sndlast[1:-3].strip() 1546 if ',' in internal: 1547 bits = internal.split(',') 1548 l = '%s(zipE2 (%s) (%s) (%s))' \ 1549 % (ws, bits[0], bits[1], last) 1550 else: 1551 l = '%s(zipE1 (%s) (%s))' % (ws, internal, last) 1552 1553 return (line, [(line2, used_children), (l, [])]) 1554 1555 1556def supplied_transforms(line, children): 1557 t = convert_to_stripped_tuple(line, children) 1558 1559 if t in supplied_transform_table: 1560 ws1 = lead_ws(line) 1561 result = supplied_transform_table[t] 1562 ws2 = lead_ws(result[0]) 1563 result = adjust_ws(result, len(ws1) - len(ws2)) 1564 supplied_transforms_usage[t] = 1 1565 return result 1566 1567 children = [supplied_transforms(l, c) for (l, c) in children] 1568 1569 return (line, children) 1570 1571 1572def convert_to_stripped_tuple(line, children): 1573 children = [convert_to_stripped_tuple(l, c) for (l, c) in children] 1574 1575 return (line.strip(), tuple(children)) 1576 1577 1578def type_assertion_transform_inner(line): 1579 m = type_assertion.search(line) 1580 if not m: 1581 return line 1582 var = m.expand('\\1') 1583 type = m.expand('\\2').strip() 1584 type = type_transform(type) 1585 return line[:m.start()] + '(%s::%s)' % (var, type) \ 1586 + type_assertion_transform_inner(line[m.end():]) 1587 1588 1589def type_assertion_transform(line, children): 1590 children = [type_assertion_transform(l, c) for (l, c) in children] 1591 1592 return (type_assertion_transform_inner(line), children) 1593 1594 1595def where_clause_guarded_body(xxx_todo_changeme1): 1596 (line, children) = xxx_todo_changeme1 1597 if children and leading_bar.match(children[0][0]): 1598 return (line + ' =', guarded_body_transform(children, ' = ')) 1599 else: 1600 return (line, children) 1601 1602 1603def where_clause_transform(xxx_todo_changeme2): 1604 (line, children) = xxx_todo_changeme2 1605 ws = line.split('where', 1)[0] 1606 if line.strip() != 'where': 1607 assert 
line.strip().startswith('where') 1608 children = [(''.join(line.split('where', 1)), [])] + children 1609 pre = ws + 'let' 1610 post = ws + 'in' 1611 1612 children = [(l, c) for (l, c) in children if l.split()[1] != '::'] 1613 children = [case_clauses_transform((l, c)) for (l, c) in children] 1614 children = [do_clauses_transform( 1615 (l, c), 1616 None, 1617 type=0) for (l, c) in children] 1618 children = list(map(where_clause_guarded_body, children)) 1619 for i, (l, c) in enumerate(children): 1620 l2 = braces.str(l, '(', ')') 1621 if len(l2.split('=')[0].split()) > 1: 1622 # turn f a = b into f = (\a -> b) 1623 l = '->'.join(l.split('=', 1)) 1624 l = lead_ws(l) + ' = (\\ '.join(l.split(None, 1)) 1625 (l, c) = add_trailing_string(')', (l, c)) 1626 children[i] = (l, c) 1627 children = order_let_children(children) 1628 for i, child in enumerate(children[:-1]): 1629 children[i] = add_trailing_string(';', child) 1630 return [(pre, [])] + children + [(post, [])] 1631 1632 1633varname_re = re.compile(r"\w+") 1634 1635 1636def order_let_children(L): 1637 single_lines = [reduce_to_single_line(elt) for elt in L] 1638 keys = [str(braces.str(line, '(', ')').split('=')[0]).split()[0] 1639 for line in single_lines] 1640 keys = dict((key, n) for (n, key) in enumerate(keys)) 1641 bits = [varname_re.findall(line) for line in single_lines] 1642 deps = {} 1643 for i, bs in enumerate(bits): 1644 for bit in bs: 1645 if bit in keys: 1646 j = keys[bit] 1647 if j != i: 1648 deps.setdefault(i, {})[j] = 1 1649 done = {} 1650 L2 = [] 1651 todo = dict(enumerate(L)) 1652 n = len(todo) 1653 while n: 1654 todo_keys = list(todo.keys()) 1655 for key in todo_keys: 1656 depstodo = [dep 1657 for dep in list(deps.get(key, {}).keys()) if dep not in done] 1658 if depstodo == []: 1659 L2.append(todo.pop(key)) 1660 done[key] = 1 1661 if len(todo) == n: 1662 print("No progress resolving let deps") 1663 print() 1664 print(list(todo.values())) 1665 print() 1666 print("Dependencies:") 1667 print(deps) 1668 
assert 0 1669 n = len(todo) 1670 return L2 1671 1672 1673def do_clauses_transform(xxx_todo_changeme3, rawsig, type=None): 1674 (line, children) = xxx_todo_changeme3 1675 if children and children[-1][0].lstrip().startswith('where'): 1676 where_clause = where_clause_transform(children[-1]) 1677 where_clause = [do_clauses_transform( 1678 (l, c), rawsig, 0) for (l, c) in where_clause] 1679 others = (line, children[:-1]) 1680 others = do_clauses_transform(others, rawsig, type) 1681 (line, children) = where_clause[0] 1682 return (line, children + where_clause[1:] + [others]) 1683 1684 if children: 1685 (l, c) = children[0] 1686 if c == [] and l.endswith('do'): 1687 line = line + ' ' + l.strip() 1688 children = children[1:] 1689 1690 if type is None: 1691 if not rawsig: 1692 type = 0 1693 sig = None 1694 else: 1695 sig = ' '.join(flatten_tree([rawsig])) 1696 type = monad_type_acquire(sig) 1697 (line, type) = monad_type_transform((line, type)) 1698 if children == []: 1699 return (line, []) 1700 1701 rhs = line.split('<-', 1)[-1] 1702 if rhs.strip() == 'syscall' or rhs.strip() == 'atomicSyscall': 1703 assert len(children) == 5 1704 children = [do_clauses_transform(elt, 1705 rawsig, 1706 type=subtype) 1707 for elt, subtype in zip(children, [1, 0, 1, 0, type])] 1708 elif line.strip().endswith('catchFailure'): 1709 assert len(children) == 2 1710 children = [do_clauses_transform(elt, 1711 rawsig, 1712 type=subtype) 1713 for elt, subtype in zip(children, [1, 0])] 1714 else: 1715 children = [do_clauses_transform(elt, 1716 rawsig, 1717 type=type) for elt in children] 1718 1719 if not line.endswith('do'): 1720 return (line, children) 1721 1722 children, other_children = split_on_unmatched_bracket(children) 1723 1724 # single statement do clause won't parse in Isabelle 1725 if len(children) == 1: 1726 ws = lead_ws(line) 1727 return (line[:-2] + '(', children + [(ws + ')', [])]) 1728 1729 line = line[:-2] + '(do' + 'E' * type 1730 1731 children = [(l, c) for (l, c) in children if 
l.strip() or c] 1732 1733 children2 = [] 1734 for (l, c) in children: 1735 if l.lstrip().startswith('let '): 1736 if '=' not in l: 1737 extra = reduce_to_single_line(c[0]) 1738 assert '=' in extra 1739 l = l + ' ' + extra 1740 c = c[1:] 1741 l = ''.join(l.split('let ', 1)) 1742 letsubs = '<- return' + 'Ok' * type + ' (' 1743 l = letsubs.join(l.split('=', 1)) 1744 (l, c) = add_trailing_string(')', (l, c)) 1745 children2.extend(do_clause_pattern(l, c, type)) 1746 else: 1747 children2.extend(do_clause_pattern(l, c, type)) 1748 1749 children = [add_trailing_string(';', child) 1750 for child in children2[:-1]] + [children2[-1]] 1751 1752 ws = lead_ws(line) 1753 children.append((ws + 'od' + 'E' * type + ')', [])) 1754 1755 return (line, children + other_children) 1756 1757 1758left_start_table = { 1759 'ASIDPool': '(inv ASIDPool)', 1760 'HardwareASID': 'id', 1761 'ArchObjectCap': 'capCap', 1762 'Just': 'the' 1763} 1764 1765 1766def do_clause_pattern(line, children, type, n=0): 1767 bits = line.split('<-') 1768 default = [('\<leftarrow>'.join(bits), children)] 1769 if len(bits) == 1: 1770 return default 1771 (left, right) = line.split('<-', 1) 1772 if ':' not in left and '[' not in left \ 1773 and len(left.split()) == 1: 1774 return default 1775 left = left.strip() 1776 1777 v = 'v%d' % get_next_unique_id() 1778 1779 ass = 'assert' + ('E' * type) 1780 ws = lead_ws(line) 1781 1782 if left.startswith('('): 1783 assert left.endswith(')') 1784 if (',' in left): 1785 return default 1786 else: 1787 left = left[1:-1] 1788 bs = braces.str(left, '[', ']') 1789 if len(bs.split(':')) > 1: 1790 bits = [str(bit).strip() for bit in bs.split(':', 1)] 1791 lines = [('%s%s <- %s' % (ws, v, right), children), 1792 ('%s%s <- headM %s' % (ws, bits[0], v), []), 1793 ('%s%s <- tailM %s' % (ws, bits[1], v), [])] 1794 result = [] 1795 for (l, c) in lines: 1796 result.extend(do_clause_pattern(l, c, type, n + 1)) 1797 return result 1798 if left == '[]': 1799 return [('%s%s <- %s' % (ws, v, right), 
children), 1800 ('%s%s (%s = []) []' % (ws, ass, v), [])] 1801 if left.startswith('['): 1802 assert left.endswith(']') 1803 bs = braces.str(left[1:-1], '[', ']') 1804 bits = bs.split(',', 1) 1805 if len(bits) == 1: 1806 new_left = '%s:%s' % (bits[0], v) 1807 new_line = '%s%s <- %s' % (ws, new_left, right) 1808 extra = [('%s%s (%s = []) []' % (ws, ass, v), [])] 1809 else: 1810 new_left = '%s:[%s]' % (bits[0], bits[1]) 1811 new_line = lead_ws(line) + new_left + ' <- ' + right 1812 extra = [] 1813 return do_clause_pattern (new_line, children, type, n + 1) \ 1814 + extra 1815 for lhs in left_start_table: 1816 if left.startswith(lhs): 1817 left = left[len(lhs):] 1818 tab = left_start_table[lhs] 1819 lM = 'liftM' + 'E' * type 1820 nl = ('%s <- %s %s $ %s' % (left, lM, tab, right)) 1821 return do_clause_pattern(nl, children, type, n + 1) 1822 1823 print(line) 1824 print(left) 1825 assert 0 1826 1827 1828def split_on_unmatched_bracket(elts, n=None): 1829 if n is None: 1830 elts, other_elts, n = split_on_unmatched_bracket(elts, 0) 1831 return (elts, other_elts) 1832 1833 for (i, (line, children)) in enumerate(elts): 1834 for (j, c) in enumerate(line): 1835 if c == '(': 1836 n = n + 1 1837 if c == ')': 1838 n = n - 1 1839 if n < 0: 1840 frag1 = line[:j] 1841 frag2 = ' ' * len(frag1) + line[j:] 1842 new_elts = elts[:i] + [(frag1, [])] 1843 oth_elts = [(frag2, children)] \ 1844 + elts[i + 1:] 1845 return (new_elts, oth_elts, n) 1846 c, other_c, n = split_on_unmatched_bracket(children, n) 1847 if other_c: 1848 new_elts = elts[:i] + [(line, c)] 1849 other_elts = other_c + elts[i + 1:] 1850 return (new_elts, other_elts, n) 1851 1852 return (elts, [], n) 1853 1854 1855def monad_type_acquire(sig, type=0): 1856 # note kernel appears after kernel_f/kernel_monad 1857 for (key, n) in [('kernel_f', 1), ('fault_monad', 1), ('syscall_monad', 2), 1858 ('kernel_monad', 0), ('kernel_init', 1), ('kernel_p', 1), 1859 ('kernel', 0)]: 1860 if key in sig: 1861 sigend = sig.split(key)[-1] 1862 
return monad_type_acquire(sigend, n) 1863 1864 return type 1865 1866 1867def monad_type_transform(xxx_todo_changeme4): 1868 (line, type) = xxx_todo_changeme4 1869 split = None 1870 if 'withoutError' in line: 1871 split = 'withoutError' 1872 newtype = 1 1873 elif 'doKernelOp' in line: 1874 split = 'doKernelOp' 1875 newtype = 0 1876 elif 'runInit' in line: 1877 split = 'runInit' 1878 newtype = 1 1879 elif 'withoutFailure' in line: 1880 split = 'withoutFailure' 1881 newtype = 0 1882 elif 'withoutFault' in line: 1883 split = 'withoutFault' 1884 newtype = 0 1885 elif 'withoutPreemption' in line: 1886 split = 'withoutPreemption' 1887 newtype = 0 1888 elif 'allowingFaults' in line: 1889 split = 'allowingFaults' 1890 newtype = 1 1891 elif 'allowingErrors' in line: 1892 split = 'allowingErrors' 1893 newtype = 2 1894 elif '`catchFailure`' in line: 1895 [left, right] = line.split('`catchFailure`', 1) 1896 left, _ = monad_type_transform((left, 1)) 1897 right, type = monad_type_transform((right, 0)) 1898 return (left + '`catchFailure`' + right, type) 1899 elif 'catchingFailure' in line: 1900 split = 'catchingFailure' 1901 newtype = 1 1902 elif 'catchF' in line: 1903 split = 'catchF' 1904 newtype = 1 1905 elif 'emptyOnFailure' in line: 1906 split = 'emptyOnFailure' 1907 newtype = 1 1908 elif 'constOnFailure' in line: 1909 split = 'constOnFailure' 1910 newtype = 1 1911 elif 'nothingOnFailure' in line: 1912 split = 'nothingOnFailure' 1913 newtype = 1 1914 elif 'nullCapOnFailure' in line: 1915 split = 'nullCapOnFailure' 1916 newtype = 1 1917 elif '`catchFault`' in line: 1918 split = '`catchFault`' 1919 newtype = 1 1920 elif 'capFaultOnFailure' in line: 1921 split = 'capFaultOnFailure' 1922 newtype = 1 1923 elif 'ignoreFailure' in line: 1924 split = 'ignoreFailure' 1925 newtype = 1 1926 elif 'handleInvocation False' in line: # THIS IS A HACK 1927 split = 'handleInvocation False' 1928 newtype = 0 1929 if split: 1930 [left, right] = line.split(split, 1) 1931 left, _ = 
monad_type_transform((left, type)) 1932 right, newnewtype = monad_type_transform((right, newtype)) 1933 return (left + split + right, newnewtype) 1934 1935 if type: 1936 line = ('return' + 'Ok' * type).join(line.split('return')) 1937 line = ('when' + 'E' * type).join(line.split('when')) 1938 line = ('unless' + 'E' * type).join(line.split('unless')) 1939 line = ('mapM' + 'E' * type).join(line.split('mapM')) 1940 line = ('forM' + 'E' * type).join(line.split('forM')) 1941 line = ('liftM' + 'E' * type).join(line.split('liftM')) 1942 line = ('assert' + 'E' * type).join(line.split('assert')) 1943 line = ('stateAssert' + 'E' * type).join(line.split('stateAssert')) 1944 1945 return (line, type) 1946 1947 1948def case_clauses_transform(xxx_todo_changeme5): 1949 (line, children) = xxx_todo_changeme5 1950 children = [case_clauses_transform(child) for child in children] 1951 1952 if not line.endswith(' of'): 1953 return (line, children) 1954 1955 bits = line.split('case ', 1) 1956 beforecase = bits[0] 1957 x = bits[1][:-3] 1958 1959 if '::' in x: 1960 x2 = braces.str(x, '(', ')') 1961 bits = x2.split('::', 1) 1962 if len(bits) == 2: 1963 x = str(bits[0]) + ':: ' + type_transform(str(bits[1])) 1964 1965 if children and children[-1][0].strip().startswith('where'): 1966 sys.stderr.write('Warning: where clause in case: %r\n' 1967 % line) 1968 return (beforecase + '(* case removed *) undefined', []) 1969 # where_clause = where_clause_transform(children[-1]) 1970 # children = children[:-1] 1971 # (in_stmt, l) = where_clause[-1] 1972 # l.append(case_clauses_transform((line, children))) 1973 # return where_clause 1974 1975 cases = [] 1976 bodies = [] 1977 for (l, c) in children: 1978 bits = l.split('->', 1) 1979 while len(bits) == 1: 1980 if c == []: 1981 sys.stderr.write('wtf %r\n' % l) 1982 sys.exit(1) 1983 if c[0][0].strip().startswith('|'): 1984 bits = [bits[0], ''] 1985 c = guarded_body_transform(c, '->') 1986 elif c[0][1] == []: 1987 l = l + ' ' + c.pop(0)[0].strip() 1988 bits = 
l.split('->', 1) 1989 else: 1990 [(moreline, c)] = c 1991 l = l + ' ' + moreline.strip() 1992 bits = l.split('->', 1) 1993 case = bits[0].strip() 1994 tail = bits[1] 1995 if c and c[-1][0].lstrip().startswith('where'): 1996 where_clause = where_clause_transform(c[-1]) 1997 ws = lead_ws(where_clause[0][0]) 1998 c = where_clause + [(ws + tail.strip(), [])] + c[:-1] 1999 tail = '' 2000 cases.append(case) 2001 bodies.append((tail, c)) 2002 2003 cases = tuple(cases) # used as key of a dict later, needs to be hashable 2004 # (since lists are mutable, they aren't) 2005 if not cases: 2006 print(line) 2007 conv = get_case_conv(cases) 2008 if conv == '<X>': 2009 sys.stderr.write('Warning: blanked case in caseconvs\n') 2010 return (beforecase + '(* case removed *) undefined', []) 2011 if not conv: 2012 sys.stderr.write('Warning: case %r\n' % (cases, )) 2013 if cases not in cases_added: 2014 casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> ' 2015 2016 f = open('caseconvs', 'a') 2017 f.write('%s ---X>\n\n' % casestr) 2018 f.close() 2019 cases_added[cases] = 1 2020 return (beforecase + '(* case removed *) undefined', []) 2021 conv = subs_nums_and_x(conv, x) 2022 2023 new_line = beforecase + '(' + conv[0][0] 2024 assert conv[0][1] is None 2025 2026 ws = lead_ws(children[0][0]) 2027 new_children = [] 2028 for (header, endnum) in conv[1:]: 2029 if endnum is None: 2030 new_children.append((ws + header, [])) 2031 else: 2032 if (len(bodies) <= endnum): 2033 sys.stderr.write('ERROR: index %d out of bounds in case %r\n' % 2034 (endnum, 2035 cases, )) 2036 sys.exit(1) 2037 2038 (body, c) = bodies[endnum] 2039 new_children.append((ws + header + ' ' + body, c)) 2040 2041 if has_trailing_string(',', new_children[-1]): 2042 new_children[-1] = \ 2043 remove_trailing_string(',', new_children[-1]) 2044 new_children.append((ws + '),', [])) 2045 else: 2046 new_children.append((ws + ')', [])) 2047 2048 return (new_line, new_children) 2049 2050 2051def guarded_body_transform(body, div): 2052 
new_body = [] 2053 if body[-1][0].strip().startswith('where'): 2054 new_body.extend(where_clause_transform(body[-1])) 2055 body = body[:-1] 2056 for i, (line, children) in enumerate(body): 2057 try: 2058 while div not in line: 2059 (l, c) = children[0] 2060 children = c + children[1:] 2061 line = line + ' ' + l.strip() 2062 except: 2063 sys.stderr.write('missing %r in %r\n' % (div, line)) 2064 sys.stderr.write('\nhandling %r\n' % body) 2065 sys.exit(1) 2066 try: 2067 [ws, guts] = line.split('| ', 1) 2068 except: 2069 sys.stderr.write('missing "|" in %r\n' % line) 2070 sys.stderr.write('\nhandling %r\n' % body) 2071 sys.exit(1) 2072 if i == 0: 2073 new_body.append((ws + 'if', [])) 2074 else: 2075 new_body.append((ws + 'else if', [])) 2076 guts = ' then '.join(guts.split(div, 1)) 2077 new_body.append((ws + guts, children)) 2078 new_body.append((ws + 'else undefined', [])) 2079 2080 return new_body 2081 2082 2083def lhs_transform(line): 2084 if '(' not in line: 2085 return line 2086 2087 [left, right] = line.split('\<equiv>') 2088 2089 ws = left[:len(left) - len(left.lstrip())] 2090 2091 left = left.lstrip() 2092 2093 bits = braces.str(left, '(', ')').split(braces=True) 2094 2095 for (i, bit) in enumerate(bits): 2096 if bit.startswith('('): 2097 bits[i] = 'arg%d' % i 2098 case = 'case arg%d of %s \<Rightarrow> ' % (i, bit) 2099 right = case + right 2100 2101 return ws + ' '.join([str(bit) for bit in bits]) + '\<equiv>' + right 2102 2103 2104def lhs_de_underscore(line): 2105 if '_' not in line: 2106 return line 2107 2108 [left, right] = line.split('\<equiv>') 2109 2110 ws = left[:len(left) - len(left.lstrip())] 2111 2112 left = left.lstrip() 2113 bits = left.split() 2114 2115 for (i, bit) in enumerate(bits): 2116 if bit == '_': 2117 bits[i] = 'arg%d' % i 2118 2119 return ws + ' '.join([str(bit) for bit in bits]) + ' \<equiv>' + right 2120 2121 2122regexes = [ 2123 (re.compile(r" \. 
"), r" \<circ> "), 2124 (re.compile('-1'), '- 1'), 2125 (re.compile('-2'), '- 2'), 2126 (re.compile(r"\(!(\w+)\)"), r"(flip id \1)"), 2127 (re.compile(r"\(\+(\w+)\)"), r"(\<lambda> x. x + \1)"), 2128 (re.compile(r"\\([^<].*?)\s*->"), r"\<lambda> \1."), 2129 (re.compile('}'), r"\<rparr>"), 2130 (re.compile(r"(\s)!!(\s)"), r"\1LIST_INDEX\2"), 2131 (re.compile(r"(\w)!"), r"\1 "), 2132 (re.compile(r"\s?!"), ''), 2133 (re.compile(r"LIST_INDEX"), r"!"), 2134 (re.compile('`testBit`'), '!!'), 2135 (re.compile(r"//"), ' aLU '), 2136 (re.compile('otherwise'), 'True '), 2137 (re.compile(r"(^|\W)fail "), r"\1haskell_fail "), 2138 (re.compile('assert '), 'haskell_assert '), 2139 (re.compile('assertE '), 'haskell_assertE '), 2140 (re.compile('=='), '='), 2141 (re.compile(r"\(/="), '(\<lambda>x. x \<noteq>'), 2142 (re.compile('/='), '\<noteq>'), 2143 (re.compile('"([^"])*"'), '[]'), 2144 (re.compile('&&'), '\<and>'), 2145 (re.compile('\|\|'), '\<or>'), 2146 (re.compile(r"(\W)not(\s)"), r"\1Not\2"), 2147 (re.compile(r"(\W)and(\s)"), r"\1andList\2"), 2148 (re.compile(r"(\W)or(\s)"), r"\1orList\2"), 2149 # regex ordering important here 2150 (re.compile(r"\.&\."), '&&'), 2151 (re.compile(r"\(\.\|\.\)"), r"bitOR"), 2152 (re.compile(r"\(\+\)"), r"op +"), 2153 (re.compile(r"\.\|\."), '||'), 2154 (re.compile(r"Data\.Word\.Word"), r"word"), 2155 (re.compile(r"Data\.Map\."), r"data_map_"), 2156 (re.compile(r"Data\.Set\."), r"data_set_"), 2157 (re.compile(r"BinaryTree\."), 'bt_'), 2158 (re.compile("mapM_"), "mapM_x"), 2159 (re.compile("forM_"), "forM_x"), 2160 (re.compile("forME_"), "forME_x"), 2161 (re.compile("mapME_"), "mapME_x"), 2162 (re.compile("zipWithM_"), "zipWithM_x"), 2163 (re.compile(r"bit\s+([0-9]+)(\s)"), r"(1 << \1)\2"), 2164 (re.compile('`mod`'), 'mod'), 2165 (re.compile('`div`'), 'div'), 2166 (re.compile(r"`((\w|\.)*)`"), r"`~\1~`"), 2167 (re.compile('size'), 'magnitude'), 2168 (re.compile('foldr'), 'foldR'), 2169 (re.compile(r"\+\+"), '@'), 2170 
(re.compile(r"(\s|\)|\w|\]):(\s|\(|\w|$|\[)"), r"\1#\2"), 2171 (re.compile(r"\[([^]]*)\.\.([^]]*)\]"), r"[\1 .e. \2]"), 2172 (re.compile(r"\[([^]]*)\.\.\s*$"), r"[\1 .e."), 2173 (re.compile(' Right'), ' Inr'), 2174 (re.compile(' Left'), ' Inl'), 2175 (re.compile('\\(Left'), '(Inl'), 2176 (re.compile('\\(Right'), '(Inr'), 2177 (re.compile(r"\$!"), r"$"), 2178 (re.compile('([^>])>='), r'\1\<ge>'), 2179 (re.compile('>>([^=])'), r'>>_\1'), 2180 (re.compile('<='), '\<le>'), 2181 (re.compile(r" \\\\ "), " `~listSubtract~` "), 2182 (re.compile(r"(\s\w+)\s*@\s*\w+\s*{\s*}\s*\<leftarrow>"), 2183 r"\1 \<leftarrow>"), 2184] 2185 2186ext_regexes = [ 2187 (re.compile(r"(\s[A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="), 2188 r"\1="), 2189 (re.compile(r"(\([A-Z]\w*)\s*{"), r"\1_ \<lparr>", re.compile(r"(\w)\s*="), 2190 r"\1="), 2191 (re.compile(r"{([^={<]*[^={<:])=([^={<]*)\\<rparr>"), 2192 r"\<lparr>\1:=\2\<rparr>", 2193 re.compile(r"THIS SHOULD NOT APPEAR IN THE SOURCE"), ""), 2194 (re.compile(r"{"), r"\<lparr>", re.compile(r"([^=:])=(\s|$|\w)"), 2195 r"\1:=\2"), 2196] 2197 2198leading_bar = re.compile(r"\s*\|") 2199 2200type_assertion = re.compile(r"\(([^(]*)::([^)]*)\)") 2201 2202 2203def run_regexes(xxx_todo_changeme6, _regexes=regexes): 2204 (line, children) = xxx_todo_changeme6 2205 for re, s in _regexes: 2206 line = re.sub(s, line) 2207 children = [run_regexes(elt, _regexes=_regexes) for elt in children] 2208 return ((line, children)) 2209 2210 2211def run_ext_regexes(xxx_todo_changeme7): 2212 (line, children) = xxx_todo_changeme7 2213 for re, s, add_re, add_s in ext_regexes: 2214 m = re.search(line) 2215 if m is None: 2216 continue 2217 before = line[:m.start()] 2218 substituted = m.expand(s) 2219 after = line[m.end():] 2220 add = [(add_re, add_s)] 2221 (after, children) = run_regexes((after, children), _regexes=add) 2222 line = before + substituted + after 2223 children = [run_ext_regexes(elt) for elt in children] 2224 return (line, children) 2225 2226 
def get_case_lhs(lhs):
    r"""Parse the left-hand side of a 'caseconvs' entry.

    ``lhs`` must start with the literal prefix ``case \x of `` (asserted).
    The remainder is split on ``->``; each piece is stripped and empty
    pieces are dropped.  The patterns are returned as a tuple so the
    result can be used as a dictionary key.
    """
    prefix = 'case \\x of '
    assert lhs.startswith(prefix)
    lhs = lhs.split(prefix, 1)[1]
    cases = [case.strip() for case in lhs.split('->')]
    return tuple(case for case in cases if case != '')


def get_case_rhs(rhs):
    r"""Parse the right-hand side of a 'caseconvs' entry.

    In ``rhs`` a literal backslash-n separates output lines, and a
    ``->N`` marker (N a 1-based integer) tags the text before it as
    belonging to the N-th case of the left-hand side.

    Returns a list of ``(text, n)`` pairs, where ``n`` is the 0-based case
    index the text belongs to, or None for untagged text.  Empty untagged
    pieces are dropped.  Writes a diagnostic to stderr and exits with
    status 1 if the first line is tagged with a case index.
    """
    original_rhs = rhs
    tuples = []
    while '->' in rhs:
        text, rest = rhs.split('->', 1)
        parts = rest.split(None, 1)
        # Only the leading digits of the token after '->' form the index.
        n = int(takeWhile(parts[0], lambda x: x.isdigit())) - 1
        rhs = parts[1] if len(parts) > 1 else ''
        tuples.append((text, n))
    if rhs != '':
        tuples.append((rhs, None))

    conv = []
    for (string, num) in tuples:
        pieces = [piece.strip() for piece in string.split('\\n')]
        # Every line but the last is untagged; the last one carries num.
        conv.extend((piece, None) for piece in pieces[:-1])
        conv.append((pieces[-1], num))

    conv = [(s, n) for (s, n) in conv if s != '' or n is not None]

    if conv[0][1] is not None:
        sys.stderr.write('%r\n' % conv[0][1])
        sys.stderr.write(
            'For technical reasons the first line of this case conversion must be split with \\n: \n')
        # Report the whole entry; the loop above has consumed ``rhs``.
        sys.stderr.write(' %r\n' % original_rhs)
        sys.stderr.write(
            '(further notes: the rhs of each caseconv must have multiple lines\n'
            'and the first cannot contain any ->1, ->2 etc.)\n')
        sys.exit(1)

    # this is a tad dodgy, but means that case_clauses_transform
    # can be safely run twice on the same input
    if conv[0][0].endswith('of'):
        conv[0] = (conv[0][0] + ' ', conv[0][1])

    return conv


def render_caseconv(cases, conv, f):
    r"""Write one entry in 'caseconvs' syntax to the open file ``f``.

    Writes the header ``case \x of P1 -> ... -> Pn -> `` followed by
    `` --->``.  The non-empty pieces of ``conv`` (split on literal
    backslash-n) follow, each terminated by a newline; the first piece
    therefore shares the header's line.  A blank line ends the entry.
    """
    bits = [bit for bit in conv.split('\\n') if bit != '']
    assert bits
    casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> '
    f.write('%s --->' % casestr)
    for bit in bits:
        f.write(bit)
        f.write('\n')
    f.write('\n')


def get_case_conv_table():
    """Load the hand-written case conversion table from 'caseconvs'.

    Entries are separated by blank lines, and the lines of one entry are
    joined with a literal backslash-n.  An entry ``lhs ---X> ...`` maps
    its patterns to the sentinel string "<X>".  An entry ``lhs ---> rhs``
    maps them to ``get_case_rhs(rhs)``.  If the patterns are neither all
    constructor patterns nor extended patterns, the entry is also
    re-rendered into 'caseconvs-useful'.

    Any parse error is reported on stderr and the process exits with
    status 1.  Returns a dict keyed by the tuple from ``get_case_lhs``.
    """
    result = {}
    # 'with' closes (and flushes) both files on the sys.exit path too.
    with open('caseconvs') as f, open('caseconvs-useful', 'w') as f2:
        stripped = map(str.rstrip, f)
        entries = ["\\n".join(lines)
                   for lines in splitList(stripped, lambda s: s == '')]

        for line in entries:
            if line.strip() == '':
                continue
            try:
                if '---X>' in line:
                    [from_case, _] = line.split('---X>')
                    cases = get_case_lhs(from_case)
                    result[cases] = "<X>"
                else:
                    [from_case, to_case] = line.split('--->')
                    cases = get_case_lhs(from_case)
                    conv = get_case_rhs(to_case)
                    result[cases] = conv
                    if (not all_constructor_patterns(cases) and
                            not is_extended_pattern(cases)):
                        render_caseconv(cases, to_case, f2)
            except Exception as e:
                sys.stderr.write('Error parsing %r\n' % line)
                sys.stderr.write('%s\n ' % e)
                sys.exit(1)

    return result


def all_constructor_patterns(cases):
    """Return True if every pattern in ``cases`` is a constructor pattern.

    A final wildcard case ``_`` is ignored.  ``cases`` must be non-empty.
    """
    if cases[-1].strip() == '_':
        cases = cases[:-1]
    return all(is_constructor_pattern(pat) for pat in cases)


def is_constructor_pattern(pat):
    """A constructor pattern takes the form Cons var1 var2 ...,
    characterised by all alphanumeric names, the constructor starting
    with an uppercase alphabetic char and the vars with lowercase.
    ``_`` is accepted as a var.  An empty pattern is not a constructor
    pattern."""
    bits = pat.split()
    if not bits:
        return False
    for bit in bits:
        if (not bit.isalnum()) and (not bit == '_'):
            return False
    if not bits[0][0].isupper():
        return False
    for bit in bits[1:]:
        if (not bit[0].islower()) and (not bit == '_'):
            return False
    return True


# Characters allowed in an "extended" pattern: brackets, braces, commas,
# '=', ':', '_', whitespace and one-letter names with an optional digit
# or prime.
ext_checker = re.compile(r"^(\(|\)|,|{|}|=|[a-zA-Z][0-9']?|\s|_|:|\[|\])*$")


def is_extended_pattern(cases):
    """Return True if every pattern in ``cases`` fully matches ext_checker."""
    for case in cases:
        if not ext_checker.match(case):
            return False
    return True


case_conv_table = get_case_conv_table()
cases_added = {}


def get_case_conv(cases):
    """Return the conversion for a tuple of case patterns.

    All-constructor and extended patterns are converted automatically.
    Any other tuple is looked up in case_conv_table (None if absent).
    """
    if all_constructor_patterns(cases):
        return all_constructor_conv(cases)

    if is_extended_pattern(cases):
        return extended_pattern_conv(cases)

    return case_conv_table.get(cases)


# Haskell constructor names rewritten when generating case conversions.
constructor_conv_table = {
    'Just': 'Some',
    'Nothing': 'None',
    'Left': 'Inl',
    'Right': 'Inr',
    'PPtr': '(* PPtr *)',
    'Register': '(* Register *)',
    'Word': '(* Word *)',
}

# Next fresh-variable number, tracked separately for each input file.
unique_ids_per_file = {}


def get_next_unique_id():
    """Return the next fresh integer for the current ``filename`` (from 1)."""
    next_id = unique_ids_per_file.get(filename, 1)
    unique_ids_per_file[filename] = next_id + 1
    return next_id


def all_constructor_conv(cases):
    r"""Build a conversion for a case whose patterns are all constructors.

    Returns ``[('case \x of', None)]`` followed by one
    ``(' Pat \<Rightarrow> ', i)`` entry per case i.  Clauses after the
    first are prefixed with ``'| '`` instead of a space.  Constructors
    listed in constructor_conv_table are renamed, and ``_`` arguments
    become fresh ``vN`` names.
    """
    conv = [('case \\x of', None)]

    for i, pat in enumerate(cases):
        bits = pat.split()
        bits[0] = constructor_conv_table.get(bits[0], bits[0])
        for j in range(1, len(bits)):
            if bits[j] == '_':
                bits[j] = 'v%d' % get_next_unique_id()
        pat = ' '.join(bits)
        if i == 0:
            conv.append((r' %s \<Rightarrow> ' % pat, i))
        else:
            conv.append((r'| %s \<Rightarrow> ' % pat, i))
    return conv


word_getter = re.compile(r"([a-zA-Z0-9]+)")

# A record pattern ``Cons { f = p, ... }`` whose field patterns contain no
# nested braces (so repeated searches reduce innermost records first).
record_getter = re.compile(r"([a-zA-Z0-9]+\s*{[a-zA-Z0-9'\s=\,_\(\):\]\[]*})")


def extended_pattern_conv(cases):
    r"""Build a conversion for a case whose patterns are extended patterns.

    In each pattern, ``:`` becomes ``#`` and every record pattern is
    reduced to positional constructor form by reduce_record_pattern.
    Constructor names are renamed via constructor_conv_table.  The
    result has the same shape as all_constructor_conv's.
    """
    conv = [('case \\x of', None)]

    for i, pat in enumerate(cases):
        pat = '#'.join(pat.split(':'))
        # Substitute one record at a time, so patterns containing several
        # records (or nested ones) are handled.
        match = record_getter.search(pat)
        while match:
            pat = (pat[:match.start()]
                   + reduce_record_pattern(match.group(1))
                   + pat[match.end():])
            match = record_getter.search(pat)
        if '{' in pat:
            print(pat)
        assert '{' not in pat
        bits = word_getter.split(pat)
        bits = [constructor_conv_table.get(bit, bit) for bit in bits]
        pat = ''.join(bits)
        if i == 0:
            conv.append((r' %s \<Rightarrow> ' % pat, i))
        else:
            conv.append((r'| %s \<Rightarrow> ' % pat, i))
    return conv


def reduce_record_pattern(string):
    """Turn ``Cons { f = p, ... }`` into positional form ``Cons a1 a2 ...``.

    Each argument is the pattern given for that field, parenthesised if
    it has several words, or ``_`` if the field is not mentioned.  Field
    order comes from all_constructor_args.  Exits with status 1 if the
    constructor has not been seen yet.
    """
    assert string[-1] == '}'
    string = string[:-1]
    [left, right] = string.split('{')
    cons = left.strip()
    # brace-aware splitting, so commas inside parentheses are kept
    fields = braces.str(right, '(', ')')
    uses = {}
    for eq in fields.split(','):
        eq = str(eq).strip()
        if eq:
            [field, value] = eq.split('=')
            (field, value) = (field.strip(), value.strip())
            if len(value.split()) > 1:
                value = '(%s)' % value
            uses[field] = value
    if cons not in all_constructor_args:
        sys.stderr.write('FAIL: trying to build case for %s\n' % cons)
        sys.stderr.write('when reading %s\n' % filename)
        sys.stderr.write('but constructor not seen yet\n')
        sys.stderr.write('perhaps parse in different order?\n')
        sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n')
        sys.exit(1)
    args = [uses.get(name, '_')
            for (name, _type) in all_constructor_args[cons]]
    return cons + ' ' + ' '.join(args)


def subs_nums_and_x(conv, x):
    r"""Instantiate a conversion template.

    Every literal ``\x`` in each line is replaced by ``x``.  Every
    ``\vN`` marker (N an integer, optionally terminated by a backslash)
    is replaced by a fresh ``vM`` name.  Equal N within one call share
    the same name.  Returns a new list of ``(line, num)`` pairs.
    """
    ids = []

    result = []
    for (line, num) in conv:
        line = x.join(line.split('\\x'))
        segments = line.split('\\v')
        line = segments[0]
        for segment in segments[1:]:
            parts = segment.split('\\', 1)
            n = int(parts[0])
            while n >= len(ids):
                ids.append(get_next_unique_id())
            line = line + 'v%d' % (ids[n])
            if len(parts) > 1:
                line = line + parts[1]
        result.append((line, num))

    return result


def get_supplied_transform_table():
    """Load the hand-supplied conversions from the file 'supplied'.

    The non-empty lines are grouped with offside_tree.  Every top-level
    node must contain ``conv:`` and have exactly two children, the
    original and its replacement; other nodes are dropped with a warning
    on stderr.  Returns a dict mapping convert_to_stripped_tuple of the
    original to the replacement subtree.
    """
    with open('supplied') as f:
        lines = [line.rstrip() for line in f]

    lines = [(line, n + 1) for (n, line) in enumerate(lines)]
    lines = [(line, n) for (line, n) in lines if line != '']

    # Check the text of each line; the list holds (text, lineno) pairs.
    for (line, n) in lines:
        if '\t' in line:
            sys.stderr.write('WARN: tab character in supplied line %d\n' % n)

    tree = offside_tree(lines)

    result = {}

    for line, n, children in tree:
        if ('conv:' not in line) or len(children) != 2:
            sys.stderr.write('WARN: supplied line %d dropped\n'
                             % n)
            if 'conv:' not in line:
                sys.stderr.write('\t\t(token "conv:" missing)\n')
            if len(children) != 2:
                sys.stderr.write('\t\t(%d children != 2)\n' % len(children))
            continue

        children = discard_line_numbers(children)

        before, after = children

        before = convert_to_stripped_tuple(before[0], before[1])

        result[before] = after

    return result


def print_tree(tree, indent=0):
    """Print a (line, children) tree, indenting each level by one tab."""
    for line, children in tree:
        print('\t' * indent + line.strip())
        print_tree(children, indent + 1)


supplied_transform_table = get_supplied_transform_table()
# How often each supplied conversion has been applied.
supplied_transforms_usage = dict.fromkeys(supplied_transform_table, 0)


def warn_supplied_usage():
    """Warn on stderr about every supplied conversion that was never used."""
    for (key, usage) in supplied_transforms_usage.items():
        if not usage:
            sys.stderr.write('WARN: supplied conv unused: %s\n'
                             % key[0])


quotes_getter = re.compile('"[^"]+"')


def detect_recursion(body):
    """Detects whether any of the bodies of the definitions of this
    function recursively refer to it.

    Double-quoted strings are ignored.  The test is a plain substring
    check, so a longer identifier that contains the name also counts.
    """
    single_lines = [reduce_to_single_line(elt) for elt in body]
    single_lines = [''.join(quotes_getter.split(l)) for l in single_lines]
    bits = [line.split(None, 1) for line in single_lines]
    name = bits[0][0]
    assert [n for (n, _) in bits if n != name] == []
    return any(name in rest for (_, rest) in bits)


def primrec_transform(d):
    r"""Rewrite definition ``d`` in place into primrec form.

    Each body clause goes through body_transform.  The first clause is
    prefixed with a space and later ones with ``'| '``.  The clause's
    single ``\<equiv>`` becomes ``= (`` and its trailing ``"`` becomes
    ``)"``.  Sets ``d.primrec`` and returns ``d``.
    """
    sig = d.sig
    defn = d.defined
    body = []
    is_not_first = False
    for (l, c) in d.body:
        [(l, c)] = body_transform([(l, c)], defn, sig, nopattern=True)
        l = ("| " if is_not_first else " ") + l
        is_not_first = True
        l = l.split(r'\<equiv>')
        assert len(l) == 2
        l = '= ('.join(l)
        (l, c) = remove_trailing_string('"', (l, c))
        (l, c) = add_trailing_string(')"', (l, c))
        body.append((l, c))
    d.primrec = True
    d.body = body
    return d


variable_name_regex = re.compile(r"^[a-z]\w*$")


def is_variable_name(string):
    """Return a match object if ``string`` looks like a variable name."""
    return variable_name_regex.match(string)


def pattern_match_transform(body):
    """Converts a body containing possibly multiple definitions
    and containing pattern matches into a normal Isabelle definition
    followed by a big Haskell case expression which is resolved
    elsewhere."""
    splits = []
    for (line, children) in body:
        string = braces.str(line, '(', ')')
        # Pull continuation lines in until the clause contains its '='.
        while len(string.split('=')) == 1:
            if len(children) == 1:
                [(moreline, children)] = children
                string = string + ' ' + moreline.strip()
            elif children and leading_bar.match(children[0][0]):
                string = string + ' ='
                children = \
                    guarded_body_transform(children, ' = ')
            elif children and children[0][1] == []:
                (moreline, _) = children.pop(0)
                string = string + ' ' + moreline.strip()
            else:
                print()
                print(line)
                print()
                for child in children:
                    print(child)
                assert 0

        [lead, tail] = string.split('=', 1)
        bits = lead.split()
        unbraced = bits
        function = str(bits[0])
        splits.append((bits[1:], unbraced[1:], tail, children))

    # An argument position keeps a shared name only if no clause pattern
    # matches on it; None marks a position that needs a case split.
    common = splits[0][0][:]
    for i, term in enumerate(common):
        if term.startswith('('):
            common[i] = None
        if '@' in term:
            common[i] = None
        if term[0].isupper():
            common[i] = None

    for (bits, _, _, _) in splits[1:]:
        for i, term in enumerate(bits):
            if i >= len(common):
                # debugging aid: clauses have differing arity
                print_tree(body)
            if term != common[i]:
                is_var = is_variable_name(str(term))
                if common[i] == '_' and is_var:
                    common[i] = term
                elif term != '_':
                    common[i] = None

    for i, term in enumerate(common):
        if term == '_':
            common[i] = 'x%d' % i

    blanks = [i for (i, n) in enumerate(common) if n is None]

    line = '%s ' % function
    for i, name in enumerate(common):
        if name is None:
            line = line + 'x%d ' % i
        else:
            line = line + '%s ' % name
    if blanks == []:
        print(splits)
        print(common)
    if len(blanks) == 1:
        line = line + '= case x%d of' % blanks[0]
    else:
        line = line + '= case (x%d' % blanks[0]
        for i in blanks[1:]:
            line = line + ', x%d' % i
        line = line + ') of'

    children = []
    for (bits, unbraced, tail, c) in splits:
        if len(blanks) == 1:
            l = ' %s' % unbraced[blanks[0]]
        else:
            l = ' (%s' % unbraced[blanks[0]]
            for i in blanks[1:]:
                l = l + ', %s' % unbraced[i]
            l = l + ')'
        l = l + ' -> %s' % tail
        children.append((l, c))

    return [(line, children)]


def get_lambda_body_lines(d):
    r"""Returns lines equivalent to the body of the function as
    a lambda expression.

    The definition's ``f args \<equiv> rhs`` becomes
    ``(\<lambda>args. rhs``, and the closing quote of the last line
    becomes ``)``.
    """
    fn = d.defined

    [(line, children)] = d.body

    line = line[1:]
    # find \<equiv> in first or 2nd line
    if r'\<equiv>' not in line and r'\<equiv>' in children[0][0]:
        (l, c) = children[0]
        children = c + children[1:]
        line = line + l
    [lhs, rhs] = line.split(r'\<equiv>', 1)
    bits = lhs.split()
    args = bits[1:]
    assert fn in bits[0]

    line = r'(\<lambda>' + ' '.join(args) + '. ' + rhs
    lines = [line] + flatten_tree(children)
    assert (lines[-1].endswith('"'))
    lines[-1] = lines[-1][:-1] + ')'

    return lines


def add_trailing_string(s, xxx_todo_changeme8):
    """Return a copy of the (line, children) tree with ``s`` appended to
    its last leaf line."""
    (line, children) = xxx_todo_changeme8
    if children == []:
        return (line + s, children)
    modified = add_trailing_string(s, children[-1])
    return (line, children[0:-1] + [modified])


def remove_trailing_string(s, xxx_todo_changeme9, _handled=False):
    """Return a copy of the (line, children) tree with ``s`` removed from
    the end of its last leaf line.

    Asserts that the leaf ends with ``s``.  The outermost call also writes
    the whole tree to stderr before re-raising any error.
    """
    (line, children) = xxx_todo_changeme9
    if not _handled:
        try:
            return remove_trailing_string(s, (line, children), _handled=True)
        except Exception:
            sys.stderr.write('handling %s\n' % ((line, children), ))
            raise
    if children == []:
        if not line.endswith(s):
            sys.stderr.write('ERR: expected %r\n' % line)
            sys.stderr.write('to end with %r\n' % s)
        assert line.endswith(s)
        n = len(s)
        return (line[:-n], [])
    modified = remove_trailing_string(s, children[-1], _handled=True)
    return (line, children[0:-1] + [modified])


def get_trailing_string(n, xxx_todo_changeme10):
    """Return the last ``n`` characters of the tree's last leaf line."""
    (line, children) = xxx_todo_changeme10
    if children == []:
        return line[-n:]
    return get_trailing_string(n, children[-1])


def has_trailing_string(s, xxx_todo_changeme11):
    """Return True if the tree's last leaf line ends with ``s``."""
    (line, children) = xxx_todo_changeme11
    if children == []:
        return line.endswith(s)
    return has_trailing_string(s, children[-1])


def ensure_type_ordering(defs):
    """Return ``defs`` with 'newtype' definitions first, dependencies first.

    A newtype is emitted only once none of the remaining newtypes is named
    in its ``typedeps``.  Other definitions follow in their original
    order.  A dependency cycle (including self-dependency) raises
    ValueError, after the remaining type names are printed.
    """
    typedefs = [d for d in defs if d.type == 'newtype']
    other = [d for d in defs if d.type != 'newtype']

    final_typedefs = []
    while typedefs:
        try:
            i = 0
            visited = set([i])
            deps = typedefs[i].typedeps
            # Follow the first outstanding dependency until reaching a
            # typedef with none left.
            while True:
                for j, term in enumerate(typedefs):
                    if term.typename in deps:
                        break
                else:
                    break
                if j in visited:
                    raise ValueError(
                        'cyclic type dependency involving %s'
                        % typedefs[j].typename)
                visited.add(j)
                i = j
                deps = typedefs[i].typedeps
            final_typedefs.append(typedefs.pop(i))
        except Exception:
            print('Exception hit ordering types:')
            for td in typedefs:
                print(' - %s' % td.typename)
            raise

    return final_typedefs + other


def lead_ws(string):
    """Return the leading whitespace of ``string``."""
    amount = len(string) - len(string.lstrip())
    return string[:amount]


def adjust_ws(xxx_todo_changeme12, n):
    """Shift every line of a (line, children) tree.

    A positive ``n`` prepends ``n`` spaces.  Otherwise the first ``-n``
    characters are cut off, without checking that they are whitespace.
    """
    (line, children) = xxx_todo_changeme12
    if n > 0:
        line = ' ' * n + line
    else:
        line = line[-n:]

    return (line, [adjust_ws(child, n) for child in children])


modulename = re.compile(r"(\w+\.)+")


def perform_module_redirects(lines, call):
    """Apply subst_module_redirects to every line."""
    return [subst_module_redirects(line, call) for line in lines]


def subst_module_redirects(line, call):
    """Rewrite dotted module qualifiers in ``line``.

    Each run of ``Name.`` prefixes matched by ``modulename`` is looked up,
    without its trailing dot, in ``call.moduletranslations``.  A
    translation to an empty string drops the qualifier.  Qualifiers with
    no translation are kept unchanged.
    """
    m = modulename.search(line)
    if not m:
        return line
    module = line[m.start():m.end() - 1]
    before = line[:m.start()]
    after = subst_module_redirects(line[m.end():], call)
    if module in call.moduletranslations:
        module = call.moduletranslations[module]
    if module:
        return before + module + '.' + after
    return before + after