#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# Syntax and simple operations for types, expressions, graph nodes
# and graph functions (functions in graph-language format).

from target_objects import structs, trace
import target_objects

quick_reference = """
Quick reference on the graph language and its syntax.
=====================================================

Example
=======

Let's build up the syntax that roughly mirrors the C statement x = y + 1;

This is a quick example. A more thorough reference is below.

We have two example types: Bool and Word 32.

Here are two atomic expressions of type Word 32:
Var y Word 32
Num 1 Word 32

These encode the variable y and the number 1, both of type Word 32.

Expressions in this language are built up by concatenating strings of tokens.
Any whitespace-delimited string can be a token, and thus a variable name.

Compound expressions are built with operators, e.g. y + 1:
Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This is quite verbose. For simplicity, the Op syntax includes the type of
the whole expression - Word 32 - and the number of arguments - 2 - even though
much of this information is redundant.

Finally, we can encode the statement x = y + 1; at address 14.

14 Basic 15 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This specifies that node 14 is a basic node (which updates variables). The
successor node is 15. It updates 1 variable (basic nodes may simultaneously
update many variables). The variable to be updated is specified by x Word 32.
The remainder of the line is the expression y + 1 as seen before.

We can encode the statements if (x == y) { return; } as follows:
15 Cond 16 17 Op Equals Bool 2 Var x Word 32 Var y Word 32
16 Basic Ret 0

Conditional node 15 continues to node 16 if x == y, and to 17 otherwise.
Node 16 is a basic node which updates no variables, i.e. a skip statement. It
continues to the special address Ret which designates return from the current
function.

Reference
=========

A graph program consists of a number of functions, each
of whose control flow is structured as a graph of nodes
with integer addresses. Nodes are of three types:
  - 'Basic' updates variable values.
  - 'Cond' chooses between two successor nodes.
  - 'Call' makes calls to functions.

A top level syntax file (interpreted by syntax.parse_all)
contains two kinds of sections:
  - Function blocks introduce functions.
  - Struct blocks introduce aggregate types (e.g. struct in C).

All lines in the input syntax are read as a whitespace-separated
list of tokens. Empty lines and lines whose first non-whitespace
character is '#' are ignored. All syntax is prefix encoded, so
each clause has a first token which determines what kind of clause
it is, followed by the concatenation of all of its subclauses.

The two kinds of lines that appear in Struct blocks illustrate this format:
Struct "struct name" <int size> <int alignment>
StructField "name" <type> <int offset>

On notation: above we denote the specific tokens 'Struct' and 'StructField',
the arbitrary tokens "name" etc, the special integer subclauses <int size> etc,
and a <type> subclause. Integer subclauses always consist of a single token.
A leading character '-' or '~' indicates negative. A leading '0x' or '0b'
indicates hexadecimal or binary encoding, otherwise decimal, e.g. 1, 0x1A,
-0x2b, ~0b1100, etc. Naming tokens can be anything at all that does not contain
whitespace, so x y'v a,b c#d # etc are all valid struct or field names.

Struct blocks begin with the 'Struct' line, which specifies their name, total
size and required alignment.
The struct block then contains a number of
StructField lines, where each field has a type and an offset from the start of
the structure. The Struct block ends where any other Struct or Function block
begins.

Function blocks begin with a function declaration line
  Function "name of function" <int>*("input argument name" <type>)
    <int>*("output argument name" <type>)

The notation <int>*(...) above denotes an arbitrary length list of subclauses.
The first token of this clause will be a decimal integer specifying the length
of the list. For instance, the classic 'XOR' function with inputs x, y and
output z could be specified as follows:
Function XOR 2 x Bool y Bool 1 z Bool

Each function has a name, a list of input arguments, and a list of output
arguments. The state in the graph language is entirely encoded in variables,
so functions will typically have global objects such as the heap as both input
and output arguments.

The function block may end immediately, for functions with unspecified body,
or may contain a number of graph node lines, followed by an entry point line:
EntryPoint <int>
The entry point line ends the function block.

A graph node line has one of these formats:
<int> Basic <next node> <int>*("var name" <type> <expression>)
<int> Cond <next node> <next node> <expression>
<int> Call <next node> "fun name" <int>*(<expression>) <int>*("var name" <type>)

Each node line begins with an integer which is the address of the node. A
<next node> is either the integer address of the following node, or one of the
special address tokens Ret and Err.

Basic nodes update the value of a (possibly empty) list of variables. The
variables are all updated simultaneously to the computed value of the
expressions given.

Cond nodes evaluate an expression which must be of boolean type. They specify
two possible next nodes, left and right.
The left node is visited when the
expression evaluates to true, the right on false.

Call nodes call another function of a given name. The list of argument
expressions is evaluated and must match the list of input arguments of the
given function (same length and types). These arguments define the starting
variable environment of that function. Its return parameters are saved to the
list of variables given.

The <type> clauses are in one of these formats:
Word <int>   -- BitVec <int> is a synonym
Array <type> <int>
Struct "struct name"
Ptr <type>
One of the builtin types as a single token:
  'Bool Mem Dom HTD PMS UNIT Type Token RoundingMode'
FloatingPoint <int> <int>

These types encode a machine word with the given number of bits, an array of
some type, a struct previously defined by a Struct clause, or a pointer to
another type. The Array, Struct and Ptr types are provided only for use in
pointer validity assertions, which are discussed below. The FloatingPoint type
exactly mirrors the SMTLIB2 (_ FloatingPoint eb sb) type, and is available as
an optional extension. Using the FloatingPoint and RoundingMode types in any
problem changes the SMT theory needed and may limit which solvers and features
can be used.

Of the builtin types, booleans are standard, UNIT is the singleton type, and
memory is a 32-bit to 8-bit mapping. We aim to include 64-bit support soon. The
Dom type is a set of 32-bit values, used to encode the valid domain of memory
operations. The heap type description HTD is used by pointer-validity
operators. The phantom machine state type PMS is unspecified and used to
represent unspecified aspects of the environment. The type Type is the type of
Type expressions which are used in pointer-validity.

The <expression> clauses have one of these formats:
Var "name" <type>
Op "name" <type> <int>*(<expression>)
Num <int> <type>
Type <type>
Symbol "name" <type>

Most of the expressions of the language are composed from variables, numerals
and operator applications, which should be self-explanatory. The Type
expression wraps a type into an expression (of type Type) so it may be passed
to a pointer-validity operator.

The special Symbol clauses are used in source languages to denote the values
that symbols will have in the binaries. They are replaced by specific numerals
in a first pass.

Many of the builtin operators are equivalent to SMTLIB2 operators, and are also
available with their SMTLIB2 names. The operators are:
  - Binary arithmetic operators on words:
    + Plus/bvadd, Minus/bvsub, Times/bvmul, Modulus/bvurem, DividedBy/bvudiv,
      BWAnd/bvand, BWOr/bvor, BWXOR/bvxor,
      ShiftLeft/bvshl, ShiftRight/bvlshr, SignedShiftRight/bvashr
  - Binary operators on booleans:
    + And/and, Or/or, Implies
  - Unary operators on bools:
    + Not/not
  - Booleans (nullary operators, i.e. constants):
    + True/true, False/false
  - Equals relation in any type:
    + Equals
  - Comparison relations on words (result is bool):
    + Less/bvult, LessEquals/bvule, SignedLess/bvslt, SignedLessEquals/bvsle
  - Unary operators on words:
    + BWNot/bvnot, CountLeadingZeroes, CountTrailingZeroes, WordReverse
  - Cast operators on words - result type may be different to argument type:
    + WordCast, WordCastSigned
  - Memory operations:
    + MemAcc (args [m :: Mem, ptr :: Word 32]) any word type
    + MemUpdate (args [m :: Mem, ptr :: Word 32, v :: any word type])
  - Pointer-validity operators:
    + PValid, PWeakValid, PAlignValid, PGlobalValid, PArrayValid
  - Miscellaneous:
    + MemDom, HTDUpdate
    + IfThenElse/ite (args [b :: bool, x :: any type T, y :: T])
  - FloatingPoint operations from the SMTLIB2 floating point specification:
    + roundNearestTiesToEven/RNE, roundNearestTiesToAway/RNA,
      roundTowardPositive/RTP, roundTowardNegative/RTN,
      roundTowardZero/RTZ
    + fp.abs, fp.neg, fp.add, fp.sub, fp.mul, fp.div, fp.fma, fp.sqrt,
      fp.rem, fp.roundToIntegral, fp.min, fp.max, fp.leq, fp.lt,
      fp.geq, fp.gt, fp.eq, fp.isNormal, fp.IsSubnormal, fp.isZero,
      fp.isInfinite, fp.isNaN, fp.isNegative, fp.isPositive
    + ToFloatingPoint, ToFloatingPointSigned, ToFloatingPointUnsigned,
      FloatingPointCast

Operators with SMTLIB2 equivalents have the same semantics. Mem and Dom
operations wrap the SMTLIB2 BitVec Array type with more convenient
operations.

Memory accesses and updates can operate on various word types.

The pointer-validity operators PValid, PWeakValid, PGlobalValid all take 3
arguments of type HTD, Type, and Word 32. The PValid operator asserts that
the heap-type-description contains this type starting at this address. This
is exclusive with any incompatible type.
PGlobalValid additionally asserts that
this is a global object, and therefore not contained within any larger object.
PWeakValid asserts that the type could be valid (it is aligned and within the
range of available addresses in the heap-type-description) and is needed only
in rare circumstances. PAlignValid omits the HTD argument and asserts only that
the pointer is appropriately aligned and that the object does not start at or
wrap around the 0 address. PArrayValid takes an additional argument (HTD,
Type, Word 32, Word 32) where the final argument specifies the number of
entries in an array.

The MemDom operator takes argument types [Word 32, Dom] and returns a boolean
indicating whether this pointer is in this domain.

The IfThenElse operator takes a bool and any two arguments of the same type.

The FloatingPoint operators are mostly taken from the SMTLIB2 floating
point standard. The conversions ToFloatingPoint ([Word] to FP),
ToFloatingPointSigned and ToFloatingPointUnsigned ([RoundingMode, Word] to FP)
and FloatingPointCast (FP to FP) represent the variants of to_fp in the SMTLIB2
standard.
"""

class Type:
    def __init__ (self, kind, name, el_typ=None):
        self.kind = kind
        if kind in ['Array', 'Word', 'TokenWords']:
            self.num = int (name)
        else:
            self.name = name
        if kind in ['Array', 'Ptr']:
            self.el_typ_symb = el_typ
            self.el_typ = concrete_type (el_typ)
        if kind in ['WordArray', 'FloatingPoint']:
            self.nums = [int (name), int (el_typ)]

    def __repr__ (self):
        if self.kind == 'Array':
            return 'Type ("Array", %r, %r)' % (self.num,
                self.el_typ_symb)
        elif self.kind in ('Word', 'TokenWords'):
            return 'Type (%r, %r)' % (self.kind, self.num)
        elif self.kind == 'Ptr':
            return 'Type ("Ptr", %r)' % self.el_typ_symb
        elif self.kind in ('WordArray', 'FloatingPoint'):
            return 'Type (%r, %r, %r)' % tuple ([self.kind]
                + self.nums)
        else:
            return 'Type (%r, %r)' % (self.kind, self.name)

    def __eq__ (self, other):
        if not other:
            return False
        if self.kind != other.kind:
            return False
        if self.kind in ['Array', 'Word', 'TokenWords']:
            if self.num != other.num:
                return False
        elif self.kind in ['WordArray', 'FloatingPoint']:
            if self.nums != other.nums:
                return False
        else:
            if self.name != other.name:
                return False
        if self.kind in ['Array', 'Ptr']:
            if self.el_typ_symb != other.el_typ_symb:
                return False
        return True

    def __ne__ (self, other):
        return not other or not (self == other)

    def __hash__ (self):
        return hash(str(self))

    def __cmp__ (self, other):
        self_ss = []
        self.serialise (self_ss)
        other_ss = []
        other.serialise (other_ss)
        return cmp (self_ss, other_ss)

    def subtypes (self):
        if self.kind == 'Struct':
            return structs[self.name].subtypes()
        elif self.kind == 'Array':
            return [self] + self.el_typ.subtypes()
        else:
            return [self]

    def size (self):
        if self.kind == 'Struct':
            return structs[self.name].size
        elif self.kind == 'Array':
            return self.el_typ.size() * self.num
        elif self.kind == 'Word':
            assert self.num % 8 == 0, self
            return self.num / 8
        elif self.kind == 'FloatingPoint':
            sz = sum (self.nums)
            assert sz % 8 == 0, self
            return sz / 8
        elif self.kind == 'Ptr':
            return 4
        else:
            assert not 'type has size'

    def align (self):
        if self.kind == 'Struct':
            return structs[self.name].align
        elif self.kind == 'Array':
            return self.el_typ.align ()
        elif self.kind in ('Word', 'FloatingPoint'):
            return self.size ()
        elif self.kind == 'Ptr':
            return 4
        else:
            assert not 'type has alignment'

    def serialise (self, xs):
        if self.kind in ('Word', 'TokenWords'):
            xs.append (self.kind)
            xs.append (str (self.num))
        elif self.kind in ('WordArray', 'FloatingPoint'):
            xs.append (self.kind)
            xs.extend ([str (n) for n in self.nums])
        elif self.kind == 'Builtin':
            xs.append (self.name)
        elif self.kind == 'Array':
            xs.append ('Array')
            self.el_typ_symb.serialise (xs)
            xs.append (str (self.num))
        elif self.kind == 'Struct':
            xs.append ('Struct')
            xs.append (self.name)
        elif self.kind == 'Ptr':
            xs.append ('Ptr')
            self.el_typ_symb.serialise (xs)
        else:
            assert not 'type serialisable', self.kind

class Expr:
    def __init__ (self, kind, typ, name = None, struct = None,
            field = None, val = None, vals = None):
        self.kind = kind
        self.typ = typ
        if name != None:
            self.name = name
        if struct != None:
            self.struct = struct
        if field != None:
            self.field = field
        if val != None:
            self.val = val
        if vals != None:
            self.vals = vals
        if kind == 'Op':
            assert type (self.vals) == list

    def binds (self):
        binds = []
        if self.kind in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
            binds.append(('name', self.name))
        elif self.kind in ['Array', 'StructCons']:
            binds.append(('vals', self.vals))
        elif self.kind == 'Field':
            binds.append(('struct', self.struct))
            binds.append(('field', self.field))
        elif self.kind == 'FieldUpd':
            binds.append(('struct', self.struct))
            binds.append(('field', self.field))
            binds.append(('val', self.val))
        elif self.kind == 'Num':
            binds.append(('val', self.val))
        elif self.kind == 'Op':
            binds.append(('name', self.name))
            binds.append(('vals', self.vals))
        elif self.kind == 'Type':
            binds.append(('val', self.val))
        elif self.kind == 'SMTExpr':
            binds.append(('val', self.val))
        else:
            assert not 'expression understood for repr', self.kind
        return binds

    def __repr__ (self):
        bits = [repr(self.kind), repr(self.typ)]
        bits.extend(['%s = %r' % b for b in self.binds()])
        return 'Expr (%s)' % ', '.join(bits)

    def __eq__ (self, other):
        return (other and self.kind == other.kind
            and self.typ == other.typ
            and self.binds() == other.binds())

    def __ne__ (self, other):
        return not other or not (self == other)

    def __hash__ (self):
        return hash_tuplify (self.kind, self.typ, self.binds ())

    def __cmp__ (self, other):
        return cmp ((self.kind, self.typ, self.binds ()),
            (other.kind, other.typ, other.binds ()))

    def is_var (self, (nm, typ)):
        return self.kind == 'Var' and all([self.name == nm,
            self.typ == typ])

    def is_op (self, nm):
        if type (nm) == str:
            return self.kind == 'Op' and self.name == nm
        else:
            return self.kind == 'Op' and self.name in nm

    def visit (self, visit):
        visit (self)
        if self.kind == 'Var':
            pass
        elif self.kind == 'Op' or self.kind == 'Array':
            for x in self.vals:
                x.visit (visit)
        elif self.kind == 'StructCons':
            for x in self.vals.itervalues ():
                x.visit (visit)
        elif self.kind == 'Field':
            self.struct.visit (visit)

            struct_typ = self.struct.typ
            assert struct_typ.kind == 'Struct'
            struct = structs[struct_typ.name]
            (name, typ) = self.field
            assert struct.fields[name][0] == typ
        elif self.kind == 'FieldUpd':
            self.val.visit (visit)
            self.struct.visit (visit)

            struct_typ = self.struct.typ
            assert struct_typ.kind == 'Struct'
            struct = structs[struct_typ.name]
            (name, typ) = self.field
            assert struct.fields[name][0] == typ
        elif self.kind == 'ConstGlobal':
            assert (target_objects.const_globals[self.name].typ
                == self.typ)
        elif self.kind in set (['Num', 'Symbol', 'Type', 'Token']):
            pass
        else:
            assert not 'expr understood', self

    def gen_visit (self, visit_lval, visit_rval):
        self.visit (visit_rval)

    def subst (self, substor, ss = None):
        ret = False
        if self.kind == 'Op':
            subst_vals = subst_list (substor, self.vals)
            if subst_vals:
                self = Expr ('Op', self.typ, name = self.name,
                    vals = subst_vals)
                ret = True
        if (ss == None or self.kind in ss or self.is_op (ss)):
            r = substor (self)
            if r != None:
                return r
        if ret:
            return self
        return

    def add_const_ranges (self, ranges):
        def visit (expr):
            if expr.kind == 'ConstGlobal':
                (start, size, _) = symbols[expr.name]
                assert size == expr.typ.size ()
                ranges[expr.name] = (start, start + size - 1)

        self.visit (visit)

    def get_mem_access (self):
        if self.is_op ('MemAcc'):
            [m, p] = self.vals
            return [('MemAcc', p, self, m)]
        elif self.is_op ('MemUpdate'):
            [m, p, v] = self.vals
            return [('MemUpdate', p, v, m)]
        else:
            return []

    def get_mem_accesses (self):
        accesses = []
        def visit (expr):
            accesses.extend (expr.get_mem_access ())
        self.visit (visit)
        return accesses

    def serialise (self, xs):
        xs.append (self.kind)
        if self.kind == 'Op':
            xs.append (self.name)
            self.typ.serialise (xs)
            xs.append (str (len (self.vals)))
            for v in self.vals:
                v.serialise (xs)
        elif self.kind == 'Num':
            xs.append (str (self.val))
            self.typ.serialise (xs)
        elif self.kind == 'Var':
            xs.append (self.name)
            self.typ.serialise (xs)
        elif self.kind == 'Type':
            self.val.serialise (xs)
        elif self.kind == 'Token':
            # the kind was already appended above; emitting it a
            # second time would not round-trip through parse_expr
            xs.append (self.name)
            self.typ.serialise (xs)
        else:
            assert not 'expr serialisable', self.kind

class Struct:
    def __init__ (self, name, size, align):
        self.name = name
        self.size = size
        self.align = align
        self.field_list = []
        self.fields = {}
        self._subtypes = None
        self.typ = Type ('Struct', name)

    def add_field (self, name, typ, offset):
        concrete = concrete_type (typ)
        self.field_list.append ((name, concrete))
        self.fields[name] = (concrete, offset, typ)
        assert self._subtypes == None

    def subtypes (self):
        if self._subtypes != None:
            return self._subtypes
        xs = [self.typ]
        for (concrete, offs, typ2) in self.fields.itervalues():
            xs.extend(typ2.subtypes())
        self._subtypes = xs
        return xs

def tuplify (x):
    if type(x) == tuple or type(x) == list:
        return tuple ([tuplify (y) for y in x])
    if type(x) == dict:
        return tuple ([tuplify (y) for y in x.iteritems ()])
    else:
        return x

def hash_tuplify (* xs):
    return hash (tuplify (xs))

def subst_list (substor, xs, ss = None):
    ys = [x.subst (substor, ss = ss) for x in xs]
    if [y for y in ys if y != None]:
        xs = list (xs)
        for (i, y) in enumerate (ys):
            if y != None:
                xs[i] = y
        return xs
    else:
        return

class Node:
    def __init__ (self, kind, conts, args):
        self.kind = kind

        # conts is either a list of continuations or a single one
        if type (conts) == list:
            if len (conts) == 1:
                the_cont = conts[0]
        else:
            the_cont = conts

        if kind == 'Basic':
            self.cont = the_cont
            self.upds = [(lv, v) for (lv, v) in args
                if not v.is_var (lv)]
        elif kind == 'Call':
            self.cont = the_cont
            (self.fname, self.args, self.rets) = args
        elif kind == 'Cond':
            (self.left, self.right) = conts
            self.cond = args
        else:
            assert not 'node kind understood', self.kind

    def __repr__ (self):
        return 'Node (%r, %r, %r)' % (self.kind,
            self.get_conts (), self.get_args ())

    def __hash__ (self):
        if self.kind == 'Call':
            return hash ((self.fname, tuple (self.args),
                tuple (self.rets), self.cont))
        elif self.kind == 'Basic':
            return hash (tuple (self.upds))
        elif self.kind == 'Cond':
            return hash ((self.cond, self.left, self.right))
        else:
            assert not 'node kind understood', self.kind

    def __eq__ (self, other):
        return all ([self.kind == other.kind,
            self.get_conts () == other.get_conts (),
            self.get_args () == other.get_args ()])

    def __ne__ (self, other):
        return not other or not self == other

    def get_args (self):
        if self.kind == 'Basic':
            return self.upds
        elif self.kind == 'Call':
            return (self.fname, self.args, self.rets)
        else:
            return self.cond

    def get_conts (self):
        if self.kind == 'Cond':
            return [self.left, self.right]
        else:
            return [self.cont]

    def get_lvals (self):
        if self.kind == 'Basic':
            return [lv for (lv, v) in self.upds]
        elif self.kind == 'Call':
            return self.rets
        else:
            return []

    def is_noop (self):
        if self.kind == 'Basic':
            return self.upds == []
        elif self.kind == 'Cond':
            return self.left == self.right
        else:
            return False

    def visit (self, visit_lval, visit_rval):
        if self.kind == 'Basic':
            for (lv, v) in self.upds:
                visit_lval (lv)
                v.visit (visit_rval)
        elif self.kind == 'Cond':
            self.cond.visit (visit_rval)
        elif self.kind == 'Call':
            for v in self.args:
                v.visit (visit_rval)
            for r in self.rets:
                visit_lval (r)

    def gen_visit (self, visit_lval, visit_rval):
        self.visit (visit_lval, visit_rval)

    def subst_exprs (self, substor, ss = None):
        if self.kind == 'Basic':
            rvs = subst_list (substor,
                [v for (lv, v) in self.upds], ss = ss)
            if rvs == None:
                return self
            return Node ('Basic', self.cont,
                zip ([lv for (lv, v) in self.upds], rvs))
        elif self.kind == 'Cond':
            r = self.cond.subst (substor, ss = ss)
            if r == None:
                return self
            return Node ('Cond', [self.left, self.right], r)
        elif self.kind == 'Call':
            args = subst_list (substor, self.args, ss = ss)
            if args == None:
                return self
            return Node ('Call', self.cont, (self.fname,
                args, self.rets))

    def get_mem_accesses (self):
        accesses = []
        def visit (expr):
            accesses.extend (expr.get_mem_access ())
        self.visit (lambda x: (), visit)
        return accesses

    def err_cond (self):
        if self.kind != 'Cond':
            return None
        if self.left != 'Err':
            if self.right == 'Err':
                return mk_not (self.cond)
            else:
                return None
        else:
            if self.right == 'Err':
                return true_term
            else:
                return self.cond

    def serialise (self, xs):
        xs.append (self.kind)
        xs.extend ([str (c) for c in self.get_conts ()])
        if self.kind == 'Basic':
            xs.append (str (len (self.upds)))
            for (lv, v) in self.upds:
                xs.append (lv[0])
                lv[1].serialise (xs)
                v.serialise (xs)
        elif self.kind == 'Cond':
            self.cond.serialise (xs)
        elif self.kind == 'Call':
            xs.append (self.fname)
            xs.append (str (len (self.args)))
            for arg in self.args:
                arg.serialise (xs)
            xs.append (str (len (self.rets)))
            for (nm, typ) in self.rets:
                xs.append (nm)
                typ.serialise (xs)

def rename_lval ((name, typ), renames):
    return (renames.get (name, name), typ)

def do_subst (expr, substor, ss = None):
    r = expr.subst (substor, ss = ss)
    if r == None:
        return expr
    else:
        return r

standard_expr_kinds = set (['Symbol', 'ConstGlobal', 'Var', 'Op', 'Num',
    'Type'])

def rename_expr_substor (renames):
    def ren (expr):
        if expr.kind == 'Var' and expr.name in renames:
            return mk_var (renames[expr.name], expr.typ)
        else:
            return
    return ren

def rename_expr (expr, renames):
    return do_subst (expr, rename_expr_substor (renames),
        ss = set (['Var']))

def copy_rename (node, renames):
    (vs, ns) = renames
    nf = lambda n: ns.get (n, n)
    node = node.subst_exprs (rename_expr_substor (vs))
    if node.kind == 'Call':
        return Node ('Call', nf (node.cont), (node.fname, node.args,
            [rename_lval (l, vs) for l in node.rets]))
    elif node.kind == 'Basic':
        return Node ('Basic', nf (node.cont),
            [(rename_lval (lv, vs), v) for (lv, v) in node.upds])
    elif node.kind == 'Cond':
        return Node ('Cond', [nf (node.left), nf (node.right)],
            node.cond)
    else:
        assert not 'node kind understood', node.kind

class Function:
    def __init__ (self, name, inputs, outputs):
        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.entry = None
        self.nodes = {}

    def __hash__ (self):
        en = self.entry
        if not en:
            en = -1
        return hash (tuple ([self.name, tuple (self.inputs),
                tuple (self.outputs), en])
            + tuple (sorted (self.nodes.iteritems ())))

    def reachable_nodes (self, simplify = False):
        if not self.entry:
            return {}
        rs = {}
        vs = [self.entry]
        while vs:
            n = vs.pop()
            if type (n) == str:
                continue
            rs[n] = True
            node = self.nodes[n]
            if simplify:
                import logic
                node = logic.simplify_node_elementary (node)
            for c in node.get_conts ():
                if not c in rs:
                    vs.append (c)
        return rs

    def serialise (self):
        xs = ['Function', self.name, str (len (self.inputs))]
        for (nm, typ) in self.inputs:
            xs.append (nm)
            typ.serialise (xs)
        xs.append (str (len (self.outputs)))
        for (nm, typ) in self.outputs:
            xs.append (nm)
            typ.serialise (xs)
        ss = [' '.join (xs)]
        if not self.entry:
            return ss
        for n in self.nodes:
            xs = [str (n)]
            self.nodes[n].serialise (xs)
            ss.append (' '.join (xs))
        ss.append ('EntryPoint %d' % self.entry)
        return ss

    def as_problem (self, Problem, name = 'temp'):
        p = Problem(None, 'Function (%s)' % self.name)
        p.clone_function (self, name)
        p.compute_preds ()
        return p

    def function_call_addrs (self):
        return [(n, self.nodes[n].fname)
            for n in self.nodes if self.nodes[n].kind == 'Call']

    def function_calls (self):
        return set ([fn for (n, fn) in self.function_call_addrs ()])

    def compile_hints (self, Problem):
        xs = ['Hints %s' % self.name]

        p = self.as_problem (Problem)

        for n in p.nodes:
            ys = ['VarDeps', str (n)]
            for (nm, typ) in p.var_deps[n]:
                ys.append (nm)
                typ.serialise (ys)
            xs.append (' '.join (ys))
            if self.nodes[n].kind != 'Basic':
                continue
        return xs

    def save_graph (self, fname):
        import problem
        problem.save_graph (self.nodes, fname)

def mk_builtinTs ():
    return dict([(n, Type('Builtin', n)) for n
        in 'Bool Mem Dom HTD PMS UNIT Type Token RelWrapper'.split()])
builtinTs = mk_builtinTs ()
boolT = builtinTs['Bool']
word32T = Type ('Word', '32')
word64T = Type ('Word', '64')
word16T = Type ('Word', 16)
word8T = Type ('Word', '8')

phantom_types = set ([builtinTs[t] for t
    in 'Dom HTD PMS UNIT Type'.split ()])

def concrete_type (typ):
    if typ.kind == 'Ptr':
        return word32T
    else:
        return typ

global_wrappers = {}
def get_global_wrapper (typ):
    if typ in global_wrappers:
        return global_wrappers[typ]
    struct_name = fresh_name ('Global (%s)' % typ, structs)
    struct = Struct (struct_name, typ.size (), typ.align ())
    struct.add_field ('v', typ, 0)
    structs[struct_name] = struct

    global_wrappers[typ] = struct.typ
    return struct.typ

# ==========================================================
# parsing code for types, expressions, structs and functions

def parse_int (s):
    if s.startswith ('-') or s.startswith ('~'):
        return (- parse_int (s[1:]))
    if s.startswith ('0x') or s.startswith ('0b'):
        return int (s, 0)
    else:
        return int (s)

def parse_typ(bits, n, symbolic_types = False):
    if bits[n] == 'Word' or bits[n] == 'BitVec':
        return (n + 2, Type('Word', parse_int (bits[n + 1])))
    elif bits[n] == 'WordArray' or bits[n] == 'FloatingPoint':
        return (n + 3, Type(bits[n],
            parse_int (bits[n + 1]), parse_int (bits[n + 2])))
    elif bits[n] in builtinTs:
        return (n + 1, builtinTs[bits[n]])
    elif bits[n] == 'Array':
        (n, typ) = parse_typ (bits, n + 1, True)
        return (n + 1, Type ('Array', parse_int (bits[n]), typ))
    elif bits[n] == 'Struct':
        return (n + 2, Type ('Struct', bits[n + 1]))
    elif bits[n] == 'Ptr':
        (n, typ) = parse_typ (bits, n + 1, True)
        if symbolic_types:
            return (n, Type ('Ptr', '', typ))
        else:
            return (n, word32T)
    else:
        assert not 'type encoded', (n, bits[n:], bits)

def node_name (name):
    if name in {'Ret':True, 'Err':True}:
        return name
    else:
        try:
            return parse_int (name)
        except ValueError, e:
            assert not 'node name understood', name

def parse_list (parser, bits, n, extra=None):
    try:
        num = parse_int (bits[n])
    except ValueError:
        assert not 'number parseable', (n, bits[n:], bits)
    n = n + 1
    xs = []
    for i in range (num):
        if extra:
            (n, x) = parser(bits, n, extra)
        else:
            (n, x) = parser(bits, n)
        xs.append(x)
    return (n, xs)

def parse_arg (bits, n):
    nm = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    return (n, (nm, typ))

ops = {'Plus':2, 'Minus':2, 'Times':2, 'Modulus':2,
    'DividedBy':2, 'BWAnd':2, 'BWOr':2, 'BWXOR':2, 'And':2,
    'Or':2, 'Implies':2, 'Equals':2, 'Less':2,
    'LessEquals':2, 'SignedLess':2, 'SignedLessEquals':2,
    'ShiftLeft':2, 'ShiftRight':2,
    'CountLeadingZeroes':1,
    'CountTrailingZeroes':1, 'WordReverse':1, 'SignedShiftRight':2,
    'Not':1, 'BWNot':1, 'WordCast':1, 'WordCastSigned':1,
    'True':0, 'False':0, 'UnspecifiedPrecond':0,
    'MemUpdate':3, 'MemAcc':2, 'IfThenElse':3, 'ArrayIndex':2,
    'ArrayUpdate':3, 'MemDom':2,
    'PValid':3, 'PWeakValid':3, 'PAlignValid':2, 'PGlobalValid':3,
    'PArrayValid':4,
    'HTDUpdate':5, 'WordArrayAccess':2, 'WordArrayUpdate':3,
    'TokenWordsAccess':2, 'TokenWordsUpdate':3,
    'ROData':1, 'StackWrapper':2,
    'ToFloatingPoint':1, 'ToFloatingPointSigned':2,
    'ToFloatingPointUnsigned':2, 'FloatingPointCast':1,
}

ops_to_smt = {'Plus':'bvadd', 'Minus':'bvsub', 'Times':'bvmul', 'Modulus':'bvurem',
    'DividedBy':'bvudiv', 'BWAnd':'bvand', 'BWOr':'bvor', 'BWXOR':'bvxor',
    'And':'and',
    'Or':'or', 'Implies':'=>', 'Equals':'=', 'Less':'bvult',
    'LessEquals':'bvule', 'SignedLess':'bvslt', 'SignedLessEquals':'bvsle',
    'ShiftLeft':'bvshl', 'ShiftRight':'bvlshr', 'SignedShiftRight':'bvashr',
    'Not':'not', 'BWNot':'bvnot',
    'True':'true', 'False':'false',
    'UnspecifiedPrecond': 'unspecified-precond',
    'IfThenElse':'ite', 'MemDom':'mem-dom',
    'ROData': 'rodata', 'ImpliesROData': 'implies-rodata',
    'WordArrayAccess':'select', 'WordArrayUpdate':'store'}

ex_smt_ops = """roundNearestTiesToEven RNE roundNearestTiesToAway RNA
    roundTowardPositive RTP roundTowardNegative RTN
    roundTowardZero RTZ
    fp.abs fp.neg fp.add fp.sub fp.mul fp.div fp.fma fp.sqrt fp.rem
    fp.roundToIntegral fp.min fp.max fp.leq fp.lt fp.geq fp.gt fp.eq
    fp.isNormal fp.isSubnormal fp.isZero fp.isInfinite fp.isNaN
    fp.isNegative fp.isPositive""".split ()

ops_to_smt.update (dict ([(smt, smt) for smt in ex_smt_ops]))

smt_to_ops = dict ([(smt, oper) for (oper, smt) in ops_to_smt.iteritems ()])

def parse_struct_elem (bits, n):
    name = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    (n, val) = parse_expr (bits, n)
    return (n, (name, val))

def parse_expr (bits, n):
    if bits[n] in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
        kind = bits[n]
        name = bits[n + 1]
        (n, typ) = parse_typ (bits, n + 2)
        return (n, Expr (kind, typ, name = name))
    if bits[n] == 'Array':
        (n, typ) = parse_typ (bits, n + 1)
        assert typ.kind == 'Array'
        (n, xs) = parse_list (parse_expr, bits, n)
        assert len(xs) == typ.num
        return (n, Expr ('Array', typ, vals = xs))
    elif bits[n] == 'Field':
        (n, typ) = parse_typ (bits, n + 1)
        name = bits[n]
        (n, typ2) = parse_typ (bits, n + 1)
        (n, struct) = parse_expr (bits, n)
        assert struct.typ == typ
        return (n, Expr ('Field', typ2, struct = struct,
            field = (name, typ2)))
    elif bits[n] == 'FieldUpd':
        (n, typ) = parse_typ (bits, n + 1)
        name = bits[n]
        (n, typ2) = parse_typ (bits, n + 1)
        (n, val) = parse_expr (bits, n)
        (n, struct) = parse_expr (bits, n)
        return (n, Expr ('FieldUpd', typ, struct = struct,
            field = (name, typ2), val = val))
    elif bits[n] == 'StructCons':
        (n, typ) = parse_typ (bits, n + 1)
        (n, xs) = parse_list (parse_struct_elem, bits, n)
        return (n, Expr ('StructCons', typ, vals = dict(xs)))
    elif bits[n] == 'Num':
        v = parse_int (bits[n + 1])
        (n, typ) = parse_typ (bits, n + 2)
        return (n, Expr ('Num', typ, val = v))
    elif bits[n] == 'Op':
        op = bits[n + 1]
        op = smt_to_ops.get (op, op)
        assert op in ops, op
        (n, typ) = parse_typ (bits, n + 2)
        (n, xs) = parse_list (parse_expr, bits, n)
        assert len (xs) == ops[op]
        return (n, Expr ('Op', typ, name = op, vals = xs))
    elif bits[n] == 'Type':
        (n, typ) = parse_typ (bits, n + 1, symbolic_types = True)
        return (n, Expr ('Type', builtinTs['Type'], val = typ))
    else:
        assert not 'expr parseable', (n, bits[n : n + 3], bits)

def parse_lval (bits, n):
    name = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    return (n, (name, typ))

def parse_lval_and_val (bits, n):
    (n, (name, typ)) = parse_lval (bits, n)
    (n, val) = parse_expr (bits, n)
    return (n, ((name, typ), val))

def parse_node (bits, n):
    if bits[n] == 'Basic':
        cont = node_name(bits[n + 1])
        (n, upds) = parse_list (parse_lval_and_val, bits, n + 2)
        return Node ('Basic', cont, upds)
    elif bits[n] == 'Cond':
        left = node_name(bits[n + 1])
        right = node_name(bits[n + 2])
        (n, cond) = parse_expr (bits, n + 3)
        return Node ('Cond', (left, right), cond)
    else:
        assert bits[n] == 'Call'
        cont = node_name(bits[n + 1])
        name = bits[n + 2]
        (n, args) = parse_list (parse_expr, bits, n + 3)
        (n, saves) = parse_list (parse_lval, bits, n)
        return Node ('Call', cont, (name, args, saves))

true_term = Expr ('Op', boolT, name = 'True', vals = [])
false_term = Expr ('Op', boolT, name = 'False', vals = [])
unspecified_precond_term = Expr ('Op', boolT, name = 'UnspecifiedPrecond', vals = [])

def parse_all(lines):
    '''Toplevel parser for input information. Accepts an iterator over
lines.
See syntax.quick_reference for an explanation.'''

    if hasattr (lines, 'name'):
        trace ('Loading syntax from %s' % lines.name)
    else:
        trace ('Loading syntax (from anonymous source).')

    structs = {}
    functions = {}
    const_globals = {}
    cfg_warnings = []
    for line in lines:
        bits = line.split()
        # empty lines and #-comments ignored
        if not bits or bits[0][0] == '#':
            continue
        if bits[0] == 'Struct':
            # Struct <name> <size> <alignment>
            # followed by block of StructField lines
            assert bits[1] not in structs
            current_struct = Struct (bits[1], parse_int (bits[2]),
                parse_int (bits[3]))
            structs[bits[1]] = current_struct
        elif bits[0] == 'StructField':
            # StructField <name> <type (encoded)> <offset>
            (_, typ) = parse_typ(bits, 2, symbolic_types = True)
            current_struct.add_field (bits[1], typ,
                parse_int (bits[-1]))
        elif bits[0] == 'ConstGlobalDef':
            # ConstGlobalDef <name> <value>
            name = bits[1]
            (_, val) = parse_expr (bits, 2)
            const_globals[name] = val
        elif bits[0] == 'Function':
            # Function <name> <inputs> <outputs>
            # followed by optional block of node lines
            # concluded by EntryPoint line
            fname = bits[1]
            (n, inputs) = parse_list (parse_arg, bits, 2)
            (_, outputs) = parse_list (parse_arg, bits, n)
            current_function = Function (fname, inputs, outputs)
            assert fname not in functions, fname
            functions[fname] = current_function
        elif bits[0] == 'EntryPoint':
            # EntryPoint <entry point>
            entry = node_name(bits[1])
            # instead of setting function.entry to this value,
            # create a dummy node.
            # this ensures there is always
            # at least one node (EntryPoint Ret is valid) and
            # also that the entry point is not in a loop
            name = fresh_node (current_function.nodes)
            current_function.nodes[name] = Node ('Basic',
                entry, [])
            current_function.entry = name
            # ensure that the function graph is closed
            check_cfg (current_function, warnings = cfg_warnings)
            current_function = None
        else:
            # <node name> <node (encoded)>
            name = node_name(bits[0])
            assert name not in current_function.nodes, (name, bits)
            current_function.nodes[name] = parse_node (bits, 1)

    print_cfg_warnings (cfg_warnings)
    trace ('Loaded %d functions, %d structs, %d globals.'
        % (len (functions), len (structs), len (const_globals)))

    return (structs, functions, const_globals)

def parse_and_install_all (lines, tag, skip_functions=None):
    if skip_functions == None:
        skip_functions = []
    (structs, functions, const_globals) = parse_all (lines)
    for f in skip_functions:
        if f in functions:
            del functions[f]
    target_objects.structs.update (structs)
    target_objects.functions.update (functions)
    target_objects.const_globals.update (const_globals)
    if tag != None:
        target_objects.functions_by_tag[tag] = set (functions)
    return (structs, functions, const_globals)

# ===============================
# simple accessor code and checks

def visit_rval (vs):
    def visit (expr):
        if expr.kind == 'Var':
            v = expr.name
            if v in vs:
                assert vs[v] == expr.typ, (expr, vs[v])
            vs[v] = expr.typ
        if expr.is_op ('MemAcc'):
            [m, p] = expr.vals
            assert p.typ == word32T, expr
        if expr.is_op ('PGlobalValid'):
            [htd, typ_expr, p] = expr.vals
            typ = typ_expr.val
            get_global_wrapper (typ)

    return visit

def visit_lval (vs):
    def visit ((name, typ)):
        assert vs.get(name, typ) == typ, (name, vs[name],
            typ)
        vs[name] = typ

    return visit

def get_expr_vars (expr, vs):
    expr.visit (visit_rval (vs))

def get_expr_var_set (expr):
    vs = {}
    get_expr_vars (expr, vs)
    return set (vs.items ())

def get_lval_vars (lval, vs):
    assert len(lval) == 2
    assert vs.get(lval[0], lval[1]) == lval[1]
    vs[lval[0]] = lval[1]

def get_node_vars (node, vs):
    node.visit (visit_lval (vs), visit_rval (vs))

def get_node_rvals (node, vs = None):
    if vs == None:
        vs = {}
    node.visit (lambda l: (), visit_rval (vs))
    return vs

def get_vars(function):
    vs = dict(function.inputs + function.outputs)
    for node in function.nodes.itervalues():
        get_node_vars(node, vs)
    return vs

def get_lval_typ(lval):
    assert len(lval) == 2
    return lval[1]

def get_expr_typ(expr):
    return expr.typ

def check_cfg (fun, warnings = None):
    dead_arcs = [(n, n2) for (n, node) in fun.nodes.iteritems ()
        for n2 in node.get_conts ()
        if n2 not in fun.nodes and n2 not in ['Ret', 'Err']]
    for (n, n2) in dead_arcs:
        assert type (n2) != str
        # OK if multiple dead arcs and we save over n2 twice
        fun.nodes[n2] = Node ('Basic', 'Err', [])
    if warnings == None:
        print_cfg_warnings ([(fun, n, n2) for (n, n2) in dead_arcs])
    else:
        warnings.extend ([(fun, n, n2) for (n, n2) in dead_arcs])

def print_cfg_warnings (warnings):
    post_calls = set ([(fun.nodes[n].fname, fun.name)
        for (fun, n, n2) in warnings
        if fun.nodes[n].kind == 'Call'])
    import logic
    for (call, sites) in logic.dict_list (post_calls).iteritems ():
        trace ('Missing nodes after calls to %s' % call)
        trace (' in %s' % str (sites))
    for (fun, n, n2) in warnings:
        if fun.nodes[n].kind != 'Call':
            trace ('Warning: dead arc in %s: %s -> %s'
                % (fun.name, n, n2))
            trace (' (follows %s node!)' % fun.nodes[n].kind)
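
# ------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It restates
# the dead-arc search done by check_cfg above on a plain dict, so the
# rule can be tried without building Function and Node objects. The
# helper name _sketch_dead_arcs and the dict-of-successor-lists input
# are assumptions made for this example only.

def _sketch_dead_arcs (conts):
    '''conts maps each node address to the list of its successors.
    Returns the arcs whose target is neither a known node nor one of
    the special addresses Ret and Err.'''
    return sorted ([(n, n2) for (n, n2s) in conts.items ()
        for n2 in n2s
        if n2 not in conts and n2 not in ['Ret', 'Err']])

# For example, _sketch_dead_arcs ({1: [2, 3], 2: ['Ret'], 4: [5]})
# reports the arcs (1, 3) and (4, 5). check_cfg repairs such arcs by
# adding a Basic node at the missing address that continues to Err.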

def check_funs (functions, verbose = False):
    for (f, fun) in functions.iteritems():
        try:
            if not fun:
                continue
            if verbose:
                trace ('Checking %s' % f)
            check_cfg (fun)
            get_vars(fun)
            for (n, node) in fun.nodes.iteritems():
                if node.kind == 'Call':
                    c = functions[node.fname]
                    assert map(get_expr_typ, node.args) == \
                        map (get_lval_typ, c.inputs), (
                        node.fname, node.args, c.inputs)
                    assert map (get_lval_typ, node.rets) == \
                        map (get_lval_typ, c.outputs), (
                        node.fname, node.rets, c.outputs)
                elif node.kind == 'Basic':
                    for (lv, v) in node.upds:
                        assert get_lval_typ(lv) == get_expr_typ(v)
                elif node.kind == 'Cond':
                    assert get_expr_typ(node.cond) == boolT
        except Exception, e:
            print "check_funs: failed for " + f
            raise e

def get_extensions (v):
    extensions = set ()
    rm = builtinTs['RoundingMode']
    def visitor (expr):
        if expr.typ == rm or expr.typ.kind == 'FloatingPoint':
            extensions.add ('FloatingPoint')
    v.gen_visit (lambda l: (), visitor)
    return extensions

# =========================================
# common constructors for basic expressions

def mk_var (nm, typ):
    return Expr ('Var', typ, name = nm)

def mk_token (nm):
    return Expr ('Token', builtinTs['Token'], name = nm)

def mk_plus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Plus', vals = [x, y])

def mk_uminus (x):
    zero = Expr ('Num', x.typ, val = 0)
    return mk_minus (zero, x)

def mk_minus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Minus', vals = [x, y])

def mk_times (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Times', vals = [x, y])

def mk_divide (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'DividedBy', vals = [x, y])

def mk_modulus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Modulus', vals = [x, y])

def mk_bwand (x, y):
    assert x.typ == y.typ
    assert x.typ.kind == 'Word'
    return Expr ('Op', x.typ, name = 'BWAnd', vals = [x, y])

def mk_eq (x, y):
    assert x.typ == y.typ
    return Expr ('Op', boolT, name = 'Equals', vals = [x, y])

def mk_less_eq (x, y, signed = False):
    assert x.typ == y.typ
    name = {False: 'LessEquals', True: 'SignedLessEquals'}[signed]
    return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_less (x, y, signed = False):
    assert x.typ == y.typ
    name = {False: 'Less', True: 'SignedLess'}[signed]
    return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_implies (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'Implies', vals = [x, y])

def mk_n_implies (xs, y):
    imp = y
    for x in reversed (xs):
        imp = mk_implies (x, imp)
    return imp

def mk_and (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'And', vals = [x, y])

def mk_or (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'Or', vals = [x, y])

def mk_not (x):
    assert x.typ == boolT
    return Expr ('Op', boolT, name = 'Not', vals = [x])

def mk_shift_gen (name, x, n):
    assert x.typ.kind == 'Word'
    if type (n) == int:
        n = Expr ('Num', x.typ, val = n)
    return Expr ('Op', x.typ, name = name, vals = [x, n])

mk_shiftr = lambda x, n: mk_shift_gen ('ShiftRight', x, n)

def mk_clz (x):
    return Expr ('Op', x.typ, name = "CountLeadingZeroes", vals = [x])

def mk_word_reverse (x):
    return Expr ('Op', x.typ, name = "WordReverse", vals = [x])

def mk_ctz (x):
    return mk_clz (mk_word_reverse (x))

def foldr1 (f, xs):
    x = xs[-1]
    for i in reversed (range (len (xs) - 1)):
        x = f (xs[i], x)
    return x

def mk_num (x, typ):
    import logic
    if logic.is_int (typ):
        typ = Type ('Word', typ)
    assert typ.kind == 'Word', typ
    assert logic.is_int (x), x
    return Expr ('Num', typ, val = x)

def mk_word32 (x):
    return mk_num (x, word32T)

def mk_word8 (x):
    return mk_num (x, word8T)

def mk_word32_maybe(x):
    import logic
    if logic.is_int (x):
        return mk_word32 (x)
    else:
        assert x.typ == word32T
        return x

def mk_cast (x, typ):
    if x.typ == typ:
        return x
    else:
        assert x.typ.kind == 'Word', x.typ
        assert typ.kind == 'Word', typ
        return Expr ('Op', typ, name = 'WordCast', vals = [x])

def mk_memacc(m, p, typ):
    assert m.typ == builtinTs['Mem']
    assert p.typ == word32T
    return Expr ('Op', typ, name = 'MemAcc', vals = [m, p])

def mk_memupd(m, p, v):
    assert m.typ == builtinTs['Mem']
    assert p.typ == word32T
    return Expr ('Op', m.typ, name = 'MemUpdate', vals = [m, p, v])

def mk_arr_index (arr, i):
    assert arr.typ.kind == 'Array'
    return Expr ('Op', arr.typ.el_typ, name = 'ArrayIndex',
        vals = [arr, i])

def mk_arroffs(p, typ, i):
    assert typ.kind == 'Array'
    import logic
    if logic.is_int (i):
        assert i < typ.num
        offs = i * typ.el_typ.size()
        assert offs == i or offs % 4 == 0
        return mk_plus (p, mk_word32 (offs))
    else:
        sz = typ.el_typ.size()
        return mk_plus (p, mk_times (i, mk_word32 (sz)))

def mk_if (P, x, y):
    assert P.typ == boolT
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'IfThenElse', vals = [P, x, y])

def mk_meta_typ (typ):
    return Expr ('Type', builtinTs['Type'], val = typ)

def mk_pvalid (htd, typ, p):
    return Expr ('Op', boolT, name = 'PValid',
        vals = [htd, mk_meta_typ (typ), p])

def mk_rel_wrapper (nm, vals):
    return Expr ('Op', builtinTs['RelWrapper'], name = nm, vals = vals)

def adjust_op_vals (expr, vals):
    assert expr.kind == 'Op'
    return Expr ('Op', expr.typ, expr.name, vals = vals)

mks = (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand,
mk_eq, mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32,
mk_word8, mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index,
mk_arroffs, mk_if, mk_meta_typ, mk_pvalid)

# ====================================================================
# pretty printing code for the syntax - only used for printing reports

def pretty_type (typ):
    if typ.kind == 'Word':
        return 'Word%d' % typ.num
    elif typ.kind == 'WordArray':
        [ix, num] = typ.nums
        return 'Word%d[%d]' % (num, ix)
    elif typ.kind == 'Ptr':
        return 'Ptr(%s)' % pretty_type (typ.el_typ_symb)
    elif typ.kind == 'Struct':
        return 'struct %s' % typ.name
    elif typ.kind == 'Builtin':
        return typ.name
    else:
        assert not 'type pretty-printable', typ

pretty_opers = {'Plus': '+', 'Minus': '-', 'Times': '*'}

known_typ_change = set (['ROData', 'MemAcc', 'IfThenElse', 'WordArrayUpdate',
    'MemDom'])

def pretty_expr (expr, print_type = False):
    if print_type:
        return '((%s) (%s))' % (pretty_type (expr.typ),
            pretty_expr (expr))
    elif expr.kind == 'Var':
        return repr (expr.name)
    elif expr.kind == 'Num':
        return '%d' % expr.val
    elif expr.kind == 'Op' and expr.name in pretty_opers:
        [x, y] = expr.vals
        return '(%s %s %s)' % (pretty_expr (x), pretty_opers[expr.name],
            pretty_expr (y))
    elif expr.kind == 'Op':
        if expr.name in known_typ_change:
            vals = [pretty_expr (v) for v in expr.vals]
        else:
            vals = [pretty_expr (v, print_type = v.typ != expr.typ)
                for v in expr.vals]
        return '%s(%s)' % (expr.name, ', '.join (vals))
    elif expr.kind == 'Token':
        return "''%s''" % expr.name
    else:
        assert not 'expr pretty-printable', expr


# =================================================
# some helper code that's needed all over the place

def fresh_name (n, D, v=True):
    if n not in D:
        D[n] = v
        return n

    x = 1
    y = 1
    while ('%s.%d' % (n, x)) in D:
        y = x
        x = x * 2
    while y < x:
        z = (y + x) / 2
        if ('%s.%d' % (n, z)) in D:
            y = z + 1
        else:
            x = z
    n = '%s.%d' % (n, x)
    assert n not in D

    D[n] = v
    return n

def fresh_node (ns, hint = 1):
    n = hint
    # use names that are *not* multiples of 4
    n = (n | 15) + 2
    while n in ns:
        n += 16
    return n
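
# ------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the search
# that fresh_name performs above, restated as a standalone helper.
# The name _sketch_fresh_name is hypothetical. It assumes, as
# fresh_name appears to, that suffixed names are handed out densely
# (n.1, n.2, ...), so the suffixes in use form a prefix and a
# doubling search followed by bisection finds the first free one.
# Integer division is written // so the sketch behaves the same
# under Python 2 and Python 3.

def _sketch_fresh_name (n, D, v = True):
    if n not in D:
        D[n] = v
        return n
    # doubling phase: find some suffix x whose name is unused
    x = 1
    y = 1
    while ('%s.%d' % (n, x)) in D:
        y = x
        x = x * 2
    # bisection phase: narrow [y, x] down to the first unused suffix
    while y < x:
        z = (y + x) // 2
        if ('%s.%d' % (n, z)) in D:
            y = z + 1
        else:
            x = z
    n = '%s.%d' % (n, x)
    assert n not in D
    D[n] = v
    return n

# Repeated requests for the same base name walk through the
# suffixes in order: the first call returns the name unchanged and
# later calls return name.1, name.2 and so on, each found with a
# number of dictionary lookups logarithmic in the suffix.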