1# 2# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) 3# 4# SPDX-License-Identifier: BSD-2-Clause 5# 6 7from __future__ import print_function 8from __future__ import absolute_import 9import braces 10import re 11import sys 12import os 13import six 14from six.moves import map 15from six.moves import range 16from six.moves import zip 17from functools import reduce 18 19 20class Call(object): 21 22 def __init__(self): 23 self.restr = None 24 self.decls_only = False 25 self.instanceproofs = False 26 self.bodies_only = False 27 self.bad_type_assignment = False 28 self.body = False 29 self.current_context = [] 30 31 32class Def(object): 33 34 def __init__(self): 35 self.type = None 36 self.defined = None 37 self.body = None 38 self.line = None 39 self.sig = None 40 self.instance_proofs = [] 41 self.instance_extras = [] 42 self.comments = [] 43 self.primrec = None 44 self.deriving = [] 45 self.instance_defs = {} 46 47 48def parse(call): 49 """Parses a file.""" 50 set_global(call) 51 52 defs = get_defs(call.filename) 53 54 lines = get_lines(defs, call) 55 56 lines = perform_module_redirects(lines, call) 57 58 return ['%s\n' % line for line in lines] 59 60 61def settings_line(l): 62 """Adjusts some global settings.""" 63 bits = l.split(',') 64 for bit in bits: 65 bit = bit.strip() 66 (kind, setting) = bit.split('=') 67 kind = kind.strip() 68 if kind == 'keep_constructor': 69 [cons] = setting.split() 70 keep_conss[cons] = 1 71 else: 72 assert not "setting kind understood", bit 73 74 75def set_global(_call): 76 global call 77 call = _call 78 global filename 79 filename = _call.filename 80 81 82file_defs = {} 83 84 85def splitList(list, pred): 86 """Splits a list according to pred.""" 87 result = [] 88 el = [] 89 for l in list: 90 if pred(l): 91 if el != []: 92 result.append(el) 93 el = [] 94 else: 95 el.append(l) 96 if el != []: 97 result.append(el) 98 return result 99 100 101def takeWhile(list, pred): 102 """Returns the initial portion of the list where each 103element 
matches pred""" 104 limit = 0 105 106 for l in list: 107 if pred(l): 108 limit = limit + 1 109 else: 110 break 111 return list[0:limit] 112 113 114def get_defs(filename): 115 # if filename in file_defs: 116 # return file_defs[filename] 117 118 cmdline = os.environ['L4CPP'] 119 f = os.popen('cpp -Wno-invalid-pp-token -traditional-cpp %s %s' % 120 (cmdline, filename)) 121 input = [line.rstrip() for line in f] 122 f.close() 123 defs = top_transform(input, filename.endswith(".lhs")) 124 125 file_defs[filename] = defs 126 return defs 127 128 129def top_transform(input, isLhs): 130 """Top level transform, deals with lhs artefacts, divides 131 the code up into a series of seperate definitions, and 132 passes these definitions through the definition transforms.""" 133 to_process = [] 134 comments = [] 135 for n, line in enumerate(input): 136 if '\t' in line: 137 sys.stderr.write('WARN: tab in line %d, %s.\n' % 138 (n, filename)) 139 if isLhs: 140 if line.startswith('> '): 141 if '--' in line: 142 line = line.split('--')[0].strip() 143 144 if line[2:].strip() == '': 145 comments.append((n, 'C', '')) 146 elif line.startswith('> {-#'): 147 comments.append((n, 'C', '(*' + line + '*)')) 148 else: 149 to_process.append((line[2:], n)) 150 else: 151 if line.strip(): 152 comments.append((n, 'C', '(*' + line + '*)')) 153 else: 154 comments.append((n, 'C', '')) 155 else: 156 if '--' in line: 157 line = line.split('--')[0].rstrip() 158 159 if line.strip() == '': 160 comments.append((n, 'C', '')) 161 elif line.strip().startswith('{-'): 162 # single-line {- -} comments only 163 comments.append((n, 'C', '(*' + line + '*)')) 164 elif line.startswith('#'): 165 # preprocessor directives 166 comments.append((n, 'C', '(*' + line + '*)')) 167 else: 168 to_process.append((line, n)) 169 170 def_tree = offside_tree(to_process) 171 defs = create_defs(def_tree) 172 defs = group_defs(defs) 173 174 # Forget about the comments for now 175 176 # defs_plus_comments = [d.line, d) for d in defs] + 
comments 177 # defs_plus_comments.sort() 178 # defs = [] 179 # prev_comments = [] 180 # for term in defs_plus_comments: 181 # if term[1] == 'C': 182 # prev_comments.append(term[2]) 183 # else: 184 # d = term[1] 185 # d.comments = prev_comments 186 # defs.append(d) 187 # prev_comments = [] 188 189 # apply def_transform and cut out any None return values 190 defs = [defs_transform(d) for d in defs] 191 defs = [d for d in defs if d is not None] 192 193 defs = ensure_type_ordering(defs) 194 195 return defs 196 197 198def get_lines(defs, call): 199 """Gets the output lines needed for this call from 200 all the potential output generated at parse time.""" 201 202 if call.restr: 203 defs = [d for d in defs if d.type == 'comments' 204 or call.restr(d)] 205 206 output = [] 207 for d in defs: 208 lines = def_lines(d, call) 209 if lines: 210 output.extend(lines) 211 output.append('') 212 213 return output 214 215 216def offside_tree(input): 217 """Breaks lines up into a tree based on the offside rule. 
218 Each line gets as children the lines following it up until 219 the next line whose indent is less or equal.""" 220 if input == []: 221 return [] 222 head, head_n = input[0] 223 head_indent = len(head) - len(head.lstrip()) 224 children = [] 225 result = [] 226 for line, n in input[1:]: 227 indent = len(line) - len(line.lstrip()) 228 if indent <= head_indent: 229 result.append((head, head_n, offside_tree(children))) 230 head, head_n, head_indent = (line, n, indent) 231 children = [] 232 else: 233 children.append((line, n)) 234 result.append((head, head_n, offside_tree(children))) 235 236 return result 237 238 239def discard_line_numbers(tree): 240 """Takes a tree containing tuples (line, n, children) and 241 discards the n terms, returning a tree with tuples 242 (line, children)""" 243 result = [] 244 for (line, _, children) in tree: 245 result.append((line, discard_line_numbers(children))) 246 return result 247 248 249def flatten_tree(tree): 250 """Returns a tree to the set of numbered lines it was 251 drawn from.""" 252 result = [] 253 for (line, children) in tree: 254 result.append(line) 255 result.extend(flatten_tree(children)) 256 257 return result 258 259 260def create_defs(tree): 261 defs = [create_def(elt) for elt in tree] 262 defs = [d for d in defs if d is not None] 263 264 return defs 265 266 267def group_defs(defs): 268 """Takes a file broken into a series of definitions, and locates 269 multiple definitions of constants, caused by type signatures or 270 pattern matching, and accumulates to a single object per genuine 271 definition""" 272 defgroups = [] 273 defined = '' 274 for d in defs: 275 this_defines = d.defined 276 if d.type != 'definitions': 277 this_defines = '' 278 if this_defines == defined and this_defines: 279 defgroups[-1].body.extend(d.body) 280 else: 281 defgroups.append(d) 282 defined = this_defines 283 284 return defgroups 285 286 287def create_def(elt): 288 """Takes an element of an offside tree and creates 289 a definition 
object.""" 290 (line, n, children) = elt 291 children = discard_line_numbers(children) 292 return create_def_2(line, children, n) 293 294 295def create_def_2(line, children, n): 296 d = Def() 297 d.body = [(line, children)] 298 d.line = n 299 lead = line.split(None, 3) 300 if lead[0] in ['import', 'module', 'class']: 301 return 302 elif lead[0] == 'instance': 303 type = 'instance' 304 defined = lead[2] 305 elif lead[0] in ['type', 'newtype', 'data']: 306 type = 'newtype' 307 defined = lead[1] 308 else: 309 type = 'definitions' 310 defined = lead[0] 311 312 d.type = type 313 d.defined = defined 314 return d 315 316 317def get_primrecs(): 318 f = open('primrecs') 319 keys = [line.strip() for line in f] 320 return set(key for key in keys if key != '') 321 322 323primrecs = get_primrecs() 324 325 326def defs_transform(d): 327 """Transforms the set of definitions for a function. This 328 may include its type signature, and may include the special 329 case of multiple definitions.""" 330 # the first tokens of the first line in the first definition 331 if d.type == 'newtype': 332 return newtype_transform(d) 333 elif d.type == 'instance': 334 return instance_transform(d) 335 336 lead = d.body[0][0].split(None, 2) 337 if lead[1] == '::': 338 d.sig = type_sig_transform(d.body[0]) 339 d.body.pop(0) 340 341 if d.defined in primrecs: 342 return primrec_transform(d) 343 344 if len(d.body) > 1: 345 d.body = pattern_match_transform(d.body) 346 347 if len(d.body) == 0: 348 print() 349 print(d) 350 assert 0 351 352 d.body = body_transform(d.body, d.defined, d.sig) 353 return d 354 355 356def wrap_qualify(lines, deep=True): 357 if len(lines) == 0: 358 return lines 359 360 """Close and then re-open a locale so instantiations can go through""" 361 if deep: 362 asdfextra = "" 363 else: 364 asdfextra = "" 365 366 if call.current_context: 367 lines.insert(0, 'end\nqualify {} (in Arch) {}'.format(call.current_context[-1], 368 asdfextra)) 369 lines.append('end_qualify\ncontext Arch begin 
global_naming %s' % call.current_context[-1]) 370 return lines 371 372 373def def_lines(d, call): 374 """Produces the set of lines associated with a definition.""" 375 if call.all_bits: 376 L = [] 377 if d.comments: 378 L.extend(flatten_tree(d.comments)) 379 L.append('') 380 if d.type == 'definitions': 381 L.append('definition') 382 if d.sig: 383 L.extend(flatten_tree([d.sig])) 384 L.append('where') 385 L.extend(flatten_tree(d.body)) 386 elif d.type == 'newtype': 387 L.extend(flatten_tree(d.body)) 388 if d.instance_proofs: 389 L.extend(wrap_qualify(flatten_tree(d.instance_proofs))) 390 L.append('') 391 if d.instance_extras: 392 L.extend(flatten_tree(d.instance_extras)) 393 L.append('') 394 return L 395 396 if call.instanceproofs: 397 if not call.bodies_only: 398 instance_proofs = wrap_qualify(flatten_tree(d.instance_proofs)) 399 else: 400 instance_proofs = [] 401 402 if not call.decls_only: 403 instance_extras = flatten_tree(d.instance_extras) 404 else: 405 instance_extras = [] 406 407 newline_needed = len(instance_proofs) > 0 and len(instance_extras) > 0 408 return instance_proofs + ([''] 409 if newline_needed else []) + instance_extras 410 411 if call.body: 412 return get_lambda_body_lines(d) 413 414 comments = d.comments 415 try: 416 typesig = flatten_tree([d.sig]) 417 except: 418 typesig = [] 419 body = flatten_tree(d.body) 420 type = d.type 421 422 if type == 'definitions': 423 if call.decls_only: 424 if typesig: 425 return comments + ["consts'"] + typesig 426 else: 427 return [] 428 elif call.bodies_only: 429 if d.sig: 430 defname = '%s_def' % d.defined 431 if d.primrec: 432 print('warning body-only primrec:') 433 print(body[0]) 434 return comments + ['primrec'] + body 435 return comments + ['defs %s:' % defname] + body 436 else: 437 return comments + ['definition'] + body 438 else: 439 if d.primrec: 440 return comments + ['primrec'] + typesig \ 441 + ['where'] + body 442 if typesig: 443 return comments + ['definition'] + typesig + ['where'] + body 444 else: 
445 return comments + ['definition'] + body 446 elif type == 'comments': 447 return comments 448 elif type == 'newtype': 449 if not call.bodies_only: 450 return body 451 452 453def type_sig_transform(tree_element): 454 """Performs transformations on a type signature line 455 preceding a function declaration or some such.""" 456 457 line = reduce_to_single_line(tree_element) 458 (pre, post) = line.split('::') 459 result = type_transform(post) 460 if '[pp' in result: 461 print(line) 462 print(pre) 463 print(post) 464 print(result) 465 assert 0 466 line = pre + ':: "' + result + '"' 467 468 return (line, []) 469 470 471ignore_classes = {'Error': 1} 472hand_classes = {'Bits': ['HS_bit'], 473 'Num': ['minus', 'one', 'zero', 'plus', 'numeral'], 474 'FiniteBits': ['finiteBit']} 475 476 477def type_transform(string): 478 """Performs transformations on a type signature, whether 479 part of a type signature line or occuring in a function.""" 480 481 # deal with type classes by recursion 482 bits = string.split('=>', 1) 483 if len(bits) == 2: 484 lhs = bits[0].strip() 485 if lhs.startswith('(') and lhs.endswith(')'): 486 instances = lhs[1:-1].split(',') 487 string = ' => '.join(instances + [bits[1]]) 488 else: 489 instances = [lhs] 490 var_annotes = {} 491 for instance in instances: 492 (name, var) = instance.split() 493 if name in ignore_classes: 494 continue 495 if name in hand_classes: 496 names = hand_classes[name] 497 else: 498 names = [type_conv(name)] 499 var = "'" + var 500 var_annotes.setdefault(var, []) 501 var_annotes[var].extend(names) 502 transformed = type_transform(bits[1]) 503 for (var, insts) in six.iteritems(var_annotes): 504 if len(insts) == 1: 505 newvar = '(%s :: %s)' % (var, insts[0]) 506 else: 507 newvar = '(%s :: {%s})' % (var, ', '.join(insts)) 508 transformed = newvar.join(transformed.split(var, 1)) 509 return transformed 510 511 # get rid of (), insert Unit, which converts to unit 512 string = 'Unit'.join(string.split('()')) 513 514 # divide up by 
-> or by , then divide on space. 515 # apply everything locally then work back up 516 bstring = braces.str(string, '(', ')') 517 bits = bstring.split('->') 518 r = ' \<Rightarrow> ' 519 if len(bits) == 1: 520 bits = bstring.split(',') 521 r = ' * ' 522 result = [type_bit_transform(bit) for bit in bits] 523 return r.join(result) 524 525 526def type_bit_transform(bit): 527 s = str(bit).strip() 528 if s.startswith('['): 529 # handling this properly is hard. 530 assert s.endswith(']') 531 bit2 = braces.str(s[1:-1], '(', ')') 532 return '%s list' % type_bit_transform(bit2) 533 bits = bit.split(None, braces=True) 534 if str(bits[0]) == 'PPtr': 535 assert len(bits) == 2 536 return 'machine_word' 537 if len(bits) > 1 and bits[1].startswith('['): 538 assert bits[-1].endswith(']') 539 arg = ' '.join([str(bit) for bit in bits[1:]])[1:-1] 540 arg = type_transform(arg) 541 return ' '.join([arg, 'list', str(type_conv(bits[0]))]) 542 bits = [type_conv(bit) for bit in bits] 543 bits = constructor_reversing(bits) 544 bits = [bit.map(type_transform) for bit in bits] 545 strs = [str(bit) for bit in bits] 546 return ' '.join(strs) 547 548 549def reduce_to_single_line(tree_element): 550 def inner(tree_element, acc): 551 (line, children) = tree_element 552 acc.append(line) 553 for child in children: 554 inner(child, acc) 555 return acc 556 return ' '.join(inner(tree_element, [])) 557 558 559type_conv_table = { 560 'Maybe': 'option', 561 'Bool': 'bool', 562 'Word': 'machine_word', 563 'Int': 'nat', 564 'String': 'unit list'} 565 566 567def type_conv(string): 568 """Converts a type used in Haskell to our equivalent""" 569 if string.startswith('('): 570 # ignore compound types, type_transform will descend into em 571 result = string 572 elif '.' in string: 573 # qualified references 574 bits = string.split('.') 575 typename = bits[-1] 576 module = reduce(lambda x, y: x + '.' + y, bits[:-1]) 577 typename = type_conv(typename) 578 result = module + '.' 
+ typename 579 elif string[0].islower(): 580 # type variable 581 result = "'%s" % string 582 elif string[0] == '[' and string[-1] == ']': 583 # list 584 inner = type_conv(string[1:-1]) 585 result = '%s list' % inner 586 elif str(string) in type_conv_table: 587 result = type_conv_table[str(string)] 588 else: 589 # convert StudlyCaps to lower_with_underscores 590 was_lower = False 591 s = '' 592 for c in string: 593 if c.isupper() and was_lower: 594 s = s + '_' + c.lower() 595 else: 596 s = s + c.lower() 597 was_lower = c.islower() 598 result = s 599 type_conv_table[str(string)] = result 600 601 return braces.clone(result, string) 602 603 604def constructor_reversing(tokens): 605 if len(tokens) < 2: 606 return tokens 607 elif len(tokens) == 2: 608 return [tokens[1], tokens[0]] 609 elif tokens[0] == '[' and tokens[2] == ']': 610 return [tokens[1], braces.str('list', '(', ')')] 611 elif len(tokens) == 4 and tokens[1] == '[' and tokens[3] == ']': 612 listToken = braces.str('(List %s)' % tokens[2], '(', ')') 613 return [listToken, tokens[0]] 614 elif tokens[0] == 'array': 615 arrow_token = braces.str('\<Rightarrow>', '(', ')') 616 return [tokens[1], arrow_token, tokens[2]] 617 elif tokens[0] == 'either': 618 plus_token = braces.str('+', '(', ')') 619 return [tokens[1], plus_token, tokens[2]] 620 elif len(tokens) == 5 and tokens[2] == '[' and tokens[4] == ']': 621 listToken = braces.str('(List %s)' % tokens[3], '(', ')') 622 lbrack = braces.str('(', '+', '+') 623 rbrack = braces.str(')', '+', '+') 624 comma = braces.str(',', '+', '+') 625 return [lbrack, tokens[1], comma, listToken, rbrack, tokens[0]] 626 elif len(tokens) == 3: 627 # here comes a fudge 628 lbrack = braces.str('(', '+', '+') 629 rbrack = braces.str(')', '+', '+') 630 comma = braces.str(',', '+', '+') 631 return [lbrack, tokens[1], comma, tokens[2], rbrack, tokens[0]] 632 else: 633 print("Error parsing " + filename) 634 print("Can't deal with %s" % tokens) 635 sys.exit(1) 636 637 638def 
newtype_transform(d): 639 """Takes a Haskell style newtype/data type declaration, whose 640 options are divided with | and each of whose options has named 641 elements, and forms a datatype statement and definitions for 642 the named extractors and their update functions.""" 643 if len(d.body) != 1: 644 print('--- newtype long body ---') 645 print(d) 646 [(line, children)] = d.body 647 648 if children and children[-1][0].lstrip().startswith('deriving'): 649 l = reduce_to_single_line(children[-1]) 650 children = children[:-1] 651 r = re.compile(r"[,\s\(\)]+") 652 bits = r.split(l) 653 bits = [bit for bit in bits if bit and bit != 'deriving'] 654 d.deriving = bits 655 656 line = reduce_to_single_line((line, children)) 657 658 bits = line.split(None, 1) 659 op = bits[0] 660 line = bits[1] 661 bits = line.split('=', 1) 662 header = type_conv(bits[0].strip()) 663 d.typename = header 664 d.typedeps = set() 665 if len(bits) == 1: 666 # line of form 'data Blah' introduces unknown type? 667 d.body = [('typedecl %s' % header, [])] 668 all_type_arities[header] = [] # HACK 669 return d 670 line = bits[1] 671 672 if op == 'type': 673 # type synonym 674 return typename_transform(line, header, d) 675 elif line.find('{') == -1: 676 # not a record 677 return simple_newtype_transform(line, header, d) 678 else: 679 return named_newtype_transform(line, header, d) 680 681 682def typename_transform(line, header, d): 683 try: 684 [oldtype] = line.split() 685 except: 686 sys.stderr.write('Warning: type assignment with parameters not supported %s\n' % d.body) 687 call.bad_type_assignment = True 688 return 689 if oldtype.startswith('Data.Word.Word'): 690 # take off the prefix, leave Word32 or Word64 etc 691 oldtype = oldtype[10:] 692 oldtype = type_conv(oldtype) 693 bits = oldtype.split() 694 for bit in bits: 695 d.typedeps.add(bit) 696 lines = [ 697 'type_synonym %s = "%s"' % (header, oldtype), 698 # translations (* TYPE 1 *)', 699 # "%s" <=(type) "%s"' % (oldtype, header) 700 ] 701 d.body 
= [(line, []) for line in lines] 702 return d 703 704 705keep_conss = {} 706 707 708def simple_newtype_transform(line, header, d): 709 lines = [] 710 arities = [] 711 for i, bit in enumerate(line.split('|')): 712 braced = braces.str(bit, '(', ')') 713 bits = braced.split() 714 if len(bits) == 2: 715 last_lhs = bits[0] 716 717 if i == 0: 718 l = ' %s' % bits[0] 719 else: 720 l = ' | %s' % bits[0] 721 for bit in bits[1:]: 722 if bit.startswith('('): 723 bit = bit[1:-1] 724 typename = type_transform(str(bit)) 725 if len(bits) == 2: 726 last_rhs = typename 727 if ' ' in typename: 728 typename = '"%s"' % typename 729 l = l + ' ' + typename 730 d.typedeps.add(typename) 731 lines.append(l) 732 733 arities.append((str(bits[0]), len(bits[1:]))) 734 735 if list((dict(arities)).values()) == [1] and header not in keep_conss: 736 return type_wrapper_type(header, last_lhs, last_rhs, d) 737 738 d.body = [('datatype %s =' % header, [(line, []) for line in lines])] 739 740 set_instance_proofs(header, arities, d) 741 742 return d 743 744 745all_constructor_args = {} 746 747 748def named_newtype_transform(line, header, d): 749 bits = line.split('|') 750 751 constructors = [get_type_map(bit) for bit in bits] 752 all_constructor_args.update(dict(constructors)) 753 754 lines = [] 755 for i, (name, map) in enumerate(constructors): 756 if i == 0: 757 l = ' %s' % name 758 else: 759 l = ' | %s' % name 760 oname = name 761 for name, type in map: 762 if type is None: 763 print("ERR: None snuck into constructor list for %s" % name) 764 print(line, header, oname) 765 assert False 766 767 if name is None: 768 opt_name = "" 769 opt_close = "" 770 else: 771 opt_name = " (" + name + " :" 772 opt_close = ")" 773 774 if len(type.split()) == 1 and '(' not in type: 775 the_type = type 776 else: 777 the_type = '"' + type + '"' 778 779 l = l + opt_name + ' ' + the_type + opt_close 780 781 for bit in type.split(): 782 d.typedeps.add(bit) 783 lines.append(l) 784 785 names = {} 786 types = {} 787 for cons, 
map in constructors: 788 for i, (name, type) in enumerate(map): 789 names.setdefault(name, {}) 790 names[name][cons] = i 791 types[name] = type 792 793 for name, map in six.iteritems(names): 794 lines.append('') 795 lines.extend(named_update_definitions(name, map, types[name], header, 796 dict(constructors))) 797 798 for name, map in constructors: 799 if map == []: 800 continue 801 lines.append('') 802 lines.extend(named_constructor_translation(name, map, header)) 803 804 if len(constructors) > 1: 805 for name, map in constructors: 806 lines.append('') 807 check = named_constructor_check(name, map, header) 808 lines.extend(check) 809 810 if len(constructors) == 1: 811 for ex_name, _ in six.iteritems(names): 812 for up_name, _ in six.iteritems(names): 813 lines.append('') 814 lines.extend(named_extractor_update_lemma(ex_name, up_name)) 815 816 arities = [(name, len(map)) for (name, map) in constructors] 817 818 if list((dict(arities)).values()) == [1] and header not in keep_conss: 819 [(cons, map)] = constructors 820 [(name, type)] = map 821 return type_wrapper_type(header, cons, type, d, decons=(name, type)) 822 823 set_instance_proofs(header, arities, d) 824 825 d.body = [('datatype %s =' % header, [(line, []) for line in lines])] 826 return d 827 828 829def named_extractor_update_lemma(ex_name, up_name): 830 lines = [] 831 lines.append('lemma %s_%s_update [simp]:' % (ex_name, up_name)) 832 833 if up_name == ex_name: 834 lines.append(' "%s (%s_update f v) = f (%s v)"' % 835 (ex_name, up_name, ex_name)) 836 else: 837 lines.append(' "%s (%s_update f v) = %s v"' % 838 (ex_name, up_name, ex_name)) 839 840 lines.append(' by (cases v) simp') 841 842 return lines 843 844 845def named_extractor_definitions(name, map, type, header, constructors): 846 lines = [] 847 lines.append('primrec') 848 lines.append(' %s :: "%s \<Rightarrow> %s"' 849 % (name, header, type)) 850 lines.append('where') 851 is_first = True 852 for cons, i in six.iteritems(map): 853 if is_first: 854 l = ' 
"%s (%s' % (name, cons) 855 is_first = False 856 else: 857 l = '| "%s (%s' % (name, cons) 858 num = len(constructors[cons]) 859 for k in range(num): 860 l = l + ' v%d' % k 861 l = l + ') = v%d"' % i 862 lines.append(l) 863 864 return lines 865 866 867def named_update_definitions(name, map, type, header, constructors): 868 lines = [] 869 lines.append('primrec') 870 ra = '\<Rightarrow>' 871 if len(type.split()) > 1: 872 type = '(%s)' % type 873 lines.append(' %s_update :: "(%s %s %s) %s %s %s %s"' 874 % (name, type, ra, type, ra, header, ra, header)) 875 lines.append('where') 876 is_first = True 877 for cons, i in six.iteritems(map): 878 if is_first: 879 l = ' "%s_update f (%s' % (name, cons) 880 is_first = False 881 else: 882 l = '| "%s_update f (%s' % (name, cons) 883 num = len(constructors[cons]) 884 for k in range(num): 885 l = l + ' v%d' % k 886 l = l + ') = %s' % cons 887 for k in range(num): 888 if k == i: 889 l = l + ' (f v%d)' % k 890 else: 891 l = l + ' v%d' % k 892 l = l + '"' 893 lines.append(l) 894 895 return lines 896 897 898def named_constructor_translation(name, map, header): 899 lines = [] 900 lines.append('abbreviation (input)') 901 l = ' %s_trans :: "' % name 902 for n, type in map: 903 l = l + '(' + type + ') \<Rightarrow> ' 904 l = l + '%s" ("%s\'_ \<lparr> %s= _' % (header, name, map[0][0]) 905 for n, type in map[1:]: 906 l = l + ', %s= _' % n 907 l = l + ' \<rparr>")' 908 lines.append(l) 909 lines.append('where') 910 l = ' "%s_ \<lparr> %s= v0' % (name, map[0][0]) 911 for i, (n, type) in enumerate(map[1:]): 912 l = l + ', %s= v%d' % (n, i + 1) 913 l = l + ' \<rparr> == %s' % name 914 for i in range(len(map)): 915 l = l + ' v%d' % i 916 l = l + '"' 917 lines.append(l) 918 919 return lines 920 921 922def named_constructor_check(name, map, header): 923 lines = [] 924 lines.append('definition') 925 lines.append(' is%s :: "%s \<Rightarrow> bool"' % (name, header)) 926 lines.append('where') 927 lines.append(' "is%s v \<equiv> case v of' % name) 928 l 
= ' %s ' % name 929 for i, _ in enumerate(map): 930 l = l + 'v%d ' % i 931 l = l + '\<Rightarrow> True' 932 lines.append(l) 933 lines.append(' | _ \<Rightarrow> False"') 934 935 return lines 936 937 938def type_wrapper_type(header, cons, rhs, d, decons=None): 939 if '\\<Rightarrow>' in d.typedeps: 940 d.body = [('(* type declaration of %s omitted *)' % header, [])] 941 return d 942 lines = [ 943 'type_synonym %s = "%s"' % (header, rhs), 944 # translations (* TYPE 2 *)', 945 # "%s" <=(type) "%s"' % (header, rhs), 946 '', 947 'definition', 948 ' %s :: "%s \\<Rightarrow> %s"' % (cons, header, header), 949 'where %s_def[simp]:' % cons, 950 ' "%s \\<equiv> id"' % cons, 951 ] 952 if decons: 953 (decons, decons_type) = decons 954 lines.extend([ 955 '', 956 'definition', 957 ' %s :: "%s \\<Rightarrow> %s"' % (decons, header, header), 958 'where', 959 ' %s_def[simp]:' % decons, 960 ' "%s \\<equiv> id"' % decons, 961 '', 962 'definition' 963 ' %s_update :: "(%s \\<Rightarrow> %s) \\<Rightarrow> %s \\<Rightarrow> %s"' 964 % (decons, header, header, header, header), 965 'where', 966 ' %s_update_def[simp]:' % decons, 967 ' "%s_update f y \<equiv> f y"' % decons, 968 '' 969 ]) 970 lines.extend(named_constructor_translation(cons, [(decons, decons_type) 971 ], header)) 972 973 d.body = [(line, []) for line in lines] 974 return d 975 976 977def instance_transform(d): 978 [(line, children)] = d.body 979 bits = line.split(None, 3) 980 assert bits[0] == 'instance' 981 classname = bits[1] 982 typename = type_conv(bits[2]) 983 if classname == 'Show': 984 print("Warning: discarding class instance '%s :: Show'" % typename) 985 return None 986 if typename == '()': 987 print("Warning: discarding class instance 'unit :: %s'" % classname) 988 return None 989 if len(bits) == 3: 990 if children == []: 991 defs = [] 992 else: 993 [(l, c)] = children 994 assert l.strip() == 'where' 995 defs = c 996 else: 997 assert bits[3:] == ['where'] 998 defs = children 999 defs = [create_def_2(l, c, 0) for 
(l, c) in defs] 1000 defs = [d2 for d2 in defs if d2 is not None] 1001 defs = group_defs(defs) 1002 defs = [defs_transform(d2) for d2 in defs] 1003 defs_dict = {} 1004 for d2 in defs: 1005 if d2 is not None: 1006 defs_dict[d2.defined] = d2 1007 d.instance_defs = defs_dict 1008 d.deriving = [classname] 1009 if typename not in all_type_arities: 1010 sys.stderr.write('FAIL: attempting %s\n' % d.defined) 1011 sys.stderr.write('(typename %r)\n' % typename) 1012 sys.stderr.write('when reading %s\n' % filename) 1013 sys.stderr.write('but class not defined yet\n') 1014 sys.stderr.write('perhaps parse in different order?\n') 1015 sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n') 1016 sys.exit(1) 1017 arities = all_type_arities[typename] 1018 set_instance_proofs(typename, arities, d) 1019 1020 return d 1021 1022 1023all_type_arities = {} 1024 1025 1026def set_instance_proofs(header, constructor_arities, d): 1027 all_type_arities[header] = constructor_arities 1028 pfs = [] 1029 exs = [] 1030 canonical = list(enumerate(constructor_arities)) 1031 1032 classes = d.deriving 1033 instance_proof_fns = set( 1034 sorted((instance_proof_table[classname] for classname in classes), 1035 key=lambda x: x.order)) 1036 for f in instance_proof_fns: 1037 (npfs, nexs) = f(header, canonical, d) 1038 pfs.extend(npfs) 1039 exs.extend(nexs) 1040 1041 if d.type == 'newtype' and len(canonical) == 1 and False: 1042 [(cons, n)] = constructor_arities 1043 if n == 1: 1044 pfs.extend(finite_instance_proofs(header, cons)) 1045 1046 if pfs: 1047 lead = '(* %s instance proofs *)' % header 1048 d.instance_proofs = [(lead, [(line, []) for line in pfs])] 1049 if exs: 1050 lead = '(* %s extra instance defs *)' % header 1051 d.instance_extras = [(lead, [(line, []) for line in exs])] 1052 1053 1054def finite_instance_proofs(header, cons): 1055 lines = [] 1056 lines.append('') 1057 lines.append('instance %s :: finite' % header) 1058 if call.current_context: 1059 lines.append('interpretation Arch .') 1060 
lines.append(' apply (intro_classes)') 1061 lines.append(' apply (rule_tac f="%s" in finite_surj_type)' 1062 % cons) 1063 lines.append(' apply (safe, case_tac x, simp_all)') 1064 lines.append(' done') 1065 1066 return (lines, []) 1067 1068# wondering where the serialisable proofs went? see 1069# commit 21361f073bbafcfc985934e563868116810d9fa2 for last known occurence. 1070 1071 1072# leave type tags 0..11 for explicit use outside of this script 1073next_type_tag = 12 1074 1075 1076def storable_instance_proofs(header, canonical, d): 1077 proofs = [] 1078 extradefs = [] 1079 1080 global next_type_tag 1081 next_type_tag = next_type_tag + 1 1082 proofs.extend([ 1083 '', 'defs (overloaded)', ' typetag_%s[simp]:' % header, 1084 ' "typetag (x :: %s) \<equiv> %d"' % (header, next_type_tag), '' 1085 'instance %s :: dynamic' % header, ' by (intro_classes, simp)' 1086 ]) 1087 1088 proofs.append('') 1089 proofs.append('instance %s :: storable ..' % header) 1090 1091 defs = d.instance_defs 1092 extradefs.append('') 1093 if 'objBits' in defs: 1094 extradefs.append('definition') 1095 body = flatten_tree(defs['objBits'].body) 1096 bits = body[0].split('objBits') 1097 assert bits[0].strip() == '"' 1098 if bits[1].strip().startswith('_'): 1099 bits[1] = 'x ' + bits[1].strip()[1:] 1100 bits = bits[1].split(None, 1) 1101 body[0] = ' objBits_%s: "objBits (%s :: %s) %s' \ 1102 % (header, bits[0], header, bits[1]) 1103 extradefs.extend(body) 1104 1105 extradefs.append('') 1106 if 'makeObject' in defs: 1107 extradefs.append('definition') 1108 body = flatten_tree(defs['makeObject'].body) 1109 bits = body[0].split('makeObject') 1110 assert bits[0].strip() == '"' 1111 body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \ 1112 % (header, header, bits[1]) 1113 extradefs.extend(body) 1114 1115 extradefs.extend(['', 'definition', ]) 1116 if 'loadObject' in defs: 1117 extradefs.append(' loadObject_%s:' % header) 1118 extradefs.extend(flatten_tree(defs['loadObject'].body)) 1119 else: 1120 
extradefs.extend([ 1121 ' loadObject_%s[simp]:' % header, 1122 ' "(loadObject p q n obj) :: %s \<equiv>' % header, 1123 ' loadObject_default p q n obj"', 1124 ]) 1125 1126 extradefs.extend(['', 'definition', ]) 1127 if 'updateObject' in defs: 1128 extradefs.append(' updateObject_%s:' % header) 1129 body = flatten_tree(defs['updateObject'].body) 1130 bits = body[0].split('updateObject') 1131 assert bits[0].strip() == '"' 1132 bits = bits[1].split(None, 1) 1133 body[0] = ' "updateObject (%s :: %s) %s' \ 1134 % (bits[0], header, bits[1]) 1135 extradefs.extend(body) 1136 else: 1137 extradefs.extend([ 1138 ' updateObject_%s[simp]:' % header, 1139 ' "updateObject (val :: %s) \<equiv>' % header, 1140 ' updateObject_default val"', 1141 ]) 1142 1143 return (proofs, extradefs) 1144 1145 1146storable_instance_proofs.order = 1 1147 1148 1149def pspace_storable_instance_proofs(header, canonical, d): 1150 proofs = [] 1151 extradefs = [] 1152 1153 proofs.append('') 1154 proofs.append('instance %s :: pre_storable' % header) 1155 proofs.append(' by (intro_classes,') 1156 proofs.append( 1157 ' auto simp: projectKO_opts_defs split: kernel_object.splits arch_kernel_object.splits)') 1158 1159 defs = d.instance_defs 1160 extradefs.append('') 1161 if 'objBits' in defs: 1162 extradefs.append('definition') 1163 body = flatten_tree(defs['objBits'].body) 1164 bits = body[0].split('objBits') 1165 assert bits[0].strip() == '"' 1166 if bits[1].strip().startswith('_'): 1167 bits[1] = 'x ' + bits[1].strip()[1:] 1168 bits = bits[1].split(None, 1) 1169 body[0] = ' objBits_%s: "objBits (%s :: %s) %s' \ 1170 % (header, bits[0], header, bits[1]) 1171 extradefs.extend(body) 1172 1173 extradefs.append('') 1174 if 'makeObject' in defs: 1175 extradefs.append('definition') 1176 body = flatten_tree(defs['makeObject'].body) 1177 bits = body[0].split('makeObject') 1178 assert bits[0].strip() == '"' 1179 body[0] = ' makeObject_%s: "(makeObject :: %s) %s' \ 1180 % (header, header, bits[1]) 1181 
def enum_instance_proofs(header, canonical, d):
    """Emit Isabelle text making type `header` an instance of the enum,
    enum_alt and enumeration_both classes.

    `canonical` is a list of (index, (constructor, arity)) pairs for the
    type's constructors; only arities 0 and 1 are handled.  Returns a
    (lines, extra_defs) pair whose second element is always [].
    """
    # True iff the type has exactly one constructor, taking one argument.
    def singular_canonical():
        if len(canonical) == 1:
            [(_, (_, n))] = canonical
            return n == 1
        else:
            return False

    # (*<*) ... (*>*) hides the generated proof text in document output
    lines = ['(*<*)']
    if singular_canonical():
        [(_, (cons, n))] = canonical
        # special case for datatypes with single constructor with one argument
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            # inside an Arch context the locale interpretation is needed
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append(' enum_%s: "enum_class.enum \<equiv> map %s enum"'
                     % (header, cons))

    else:
        # split constructors by arity; more than one argument unsupported
        cons_no_args = [cons for i, (cons, n) in canonical if n == 0]
        cons_one_arg = [cons for i, (cons, n) in canonical if n == 1]
        cons_two_args = [cons for i, (cons, n) in canonical if n > 1]
        assert cons_two_args == []
        lines.append('instantiation %s :: enum begin' % header)
        if call.current_context:
            lines.append('interpretation Arch .')
        lines.append('definition')
        lines.append(' enum_%s: "enum_class.enum \<equiv> ' % header)
        # enumerate the nullary constructors literally ...
        if len(cons_no_args) == 0:
            lines.append(' []')
        else:
            lines.append(' [ ')
            for cons in cons_no_args[:-1]:
                lines.append(' %s,' % cons)
            for cons in cons_no_args[-1:]:
                lines.append(' %s' % cons)
            lines.append(' ]')
        # ... and append each unary constructor mapped over its enum
        for cons in cons_one_arg:
            lines.append(' @ (map %s enum)' % cons)
        lines[-1] = lines[-1] + '"'
        lines.append('')
        # distinctness lemmas for the mapped constructor lists
        for cons in cons_one_arg:
            lines.append('lemma %s_map_distinct[simp]: "distinct (map %s enum)"' % (cons, cons))
            lines.append(' apply (simp add: distinct_map)')
            lines.append(' by (meson injI %s.inject)' % header)
        lines.append('')

    lines.append('definition')
    lines.append(' "enum_class.enum_all (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Ball UNIV P"'
                 % header)
    lines.append('')
    lines.append('definition')
    lines.append(' "enum_class.enum_ex (P :: %s \<Rightarrow> bool) \<longleftrightarrow> Bex UNIV P"'
                 % header)
    lines.append('')
    lines.append(' instance')
    lines.append(' apply intro_classes')
    lines.append('  apply (safe, simp)')
    lines.append('  apply (case_tac x)')
    if len(canonical) == 1:
        lines.append(' apply (auto simp: enum_%s enum_all_%s_def enum_ex_%s_def'
                     % (header, header, header))
        lines.append(' distinct_map_enum)')
        lines.append(' done')
    else:
        lines.append(' apply (simp_all add: enum_%s enum_all_%s_def enum_ex_%s_def)'
                     % (header, header, header))
        lines.append(' by fast+')
    lines.append('end')
    lines.append('')
    # enum_alt: alternative enumeration derived from the ordinary one
    lines.append('instantiation %s :: enum_alt' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('definition')
    lines.append(' enum_alt_%s: "enum_alt \<equiv> ' % header)
    lines.append(' alt_from_ord (enum :: %s list)"' % header)
    lines.append('instance ..')
    lines.append('end')
    lines.append('')
    lines.append('instantiation %s :: enumeration_both' % header)
    lines.append('begin')
    if call.current_context:
        lines.append('interpretation Arch .')
    lines.append('instance by (intro_classes, simp add: enum_alt_%s)'
                 % header)
    lines.append('end')
    lines.append('')
    lines.append('(*>*)')

    return (lines, [])


enum_instance_proofs.order = 3
def get_type_map(string):
    """Takes Haskell named record syntax and produces
    a map (in the form of a list of tuples) defining
    the converted types of the names.

    Fields in a group sharing one annotation (e.g. "a, b :: T") all get
    that type: the record is walked right-to-left so an annotation
    propagates to the names listed before it.
    """
    parts = string.split(None, 1)
    header = parts[0].strip()
    if len(parts) == 1:
        # bare constructor with no record fields
        return (header, [])

    record = parts[1].strip()
    if not (record.startswith('{') and record.endswith('}')):
        print('Error in ' + filename)
        print('Expected "A { blah :: blah etc }"')
        print('However { and } not found correctly')
        print('When parsing %s' % string)
        sys.exit(1)
    record = record[1:-1]

    fields = braces.str(record, '(', ')').split(',')
    current_type = None
    mapping = []
    for field in reversed(fields):
        pieces = field.split('::')
        if len(pieces) == 2:
            current_type = type_transform(str(pieces[1]).strip())
            name = str(pieces[0]).strip()
        else:
            name = str(field).strip()
        mapping.append((name, current_type))
    mapping.reverse()
    return (header, mapping)
# per-run counter: gives each liftIO-axiomatised body a unique Int tag
numLiftIO = [0]


def body_transform(body, defined, sig, nopattern=False):
    """Transform a Haskell definition body into an Isabelle one.

    `body` is a single-element [(line, children)] indentation tree,
    `defined` is the name being defined (used only in warnings), and
    `sig` is the raw type-signature tree used to pick the monad the body
    lives in.  With nopattern=True, argument pattern matching and the
    lhs rewrites are skipped.  Returns the transformed tree in a list.
    """
    # assume single object
    [(line, children)] = body

    # parenthesised patterns in the argument list are split out into
    # case expressions first
    if '(' in line.split('=')[0] and not nopattern:
        [(line, children)] = \
            pattern_match_transform([(line, children)])

    if 'liftIO' in reduce_to_single_line((line, children)):
        # liftIO is the translation boundary for current
        # SEL4, below which we get into details of interaction
        # with the foreign function interface - axiomatise
        assert '=' in line
        line = line.split('=')[0]
        bits = line.split()
        numLiftIO[0] = numLiftIO[0] + 1
        rhs = '(%d :: Int) %s' % (numLiftIO[0], ' '.join(bits[1:]))
        line = '%s\<equiv> underlying_arch_op %s' % (line, rhs)
        children = []
    elif '=' in line:
        # replace the defining '=' with Isabelle's definition arrow
        line = '\<equiv>'.join(line.split('=', 1))
    elif leading_bar.match(children[0][0]):
        # guarded body: the '=' is handled by guarded_body_transform below
        pass
    elif '=' in children[0][0]:
        # the '=' sits on the first continuation line instead
        (nextline, c2) = children[0]
        children[0] = ('\<equiv>'.join(nextline.split('=', 1)), c2)
    else:
        sys.stderr.write('WARN: def of %s missing =\n' % defined)

    # a trailing 'where' block becomes a let/in wrapper around the body
    if children and (children[-1][0].endswith('where') or
                     children[-1][0].lstrip().startswith('where')):
        bits = line.split('\<equiv>')
        where_clause = where_clause_transform(children[-1])
        children = children[:-1]
        if len(bits) == 2 and bits[1].strip():
            line = bits[0] + '\<equiv>'
            new_line = ' ' * len(line) + bits[1]
            children = [(new_line, children)]
    else:
        where_clause = []

    (line, children) = zipWith_transforms(line, children)

    (line, children) = supplied_transforms(line, children)

    (line, children) = case_clauses_transform((line, children))

    (line, children) = do_clauses_transform((line, children), sig)

    # top-level guards become an if/else chain
    if children and leading_bar.match(children[0][0]):
        line = line + ' \<equiv>'
        children = guarded_body_transform(children, ' = ')

    children = where_clause + children

    if not nopattern:
        line = lhs_transform(line)
        line = lhs_de_underscore(line)

    (line, children) = type_assertion_transform(line, children)

    (line, children) = run_regexes((line, children))
    (line, children) = run_ext_regexes((line, children))

    (line, children) = bracket_dollar_lambdas((line, children))

    # finally wrap the whole definition body in quotation marks
    line = '"' + line
    (line, children) = add_trailing_string('"', (line, children))
    return [(line, children)]
def convert_to_stripped_tuple(line, children):
    """Recursively strip the line of each node and freeze every child
    list into a tuple, yielding a fully hashable (line, children) tree."""
    frozen = tuple(convert_to_stripped_tuple(l, c) for (l, c) in children)
    return (line.strip(), frozen)
def where_clause_transform(xxx_todo_changeme2):
    """Turn a Haskell 'where' block into an Isabelle 'let ... in' block.

    Returns a list of (line, children) trees: a 'let' opener, the
    bindings (dependency-ordered, separated by ';'), then an 'in' closer.
    """
    (line, children) = xxx_todo_changeme2
    ws = line.split('where', 1)[0]
    if line.strip() != 'where':
        # text following 'where' on the same line becomes the first binding
        assert line.strip().startswith('where')
        children = [(''.join(line.split('where', 1)), [])] + children
    pre = ws + 'let'
    post = ws + 'in'

    # drop type signatures (second word '::') from the bindings
    children = [(l, c) for (l, c) in children if l.split()[1] != '::']
    children = [case_clauses_transform((l, c)) for (l, c) in children]
    children = [do_clauses_transform(
        (l, c),
        None,
        type=0) for (l, c) in children]
    children = list(map(where_clause_guarded_body, children))
    for i, (l, c) in enumerate(children):
        l2 = braces.str(l, '(', ')')
        if len(l2.split('=')[0].split()) > 1:
            # turn f a = b into f = (\a -> b)
            l = '->'.join(l.split('=', 1))
            l = lead_ws(l) + ' = (\\ '.join(l.split(None, 1))
            (l, c) = add_trailing_string(')', (l, c))
        children[i] = (l, c)
    # let-bindings must be emitted before their uses
    children = order_let_children(children)
    for i, child in enumerate(children[:-1]):
        children[i] = add_trailing_string(';', child)
    return [(pre, [])] + children + [(post, [])]
def do_clauses_transform(xxx_todo_changeme3, rawsig, type=None):
    """Rewrite Haskell do-blocks as Isabelle do/od (doE/odE) blocks.

    Operates on a (line, children) tree.  `rawsig` is the raw type
    signature tree; when `type` (the error-monad depth selecting the
    E-suffixed operators) is not supplied it is derived from the
    signature via monad_type_acquire.
    """
    (line, children) = xxx_todo_changeme3
    # hoist a trailing where block: transform it first, then nest the
    # rest of the body inside it
    if children and children[-1][0].lstrip().startswith('where'):
        where_clause = where_clause_transform(children[-1])
        where_clause = [do_clauses_transform(
            (l, c), rawsig, 0) for (l, c) in where_clause]
        others = (line, children[:-1])
        others = do_clauses_transform(others, rawsig, type)
        (line, children) = where_clause[0]
        return (line, children + where_clause[1:] + [others])

    # merge a lone continuation line ending in 'do' into its parent
    if children:
        (l, c) = children[0]
        if c == [] and l.endswith('do'):
            line = line + ' ' + l.strip()
            children = children[1:]

    if type is None:
        if not rawsig:
            type = 0
            sig = None
        else:
            sig = ' '.join(flatten_tree([rawsig]))
            type = monad_type_acquire(sig)
    (line, type) = monad_type_transform((line, type))
    if children == []:
        return (line, [])

    rhs = line.split('<-', 1)[-1]
    if rhs.strip() == 'syscall' or rhs.strip() == 'atomicSyscall':
        # syscall's five argument blocks alternate error-monad depth
        assert len(children) == 5
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0, 1, 0, type])]
    elif line.strip().endswith('catchFailure'):
        # first block runs in the error monad, the handler does not
        assert len(children) == 2
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=subtype)
                    for elt, subtype in zip(children, [1, 0])]
    else:
        children = [do_clauses_transform(elt,
                                         rawsig,
                                         type=type) for elt in children]

    if not line.endswith('do'):
        return (line, children)

    # statements past an unmatched ')' belong to an enclosing
    # expression, not to this do-block
    children, other_children = split_on_unmatched_bracket(children)

    # single statement do clause won't parse in Isabelle
    if len(children) == 1:
        ws = lead_ws(line)
        return (line[:-2] + '(', children + [(ws + ')', [])])

    line = line[:-2] + '(do' + 'E' * type

    children = [(l, c) for (l, c) in children if l.strip() or c]

    children2 = []
    for (l, c) in children:
        if l.lstrip().startswith('let '):
            # 'let x = e' inside do becomes 'x <- return (e)'
            # (returnOk in the error monad)
            if '=' not in l:
                extra = reduce_to_single_line(c[0])
                assert '=' in extra
                l = l + ' ' + extra
                c = c[1:]
            l = ''.join(l.split('let ', 1))
            letsubs = '<- return' + 'Ok' * type + ' ('
            l = letsubs.join(l.split('=', 1))
            (l, c) = add_trailing_string(')', (l, c))
            children2.extend(do_clause_pattern(l, c, type))
        else:
            children2.extend(do_clause_pattern(l, c, type))

    # ';' separates statements; the final one is left bare
    children = [add_trailing_string(';', child)
                for child in children2[:-1]] + [children2[-1]]

    ws = lead_ws(line)
    children.append((ws + 'od' + 'E' * type + ')', []))

    return (line, children + other_children)
def split_on_unmatched_bracket(elts, n=None):
    """Split a list of (line, children) trees at the first ')' that has
    no matching '(' earlier in the list.

    Returns (before, after): `before` ends just before the unmatched
    bracket and `after` starts with it, prefixed by whitespace the same
    width as the removed text.  When every bracket matches, `after` is
    [].  The internal three-tuple recursion threads the running bracket
    balance `n` through lines and child trees.
    """
    if n is None:
        before, after, n = split_on_unmatched_bracket(elts, 0)
        return (before, after)

    for idx, (text, kids) in enumerate(elts):
        for pos, ch in enumerate(text):
            if ch == '(':
                n = n + 1
            if ch == ')':
                n = n - 1
            if n < 0:
                # found the unmatched ')': cut the line right here
                head = text[:pos]
                tail = ' ' * len(head) + text[pos:]
                return (elts[:idx] + [(head, [])],
                        [(tail, kids)] + elts[idx + 1:],
                        n)
        kids, spill, n = split_on_unmatched_bracket(kids, n)
        if spill:
            # the unmatched bracket was inside this node's children
            return (elts[:idx] + [(text, kids)],
                    spill + elts[idx + 1:],
                    n)

    return (elts, [], n)
def monad_type_transform(xxx_todo_changeme4):
    """Track the error-monad depth of a line and rewrite its operators.

    Takes a (line, type) pair, where type is the current error-nesting
    depth (0 = plain kernel monad, 1 = one error layer, 2 = two).
    Splitter functions such as doKernelOp or withoutFailure switch the
    depth for the text to their right; at non-zero depth the common
    combinators gain their E/Ok-suffixed spellings.  Returns the
    rewritten (line, new_type) pair.
    """
    (line, depth) = xxx_todo_changeme4
    # Ordered exactly like the original if/elif chain: the first
    # splitter present in the line wins (order matters - e.g. 'catchF'
    # is a substring of 'catchingFailure').
    splitters = (
        ('withoutError', 1),
        ('doKernelOp', 0),
        ('runInit', 1),
        ('withoutFailure', 0),
        ('withoutFault', 0),
        ('withoutPreemption', 0),
        ('allowingFaults', 1),
        ('allowingErrors', 2),
        ('`catchFailure`', None),  # special-cased below
        ('catchingFailure', 1),
        ('catchF', 1),
        ('emptyOnFailure', 1),
        ('constOnFailure', 1),
        ('nothingOnFailure', 1),
        ('nullCapOnFailure', 1),
        ('`catchFault`', 1),
        ('capFaultOnFailure', 1),
        ('ignoreFailure', 1),
        ('handleInvocation False', 0),  # THIS IS A HACK
    )
    for key, newdepth in splitters:
        if key not in line:
            continue
        if newdepth is None:
            # `catchFailure`: guarded side is depth 1, handler is plain
            [left, right] = line.split('`catchFailure`', 1)
            left, _ = monad_type_transform((left, 1))
            right, rdepth = monad_type_transform((right, 0))
            return (left + '`catchFailure`' + right, rdepth)
        [left, right] = line.split(key, 1)
        left, _ = monad_type_transform((left, depth))
        right, rdepth = monad_type_transform((right, newdepth))
        return (left + key + right, rdepth)

    if depth:
        # rewrite combinators to their error-monad spellings, in the
        # same substitution order as the original chain of joins
        for word in ('return', 'when', 'unless', 'mapM', 'forM',
                     'liftM', 'assert', 'stateAssert'):
            suffix = 'Ok' * depth if word == 'return' else 'E' * depth
            line = (word + suffix).join(line.split(word))

    return (line, depth)
[]: 1999 sys.stderr.write('wtf %r\n' % l) 2000 sys.exit(1) 2001 if c[0][0].strip().startswith('|'): 2002 bits = [bits[0], ''] 2003 c = guarded_body_transform(c, '->') 2004 elif c[0][1] == []: 2005 l = l + ' ' + c.pop(0)[0].strip() 2006 bits = l.split('->', 1) 2007 else: 2008 [(moreline, c)] = c 2009 l = l + ' ' + moreline.strip() 2010 bits = l.split('->', 1) 2011 case = bits[0].strip() 2012 tail = bits[1] 2013 if c and c[-1][0].lstrip().startswith('where'): 2014 where_clause = where_clause_transform(c[-1]) 2015 ws = lead_ws(where_clause[0][0]) 2016 c = where_clause + [(ws + tail.strip(), [])] + c[:-1] 2017 tail = '' 2018 cases.append(case) 2019 bodies.append((tail, c)) 2020 2021 cases = tuple(cases) # used as key of a dict later, needs to be hashable 2022 # (since lists are mutable, they aren't) 2023 if not cases: 2024 print(line) 2025 conv = get_case_conv(cases) 2026 if conv == '<X>': 2027 sys.stderr.write('Warning: blanked case in caseconvs\n') 2028 return (beforecase + '\<comment> \<open>case removed\<close> undefined', []) 2029 if not conv: 2030 sys.stderr.write('Warning: case %r\n' % (cases, )) 2031 if cases not in cases_added: 2032 casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> ' 2033 2034 f = open('caseconvs', 'a') 2035 f.write('%s ---X>\n\n' % casestr) 2036 f.close() 2037 cases_added[cases] = 1 2038 return (beforecase + '\<comment> \<open>case removed\<close> undefined', []) 2039 conv = subs_nums_and_x(conv, x) 2040 2041 new_line = beforecase + '(' + conv[0][0] 2042 assert conv[0][1] is None 2043 2044 ws = lead_ws(children[0][0]) 2045 new_children = [] 2046 for (header, endnum) in conv[1:]: 2047 if endnum is None: 2048 new_children.append((ws + header, [])) 2049 else: 2050 if (len(bodies) <= endnum): 2051 sys.stderr.write('ERROR: index %d out of bounds in case %r\n' % 2052 (endnum, 2053 cases, )) 2054 sys.exit(1) 2055 2056 (body, c) = bodies[endnum] 2057 new_children.append((ws + header + ' ' + body, c)) 2058 2059 if has_trailing_string(',', 
def lhs_transform(line):
    """Rewrite parenthesised patterns on the left-hand side of a
    definition as plain argN names, pushing each pattern into a case
    expression prepended to the right-hand side."""
    if '(' not in line:
        return line

    lhs, rhs = line.split('\<equiv>')
    indent = lhs[:len(lhs) - len(lhs.lstrip())]

    pieces = braces.str(lhs.lstrip(), '(', ')').split(braces=True)
    for idx, piece in enumerate(pieces):
        if not piece.startswith('('):
            continue
        # replace the pattern with a fresh name, match it on the rhs
        pieces[idx] = 'arg%d' % idx
        rhs = 'case arg%d of %s \<Rightarrow> ' % (idx, piece) + rhs

    return indent + ' '.join(str(p) for p in pieces) + '\<equiv>' + rhs
bit) in enumerate(bits): 2134 if bit == '_': 2135 bits[i] = 'arg%d' % i 2136 2137 return ws + ' '.join([str(bit) for bit in bits]) + ' \<equiv>' + right 2138 2139 2140regexes = [ 2141 (re.compile(r" \. "), r" \<circ> "), 2142 (re.compile('-1'), '- 1'), 2143 (re.compile('-2'), '- 2'), 2144 (re.compile(r"\(!(\w+)\)"), r"(flip id \1)"), 2145 (re.compile(r"\(\+(\w+)\)"), r"(\<lambda> x. x + \1)"), 2146 (re.compile(r"\\([^<].*?)\s*->"), r"\<lambda> \1."), 2147 (re.compile('}'), r"\<rparr>"), 2148 (re.compile(r"(\s)!!(\s)"), r"\1LIST_INDEX\2"), 2149 (re.compile(r"(\w)!"), r"\1 "), 2150 (re.compile(r"\s?!"), ''), 2151 (re.compile(r"LIST_INDEX"), r"!"), 2152 (re.compile('`testBit`'), '!!'), 2153 (re.compile(r"//"), ' aLU '), 2154 (re.compile('otherwise'), 'True '), 2155 (re.compile(r"(^|\W)fail "), r"\1haskell_fail "), 2156 (re.compile('assert '), 'haskell_assert '), 2157 (re.compile('assertE '), 'haskell_assertE '), 2158 (re.compile('=='), '='), 2159 (re.compile(r"\(/="), '(\<lambda>x. x \<noteq>'), 2160 (re.compile('/='), '\<noteq>'), 2161 (re.compile('"([^"])*"'), '[]'), 2162 (re.compile('&&'), '\<and>'), 2163 (re.compile('\|\|'), '\<or>'), 2164 (re.compile(r"(\W)not(\s)"), r"\1Not\2"), 2165 (re.compile(r"(\W)and(\s)"), r"\1andList\2"), 2166 (re.compile(r"(\W)or(\s)"), r"\1orList\2"), 2167 # regex ordering important here 2168 (re.compile(r"\.&\."), '&&'), 2169 (re.compile(r"\(\.\|\.\)"), r"bitOR"), 2170 (re.compile(r"\(\+\)"), r"op +"), 2171 (re.compile(r"\.\|\."), '||'), 2172 (re.compile(r"Data\.Word\.Word"), r"word"), 2173 (re.compile(r"Data\.Map\."), r"data_map_"), 2174 (re.compile(r"Data\.Set\."), r"data_set_"), 2175 (re.compile(r"BinaryTree\."), 'bt_'), 2176 (re.compile("mapM_"), "mapM_x"), 2177 (re.compile("forM_"), "forM_x"), 2178 (re.compile("forME_"), "forME_x"), 2179 (re.compile("mapME_"), "mapME_x"), 2180 (re.compile("zipWithM_"), "zipWithM_x"), 2181 (re.compile(r"bit\s+([0-9]+)(\s)"), r"(1 << \1)\2"), 2182 (re.compile('`mod`'), 'mod'), 2183 
def run_regexes(tree, _regexes=regexes):
    """Apply every (pattern, replacement) pair in _regexes to the line
    of this node, then recurse into all of its children."""
    line, kids = tree
    for pattern, repl in _regexes:
        line = pattern.sub(repl, line)
    return (line, [run_regexes(kid, _regexes=_regexes) for kid in kids])
= run_regexes((after, children), _regexes=add) 2240 line = before + substituted + after 2241 children = [run_ext_regexes(elt) for elt in children] 2242 return (line, children) 2243 2244 2245def get_case_lhs(lhs): 2246 assert lhs.startswith('case \\x of ') 2247 lhs = lhs.split('case \\x of ', 1)[1] 2248 cases = lhs.split('->') 2249 cases = [case.strip() for case in cases] 2250 cases = [case for case in cases if case != ''] 2251 cases = tuple(cases) 2252 2253 return cases 2254 2255 2256def get_case_rhs(rhs): 2257 tuples = [] 2258 while '->' in rhs: 2259 bits = rhs.split('->', 1) 2260 s = bits[0] 2261 bits = bits[1].split(None, 1) 2262 n = int(takeWhile(bits[0], lambda x: x.isdigit())) - 1 2263 if len(bits) > 1: 2264 rhs = bits[1] 2265 else: 2266 rhs = '' 2267 tuples.append((s, n)) 2268 if rhs != '': 2269 tuples.append((rhs, None)) 2270 2271 conv = [] 2272 for (string, num) in tuples: 2273 bits = string.split('\\n') 2274 bits = [bit.strip() for bit in bits] 2275 conv.extend([(bit, None) for bit in bits[:-1]]) 2276 conv.append((bits[-1], num)) 2277 2278 conv = [(s, n) for (s, n) in conv if s != '' or n is not None] 2279 2280 if conv[0][1] is not None: 2281 sys.stderr.write('%r\n' % conv[0][1]) 2282 sys.stderr.write( 2283 'For technical reasons the first line of this case conversion must be split with \\n: \n') 2284 sys.stderr.write(' %r\n' % rhs) 2285 sys.stderr.write( 2286 '(further notes: the rhs of each caseconv must have multiple lines\n' 2287 'and the first cannot contain any ->1, ->2 etc.)\n') 2288 sys.exit(1) 2289 2290 # this is a tad dodgy, but means that case_clauses_transform 2291 # can be safely run twice on the same input 2292 if conv[0][0].endswith('of'): 2293 conv[0] = (conv[0][0] + ' ', conv[0][1]) 2294 2295 return conv 2296 2297 2298def render_caseconv(cases, conv, f): 2299 bits = [bit for bit in conv.split('\\n') if bit != ''] 2300 assert bits 2301 casestr = 'case \\x of ' + ' -> '.join(cases) + ' -> ' 2302 f.write('%s --->' % casestr) 2303 for bit in 
bits: 2304 f.write(bit) 2305 f.write('\n') 2306 f.write('\n') 2307 2308 2309def get_case_conv_table(): 2310 f = open('caseconvs') 2311 f2 = open('caseconvs-useful', 'w') 2312 result = {} 2313 input = map(str.rstrip, f) 2314 input = ("\\n".join(lines) for lines in splitList(input, lambda s: s == '')) 2315 2316 for line in input: 2317 if line.strip() == '': 2318 continue 2319 try: 2320 if '---X>' in line: 2321 [from_case, _] = line.split('---X>') 2322 cases = get_case_lhs(from_case) 2323 result[cases] = "<X>" 2324 else: 2325 [from_case, to_case] = line.split('--->') 2326 cases = get_case_lhs(from_case) 2327 conv = get_case_rhs(to_case) 2328 result[cases] = conv 2329 if (not all_constructor_patterns(cases) and 2330 not is_extended_pattern(cases)): 2331 render_caseconv(cases, to_case, f2) 2332 except Exception as e: 2333 sys.stderr.write('Error parsing %r\n' % line) 2334 sys.stderr.write('%s\n ' % e) 2335 sys.exit(1) 2336 2337 f.close() 2338 f2.close() 2339 2340 return result 2341 2342 2343def all_constructor_patterns(cases): 2344 if cases[-1].strip() == '_': 2345 cases = cases[:-1] 2346 for pat in cases: 2347 if not is_constructor_pattern(pat): 2348 return False 2349 return True 2350 2351 2352def is_constructor_pattern(pat): 2353 """A constructor pattern takes the form Cons var1 var2 ..., 2354 characterised by all alphanumeric names, the constructor starting 2355 with an uppercase alphabetic char and the vars with lowercase.""" 2356 bits = pat.split() 2357 for bit in bits: 2358 if (not bit.isalnum()) and (not bit == '_'): 2359 return False 2360 if not bits[0][0].isupper(): 2361 return False 2362 for bit in bits[1:]: 2363 if (not bit[0].islower()) and (not bit == '_'): 2364 return False 2365 return True 2366 2367 2368ext_checker = re.compile(r"^(\(|\)|,|{|}|=|[a-zA-Z][0-9']?|\s|_|:|\[|\])*$") 2369 2370 2371def is_extended_pattern(cases): 2372 for case in cases: 2373 if not ext_checker.match(case): 2374 return False 2375 return True 2376 2377 2378case_conv_table = 
get_case_conv_table() 2379cases_added = {} 2380 2381 2382def get_case_conv(cases): 2383 if all_constructor_patterns(cases): 2384 return all_constructor_conv(cases) 2385 2386 if is_extended_pattern(cases): 2387 return extended_pattern_conv(cases) 2388 2389 return case_conv_table.get(cases) 2390 2391 2392constructor_conv_table = { 2393 'Just': 'Some', 2394 'Nothing': 'None', 2395 'Left': 'Inl', 2396 'Right': 'Inr', 2397 'PPtr': '\<comment> \<open>PPtr\<close>', 2398 'Register': '\<comment> \<open>Register\<close>', 2399 'Word': '\<comment> \<open>Word\<close>', 2400} 2401 2402unique_ids_per_file = {} 2403 2404 2405def get_next_unique_id(): 2406 id = unique_ids_per_file.get(filename, 1) 2407 unique_ids_per_file[filename] = id + 1 2408 return id 2409 2410 2411def all_constructor_conv(cases): 2412 conv = [('case \\x of', None)] 2413 2414 for i, pat in enumerate(cases): 2415 bits = pat.split() 2416 if bits[0] in constructor_conv_table: 2417 bits[0] = constructor_conv_table[bits[0]] 2418 for j, bit in enumerate(bits): 2419 if j > 0 and bit == '_': 2420 bits[j] = 'v%d' % get_next_unique_id() 2421 pat = ' '.join(bits) 2422 if i == 0: 2423 conv.append((' %s \<Rightarrow> ' % pat, i)) 2424 else: 2425 conv.append(('| %s \<Rightarrow> ' % pat, i)) 2426 return conv 2427 2428 2429word_getter = re.compile(r"([a-zA-Z0-9]+)") 2430 2431record_getter = re.compile(r"([a-zA-Z0-9]+\s*{[a-zA-Z0-9'\s=\,_\(\):\]\[]*})") 2432 2433 2434def extended_pattern_conv(cases): 2435 conv = [('case \\x of', None)] 2436 2437 for i, pat in enumerate(cases): 2438 pat = '#'.join(pat.split(':')) 2439 while record_getter.search(pat): 2440 [left, record, right] = record_getter.split(pat) 2441 record = reduce_record_pattern(record) 2442 pat = left + record + right 2443 if '{' in pat: 2444 print(pat) 2445 assert '{' not in pat 2446 bits = word_getter.split(pat) 2447 bits = [constructor_conv_table.get(bit, bit) for bit in bits] 2448 pat = ''.join(bits) 2449 if i == 0: 2450 conv.append((' %s \<Rightarrow> ' % 
pat, i)) 2451 else: 2452 conv.append(('| %s \<Rightarrow> ' % pat, i)) 2453 return conv 2454 2455 2456def reduce_record_pattern(string): 2457 assert string[-1] == '}' 2458 string = string[:-1] 2459 [left, right] = string.split('{') 2460 cons = left.strip() 2461 right = braces.str(right, '(', ')') 2462 eqs = right.split(',') 2463 uses = {} 2464 for eq in eqs: 2465 eq = str(eq).strip() 2466 if eq: 2467 [left, right] = eq.split('=') 2468 (left, right) = (left.strip(), right.strip()) 2469 if len(right.split()) > 1: 2470 right = '(%s)' % right 2471 uses[left] = right 2472 if cons not in all_constructor_args: 2473 sys.stderr.write('FAIL: trying to build case for %s\n' % cons) 2474 sys.stderr.write('when reading %s\n' % filename) 2475 sys.stderr.write('but constructor not seen yet\n') 2476 sys.stderr.write('perhaps parse in different order?\n') 2477 sys.stderr.write('hint: #INCLUDE_HASKELL_PREPARSE\n') 2478 sys.exit(1) 2479 args = all_constructor_args[cons] 2480 args = [uses.get(name, '_') for (name, type) in args] 2481 return cons + ' ' + ' '.join(args) 2482 2483 2484def subs_nums_and_x(conv, x): 2485 ids = [] 2486 2487 result = [] 2488 for (line, num) in conv: 2489 line = x.join(line.split('\\x')) 2490 bits = line.split('\\v') 2491 line = bits[0] 2492 for bit in bits[1:]: 2493 bits = bit.split('\\', 1) 2494 n = int(bits[0]) 2495 while n >= len(ids): 2496 ids.append(get_next_unique_id()) 2497 line = line + 'v%d' % (ids[n]) 2498 if len(bits) > 1: 2499 line = line + bits[1] 2500 result.append((line, num)) 2501 2502 return result 2503 2504 2505def get_supplied_transform_table(): 2506 f = open('supplied') 2507 2508 lines = [line.rstrip() for line in f] 2509 f.close() 2510 2511 lines = [(line, n + 1) for (n, line) in enumerate(lines)] 2512 lines = [(line, n) for (line, n) in lines if line != ''] 2513 2514 for line in lines: 2515 if '\t' in line: 2516 sys.stderr.write('WARN: tab character in supplied') 2517 2518 tree = offside_tree(lines) 2519 2520 result = {} 2521 2522 for 
line, n, children in tree: 2523 if ('conv:' not in line) or len(children) != 2: 2524 sys.stderr.write('WARN: supplied line %d dropped\n' 2525 % n) 2526 if 'conv:' not in line: 2527 sys.stderr.write('\t\t(token "conv:" missing)\n') 2528 if len(children) != 2: 2529 sys.stderr.write('\t\t(%d children != 2)\n' % len(children)) 2530 continue 2531 2532 children = discard_line_numbers(children) 2533 2534 before, after = children 2535 2536 before = convert_to_stripped_tuple(before[0], before[1]) 2537 2538 result[before] = after 2539 2540 return result 2541 2542 2543def print_tree(tree, indent=0): 2544 for line, children in tree: 2545 print('\t' * indent) + line.strip() 2546 print_tree(children, indent + 1) 2547 2548 2549supplied_transform_table = get_supplied_transform_table() 2550supplied_transforms_usage = dict(( 2551 key, 0) for key in six.iterkeys(supplied_transform_table)) 2552 2553 2554def warn_supplied_usage(): 2555 for (key, usage) in six.iteritems(supplied_transforms_usage): 2556 if not usage: 2557 sys.stderr.write('WARN: supplied conv unused: %s\n' 2558 % key[0]) 2559 2560 2561quotes_getter = re.compile('"[^"]+"') 2562 2563 2564def detect_recursion(body): 2565 """Detects whether any of the bodies of the definitions of this 2566 function recursively refer to it.""" 2567 single_lines = [reduce_to_single_line(elt) for elt in body] 2568 single_lines = [''.join(quotes_getter.split(l)) for l in single_lines] 2569 bits = [line.split(None, 1) for line in single_lines] 2570 name = bits[0][0] 2571 assert [n for (n, _) in bits if n != name] == [] 2572 return [body for (n, body) in bits if name in body] != [] 2573 2574 2575def primrec_transform(d): 2576 sig = d.sig 2577 defn = d.defined 2578 body = [] 2579 is_not_first = False 2580 for (l, c) in d.body: 2581 [(l, c)] = body_transform([(l, c)], defn, sig, nopattern=True) 2582 if is_not_first: 2583 l = "| " + l 2584 else: 2585 l = " " + l 2586 is_not_first = True 2587 l = l.split('\<equiv>') 2588 assert len(l) == 2 2589 l = '= 
('.join(l) 2590 (l, c) = remove_trailing_string('"', (l, c)) 2591 (l, c) = add_trailing_string(')"', (l, c)) 2592 body.append((l, c)) 2593 d.primrec = True 2594 d.body = body 2595 return d 2596 2597 2598variable_name_regex = re.compile(r"^[a-z]\w*$") 2599 2600 2601def is_variable_name(string): 2602 return variable_name_regex.match(string) 2603 2604 2605def pattern_match_transform(body): 2606 """Converts a body containing possibly multiple definitions 2607 and containing pattern matches into a normal Isabelle definition 2608 followed by a big Haskell case expression which is resolved 2609 elsewhere.""" 2610 splits = [] 2611 for (line, children) in body: 2612 string = braces.str(line, '(', ')') 2613 while len(string.split('=')) == 1: 2614 if len(children) == 1: 2615 [(moreline, children)] = children 2616 string = string + ' ' + moreline.strip() 2617 elif children and leading_bar.match(children[0][0]): 2618 string = string + ' =' 2619 children = \ 2620 guarded_body_transform(children, ' = ') 2621 elif children and children[0][1] == []: 2622 (moreline, _) = children.pop(0) 2623 string = string + ' ' + moreline.strip() 2624 else: 2625 print() 2626 print(line) 2627 print() 2628 for child in children: 2629 print(child) 2630 assert 0 2631 2632 [lead, tail] = string.split('=', 1) 2633 bits = lead.split() 2634 unbraced = bits 2635 function = str(bits[0]) 2636 splits.append((bits[1:], unbraced[1:], tail, children)) 2637 2638 common = splits[0][0][:] 2639 for i, term in enumerate(common): 2640 if term.startswith('('): 2641 common[i] = None 2642 if '@' in term: 2643 common[i] = None 2644 if term[0].isupper(): 2645 common[i] = None 2646 2647 for (bits, _, _, _) in splits[1:]: 2648 for i, term in enumerate(bits): 2649 if i >= len(common): 2650 print_tree(body) 2651 if term != common[i]: 2652 is_var = is_variable_name(str(term)) 2653 if common[i] == '_' and is_var: 2654 common[i] = term 2655 elif term != '_': 2656 common[i] = None 2657 2658 for i, term in enumerate(common): 2659 
if term == '_': 2660 common[i] = 'x%d' % i 2661 2662 blanks = [i for (i, n) in enumerate(common) if n is None] 2663 2664 line = '%s ' % function 2665 for i, name in enumerate(common): 2666 if name is None: 2667 line = line + 'x%d ' % i 2668 else: 2669 line = line + '%s ' % name 2670 if blanks == []: 2671 print(splits) 2672 print(common) 2673 if len(blanks) == 1: 2674 line = line + '= case x%d of' % blanks[0] 2675 else: 2676 line = line + '= case (x%d' % blanks[0] 2677 for i in blanks[1:]: 2678 line = line + ', x%d' % i 2679 line = line + ') of' 2680 2681 children = [] 2682 for (bits, unbraced, tail, c) in splits: 2683 if len(blanks) == 1: 2684 l = ' %s' % unbraced[blanks[0]] 2685 else: 2686 l = ' (%s' % unbraced[blanks[0]] 2687 for i in blanks[1:]: 2688 l = l + ', %s' % unbraced[i] 2689 l = l + ')' 2690 l = l + ' -> %s' % tail 2691 children.append((l, c)) 2692 2693 return [(line, children)] 2694 2695 2696def get_lambda_body_lines(d): 2697 """Returns lines equivalent to the body of the function as 2698 a lambda expression.""" 2699 fn = d.defined 2700 2701 [(line, children)] = d.body 2702 2703 line = line[1:] 2704 # find \<equiv> in first or 2nd line 2705 if '\<equiv>' not in line and '\<equiv>' in children[0][0]: 2706 (l, c) = children[0] 2707 children = c + children[1:] 2708 line = line + l 2709 [lhs, rhs] = line.split('\<equiv>', 1) 2710 bits = lhs.split() 2711 args = bits[1:] 2712 assert fn in bits[0] 2713 2714 line = '(\<lambda>' + ' '.join(args) + '. 
' + rhs 2715 # lines = ['(* body of %s *)' % fn, line] + flatten_tree (children) 2716 lines = [line] + flatten_tree(children) 2717 assert (lines[-1].endswith('"')) 2718 lines[-1] = lines[-1][:-1] + ')' 2719 2720 return lines 2721 2722 2723def add_trailing_string(s, xxx_todo_changeme8): 2724 (line, children) = xxx_todo_changeme8 2725 if children == []: 2726 return (line + s, children) 2727 else: 2728 modified = add_trailing_string(s, children[-1]) 2729 return (line, children[0:-1] + [modified]) 2730 2731 2732def remove_trailing_string(s, xxx_todo_changeme9, _handled=False): 2733 (line, children) = xxx_todo_changeme9 2734 if not _handled: 2735 try: 2736 return remove_trailing_string(s, (line, children), _handled=True) 2737 except: 2738 sys.stderr.write('handling %s\n' % ((line, children), )) 2739 raise 2740 if children == []: 2741 if not line.endswith(s): 2742 sys.stderr.write('ERR: expected %r\n' % line) 2743 sys.stderr.write('to end with %r\n' % s) 2744 assert line.endswith(s) 2745 n = len(s) 2746 return (line[:-n], []) 2747 else: 2748 modified = remove_trailing_string(s, children[-1], _handled=True) 2749 return (line, children[0:-1] + [modified]) 2750 2751 2752def get_trailing_string(n, xxx_todo_changeme10): 2753 (line, children) = xxx_todo_changeme10 2754 if children == []: 2755 return line[-n:] 2756 else: 2757 return get_trailing_string(n, children[-1]) 2758 2759 2760def has_trailing_string(s, xxx_todo_changeme11): 2761 (line, children) = xxx_todo_changeme11 2762 if children == []: 2763 return line.endswith(s) 2764 else: 2765 return has_trailing_string(s, children[-1]) 2766 2767 2768def ensure_type_ordering(defs): 2769 typedefs = [d for d in defs if d.type == 'newtype'] 2770 other = [d for d in defs if d.type != 'newtype'] 2771 2772 final_typedefs = [] 2773 while typedefs: 2774 try: 2775 i = 0 2776 deps = typedefs[i].typedeps 2777 while 1: 2778 for j, term in enumerate(typedefs): 2779 if term.typename in deps: 2780 break 2781 else: 2782 break 2783 i = j 2784 
deps = typedefs[i].typedeps 2785 final_typedefs.append(typedefs.pop(i)) 2786 except Exception as e: 2787 print('Exception hit ordering types:') 2788 for td in typedefs: 2789 print(' - %s' % td.typename) 2790 raise e 2791 2792 return final_typedefs + other 2793 2794 2795def lead_ws(string): 2796 amount = len(string) - len(string.lstrip()) 2797 return string[:amount] 2798 2799 2800def adjust_ws(xxx_todo_changeme12, n): 2801 (line, children) = xxx_todo_changeme12 2802 if n > 0: 2803 line = ' ' * n + line 2804 else: 2805 x = -n 2806 line = line[x:] 2807 2808 return (line, [adjust_ws(child, n) for child in children]) 2809 2810 2811modulename = re.compile(r"(\w+\.)+") 2812 2813 2814def perform_module_redirects(lines, call): 2815 return [subst_module_redirects(line, call) for line in lines] 2816 2817 2818def subst_module_redirects(line, call): 2819 m = modulename.search(line) 2820 if not m: 2821 return line 2822 module = line[m.start():m.end() - 1] 2823 before = line[:m.start()] 2824 after = line[m.end():] 2825 after = subst_module_redirects(after, call) 2826 if module in call.moduletranslations: 2827 module = call.moduletranslations[module] 2828 if module: 2829 return before + module + '.' + after 2830 else: 2831 return before + after 2832