# * Copyright 2015, NICTA
# *
# * This software may be distributed and modified according to the terms of
# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
# * See "LICENSE_BSD2.txt" for details.
# *
# * @TAG(NICTA_BSD)

# Syntax and simple operations for types, expressions, graph nodes
# and graph functions (functions in graph-language format).

from target_objects import structs, trace
import target_objects

quick_reference = """
Quick reference on the graph language and its syntax.
=====================================================

Example
=======

Let's build up the syntax that roughly mirrors the C statement x = y + 1;

This is a quick example. A more thorough reference is below.

We have two example types: Bool and Word 32.

Here are two atomic expressions of type Word 32:
Var y Word 32
Num 1 Word 32

These encode the variable y and the number 1, both of type Word 32.

Expressions in this language are built up by concatenating strings of tokens.
Any whitespace delimited string can be a token, and thus a variable name.

Compound expressions are built with operators, e.g. y + 1:
Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This is quite verbose. For simplicity, the Op syntax includes the type of
the whole expression - Word 32 - and the number of arguments - 2 - even though
much of this information is redundant.

Finally, we can encode the statement x = y + 1; at address 14.

14 Basic 15 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This specifies that node 14 is a basic node (which updates variables). The
successor node is 15. It updates 1 variable (basic nodes may simultaneously
update many variables). The variable to be updated is specified by x Word 32.
The remainder of the line is the expression y + 1 as seen before.
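
The encoding above can be reproduced mechanically. The following is a minimal
standalone sketch in Python. The helpers var, num, op and basic_node are
hypothetical names used only for illustration and are not part of this module,
and the sketch assumes that every type involved is a Word type.

```python
# Standalone illustration only: these helpers are hypothetical and are
# not defined anywhere in this module.

def var (name, bits):
    # Var <name> Word <bits>
    return ['Var', name, 'Word', str (bits)]

def num (value, bits):
    # Num <value> Word <bits>
    return ['Num', str (value), 'Word', str (bits)]

def op (name, bits, args):
    # Op <name> <result type> <argument count> <arguments...>
    toks = ['Op', name, 'Word', str (bits), str (len (args))]
    for arg in args:
        toks.extend (arg)
    return toks

def basic_node (addr, succ, updates):
    # <addr> Basic <successor> <update count> (<var name> <type> <expr>)*
    toks = [str (addr), 'Basic', str (succ), str (len (updates))]
    for (name, bits, expr) in updates:
        toks.extend ([name, 'Word', str (bits)])
        toks.extend (expr)
    return ' '.join (toks)

y_plus_1 = op ('Plus', 32, [var ('y', 32), num (1, 32)])
line = basic_node (14, 15, [('x', 32, y_plus_1)])
```

Under these assumptions, line holds the same text as the node 14 example
above.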

We can encode the statements if (x == y) { return; } as follows:
15 Cond 16 17 Op Equals Bool 2 Var x Word 32 Var y Word 32
16 Basic Ret 0

Conditional node 15 continues to node 16 if x == y, and to 17 otherwise. Node
16 is a basic node which updates no variables, i.e. a skip statement. It
continues to the special address Ret which designates return from the current
function.

Reference
=========

A graph program consists of a number of functions, each
of whose control flow is structured as a graph of nodes
with integer addresses. Nodes are of three types:
  - 'Basic' updates variable values.
  - 'Cond' chooses between two successor nodes.
  - 'Call' makes calls to functions.

A top level syntax file (interpreted by syntax.parse_all)
contains two kinds of sections:
  - Function blocks introduce functions.
  - Struct blocks introduce aggregate types (e.g. struct in C).

All lines in the input syntax are read as a whitespace-separated
list of tokens. Empty lines and lines whose first non-whitespace
character is '#' are ignored. All syntax is prefix encoded, so
each clause has a first token which determines what kind of clause
it is, followed by the concatenation of all of its subclauses.

The two kinds of lines that appear in Struct blocks illustrate this format:
Struct "struct name" <int size> <int alignment>
StructField "name" <type> <int offset>

On notation: above we denote the specific tokens 'Struct' and 'StructField',
the arbitrary tokens "name" etc, the special integer subclauses <int size> etc,
and a <type> subclause. Integer subclauses always consist of a single token.
A leading character '-' or '~' indicates negative. A leading '0x' or '0b'
indicates hexadecimal or binary encoding, otherwise decimal, e.g. 1, 0x1A,
-0x2b, ~0b1100, etc. Naming tokens can be anything at all that does not contain
whitespace, so x y'v a,b c#d # etc are all valid struct or field names.
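
The integer token rules above can be sketched as follows. This is a standalone
illustration, and parse_int_token is a hypothetical name chosen for this
example. It follows the rule stated above that a leading '~' marks a negative
number in the same way as a leading '-'.

```python
# Standalone illustration of the integer token rules described above.
# parse_int_token is a hypothetical name used only in this example.

def parse_int_token (tok):
    # a leading '-' or '~' negates the rest of the token
    if tok.startswith ('-') or tok.startswith ('~'):
        return - parse_int_token (tok[1:])
    # a '0x' or '0b' prefix selects hexadecimal or binary; base 0 tells
    # int () to read the base from the prefix
    if tok.startswith ('0x') or tok.startswith ('0b'):
        return int (tok, 0)
    # anything else is decimal
    return int (tok)

examples = [parse_int_token (t) for t in ['1', '0x1A', '-0x2b', '~0b1100']]
```

With these rules the four example tokens from the text decode to 1, 26, -43
and -12.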

Struct blocks begin with the 'Struct' line, which specifies their name, total
size and required alignment. The struct block then contains a number of
StructField lines, where each field has a type and an offset from the start of
the structure. The Struct block ends where any other Struct or Function block
begins.

Function blocks begin with a function declaration line
  Function "name of function" <int>*("input argument name" <type>)
      <int>*("output argument name" <type>)

The notation <int>*(...) above denotes an arbitrary length list of subclauses.
The first token of this clause will be a decimal integer specifying the length
of the list. For instance, the classic 'XOR' function with inputs x, y and
output z could be specified as follows:
Function XOR 2 x Bool y Bool 1 z Bool

Each function has a name, a list of input arguments, and a list of output
arguments. The state in the graph language is entirely encoded in variables,
so functions will typically have global objects such as the heap as both input
and output arguments.

The function block may end immediately, for functions with unspecified body,
or may contain a number of graph node lines, followed by an entry point line:
EntryPoint <int>
The entry point line ends the function block.

A graph node line has one of these formats:
<int> Basic <next node> <int>*("var name" <type> <expression>)
<int> Cond <next node> <next node> <expression>
<int> Call <next node> "fun name" <int>*(<expression>) <int>*("var name" <type>)

Each node line begins with an integer which is the address of the node. A
<next node> is either the integer address of the following node, or one of the
special address tokens Ret and Err.

Basic nodes update the value of a (possibly empty) list of variables. The
variables are all updated simultaneously to the computed value of the
expressions given.
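
The <int>*(...) list notation described above can be read with a small
counted-list helper. The following standalone sketch decodes the XOR
declaration line. The names parse_counted_list and parse_simple_arg are
hypothetical, and the argument parser is simplified to handle only
single-token types such as Bool.

```python
# Standalone illustration of the <int>*(...) notation. The names below
# are hypothetical and the type handling is deliberately simplified.

def parse_counted_list (tokens, n, parse_item):
    # the first token gives the number of items that follow
    count = int (tokens[n])
    n += 1
    items = []
    for _ in range (count):
        (n, item) = parse_item (tokens, n)
        items.append (item)
    return (n, items)

def parse_simple_arg (tokens, n):
    # simplification: the type is assumed to be a single token
    return (n + 2, (tokens[n], tokens[n + 1]))

tokens = 'Function XOR 2 x Bool y Bool 1 z Bool'.split ()
# tokens[0] is 'Function', tokens[1] is the name, lists start at index 2
(n, inputs) = parse_counted_list (tokens, 2, parse_simple_arg)
(n, outputs) = parse_counted_list (tokens, n, parse_simple_arg)
```

Each parser returns the index of the next unread token together with its
result, so the output list is read starting exactly where the input list ended.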

Cond nodes evaluate an expression which must be of boolean type. They specify
two possible next nodes, left and right. The left node is visited when the
expression evaluates to true, the right on false.

Call nodes call another function of a given name. The list of argument
expressions is evaluated and must match the list of input arguments of the
given function (same length and types). These arguments define the starting
variable environment of that function. Its return parameters are saved to the
list of variables given.

The <type> clauses are in one of these formats:
Word <int>    -- BitVec <int> is a synonym
Array <type> <int>
Struct "struct name"
Ptr <type>
One of the builtin types as a single token:
  'Bool Mem Dom HTD PMS UNIT Type Token RoundingMode'
FloatingPoint <int> <int>

These types encode a machine word with the given number of bits, an array of
some type, a struct previously defined by a Struct clause, or a pointer to
another type. The Array, Struct and Ptr types are provided only for use in
pointer validity assertions, which are discussed below. The FloatingPoint type
exactly mirrors the SMTLIB2 (_ FloatingPoint eb sb) type, and is available as
an optional extension. Using the FloatingPoint and RoundingMode types in any
problem changes the SMT theory needed and may limit which solvers and features
can be used.

Of the builtin types, booleans are standard, UNIT is the singleton type, and
memory is a 32-bit to 8-bit mapping. We aim to include 64-bit support soon. The
Dom type is a set of 32-bit values, used to encode the valid domain of memory
operations. The heap type description HTD is used by pointer-validity
operators. The phantom machine state type PMS is unspecified and used to
represent unspecified aspects of the environment. The type Type is the type of
Type expressions which are used in pointer-validity.
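
The <type> clause formats listed above nest, since Array and Ptr contain
another type. The following standalone sketch reads a type from a token list
into nested tuples. The name parse_type_tokens and the tuple representation
are assumptions made for this example, not the representation used elsewhere
in this module.

```python
# Standalone illustration of the <type> clause formats. The function
# name and the nested-tuple result are hypothetical choices.

BUILTINS = set ('Bool Mem Dom HTD PMS UNIT Type Token RoundingMode'.split ())

def parse_type_tokens (tokens, n):
    # returns (index of the next unread token, parsed type)
    head = tokens[n]
    if head in ('Word', 'BitVec'):
        return (n + 2, ('Word', int (tokens[n + 1])))
    elif head == 'FloatingPoint':
        return (n + 3, ('FloatingPoint', int (tokens[n + 1]),
            int (tokens[n + 2])))
    elif head == 'Array':
        # Array <type> <int>: the element type comes first, then the length
        (n, el) = parse_type_tokens (tokens, n + 1)
        return (n + 1, ('Array', el, int (tokens[n])))
    elif head == 'Struct':
        return (n + 2, ('Struct', tokens[n + 1]))
    elif head == 'Ptr':
        (n, el) = parse_type_tokens (tokens, n + 1)
        return (n, ('Ptr', el))
    elif head in BUILTINS:
        return (n + 1, ('Builtin', head))
    raise ValueError ('unknown type token: %r' % head)

(end, typ) = parse_type_tokens ('Ptr Array Word 8 16'.split (), 0)
```

Here the tokens describe a pointer to an array of 16 words of 8 bits each, and
all five tokens are consumed.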

The <expression> clauses have one of these formats:
Var "name" <type>
Op "name" <type> <int>*(<expression>)
Num <int> <type>
Type <type>
Symbol "name" <type>

Most of the expressions of the language are composed from variables, numerals
and operator applications, which should be self-explanatory. The Type
expression wraps a type into an expression (of type Type) so it may be passed
to a pointer-validity operator.

The special Symbol clauses are used to denote in source languages the values
symbols will have in the binaries. They are replaced by specific numerals in
a first pass.

Many of the builtin operators are equivalent to SMTLIB2 operators, and are also
available with their SMTLIB2 names. The operators are:
  - Binary arithmetic operators on words:
    + Plus/bvadd, Minus/bvsub, Times/bvmul, Modulus/bvurem, DividedBy/bvudiv,
        BWAnd/bvand, BWOr/bvor, BWXOR/bvxor,
        ShiftLeft/bvshl, ShiftRight/bvlshr, SignedShiftRight/bvashr
  - Binary operators on booleans:
    + And/and, Or/or, Implies
  - Unary operators on bools:
    + Not/not
  - Booleans (nullary operators, i.e. constants):
    + True/true, False/false
  - Equals relation in any type:
    + Equals
  - Comparison relations on words (result is bool):
    + Less/bvult, LessEquals/bvule, SignedLess/bvslt, SignedLessEquals/bvsle
  - Unary operators on words:
    + BWNot/bvnot, CountLeadingZeroes, CountTrailingZeroes, WordReverse
  - Cast operators on words - result type may be different to argument type:
    + WordCast, WordCastSigned
  - Memory operations:
    + MemAcc (args [m :: Mem, ptr :: Word 32]) returns any word type
    + MemUpdate (args [m :: Mem, ptr :: Word 32, v :: any word type])
  - Pointer-validity operators:
    + PValid, PWeakValid, PAlignValid, PGlobalValid, PArrayValid
  - Miscellaneous:
    + MemDom, HTDUpdate
    + IfThenElse/ite (args [b :: bool, x :: any type T, y :: T])
  - FloatingPoint operations from the SMTLIB2 floating point specification:
    + roundNearestTiesToEven/RNE, roundNearestTiesToAway/RNA,
        roundTowardPositive/RTP, roundTowardNegative/RTN,
        roundTowardZero/RTZ
    + fp.abs, fp.neg, fp.add, fp.sub, fp.mul, fp.div, fp.fma, fp.sqrt,
        fp.rem, fp.roundToIntegral, fp.min, fp.max, fp.leq, fp.lt,
        fp.geq, fp.gt, fp.eq, fp.isNormal, fp.isSubnormal, fp.isZero,
        fp.isInfinite, fp.isNaN, fp.isNegative, fp.isPositive
    + ToFloatingPoint, ToFloatingPointSigned, ToFloatingPointUnsigned,
        FloatingPointCast

Operators with SMTLIB2 equivalents have the same semantics. Mem and Dom
operations wrap the SMTLIB2 BitVec Array type with more convenient
operations.

Memory accesses and updates can operate on various word types.

The pointer-validity operators PValid, PWeakValid, PGlobalValid all take 3
arguments of type HTD, Type, and Word 32. The PValid operator asserts that
the heap-type-description contains this type starting at this address. This
is exclusive with any incompatible type.
PGlobalValid additionally asserts that
this is a global object, and therefore not contained within any larger object.
PWeakValid asserts that the type could be valid (it is aligned and within the
range of available addresses in the heap-type-description) and is needed only
in rare circumstances. PAlignValid omits the HTD argument and asserts only that
the pointer is appropriately aligned and that the object does not start at or
wrap around the 0 address. PArrayValid takes an additional argument (HTD,
Type, Word 32, Word 32) where the final argument specifies the number of
entries in an array.

The MemDom operator takes argument types [Word 32, Dom] and returns the boolean
of whether this pointer is in this domain.

The IfThenElse operator takes a bool and any two arguments of the same type.

The FloatingPoint operators are mostly taken from the SMTLIB2 floating
point standard. The conversions ToFloatingPoint ([Word] to FP),
ToFloatingPointSigned and ToFloatingPointUnsigned ([RoundingMode, Word] to FP)
and FloatingPointCast (FP to FP) represent the variants of to_fp in the SMTLIB2
standard.
"""

class Type:
    def __init__ (self, kind, name, el_typ=None):
        self.kind = kind
        if kind in ['Array', 'Word', 'TokenWords']:
            self.num = int (name)
        else:
            self.name = name
        if kind in ['Array', 'Ptr']:
            self.el_typ_symb = el_typ
            self.el_typ = concrete_type (el_typ)
        if kind in ['WordArray', 'FloatingPoint']:
            self.nums = [int (name), int (el_typ)]

    def __repr__ (self):
        if self.kind == 'Array':
            return 'Type ("Array", %r, %r)' % (self.num,
                self.el_typ_symb)
        elif self.kind in ('Word', 'TokenWords'):
            return 'Type (%r, %r)' % (self.kind, self.num)
        elif self.kind == 'Ptr':
            return 'Type ("Ptr", %r)' % self.el_typ_symb
        elif self.kind in ('WordArray', 'FloatingPoint'):
            return 'Type (%r, %r, %r)' % tuple ([self.kind]
                + self.nums)
        else:
            return 'Type (%r, %r)' % (self.kind, self.name)

    def __eq__ (self, other):
        if not other:
            return False
        if self.kind != other.kind:
            return False
        if self.kind in ['Array', 'Word', 'TokenWords']:
            if self.num != other.num:
                return False
        elif self.kind in ['WordArray', 'FloatingPoint']:
            if self.nums != other.nums:
                return False
        else:
            if self.name != other.name:
                return False
        if self.kind in ['Array', 'Ptr']:
            if self.el_typ_symb != other.el_typ_symb:
                return False
        return True

    def __ne__ (self, other):
        return not other or not (self == other)

    def __hash__ (self):
        return hash(str(self))

    def __cmp__ (self, other):
        self_ss = []
        self.serialise (self_ss)
        other_ss = []
        other.serialise (other_ss)
        return cmp (self_ss, other_ss)

    def subtypes (self):
        if self.kind == 'Struct':
            return structs[self.name].subtypes()
        elif self.kind == 'Array':
            return [self] + self.el_typ.subtypes()
        else:
            return [self]

    def size (self):
        if self.kind == 'Struct':
            return structs[self.name].size
        elif self.kind == 'Array':
            return self.el_typ.size() * self.num
        elif self.kind == 'Word':
            assert self.num % 8 == 0, self
            return self.num / 8
        elif self.kind == 'FloatingPoint':
            sz = sum (self.nums)
            assert sz % 8 == 0, self
            return sz / 8
        elif self.kind == 'Ptr':
            return 4
        else:
            assert not 'type has size'

    def align (self):
        if self.kind == 'Struct':
            return structs[self.name].align
        elif self.kind == 'Array':
            return self.el_typ.align ()
        elif self.kind in ('Word', 'FloatingPoint'):
            return self.size ()
        elif self.kind == 'Ptr':
            return 4
        else:
            assert not 'type has alignment'

    def serialise (self, xs):
        if self.kind in ('Word', 'TokenWords'):
            xs.append (self.kind)
            xs.append (str (self.num))
        elif self.kind in ('WordArray', 'FloatingPoint'):
            xs.append (self.kind)
            xs.extend ([str (n) for n in self.nums])
        elif self.kind == 'Builtin':
            xs.append (self.name)
        elif self.kind == 'Array':
            xs.append ('Array')
            self.el_typ_symb.serialise (xs)
            xs.append (str (self.num))
        elif self.kind == 'Struct':
            xs.append ('Struct')
            xs.append (self.name)
        elif self.kind == 'Ptr':
            xs.append ('Ptr')
            self.el_typ_symb.serialise (xs)
        else:
            assert not 'type serialisable', self.kind

class Expr:
    def __init__ (self, kind, typ, name = None, struct = None,
            field = None, val = None, vals = None):
        self.kind = kind
        self.typ = typ
        if name != None:
            self.name = name
        if struct != None:
            self.struct = struct
        if field != None:
            self.field = field
        if val != None:
            self.val = val
        if vals != None:
            self.vals = vals
        if kind == 'Op':
            assert type (self.vals) == list

    def binds (self):
        binds = []
        if self.kind in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
            binds.append(('name', self.name))
        elif self.kind in ['Array', 'StructCons']:
            binds.append(('vals', self.vals))
        elif self.kind == 'Field':
            binds.append(('struct', self.struct))
            binds.append(('field', self.field))
        elif self.kind == 'FieldUpd':
            binds.append(('struct', self.struct))
            binds.append(('field', self.field))
            binds.append(('val', self.val))
        elif self.kind == 'Num':
            binds.append(('val', self.val))
        elif self.kind == 'Op':
            binds.append(('name', self.name))
            binds.append(('vals', self.vals))
        elif self.kind == 'Type':
            binds.append(('val', self.val))
        elif self.kind == 'SMTExpr':
            binds.append(('val', self.val))
        else:
            assert not 'expression understood for repr', self.kind
        return binds

    def __repr__ (self):
        bits = [repr(self.kind), repr(self.typ)]
        bits.extend(['%s = %r' % b for b in self.binds()])
        return 'Expr (%s)' % ', '.join(bits)

    def __eq__ (self, other):
        return (other and self.kind == other.kind
            and self.typ == other.typ
            and self.binds() == other.binds())

    def __ne__ (self, other):
        return not other or not (self == other)

    def __hash__ (self):
        return hash_tuplify (self.kind, self.typ, self.binds ())

    def __cmp__ (self, other):
        return cmp ((self.kind, self.typ, self.binds ()),
            (other.kind, other.typ, other.binds ()))

    def is_var (self, (nm, typ)):
        return self.kind == 'Var' and all([self.name == nm,
            self.typ == typ])

    def is_op (self, nm):
        if type (nm) == str:
            return self.kind == 'Op' and self.name == nm
        else:
            return self.kind == 'Op' and self.name in nm

    def visit (self, visit):
        visit (self)
        if self.kind == 'Var':
            pass
        elif self.kind == 'Op' or self.kind == 'Array':
            for x in self.vals:
                x.visit (visit)
        elif self.kind == 'StructCons':
            for x in self.vals.itervalues ():
                x.visit (visit)
        elif self.kind == 'Field':
            self.struct.visit (visit)

            struct_typ = self.struct.typ
            assert struct_typ.kind == 'Struct'
            struct = structs[struct_typ.name]
            (name, typ) = self.field
            assert struct.fields[name][0] == typ
        elif self.kind == 'FieldUpd':
            self.val.visit (visit)
            self.struct.visit (visit)

            struct_typ = self.struct.typ
            assert struct_typ.kind == 'Struct'
            struct = structs[struct_typ.name]
            (name, typ) = self.field
            assert struct.fields[name][0] == typ
        elif self.kind == 'ConstGlobal':
            assert (target_objects.const_globals[self.name].typ
                == self.typ)
        elif self.kind in set (['Num', 'Symbol', 'Type', 'Token']):
            pass
        else:
            assert not 'expr understood', self

    def gen_visit (self, visit_lval, visit_rval):
        self.visit (visit_rval)

    def subst (self, substor, ss = None):
        ret = False
        if self.kind == 'Op':
            subst_vals = subst_list (substor, self.vals)
            if subst_vals:
                self = Expr ('Op', self.typ, name = self.name,
                    vals = subst_vals)
                ret = True
        if (ss == None or self.kind in ss or self.is_op (ss)):
            r = substor (self)
            if r != None:
                return r
        if ret:
            return self
        return

    def add_const_ranges (self, ranges):
        def visit (expr):
            if expr.kind == 'ConstGlobal':
                (start, size, _) = symbols[expr.name]
                assert size == expr.typ.size ()
                ranges[expr.name] = (start, start + size - 1)

        self.visit (visit)

    def get_mem_access (self):
        if self.is_op ('MemAcc'):
            [m, p] = self.vals
            return [('MemAcc', p, self, m)]
        elif self.is_op ('MemUpdate'):
            [m, p, v] = self.vals
            return [('MemUpdate', p, v, m)]
        else:
            return []

    def get_mem_accesses (self):
        accesses = []
        def visit (expr):
            accesses.extend (expr.get_mem_access ())
        self.visit (visit)
        return accesses

    def serialise (self, xs):
        xs.append (self.kind)
        if self.kind == 'Op':
            xs.append (self.name)
            self.typ.serialise (xs)
            xs.append (str (len (self.vals)))
            for v in self.vals:
                v.serialise (xs)
        elif self.kind == 'Num':
            xs.append (str (self.val))
            self.typ.serialise (xs)
        elif self.kind == 'Var':
            xs.append (self.name)
            self.typ.serialise (xs)
        elif self.kind == 'Type':
            self.val.serialise (xs)
        elif self.kind == 'Token':
            xs.extend ([self.kind, self.name])
            self.typ.serialise (xs)
        else:
            assert not 'expr serialisable', self.kind

class Struct:
    def __init__ (self, name, size, align):
        self.name = name
        self.size = size
        self.align = align
        self.field_list = []
        self.fields = {}
        self._subtypes = None
        self.typ = Type ('Struct', name)

    def add_field (self, name, typ, offset):
        concrete = concrete_type (typ)
        self.field_list.append ((name, concrete))
        self.fields[name] = (concrete, offset, typ)
        assert self._subtypes == None

    def subtypes (self):
        if self._subtypes != None:
            return self._subtypes
        xs = [self.typ]
        for (concrete, offs, typ2) in self.fields.itervalues():
            xs.extend(typ2.subtypes())
        self._subtypes = xs
        return xs

def tuplify (x):
    if type(x) == tuple or type(x) == list:
        return tuple ([tuplify (y) for y in x])
    if type(x) == dict:
        return tuple ([tuplify (y) for y in x.iteritems ()])
    else:
        return x

def hash_tuplify (* xs):
    return hash (tuplify (xs))

def subst_list (substor, xs, ss = None):
    ys = [x.subst (substor, ss = ss) for x in xs]
    if [y for y in ys if y != None]:
        xs = list (xs)
        for (i, y) in enumerate (ys):
            if y != None:
                xs[i] = y
        return xs
    else:
        return

class Node:
    def __init__ (self, kind, conts, args):
        self.kind = kind

        # conts is either a list of continuations or a single continuation
        if type (conts) == list:
            if len (conts) == 1:
                the_cont = conts[0]
        else:
            the_cont = conts

        if kind == 'Basic':
            self.cont = the_cont
            self.upds = [(lv, v) for (lv, v) in args
                if not v.is_var (lv)]
        elif kind == 'Call':
            self.cont = the_cont
            (self.fname, self.args, self.rets) = args
        elif kind == 'Cond':
            (self.left, self.right) = conts
            self.cond = args
        else:
            assert not 'node kind understood', self.kind

    def __repr__ (self):
        return 'Node (%r, %r, %r)' % (self.kind,
            self.get_conts (), self.get_args ())

    def __hash__ (self):
        if self.kind == 'Call':
            return hash ((self.fname, tuple (self.args),
                tuple (self.rets), self.cont))
        elif self.kind == 'Basic':
            return hash (tuple (self.upds))
        elif self.kind == 'Cond':
            return hash ((self.cond, self.left, self.right))
        else:
            assert not 'node kind understood', self.kind

    def __eq__ (self, other):
        return all ([self.kind == other.kind,
            self.get_conts () == other.get_conts (),
            self.get_args () == other.get_args ()])

    def __ne__ (self, other):
        return not other or not self == other

    def get_args (self):
        if self.kind == 'Basic':
            return self.upds
        elif self.kind == 'Call':
            return (self.fname, self.args, self.rets)
        else:
            return self.cond

    def get_conts (self):
        if self.kind == 'Cond':
            return [self.left, self.right]
        else:
            return [self.cont]

    def get_lvals (self):
        if self.kind == 'Basic':
            return [lv for (lv, v) in self.upds]
        elif self.kind == 'Call':
            return self.rets
        else:
            return []

    def is_noop (self):
        if self.kind == 'Basic':
            return self.upds == []
        elif self.kind == 'Cond':
            return self.left == self.right
        else:
            return False

    def visit (self, visit_lval, visit_rval):
        if self.kind == 'Basic':
            for (lv, v) in self.upds:
                visit_lval (lv)
                v.visit (visit_rval)
        elif self.kind == 'Cond':
            self.cond.visit (visit_rval)
        elif self.kind == 'Call':
            for v in self.args:
                v.visit (visit_rval)
            for r in self.rets:
                visit_lval (r)

    def gen_visit (self, visit_lval, visit_rval):
        self.visit (visit_lval, visit_rval)

    def subst_exprs (self, substor, ss = None):
        if self.kind == 'Basic':
            rvs = subst_list (substor,
                [v for (lv, v) in self.upds], ss = ss)
            if rvs == None:
                return self
            return Node ('Basic', self.cont,
                zip ([lv for (lv, v) in self.upds], rvs))
        elif self.kind == 'Cond':
            r = self.cond.subst (substor, ss = ss)
            if r == None:
                return self
            return Node ('Cond', [self.left, self.right], r)
        elif self.kind == 'Call':
            args = subst_list (substor, self.args, ss = ss)
            if args == None:
                return self
            return Node ('Call', self.cont, (self.fname,
                args, self.rets))

    def get_mem_accesses (self):
        accesses = []
        def visit (expr):
            accesses.extend (expr.get_mem_access ())
        self.visit (lambda x: (), visit)
        return accesses

    def err_cond (self):
        if self.kind != 'Cond':
            return None
        if self.left != 'Err':
            if self.right == 'Err':
                return mk_not (self.cond)
            else:
                return None
        else:
            if self.right == 'Err':
                return true_term
            else:
                return self.cond

    def serialise (self, xs):
        xs.append (self.kind)
        xs.extend ([str (c) for c in self.get_conts ()])
        if self.kind == 'Basic':
            xs.append (str (len (self.upds)))
            for (lv, v) in self.upds:
                xs.append (lv[0])
                lv[1].serialise (xs)
                v.serialise (xs)
        elif self.kind == 'Cond':
            self.cond.serialise (xs)
        elif self.kind == 'Call':
            xs.append (self.fname)
            xs.append (str (len (self.args)))
            for arg in self.args:
                arg.serialise (xs)
            xs.append (str (len (self.rets)))
            for (nm, typ) in self.rets:
                xs.append (nm)
                typ.serialise (xs)

def rename_lval ((name, typ), renames):
    return (renames.get (name, name), typ)

def do_subst (expr, substor, ss = None):
    r = expr.subst (substor, ss = ss)
    if r == None:
        return expr
    else:
        return r

standard_expr_kinds = set (['Symbol', 'ConstGlobal', 'Var', 'Op', 'Num',
    'Type'])

def rename_expr_substor (renames):
    def ren (expr):
        if expr.kind == 'Var' and expr.name in renames:
            return mk_var (renames[expr.name], expr.typ)
        else:
            return
    return ren

def rename_expr (expr, renames):
    return do_subst (expr, rename_expr_substor (renames),
        ss = set (['Var']))

def copy_rename (node, renames):
    (vs, ns) = renames
    nf = lambda n: ns.get (n, n)
    node = node.subst_exprs (rename_expr_substor (vs))
    if node.kind == 'Call':
        return Node ('Call', nf (node.cont), (node.fname, node.args,
            [rename_lval (l, vs) for l in node.rets]))
    elif node.kind == 'Basic':
        return Node ('Basic', nf (node.cont),
            [(rename_lval (lv, vs), v) for (lv, v) in node.upds])
    elif node.kind == 'Cond':
        return Node ('Cond', [nf (node.left), nf (node.right)],
            node.cond)
    else:
        assert not 'node kind understood', node.kind

class Function:
    def __init__ (self, name, inputs, outputs):
        self.name = name
        self.inputs = inputs
        self.outputs = outputs
        self.entry = None
        self.nodes = {}

    def __hash__ (self):
        en = self.entry
        if not en:
            en = -1
        return hash (tuple ([self.name, tuple (self.inputs),
                tuple (self.outputs), en])
            + tuple (sorted (self.nodes.iteritems ())))

    def reachable_nodes (self, simplify = False):
        if not self.entry:
            return {}
        rs = {}
        vs = [self.entry]
        while vs:
            n = vs.pop()
            if type (n) == str:
                continue
            rs[n] = True
            node = self.nodes[n]
            if simplify:
                import logic
                node = logic.simplify_node_elementary (node)
            for c in node.get_conts ():
                if not c in rs:
                    vs.append (c)
        return rs

    def serialise (self):
        xs = ['Function', self.name, str (len (self.inputs))]
        for (nm, typ) in self.inputs:
            xs.append (nm)
            typ.serialise (xs)
        xs.append (str (len (self.outputs)))
        for (nm, typ) in self.outputs:
            xs.append (nm)
            typ.serialise (xs)
        ss = [' '.join (xs)]
        if not self.entry:
            return ss
        for n in self.nodes:
            xs = [str (n)]
            self.nodes[n].serialise (xs)
            ss.append (' '.join (xs))
        ss.append ('EntryPoint %d' % self.entry)
        return ss

    def as_problem (self, Problem, name = 'temp'):
        p = Problem(None, 'Function (%s)' % self.name)
        p.clone_function (self, name)
        p.compute_preds ()
        return p

    def function_call_addrs (self):
        return [(n, self.nodes[n].fname)
            for n in self.nodes if self.nodes[n].kind == 'Call']

    def function_calls (self):
        return set ([fn for (n, fn) in self.function_call_addrs ()])

    def compile_hints (self, Problem):
        xs = ['Hints %s' % self.name]

        p = self.as_problem (Problem)

        for n in p.nodes:
            ys = ['VarDeps', str (n)]
            for (nm, typ) in p.var_deps[n]:
                ys.append (nm)
                typ.serialise (ys)
            xs.append (' '.join (ys))
            if self.nodes[n].kind != 'Basic':
                continue
        return xs

    def save_graph (self, fname):
        import problem
        problem.save_graph (self.nodes, fname)

def mk_builtinTs ():
    return dict([(n, Type('Builtin', n)) for n
        in 'Bool Mem Dom HTD PMS UNIT Type Token RelWrapper'.split()])
builtinTs = mk_builtinTs ()
boolT = builtinTs['Bool']
word32T = Type ('Word', '32')
word64T = Type ('Word', '64')
word16T = Type ('Word', 16)
word8T = Type ('Word', '8')

phantom_types = set ([builtinTs[t] for t
    in 'Dom HTD PMS UNIT Type'.split ()])

def concrete_type (typ):
    if typ.kind == 'Ptr':
        return word32T
    else:
        return typ

global_wrappers = {}
def get_global_wrapper (typ):
    if typ in global_wrappers:
        return global_wrappers[typ]
    struct_name = fresh_name ('Global (%s)' % typ, structs)
    struct = Struct (struct_name, typ.size (), typ.align ())
    struct.add_field ('v', typ, 0)
    structs[struct_name] = struct

    global_wrappers[typ] = struct.typ
    return struct.typ

# ==========================================================
# parsing code for types,
# expressions, structs and functions

def parse_int (s):
    if s.startswith ('-') or s.startswith ('~'):
        return (- parse_int (s[1:]))
    if s.startswith ('0x') or s.startswith ('0b'):
        return int (s, 0)
    else:
        return int (s)

def parse_typ(bits, n, symbolic_types = False):
    if bits[n] == 'Word' or bits[n] == 'BitVec':
        return (n + 2, Type('Word', parse_int (bits[n + 1])))
    elif bits[n] == 'WordArray' or bits[n] == 'FloatingPoint':
        return (n + 3, Type(bits[n],
            parse_int (bits[n + 1]), parse_int (bits[n + 2])))
    elif bits[n] in builtinTs:
        return (n + 1, builtinTs[bits[n]])
    elif bits[n] == 'Array':
        (n, typ) = parse_typ (bits, n + 1, True)
        return (n + 1, Type ('Array', parse_int (bits[n]), typ))
    elif bits[n] == 'Struct':
        return (n + 2, Type ('Struct', bits[n + 1]))
    elif bits[n] == 'Ptr':
        (n, typ) = parse_typ (bits, n + 1, True)
        if symbolic_types:
            return (n, Type ('Ptr', '', typ))
        else:
            return (n, word32T)
    else:
        assert not 'type encoded', (n, bits[n:], bits)

def node_name (name):
    if name in {'Ret':True, 'Err':True}:
        return name
    else:
        try:
            return parse_int (name)
        except ValueError, e:
            assert not 'node name understood', name

def parse_list (parser, bits, n, extra=None):
    try:
        num = parse_int (bits[n])
    except ValueError:
        assert not 'number parseable', (n, bits[n:], bits)
    n = n + 1
    xs = []
    for i in range (num):
        if extra:
            (n, x) = parser(bits, n, extra)
        else:
            (n, x) = parser(bits, n)
        xs.append(x)
    return (n, xs)

def parse_arg (bits, n):
    nm = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    return (n, (nm, typ))

ops = {'Plus':2, 'Minus':2, 'Times':2, 'Modulus':2,
    'DividedBy':2, 'BWAnd':2, 'BWOr':2, 'BWXOR':2, 'And':2,
    'Or':2, 'Implies':2, 'Equals':2, 'Less':2,
    'LessEquals':2, 'SignedLess':2, 'SignedLessEquals':2,
    'ShiftLeft':2, 'ShiftRight':2,
    'CountLeadingZeroes':1,
    'CountTrailingZeroes':1, 'WordReverse':1, 'SignedShiftRight':2,
    'Not':1, 'BWNot':1, 'WordCast':1, 'WordCastSigned':1,
    'True':0, 'False':0, 'UnspecifiedPrecond':0,
    'MemUpdate':3, 'MemAcc':2, 'IfThenElse':3, 'ArrayIndex':2,
    'ArrayUpdate':3, 'MemDom':2,
    'PValid':3, 'PWeakValid':3, 'PAlignValid':2, 'PGlobalValid':3,
    'PArrayValid':4,
    'HTDUpdate':5, 'WordArrayAccess':2, 'WordArrayUpdate':3,
    'TokenWordsAccess':2, 'TokenWordsUpdate':3,
    'ROData':1, 'StackWrapper':2,
    'ToFloatingPoint':1, 'ToFloatingPointSigned':2,
    'ToFloatingPointUnsigned':2, 'FloatingPointCast':1,
}

ops_to_smt = {'Plus':'bvadd', 'Minus':'bvsub', 'Times':'bvmul', 'Modulus':'bvurem',
    'DividedBy':'bvudiv', 'BWAnd':'bvand', 'BWOr':'bvor', 'BWXOR':'bvxor',
    'And':'and',
    'Or':'or', 'Implies':'=>', 'Equals':'=', 'Less':'bvult',
    'LessEquals':'bvule', 'SignedLess':'bvslt', 'SignedLessEquals':'bvsle',
    'ShiftLeft':'bvshl', 'ShiftRight':'bvlshr', 'SignedShiftRight':'bvashr',
    'Not':'not', 'BWNot':'bvnot',
    'True':'true', 'False':'false',
    'UnspecifiedPrecond': 'unspecified-precond',
    'IfThenElse':'ite', 'MemDom':'mem-dom',
    'ROData': 'rodata', 'ImpliesROData': 'implies-rodata',
    'WordArrayAccess':'select', 'WordArrayUpdate':'store'}

# SMT-LIB floating point names that are used directly as operator names
ex_smt_ops = """roundNearestTiesToEven RNE roundNearestTiesToAway RNA
    roundTowardPositive RTP roundTowardNegative RTN
    roundTowardZero RTZ
    fp.abs fp.neg fp.add fp.sub fp.mul fp.div fp.fma fp.sqrt fp.rem
    fp.roundToIntegral fp.min fp.max fp.leq fp.lt fp.geq fp.gt fp.eq
    fp.isNormal fp.isSubnormal fp.isZero fp.isInfinite fp.isNaN
    fp.isNegative fp.isPositive""".split ()

ops_to_smt.update (dict ([(smt, smt) for smt in ex_smt_ops]))

smt_to_ops = dict ([(smt, oper) for (oper, smt) in ops_to_smt.iteritems ()])

def parse_struct_elem (bits, n):
    name = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    (n, val) = parse_expr (bits, n)
    return (n, (name, val))

def parse_expr (bits, n):
    if bits[n] in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
        kind = bits[n]
        name = bits[n + 1]
        (n, typ) = parse_typ (bits, n + 2)
        return (n, Expr (kind, typ, name = name))
    if bits[n] == 'Array':
        (n, typ) = parse_typ (bits, n + 1)
        assert typ.kind == 'Array'
        (n, xs) = parse_list (parse_expr, bits, n)
        assert len(xs) == typ.num
        return (n, Expr ('Array', typ, vals = xs))
    elif bits[n] == 'Field':
        (n, typ) = parse_typ (bits, n + 1)
        name = bits[n]
        (n, typ2) = parse_typ (bits, n + 1)
        (n, struct) = parse_expr (bits, n)
        assert struct.typ == typ
        return (n, Expr ('Field', typ2, struct = struct,
            field = (name, typ2)))
    elif bits[n] == 'FieldUpd':
        (n, typ) = parse_typ (bits, n + 1)
        name = bits[n]
        (n, typ2) = parse_typ (bits, n + 1)
        (n, val) = parse_expr (bits, n)
        (n, struct) = parse_expr (bits, n)
        return (n, Expr ('FieldUpd', typ, struct = struct,
            field = (name, typ2), val = val))
    elif bits[n] == 'StructCons':
        (n, typ) = parse_typ (bits, n + 1)
        (n, xs) = parse_list (parse_struct_elem, bits, n)
        return (n, Expr ('StructCons', typ, vals = dict(xs)))
    elif bits[n] == 'Num':
        v = parse_int (bits[n + 1])
        (n, typ) = parse_typ (bits, n + 2)
        return (n, Expr ('Num', typ, val = v))
    elif bits[n] == 'Op':
        op = bits[n + 1]
        op = smt_to_ops.get (op, op)
        assert op in ops, op
        (n, typ) = parse_typ (bits, n + 2)
        (n, xs) = parse_list (parse_expr, bits, n)
        assert len (xs) == ops[op]
        return (n, Expr ('Op', typ, name = op, vals = xs))
    elif bits[n] == 'Type':
        (n, typ) = parse_typ (bits, n + 1, symbolic_types = True)
        return (n, Expr ('Type', builtinTs['Type'], val = typ))
    else:
        assert not 'expr parseable', (n, bits[n : n + 3], bits)

def parse_lval (bits, n):
    name = bits[n]
    (n, typ) = parse_typ (bits, n + 1)
    return (n, (name, typ))

def parse_lval_and_val (bits, n):
    (n, (name, typ)) = parse_lval (bits, n)
    (n, val) = parse_expr (bits, n)
    return (n, ((name, typ), val))

def parse_node (bits, n):
    if bits[n] == 'Basic':
        cont = node_name(bits[n + 1])
        (n, upds) = parse_list (parse_lval_and_val, bits, n + 2)
        return Node ('Basic', cont, upds)
    elif bits[n] == 'Cond':
        left = node_name(bits[n + 1])
        right = node_name(bits[n + 2])
        (n, cond) = parse_expr (bits, n + 3)
        return Node ('Cond', (left, right), cond)
    else:
        assert bits[n] == 'Call'
        cont = node_name(bits[n + 1])
        name = bits[n + 2]
        (n, args) = parse_list (parse_expr, bits, n + 3)
        (n, saves) = parse_list (parse_lval, bits, n)
        return Node ('Call', cont, (name, args, saves))

true_term = Expr ('Op', boolT, name = 'True', vals = [])
false_term = Expr ('Op', boolT, name = 'False', vals = [])
unspecified_precond_term = Expr ('Op', boolT, name = 'UnspecifiedPrecond', vals = [])

def parse_all(lines):
    '''Toplevel parser for input information. Accepts an iterator over
    lines.
    See syntax.quick_reference for an explanation.'''

    if hasattr (lines, 'name'):
        trace ('Loading syntax from %s' % lines.name)
    else:
        trace ('Loading syntax (from anonymous source).')

    structs = {}
    functions = {}
    const_globals = {}
    cfg_warnings = []
    for line in lines:
        bits = line.split()
        # empty lines and #-comments ignored
        if not bits or bits[0][0] == '#':
            continue
        if bits[0] == 'Struct':
            # Struct <name> <size> <alignment>
            # followed by block of StructField lines
            assert bits[1] not in structs
            current_struct = Struct (bits[1], parse_int (bits[2]),
                parse_int (bits[3]))
            structs[bits[1]] = current_struct
        elif bits[0] == 'StructField':
            # StructField <name> <type (encoded)> <offset>
            (_, typ) = parse_typ(bits, 2, symbolic_types = True)
            current_struct.add_field (bits[1], typ,
                parse_int (bits[-1]))
        elif bits[0] == 'ConstGlobalDef':
            # ConstGlobalDef <name> <value>
            name = bits[1]
            (_, val) = parse_expr (bits, 2)
            const_globals[name] = val
        elif bits[0] == 'Function':
            # Function <name> <inputs> <outputs>
            # followed by optional block of node lines
            # concluded by EntryPoint line
            fname = bits[1]
            (n, inputs) = parse_list (parse_arg, bits, 2)
            (_, outputs) = parse_list (parse_arg, bits, n)
            current_function = Function (fname, inputs, outputs)
            assert fname not in functions, fname
            functions[fname] = current_function
        elif bits[0] == 'EntryPoint':
            # EntryPoint <entry point>
            entry = node_name(bits[1])
            # instead of setting function.entry to this value,
            # create a dummy node.
            # this ensures there is always
            # at least one node (EntryPoint Ret is valid) and
            # also that the entry point is not in a loop
            name = fresh_node (current_function.nodes)
            current_function.nodes[name] = Node ('Basic',
                entry, [])
            current_function.entry = name
            # ensure that the function graph is closed
            check_cfg (current_function, warnings = cfg_warnings)
            current_function = None
        else:
            # <node name> <node (encoded)>
            name = node_name(bits[0])
            assert name not in current_function.nodes, (name, bits)
            current_function.nodes[name] = parse_node (bits, 1)

    print_cfg_warnings (cfg_warnings)
    trace ('Loaded %d functions, %d structs, %d globals.'
        % (len (functions), len (structs), len (const_globals)))

    return (structs, functions, const_globals)

def parse_and_install_all (lines, tag, skip_functions=None):
    if skip_functions == None:
        skip_functions = []
    (structs, functions, const_globals) = parse_all (lines)
    for f in skip_functions:
        if f in functions:
            del functions[f]
    target_objects.structs.update (structs)
    target_objects.functions.update (functions)
    target_objects.const_globals.update (const_globals)
    if tag != None:
        target_objects.functions_by_tag[tag] = set (functions)
    return (structs, functions, const_globals)

# ===============================
# simple accessor code and checks

def visit_rval (vs):
    def visit (expr):
        if expr.kind == 'Var':
            v = expr.name
            if v in vs:
                assert vs[v] == expr.typ, (expr, vs[v])
            vs[v] = expr.typ
        if expr.is_op ('MemAcc'):
            [m, p] = expr.vals
            assert p.typ == word32T, expr
        if expr.is_op ('PGlobalValid'):
            [htd, typ_expr, p] = expr.vals
            typ = typ_expr.val
            get_global_wrapper (typ)

    return visit

def visit_lval (vs):
    def visit ((name, typ)):
        assert vs.get(name, typ) == typ, (name, vs[name],
            typ)
        vs[name] = typ

    return visit

def get_expr_vars (expr, vs):
    expr.visit (visit_rval (vs))

def get_expr_var_set (expr):
    vs = {}
    get_expr_vars (expr, vs)
    return set (vs.items ())

def get_lval_vars (lval, vs):
    assert len(lval) == 2
    assert vs.get(lval[0], lval[1]) == lval[1]
    vs[lval[0]] = lval[1]

def get_node_vars (node, vs):
    node.visit (visit_lval (vs), visit_rval (vs))

def get_node_rvals (node, vs = None):
    if vs == None:
        vs = {}
    node.visit (lambda l: (), visit_rval (vs))
    return vs

def get_vars(function):
    vs = dict(function.inputs + function.outputs)
    for node in function.nodes.itervalues():
        get_node_vars(node, vs)
    return vs

def get_lval_typ(lval):
    assert len(lval) == 2
    return lval[1]

def get_expr_typ(expr):
    return expr.typ

def check_cfg (fun, warnings = None):
    dead_arcs = [(n, n2) for (n, node) in fun.nodes.iteritems ()
        for n2 in node.get_conts ()
        if n2 not in fun.nodes and n2 not in ['Ret', 'Err']]
    for (n, n2) in dead_arcs:
        assert type (n2) != str
        # OK if multiple dead arcs and we save over n2 twice
        fun.nodes[n2] = Node ('Basic', 'Err', [])
    if warnings == None:
        print_cfg_warnings ([(fun, n, n2) for (n, n2) in dead_arcs])
    else:
        warnings.extend ([(fun, n, n2) for (n, n2) in dead_arcs])

def print_cfg_warnings (warnings):
    post_calls = set ([(fun.nodes[n].fname, fun.name)
        for (fun, n, n2) in warnings
        if fun.nodes[n].kind == 'Call'])
    import logic
    for (call, sites) in logic.dict_list (post_calls).iteritems ():
        trace ('Missing nodes after calls to %s' % call)
        trace ('  in %s' % str (sites))
    for (fun, n, n2) in warnings:
        if fun.nodes[n].kind != 'Call':
            trace ('Warning: dead arc in %s: %s -> %s'
                % (fun.name, n, n2))
            trace ('  (follows %s node!)' % fun.nodes[n].kind)
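
# Illustration only, not part of the original module: a minimal sketch
# of the dead-arc repair that check_cfg performs, written against plain
# dicts so that it can be read and run in isolation. The name
# _sketch_close_cfg and the dict-of-successor-lists representation are
# assumptions made for this example; the real code works on Node objects.
def _sketch_close_cfg (succs):
    # succs maps each node address to the list of its successor addresses
    dead_arcs = [(n, n2) for n in sorted (succs)
        for n2 in succs[n]
        if n2 not in succs and n2 not in ['Ret', 'Err']]
    for (n, n2) in dead_arcs:
        # a missing target becomes a node that continues to 'Err'
        succs[n2] = ['Err']
    return dead_arcs

# For example, with succs = {1: [2, 3], 2: ['Ret']} the arc 1 -> 3 is
# dead, so node 3 is created and the graph is closed afterwards.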

def check_funs (functions, verbose = False):
    for (f, fun) in functions.iteritems():
        if not fun:
            continue
        if verbose:
            trace ('Checking %s' % f)
        check_cfg (fun)
        get_vars(fun)
        for (n, node) in fun.nodes.iteritems():
            if node.kind == 'Call':
                c = functions[node.fname]
                assert map(get_expr_typ, node.args) == \
                    map (get_lval_typ, c.inputs), (
                        node.fname, node.args, c.inputs)
                assert map (get_lval_typ, node.rets) == \
                    map (get_lval_typ, c.outputs), (
                        node.fname, node.rets, c.outputs)
            elif node.kind == 'Basic':
                for (lv, v) in node.upds:
                    assert get_lval_typ(lv) == get_expr_typ(v)
            elif node.kind == 'Cond':
                assert get_expr_typ(node.cond) == boolT

def get_extensions (v):
    extensions = set ()
    rm = builtinTs['RoundingMode']
    def visitor (expr):
        if expr.typ == rm or expr.typ.kind == 'FloatingPoint':
            extensions.add ('FloatingPoint')
    v.gen_visit (lambda l: (), visitor)
    return extensions

# =========================================
# common constructors for basic expressions

def mk_var (nm, typ):
    return Expr ('Var', typ, name = nm)

def mk_token (nm):
    return Expr ('Token', builtinTs['Token'], name = nm)

def mk_plus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Plus', vals = [x, y])

def mk_uminus (x):
    zero = Expr ('Num', x.typ, val = 0)
    return mk_minus (zero, x)

def mk_minus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Minus', vals = [x, y])

def mk_times (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Times', vals = [x, y])

def mk_divide (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'DividedBy', vals = [x, y])

def mk_modulus (x, y):
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'Modulus', vals = [x, y])
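
# Illustration only, not part of the original module: the constructors
# above build the same trees that parse_expr reads from the prefix token
# syntax. This sketch goes the other way for a tiny subset, turning a
# nested tuple into the token list. The name _sketch_encode and the
# tuple layout are assumptions made for this example, and only Word
# types are handled.
def _sketch_encode (expr):
    # expr is ('Var', name, bits), ('Num', value, bits) or
    # ('Op', name, bits, [args])
    kind = expr[0]
    if kind in ['Var', 'Num']:
        return [kind, str (expr[1]), 'Word', str (expr[2])]
    assert kind == 'Op', expr
    (name, bits, args) = expr[1:]
    # an Op clause carries its result type and its argument count
    toks = ['Op', name, 'Word', str (bits), str (len (args))]
    for arg in args:
        toks.extend (_sketch_encode (arg))
    return toks

# Joining the tokens for y + 1 at Word 32 with spaces gives the line
# used in quick_reference:
#   Op Plus Word 32 2 Var y Word 32 Num 1 Word 32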

def mk_bwand (x, y):
    assert x.typ == y.typ
    assert x.typ.kind == 'Word'
    return Expr ('Op', x.typ, name = 'BWAnd', vals = [x, y])

def mk_eq (x, y):
    assert x.typ == y.typ
    return Expr ('Op', boolT, name = 'Equals', vals = [x, y])

def mk_less_eq (x, y, signed = False):
    assert x.typ == y.typ
    name = {False: 'LessEquals', True: 'SignedLessEquals'}[signed]
    return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_less (x, y, signed = False):
    assert x.typ == y.typ
    name = {False: 'Less', True: 'SignedLess'}[signed]
    return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_implies (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'Implies', vals = [x, y])

def mk_n_implies (xs, y):
    imp = y
    for x in reversed (xs):
        imp = mk_implies (x, imp)
    return imp

def mk_and (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'And', vals = [x, y])

def mk_or (x, y):
    assert x.typ == boolT
    assert y.typ == boolT
    return Expr ('Op', boolT, name = 'Or', vals = [x, y])

def mk_not (x):
    assert x.typ == boolT
    return Expr ('Op', boolT, name = 'Not', vals = [x])

def mk_shift_gen (name, x, n):
    assert x.typ.kind == 'Word'
    if type (n) == int:
        n = Expr ('Num', x.typ, val = n)
    return Expr ('Op', x.typ, name = name, vals = [x, n])

mk_shiftr = lambda x, n: mk_shift_gen ('ShiftRight', x, n)

def mk_clz (x):
    return Expr ('Op', x.typ, name = "CountLeadingZeroes", vals = [x])

def mk_word_reverse (x):
    return Expr ('Op', x.typ, name = "WordReverse", vals = [x])

def mk_ctz (x):
    return mk_clz (mk_word_reverse (x))

def foldr1 (f, xs):
    x = xs[-1]
    for i in reversed (range (len (xs) - 1)):
        x = f (xs[i], x)
    return x

def mk_num (x,
        typ):
    import logic
    if logic.is_int (typ):
        typ = Type ('Word', typ)
    assert typ.kind == 'Word', typ
    assert logic.is_int (x), x
    return Expr ('Num', typ, val = x)

def mk_word32 (x):
    return mk_num (x, word32T)

def mk_word8 (x):
    return mk_num (x, word8T)

def mk_word32_maybe(x):
    import logic
    if logic.is_int (x):
        return mk_word32 (x)
    else:
        assert x.typ == word32T
        return x

def mk_cast (x, typ):
    if x.typ == typ:
        return x
    else:
        assert x.typ.kind == 'Word', x.typ
        assert typ.kind == 'Word', typ
        return Expr ('Op', typ, name = 'WordCast', vals = [x])

def mk_memacc(m, p, typ):
    assert m.typ == builtinTs['Mem']
    assert p.typ == word32T
    return Expr ('Op', typ, name = 'MemAcc', vals = [m, p])

def mk_memupd(m, p, v):
    assert m.typ == builtinTs['Mem']
    assert p.typ == word32T
    return Expr ('Op', m.typ, name = 'MemUpdate', vals = [m, p, v])

def mk_arr_index (arr, i):
    assert arr.typ.kind == 'Array'
    return Expr ('Op', arr.typ.el_typ, name = 'ArrayIndex',
        vals = [arr, i])

def mk_arroffs(p, typ, i):
    assert typ.kind == 'Array'
    import logic
    if logic.is_int (i):
        assert i < typ.num
        offs = i * typ.el_typ.size()
        assert offs == i or offs % 4 == 0
        return mk_plus (p, mk_word32 (offs))
    else:
        sz = typ.el_typ.size()
        return mk_plus (p, mk_times (i, mk_word32 (sz)))

def mk_if (P, x, y):
    assert P.typ == boolT
    assert x.typ == y.typ
    return Expr ('Op', x.typ, name = 'IfThenElse', vals = [P, x, y])

def mk_meta_typ (typ):
    return Expr ('Type', builtinTs['Type'], val = typ)

def mk_pvalid (htd, typ, p):
    return Expr ('Op', boolT, name = 'PValid',
        vals = [htd, mk_meta_typ (typ), p])

def mk_rel_wrapper (nm, vals):
    return Expr ('Op', builtinTs['RelWrapper'], name = nm, vals = vals)

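# Illustration only, not part of the original module: mk_n_implies and
# foldr1 above both fold from the right, so the last element ends up
# innermost. This sketch shows the resulting nesting with strings in
# place of Expr objects. The name _sketch_n_implies is an assumption
# made for this example.
def _sketch_n_implies (xs, y):
    imp = y
    for x in reversed (xs):
        # each hypothesis wraps everything built so far
        imp = '(%s => %s)' % (x, imp)
    return imp

# With hypotheses ['a', 'b'] and conclusion 'c' the fold nests as
# a => (b => c); with no hypotheses the conclusion is returned as is.
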
def adjust_op_vals (expr, vals):
    assert expr.kind == 'Op'
    return Expr ('Op', expr.typ, expr.name, vals = vals)

mks = (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand,
    mk_eq, mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32,
    mk_word8, mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index,
    mk_arroffs, mk_if, mk_meta_typ, mk_pvalid)

# ====================================================================
# pretty printing code for the syntax - only used for printing reports

def pretty_type (typ):
    if typ.kind == 'Word':
        return 'Word%d' % typ.num
    elif typ.kind == 'WordArray':
        [ix, num] = typ.nums
        return 'Word%d[%d]' % (num, ix)
    elif typ.kind == 'Ptr':
        return 'Ptr(%s)' % pretty_type (typ.el_typ_symb)
    elif typ.kind == 'Struct':
        return 'struct %s' % typ.name
    elif typ.kind == 'Builtin':
        return typ.name
    else:
        assert not 'type pretty-printable', typ

pretty_opers = {'Plus': '+', 'Minus': '-', 'Times': '*'}

known_typ_change = set (['ROData', 'MemAcc', 'IfThenElse', 'WordArrayUpdate',
    'MemDom'])

def pretty_expr (expr, print_type = False):
    if print_type:
        return '((%s) (%s))' % (pretty_type (expr.typ),
            pretty_expr (expr))
    elif expr.kind == 'Var':
        return repr (expr.name)
    elif expr.kind == 'Num':
        return '%d' % expr.val
    elif expr.kind == 'Op' and expr.name in pretty_opers:
        [x, y] = expr.vals
        return '(%s %s %s)' % (pretty_expr (x), pretty_opers[expr.name],
            pretty_expr (y))
    elif expr.kind == 'Op':
        if expr.name in known_typ_change:
            vals = [pretty_expr (v) for v in expr.vals]
        else:
            vals = [pretty_expr (v, print_type = v.typ != expr.typ)
                for v in expr.vals]
        return '%s(%s)' % (expr.name, ', '.join (vals))
    elif expr.kind == 'Token':
        return "''%s''" % expr.name
    else:
        assert not 'expr pretty-printable', expr


# =================================================
# some helper code that's needed all over the place

def fresh_name (n, D, v=True):
    if n not in D:
        D[n] = v
        return n

    x = 1
    y = 1
    while ('%s.%d' % (n, x)) in D:
        y = x
        x = x * 2
    while y < x:
        z = (y + x) / 2
        if ('%s.%d' % (n, z)) in D:
            y = z + 1
        else:
            x = z
    n = '%s.%d' % (n, x)
    assert n not in D

    D[n] = v
    return n

def fresh_node (ns, hint = 1):
    n = hint
    # use names that are *not* multiples of 4
    n = (n | 15) + 2
    while n in ns:
        n += 16
    return n

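
# Illustration only, not part of the original module: fresh_name above
# finds an unused numeric suffix by doubling until it hits a free one
# and then bisecting back toward the last used one. This sketch isolates
# that search over a plain set of integers. The name
# _sketch_fresh_suffix is an assumption made for this example, and it
# uses // so that it behaves the same under Python 2 and Python 3.
def _sketch_fresh_suffix (taken):
    # taken is the set of positive integer suffixes already in use
    x = 1
    y = 1
    # doubling phase: x ends on a suffix that is not taken
    while x in taken:
        y = x
        x = x * 2
    # bisection phase: x always stays on a suffix that is not taken
    while y < x:
        z = (y + x) // 2
        if z in taken:
            y = z + 1
        else:
            x = z
    return x

# When the used suffixes are contiguous from 1, the result is the next
# one. With gaps, the result is still unused but need not be the
# smallest free suffix.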