1# * Copyright 2015, NICTA
2# *
3# * This software may be distributed and modified according to the terms of
4# * the BSD 2-Clause license. Note that NO WARRANTY is provided.
5# * See "LICENSE_BSD2.txt" for details.
6# *
7# * @TAG(NICTA_BSD)
8
9# Syntax and simple operations for types, expressions, graph nodes
10# and graph functions (functions in graph-language format).
11
12from target_objects import structs, trace
13import target_objects
14
15quick_reference = """
16Quick reference on the graph language and its syntax.
17=====================================================
18
19Example
20=======
21
22Let's build up the syntax that roughly mirrors the C statement x = y + 1;
23
24This is a quick example. A more thorough reference is below.
25
26We have two example types: Bool and Word 32.
27
28Here are two atomic expressions of type Word 32:
29Var y Word 32
30Num 1 Word 32
31
32These encode the variable y and the number 1, both of type Word 32.
33
34Expressions in this language are built up by concatenating strings of tokens.
35Any whitespace delimited string can be a token, and thus a variable name.
36
37Compound expressions are built with operators, e.g. y + 1:
38Op Plus Word 32 2 Var y Word 32 Num 1 Word 32
39
40This is quite verbose. For simplicity, the Op syntax includes the type of
41the whole expression - Word 32 - and the number of arguments - 2 - even though
42much of this information is redundant.
43
44Finally, we can encode the statement x = y + 1; at address 14.
45
4614 Basic 15 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32
47
48This specifies that node 14 is a basic node (which updates variables). The
49successor node is 15. It updates 1 variable (basic nodes may simultaneously
50update many variables). The variable to be updated is specified by x Word 32.
51The remainder of the line is the expression y + 1 as seen before.
52
53We can encode the statements if (x == z) { return; } as follows:
5415 Cond 16 17 Op Equals Bool Var x Word 32 Var y Word 32
5516 Basic Ret 0
56
57Conditional node 15 continues to node 16 if x == y, and to 17 otherwise. Node
5816 is a basic node which updates no variables, i.e. a skip statement. It
59continues to the special address Ret which designates return from the current
60function.
61
62Reference
63=========
64
65A graph program consists of a number of functions, each
66of whose control flow is structured as a graph of nodes
67with integer addresses. Nodes are of three types:
68  - 'Basic' updates variable values.
69  - 'Cond' chooses between two successor nodes.
70  - 'Call' makes calls to functions.
71
72A top level syntax file (interpreted by syntax.parse_all)
73contains two kinds of sections:
74  - Function blocks introduce functions.
75  - Struct blocks introduce aggregate types (e.g. struct in C).
76
77All lines in the input syntax are read as a whitespace-separated
78list of tokens. Empty lines and lines whose first non-whitespace
79character is '#' are ignored. All syntax is prefix encoded, so
80each clause has a first token which determines what kind of clause
81it is, followed by the concatenation of all of its subclauses.
82
83The two kinds of lines that appear in Struct blocks illustrate this format:
84Struct "struct name" <int size> <int alignment>
85StructField "name" <type> <int offset>
86
87On notation: above we denote the specific tokens 'Struct' and 'StructField',
88the arbitrary tokens "name" etc, the special integer subclauses <int size> etc,
89and a <type> subclause. Integer subclauses always consist of a single token.
90A leading character '-' or '~' indicates negative. A leading '0x' or '0b'
91indicates hexadecimal or binary encoding, otherwise decimal, e.g. 1, 0x1A,
92-0x2b, ~0b1100, etc. Naming tokens can be anything at all that does not contain
93whitespace, so x y'v a,b c#d # etc are all valid struct or field names.
94
95Struct blocks begin with the 'Struct' line, which specifies their name, total
96size and required alignment. The struct block then contains a number of
97StructField lines, where each field has a type and an offset from the start of
98the structure. The Struct block ends where any other Struct or Function block
99begins.
100
101Function blocks begin with a function declaration line
102  Function "name of function" <int>*("input argument name" <type>)
103    <int>*("output argument name" <type>)
104
105The notation <int>*(...) above denotes an arbitrary length list of subclauses.
106The first token of this clause will be a decimal integer specifying the length
107of the list. For instance, the classic 'XOR' function with inputs x, y and
108output z could be specified as follows:
109Function XOR 2 x Bool y Bool 1 z Bool
110
111Each function has a name, a list of input arguments, and a list of output
112arguments. The state in the graph language is entirely encoded in variables,
113so functions will typically have global objects such as the heap as both input
114and output arguments.
115
116The function block may end immediately, for functions with unspecified body,
117or may contain a number of graph node lines, followed by an entry point line:
118EntryPoint <int>
119The entry point line ends the function block.
120
121A graph node line has one of these formats:
122<int> Basic <next node> <int>*("var name" <type> <expression>)
123<int> Cond <next node> <next node> <expression>
124<int> Call <next node> "fun name" <int>*(<expression>) <int>*("var name" <type>)
125
126Each node line begins with an integer which is the address of the node. A
127<next node> is either the integer address of the following node, or one of the
128special address tokens Ret and Err.
129
130Basic nodes update the value of a (possibly empty) list of variables. The
131variables are all updated simultaneously to the computed value of the
132expressions given.
133
134Cond nodes evaluate an expresion which must be of boolean type. They specify
135two possible next nodes, left and right. The left node is visited when the
136expression evaluates to true, the right on false.
137
138Call nodes call another function of a given name. The list of arguments
139expressions is evaluated and must match the list of input arguments of the
140given function (same length and types). These arguments define the starting
141variable environment of that function. Its return parameters are saved to the
142list of variables given.
143
144The <type> clauses are in one of these formats:
145Word <int>  -- BitVec <int> is a synonym
146Array <type> <int>
147Struct "struct name"
148Ptr <type>
149One of the builtin types as a single token:
150  'Bool Mem Dom HTD PMS UNIT Type Token RoundingMode'
151FloatingPoint <int> <int>
152
153These types encode a machine word with the given number of bits, an array of
154some type, a struct previously defined by a Struct clause, or a pointer to
155another type. The Array, Struct and Ptr types are provided only for use in
156pointer validity assertions, which are discussed below. The FloatingPoint type
157exactly mirrors the SMTLIB2 (_ FloatingPoint eb sb) type, and is available as
158an optional extension. Using the FloatingPoint and RoundingMode types in any
159problem changes the SMT theory needed and may limit which solvers and features
160can be used.
161
162Of the builtin types, booleans are standard, UNIT is the singleton type, and
163memory is a 32-bit to 8-bit mapping. We aim to include 64-bit support soon. The
164Dom type is a set of 32-bit values, used to encode the valid domain of memory
165operations. The heap type description HTD is used by pointer-validity
166operators. The phantom machine state type PMS is unspecified and used to
167represent unspecified aspects of the environment. The type Type is the type of
168Type expressions which are used in pointer-validity.
169
170The <expression> clauses have one of these formats:
171Var "name" <type>
172Op "name" <type> <int>*(<expression>)
173Num <int> <type>
174Type <type>
175Symbol "name" <type>
176
177Most of the expressions of the language are composed from variables, numerals
178and operator applications, which should be self explanatory. The Type
179expression wraps a type into an expression (of type Type) so it may be passed
180to a pointer-validity operator.
181
182The special Symbol clauses are used to denote in source languages the values
183symbols will have in the binaries. They are replaced by specific numerals in
184a first pass.
185
186Many of the builtin operators are equivalent to SMTLIB2 operators, and are also
187available with their SMTLIB2 names. The operators are:
188  - Binary arithmetic operators on words:
189    + Plus/bvadd, Minus/bvsub, Times/bvmul, Modulus/bvurem, DividedBy/bvudiv,
190      BWAnd/bvand, BWOr/bvor, BWXOR/bvxor,
191      ShiftLeft/bvshl, ShiftRight/bvlshr, SignedShiftRight/bvashr
192  - Binary operators on booleans:
193    + And/and, Or/or, Implies
194  - Unary operators on bools:
195    + Not/not
196  - Booleans (nullary operators, i.e. constants):
197    + True/true, False/false
198  - Equals relation in any type:
199    + Equals
200  - Comparison relations on words (result is bool):
201    + Less/bvult, LessEquals/bvule, SignedLess/bvslt, SignedLessEquals/bvsle
202  - Unary operators on words:
203    + BWNot/bvnot, CountLeadingZeroes, CountTrailingZeroes, WordReverse
204  - Cast operators on words - result type may be different to argument type:
205    + WordCast, WordCastSigned
206  - Memory operations:
207    + MemAcc (args [m :: Mem, ptr :: Word 32]) any word type
208    + MemUpdate (args, [m :: Mem, ptr :: Word 32, v :: any word type])
209  - Pointer-validity operators:
210    + PValid, PWeakValid, PAlignValid, PGlobalValid, PArrayValid
211  - Miscellaneous:
212    + MemDom, HTDUpdate
213    + IfThenElse/ite (args [b :: bool, x :: any type T, y :: T])
214  - FloatingPoint operations from the SMTLIB2 floating point specification:
215    + roundNearestTiesToEven/RNE, roundNearestTiesToAway/RNA,
216      roundTowardPositive/RTP, roundTowardNegative/RTN,
217      roundTowardZero/RTZ
218    + fp.abs, fp.neg, fp.add, fp.sub, fp.mul, fp.div, fp.fma, fp.sqrt,
219      fp.rem, fp.roundToIntegral, fp.min, fp.max, fp.leq, fp.lt,
220      fp.geq, fp.gt, fp.eq, fp.isNormal, fp.IsSubnormal, fp.isZero,
221      fp.isInfinite, fp.isNaN, fp.isNegative, fp.isPositive
222    + ToFloatingPoint, ToFloatingPointSigned, ToFloatingPointUnsigned,
223      FloatingPointCast
224
225Operators with SMTLIB2 equivalents have the same semantics. Mem and Dom
226operations are wrap the SMTLIB2 BitVec Array type with more convenient
227operations.
228
229Memory accesses and updates can operate on various word types.
230
231The pointer-validity operators PValid, PWeakValid, PGlobalValid all take 3
232arguments of type HTD, Type, and Word 32. The PValid operator asserts that
233the heap-type-description contains this type starting at this address. This
234is exclusive with any incompatible type. PGlobalValid additionally asserts that
235this is a global object, and therefore not contained within any larger object.
236PWeakValid asserts that the type could be valid (it is aligned and within the
237range of available addresses in the heap-type-description) and is needed only
238in rare circumstances. PAlignValid omits the HTD argument and asserts only that
239the pointer is appropriately aligned and that the object does not start at or
240wrap around the 0 address. PArrayValid takes an additional argument (HTD,
241Type, Word 32, Word 32) where the final argument specifies the number of
242entries in an array.
243
244The MemDom operator takes argument types [Word 32, Dom] and returns the boolean
245of whether this pointer is in this domain.
246
247The IfThenElse operator takes a bool and any two arguments of the same type.
248
249The FloatingPoint operators are mostly taken from the SMTLIB2 floating
250point standard. The conversions ToFloatingPoint ([Word] to FP),
251ToFloatingPointSigned and ToFloatingPointUnsigned ([RoundingMode, Word] to FP)
252and FloatingPointCast (FP to FP) represent the variants of to_fp in the SMTLIB2
253standard.
254"""
255
256class Type:
257	def __init__ (self, kind, name, el_typ=None):
258		self.kind = kind
259		if kind in ['Array', 'Word', 'TokenWords']:
260			self.num = int (name)
261		else:
262			self.name = name
263		if kind in ['Array', 'Ptr']:
264			self.el_typ_symb = el_typ
265			self.el_typ = concrete_type (el_typ)
266		if kind in ['WordArray', 'FloatingPoint']:
267			self.nums = [int (name), int (el_typ)]
268
269	def __repr__ (self):
270		if self.kind == 'Array':
271			return 'Type ("Array", %r, %r)' % (self.num,
272				self.el_typ_symb)
273		elif self.kind in ('Word', 'TokenWords'):
274			return 'Type (%r, %r)' % (self.kind, self.num)
275		elif self.kind == 'Ptr':
276			return 'Type ("Ptr", %r)' % self.el_typ_symb
277		elif self.kind in ('WordArray', 'FloatingPoint'):
278			return 'Type (%r, %r, %r)' % tuple ([self.kind]
279				+ self.nums)
280		else:
281			return 'Type (%r, %r)' % (self.kind, self.name)
282
283	def __eq__ (self, other):
284		if not other:
285			return False
286		if self.kind != other.kind:
287			return False
288		if self.kind in ['Array', 'Word', 'TokenWords']:
289			if self.num != other.num:
290				return False
291		elif self.kind in ['WordArray', 'FloatingPoint']:
292			if self.nums != other.nums:
293				return False
294		else:
295			if self.name != other.name:
296				return False
297		if self.kind in ['Array', 'Ptr']:
298			if self.el_typ_symb != other.el_typ_symb:
299				return False
300		return True
301
302	def __ne__ (self, other):
303		return not other or not (self == other)
304
305	def __hash__ (self):
306		return hash(str(self))
307
308	def __cmp__ (self, other):
309		self_ss = []
310		self.serialise (self_ss)
311		other_ss = []
312		other.serialise (other_ss)
313		return cmp (self_ss, other_ss)
314
315	def subtypes (self):
316		if self.kind == 'Struct':
317			return structs[self.name].subtypes()
318		elif self.kind == 'Array':
319			return [self] + self.el_typ.subtypes()
320		else:
321			return [self]
322
323	def size (self):
324		if self.kind == 'Struct':
325			return structs[self.name].size
326		elif self.kind == 'Array':
327			return self.el_typ.size() * self.num
328		elif self.kind == 'Word':
329			assert self.num % 8 == 0, self
330			return self.num / 8
331		elif self.kind == 'FloatingPoint':
332			sz = sum (self.nums)
333			assert sz % 8 == 0, self
334			return sz / 8
335		elif self.kind == 'Ptr':
336			return 4
337		else:
338			assert not 'type has size'
339
340	def align (self):
341		if self.kind == 'Struct':
342			return structs[self.name].align
343		elif self.kind == 'Array':
344			return self.el_typ.align ()
345		elif self.kind in ('Word', 'FloatingPoint'):
346			return self.size ()
347		elif self.kind == 'Ptr':
348			return 4
349		else:
350			assert not 'type has alignment'
351
352	def serialise (self, xs):
353		if self.kind in ('Word', 'TokenWords'):
354			xs.append (self.kind)
355			xs.append (str (self.num))
356		elif self.kind in ('WordArray', 'FloatingPoint'):
357			xs.append (self.kind)
358			xs.extend ([str (n) for n in self.nums])
359		elif self.kind == 'Builtin':
360			xs.append (self.name)
361		elif self.kind == 'Array':
362			xs.append ('Array')
363			self.el_typ_symb.serialise (xs)
364			xs.append (str (self.num))
365		elif self.kind == 'Struct':
366			xs.append ('Struct')
367			xs.append (self.name)
368		elif self.kind == 'Ptr':
369			xs.append ('Ptr')
370			self.el_typ_symb.serialise (xs)
371		else:
372			assert not 'type serialisable', self.kind
373
374class Expr:
375	def __init__ (self, kind, typ, name = None, struct = None,
376			field = None, val = None, vals = None):
377		self.kind = kind
378		self.typ = typ
379		if name != None:
380			self.name = name
381		if struct != None:
382			self.struct = struct
383		if field != None:
384			self.field = field
385		if val != None:
386			self.val = val
387		if vals != None:
388			self.vals = vals
389		if kind == 'Op':
390			assert type (self.vals) == list
391
392	def binds (self):
393		binds = []
394		if self.kind in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
395			binds.append(('name', self.name))
396		elif self.kind in ['Array', 'StructCons']:
397			binds.append(('vals', self.vals))
398		elif self.kind == 'Field':
399			binds.append(('struct', self.struct))
400			binds.append(('field', self.field))
401		elif self.kind == 'FieldUpd':
402			binds.append(('struct', self.struct))
403			binds.append(('field', self.field))
404			binds.append(('val', self.val))
405		elif self.kind == 'Num':
406			binds.append(('val', self.val))
407		elif self.kind == 'Op':
408			binds.append(('name', self.name))
409			binds.append(('vals', self.vals))
410		elif self.kind == 'Type':
411			binds.append(('val', self.val))
412		elif self.kind == 'SMTExpr':
413			binds.append(('val', self.val))
414		else:
415			assert not 'expression understood for repr', self.kind
416		return binds
417
418	def __repr__ (self):
419		bits = [repr(self.kind), repr(self.typ)]
420		bits.extend(['%s = %r' % b for b in self.binds()])
421		return 'Expr (%s)' % ', '.join(bits)
422
423	def __eq__ (self, other):
424		return (other and self.kind == other.kind
425			and self.typ == other.typ
426			and self.binds() == other.binds())
427
428	def __ne__ (self, other):
429		return not other or not (self == other)
430
431	def __hash__ (self):
432		return hash_tuplify (self.kind, self.typ, self.binds ())
433
434	def __cmp__ (self, other):
435		return cmp ((self.kind, self.typ, self.binds ()),
436			(other.kind, other.typ, other.binds ()))
437
438	def is_var (self, (nm, typ)):
439		return self.kind == 'Var' and all([self.name == nm,
440			self.typ == typ])
441
442	def is_op (self, nm):
443		if type (nm) == str:
444			return self.kind == 'Op' and self.name == nm
445		else:
446			return self.kind == 'Op' and self.name in nm
447
448	def visit (self, visit):
449		visit (self)
450		if self.kind == 'Var':
451			pass
452		elif self.kind == 'Op' or self.kind == 'Array':
453			for x in self.vals:
454				x.visit (visit)
455		elif self.kind == 'StructCons':
456			for x in self.vals.itervalues ():
457				x.visit (visit)
458		elif self.kind == 'Field':
459			self.struct.visit (visit)
460
461			struct_typ = self.struct.typ
462			assert struct_typ.kind == 'Struct'
463			struct = structs[struct_typ.name]
464			(name, typ) = self.field
465			assert struct.fields[name][0] == typ
466		elif self.kind == 'FieldUpd':
467			self.val.visit (visit)
468			self.struct.visit (visit)
469
470			struct_typ = self.struct.typ
471			assert struct_typ.kind == 'Struct'
472			struct = structs[struct_typ.name]
473			(name, typ) = self.field
474			assert struct.fields[name][0] == typ
475		elif self.kind == 'ConstGlobal':
476			assert (target_objects.const_globals[self.name].typ
477				== self.typ)
478		elif self.kind in set (['Num', 'Symbol', 'Type', 'Token']):
479			pass
480		else:
481			assert not 'expr understood', self
482
483	def gen_visit (self, visit_lval, visit_rval):
484		self.visit (visit_rval)
485
486	def subst (self, substor, ss = None):
487		ret = False
488		if self.kind == 'Op':
489			subst_vals = subst_list (substor, self.vals)
490			if subst_vals:
491				self = Expr ('Op', self.typ, name = self.name,
492					vals = subst_vals)
493				ret = True
494		if (ss == None or self.kind in ss or self.is_op (ss)):
495			r = substor (self)
496			if r != None:
497				return r
498		if ret:
499			return self
500		return
501
502	def add_const_ranges (self, ranges):
503		def visit (expr):
504			if expr.kind == 'ConstGlobal':
505				(start, size, _) = symbols[expr.name]
506				assert size == expr.typ.size ()
507				ranges[expr.name] = (start, start + size - 1)
508
509		self.visit (visit)
510
511	def get_mem_access (self):
512		if self.is_op ('MemAcc'):
513			[m, p] = self.vals
514			return [('MemAcc', p, self, m)]
515		elif self.is_op ('MemUpdate'):
516			[m, p, v] = self.vals
517			return [('MemUpdate', p, v, m)]
518		else:
519			return []
520
521	def get_mem_accesses (self):
522		accesses = []
523		def visit (expr):
524			accesses.extend (expr.get_mem_access ())
525		self.visit (visit)
526		return accesses
527
528	def serialise (self, xs):
529		xs.append (self.kind)
530		if self.kind == 'Op':
531			xs.append (self.name)
532			self.typ.serialise (xs)
533			xs.append (str (len (self.vals)))
534			for v in self.vals:
535				v.serialise (xs)
536		elif self.kind == 'Num':
537			xs.append (str (self.val))
538			self.typ.serialise (xs)
539		elif self.kind == 'Var':
540			xs.append (self.name)
541			self.typ.serialise (xs)
542		elif self.kind == 'Type':
543			self.val.serialise (xs)
544		elif self.kind == 'Token':
545			xs.extend ([self.kind, self.name])
546			self.typ.serialise (xs)
547		else:
548			assert not 'expr serialisable', self.kind
549
550class Struct:
551	def __init__ (self, name, size, align):
552		self.name = name
553		self.size = size
554		self.align = align
555		self.field_list = []
556		self.fields = {}
557		self._subtypes = None
558		self.typ = Type ('Struct', name)
559
560	def add_field (self, name, typ, offset):
561		concrete = concrete_type (typ)
562		self.field_list.append ((name, concrete))
563		self.fields[name] = (concrete, offset, typ)
564		assert self._subtypes == None
565
566	def subtypes (self):
567		if self._subtypes != None:
568			return self._subtypes
569		xs = [self.typ]
570		for (concrete, offs, typ2) in self.fields.itervalues():
571			xs.extend(typ2.subtypes())
572		self._subtypes = xs
573		return xs
574
575def tuplify (x):
576	if type(x) == tuple or type(x) == list:
577		return tuple ([tuplify (y) for y in x])
578	if type(x) == dict:
579		return tuple ([tuplify (y) for y in x.iteritems ()])
580	else:
581		return x
582
583def hash_tuplify (* xs):
584	return hash (tuplify (xs))
585
586def subst_list (substor, xs, ss = None):
587	ys = [x.subst (substor, ss = ss) for x in xs]
588	if [y for y in ys if y != None]:
589		xs = list (xs)
590		for (i, y) in enumerate (ys):
591			if y != None:
592				xs[i] = y
593		return xs
594	else:
595		return
596
597class Node:
598	def __init__ (self, kind, conts, args):
599		self.kind = kind
600
601		if type (conts) == list:
602			if len (conts) == 1:
603				the_cont = conts[0]
604		else:
605			the_cont = conts
606
607		if kind == 'Basic':
608			self.cont = the_cont
609			self.upds = [(lv, v) for (lv, v) in args
610				if not v.is_var (lv)]
611		elif kind == 'Call':
612			self.cont = the_cont
613			(self.fname, self.args, self.rets) = args
614		elif kind == 'Cond':
615			(self.left, self.right) = conts
616			self.cond = args
617		else:
618			assert not 'node kind understood', self.kind
619
620	def __repr__ (self):
621		return 'Node (%r, %r, %r)' % (self.kind,
622			self.get_conts (), self.get_args ())
623
624	def __hash__ (self):
625		if self.kind == 'Call':
626			return hash ((self.fname, tuple (self.args),
627				tuple (self.rets), self.cont))
628		elif self.kind == 'Basic':
629			return hash (tuple (self.upds))
630		elif self.kind == 'Cond':
631			return hash ((self.cond, self.left, self.right))
632		else:
633			assert not 'node kind understood', self.kind
634
635	def __eq__ (self, other):
636		return all ([self.kind == other.kind,
637			self.get_conts () == other.get_conts (),
638			self.get_args () == other.get_args ()])
639
640	def __ne__ (self, other):
641		return not other or not self == other
642
643	def get_args (self):
644		if self.kind == 'Basic':
645			return self.upds
646		elif self.kind == 'Call':
647			return (self.fname, self.args, self.rets)
648		else:
649			return self.cond
650
651	def get_conts (self):
652		if self.kind == 'Cond':
653			return [self.left, self.right]
654		else:
655			return [self.cont]
656
657	def get_lvals (self):
658		if self.kind == 'Basic':
659			return [lv for (lv, v) in self.upds]
660		elif self.kind == 'Call':
661			return self.rets
662		else:
663			return []
664
665	def is_noop (self):
666		if self.kind == 'Basic':
667			return self.upds == []
668		elif self.kind == 'Cond':
669			return self.left == self.right
670		else:
671			return False
672
673	def visit (self, visit_lval, visit_rval):
674		if self.kind == 'Basic':
675			for (lv, v) in self.upds:
676				visit_lval (lv)
677				v.visit (visit_rval)
678		elif self.kind == 'Cond':
679			self.cond.visit (visit_rval)
680		elif self.kind == 'Call':
681			for v in self.args:
682				v.visit (visit_rval)
683			for r in self.rets:
684				visit_lval (r)
685
686	def gen_visit (self, visit_lval, visit_rval):
687		self.visit (visit_lval, visit_rval)
688
689	def subst_exprs (self, substor, ss = None):
690		if self.kind == 'Basic':
691			rvs = subst_list (substor,
692				[v for (lv, v) in self.upds], ss = ss)
693			if rvs == None:
694				return self
695			return Node ('Basic', self.cont,
696				zip ([lv for (lv, v) in self.upds], rvs))
697		elif self.kind == 'Cond':
698			r = self.cond.subst (substor, ss = ss)
699			if r == None:
700				return self
701			return Node ('Cond', [self.left, self.right], r)
702		elif self.kind == 'Call':
703			args = subst_list (substor, self.args, ss = ss)
704			if args == None:
705				return self
706			return Node ('Call', self.cont, (self.fname,
707				args, self.rets))
708
709	def get_mem_accesses (self):
710		accesses = []
711		def visit (expr):
712			accesses.extend (expr.get_mem_access ())
713		self.visit (lambda x: (), visit)
714		return accesses
715
716	def err_cond (self):
717		if self.kind != 'Cond':
718			return None
719		if self.left != 'Err':
720			if self.right == 'Err':
721				return mk_not (self.cond)
722			else:
723				return None
724		else:
725			if self.right == 'Err':
726				return true_term
727			else:
728				return self.cond
729
730	def serialise (self, xs):
731		xs.append (self.kind)
732		xs.extend ([str (c) for c in self.get_conts ()])
733		if self.kind == 'Basic':
734			xs.append (str (len (self.upds)))
735			for (lv, v) in self.upds:
736				xs.append (lv[0])
737				lv[1].serialise (xs)
738				v.serialise (xs)
739		elif self.kind == 'Cond':
740			self.cond.serialise (xs)
741		elif self.kind == 'Call':
742			xs.append (self.fname)
743			xs.append (str (len (self.args)))
744			for arg in self.args:
745				arg.serialise (xs)
746			xs.append (str (len (self.rets)))
747			for (nm, typ) in self.rets:
748				xs.append (nm)
749				typ.serialise (xs)
750
751def rename_lval ((name, typ), renames):
752	return (renames.get (name, name), typ)
753
754def do_subst (expr, substor, ss = None):
755	r = expr.subst (substor, ss = ss)
756	if r == None:
757		return expr
758	else:
759		return r
760
761standard_expr_kinds = set (['Symbol', 'ConstGlobal', 'Var', 'Op', 'Num',
762	'Type'])
763
764def rename_expr_substor (renames):
765	def ren (expr):
766		if expr.kind == 'Var' and expr.name in renames:
767			return mk_var (renames[expr.name], expr.typ)
768		else:
769			return
770	return ren
771
772def rename_expr (expr, renames):
773	return do_subst (expr, rename_expr_substor (renames),
774		ss = set (['Var']))
775
776def copy_rename (node, renames):
777	(vs, ns) = renames
778	nf = lambda n: ns.get (n, n)
779	node = node.subst_exprs (rename_expr_substor (vs))
780	if node.kind == 'Call':
781		return Node ('Call', nf (node.cont), (node.fname, node.args,
782                              [rename_lval (l, vs) for l in node.rets]))
783	elif node.kind == 'Basic':
784		return Node ('Basic', nf (node.cont),
785			[(rename_lval (lv, vs), v) for (lv, v) in node.upds])
786	elif node.kind == 'Cond':
787		return Node ('Cond', [nf (node.left), nf (node.right)],
788			node.cond)
789	else:
790		assert not 'node kind understood', node.kind
791
792class Function:
793	def __init__ (self, name, inputs, outputs):
794		self.name = name
795		self.inputs = inputs
796		self.outputs = outputs
797		self.entry = None
798		self.nodes = {}
799
800	def __hash__ (self):
801		en = self.entry
802		if not en:
803			en = -1
804		return hash (tuple ([self.name, tuple (self.inputs),
805				tuple (self.outputs), en])
806			+ tuple (sorted (self.nodes.iteritems ())))
807
808	def reachable_nodes (self, simplify = False):
809		if not self.entry:
810			return {}
811		rs = {}
812		vs = [self.entry]
813		while vs:
814			n = vs.pop()
815			if type (n) == str:
816				continue
817			rs[n] = True
818			node = self.nodes[n]
819			if simplify:
820				import logic
821				node = logic.simplify_node_elementary (node)
822			for c in node.get_conts ():
823				if not c in rs:
824					vs.append (c)
825		return rs
826
827	def serialise (self):
828		xs = ['Function', self.name, str (len (self.inputs))]
829		for (nm, typ) in self.inputs:
830			xs.append (nm)
831			typ.serialise (xs)
832		xs.append (str (len (self.outputs)))
833		for (nm, typ) in self.outputs:
834			xs.append (nm)
835			typ.serialise (xs)
836		ss = [' '.join (xs)]
837		if not self.entry:
838			return ss
839		for n in self.nodes:
840			xs = [str (n)]
841			self.nodes[n].serialise (xs)
842			ss.append (' '.join (xs))
843		ss.append ('EntryPoint %d' % self.entry)
844		return ss
845
846	def as_problem (self, Problem, name = 'temp'):
847		p = Problem(None, 'Function (%s)' % self.name)
848		p.clone_function (self, name)
849		p.compute_preds ()
850		return p
851
852	def function_call_addrs (self):
853		return [(n, self.nodes[n].fname)
854			for n in self.nodes if self.nodes[n].kind == 'Call']
855
856	def function_calls (self):
857		return set ([fn for (n, fn) in self.function_call_addrs ()])
858
859	def compile_hints (self, Problem):
860		xs = ['Hints %s' % self.name]
861
862		p = self.as_problem (Problem)
863
864		for n in p.nodes:
865			ys = ['VarDeps', str (n)]
866			for (nm, typ) in p.var_deps[n]:
867				ys.append (nm)
868				typ.serialise (ys)
869			xs.append (' '.join (ys))
870			if self.nodes[n].kind != 'Basic':
871				continue
872		return xs
873
874	def save_graph (self, fname):
875		import problem
876		problem.save_graph (self.nodes, fname)
877
878def mk_builtinTs ():
879	return dict([(n, Type('Builtin', n)) for n
880	in 'Bool Mem Dom HTD PMS UNIT Type Token RelWrapper'.split()])
881builtinTs = mk_builtinTs ()
882boolT = builtinTs['Bool']
883word32T = Type ('Word', '32')
884word64T = Type ('Word', '64')
885word16T = Type ('Word', 16)
886word8T = Type ('Word', '8')
887
888phantom_types = set ([builtinTs[t] for t
889	in 'Dom HTD PMS UNIT Type'.split ()])
890
891def concrete_type (typ):
892	if typ.kind == 'Ptr':
893		return word32T
894	else:
895		return typ
896
897global_wrappers = {}
898def get_global_wrapper (typ):
899	if typ in global_wrappers:
900		return global_wrappers[typ]
901	struct_name = fresh_name ('Global (%s)' % typ, structs)
902	struct = Struct (struct_name, typ.size (), typ.align ())
903	struct.add_field ('v', typ, 0)
904	structs[struct_name] = struct
905
906	global_wrappers[typ] = struct.typ
907	return struct.typ
908
909# ==========================================================
910# parsing code for types, expressions, structs and functions
911
912def parse_int (s):
913	if s.startswith ('-') or s.startswith ('~'):
914		return (- parse_int (s[1:]))
915	if s.startswith ('0x') or s.startswith ('0b'):
916		return int (s, 0)
917	else:
918		return int (s)
919
920def parse_typ(bits, n, symbolic_types = False):
921	if bits[n] == 'Word' or bits[n] == 'BitVec':
922		return (n + 2, Type('Word', parse_int (bits[n + 1])))
923	elif bits[n] == 'WordArray' or bits[n] == 'FloatingPoint':
924		return (n + 3, Type(bits[n],
925			parse_int (bits[n + 1]), parse_int (bits[n + 2])))
926	elif bits[n] in builtinTs:
927		return (n + 1, builtinTs[bits[n]])
928	elif bits[n] == 'Array':
929		(n, typ) = parse_typ (bits, n + 1, True)
930		return (n + 1, Type ('Array', parse_int (bits[n]), typ))
931	elif bits[n] == 'Struct':
932		return (n + 2, Type ('Struct', bits[n + 1]))
933	elif bits[n] == 'Ptr':
934		(n, typ) = parse_typ (bits, n + 1, True)
935		if symbolic_types:
936			return (n, Type ('Ptr', '', typ))
937		else:
938			return (n, word32T)
939	else:
940		assert not 'type encoded', (n, bits[n:], bits)
941
942def node_name (name):
943	if name in {'Ret':True, 'Err':True}:
944		return name
945	else:
946		try:
947			return parse_int (name)
948		except ValueError, e:
949			assert not 'node name understood', name
950
951def parse_list (parser, bits, n, extra=None):
952	try:
953		num = parse_int (bits[n])
954	except ValueError:
955		assert not 'number parseable', (n, bits[n:], bits)
956	n = n + 1
957	xs = []
958	for i in range (num):
959		if extra:
960			(n, x) = parser(bits, n, extra)
961		else:
962			(n, x) = parser(bits, n)
963		xs.append(x)
964	return (n, xs)
965
966def parse_arg (bits, n):
967	nm = bits[n]
968	(n, typ) = parse_typ (bits, n + 1)
969	return (n, (nm, typ))
970
971ops = {'Plus':2, 'Minus':2, 'Times':2, 'Modulus':2,
972	'DividedBy':2, 'BWAnd':2, 'BWOr':2, 'BWXOR':2, 'And':2,
973	'Or':2, 'Implies':2, 'Equals':2, 'Less':2,
974	'LessEquals':2, 'SignedLess':2, 'SignedLessEquals':2,
975	'ShiftLeft':2, 'ShiftRight':2, 'CountLeadingZeroes':1,
976	'CountTrailingZeroes':1, 'WordReverse':1, 'SignedShiftRight':2,
977	'Not':1, 'BWNot':1, 'WordCast':1, 'WordCastSigned':1,
978	'True':0, 'False':0, 'UnspecifiedPrecond':0,
979	'MemUpdate':3, 'MemAcc':2, 'IfThenElse':3, 'ArrayIndex':2,
980	'ArrayUpdate':3, 'MemDom':2,
981	'PValid':3, 'PWeakValid':3, 'PAlignValid':2, 'PGlobalValid':3,
982	'PArrayValid':4,
983	'HTDUpdate':5, 'WordArrayAccess':2, 'WordArrayUpdate':3,
984	'TokenWordsAccess':2, 'TokenWordsUpdate':3,
985	'ROData':1, 'StackWrapper':2,
986	'ToFloatingPoint':1, 'ToFloatingPointSigned':2,
987	'ToFloatingPointUnsigned':2, 'FloatingPointCast':1,
988}
989
990ops_to_smt = {'Plus':'bvadd', 'Minus':'bvsub', 'Times':'bvmul', 'Modulus':'bvurem',
991	'DividedBy':'bvudiv', 'BWAnd':'bvand', 'BWOr':'bvor', 'BWXOR':'bvxor',
992	'And':'and',
993	'Or':'or', 'Implies':'=>', 'Equals':'=', 'Less':'bvult',
994	'LessEquals':'bvule', 'SignedLess':'bvslt', 'SignedLessEquals':'bvsle',
995	'ShiftLeft':'bvshl', 'ShiftRight':'bvlshr', 'SignedShiftRight':'bvashr',
996	'Not':'not', 'BWNot':'bvnot',
997	'True':'true', 'False':'false',
998	'UnspecifiedPrecond': 'unspecified-precond',
999	'IfThenElse':'ite', 'MemDom':'mem-dom',
1000	'ROData': 'rodata', 'ImpliesROData': 'implies-rodata',
1001	'WordArrayAccess':'select', 'WordArrayUpdate':'store'}
1002
1003ex_smt_ops = """roundNearestTiesToEven RNE roundNearestTiesToAway RNA
1004    roundTowardPositive RTP roundTowardNegative RTN
1005    roundTowardZero RTZ
1006    fp.abs fp.ne fp.add fp.sub fp.mul fp.div fp.fma fp.sqrt fp.rem
1007    fp.roundToInteral fp.min fp.max fp.leq fp.lt fp.eq fp.t fp.eq
1008    fp.isNormal fp.IsSubnormal fp.isZero fp.isInfinite fp.isNaN
1009    fp.isNeative fp.isPositive""".split ()
1010
1011ops_to_smt.update (dict ([(smt, smt) for smt in ex_smt_ops]))
1012
1013smt_to_ops = dict ([(smt, oper) for (oper, smt) in ops_to_smt.iteritems ()])
1014
1015def parse_struct_elem (bits, n):
1016	name = bits[n]
1017	(n, typ) = parse_typ (bits, n + 1)
1018	(n, val) = parse_expr (bits, n)
1019	return (n, (name, val))
1020
1021def parse_expr (bits, n):
1022	if bits[n] in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
1023		kind = bits[n]
1024		name = bits[n + 1]
1025		(n, typ) = parse_typ (bits, n + 2)
1026		return (n, Expr (kind, typ, name = name))
1027	if bits[n] == 'Array':
1028		(n, typ) = parse_typ (bits, n + 1)
1029		assert typ.kind == 'Array'
1030		(n, xs) = parse_list (parse_expr, bits, n)
1031		assert len(xs) == typ.num
1032		return (n, Expr ('Array', typ, vals = xs))
1033	elif bits[n] == 'Field':
1034		(n, typ) = parse_typ (bits, n + 1)
1035		name = bits[n]
1036		(n, typ2) = parse_typ (bits, n + 1)
1037		(n, struct) = parse_expr (bits, n)
1038		assert struct.typ == typ
1039		return (n, Expr ('Field', typ2, struct = struct,
1040			field = (name, typ2)))
1041	elif bits[n] == 'FieldUpd':
1042		(n, typ) = parse_typ (bits, n + 1)
1043		name = bits[n]
1044		(n, typ2) = parse_typ (bits, n + 1)
1045		(n, val) = parse_expr (bits, n)
1046		(n, struct) = parse_expr (bits, n)
1047		return (n, Expr ('FieldUpd', typ, struct = struct,
1048			field = (name, typ2), val = val))
1049	elif bits[n] == 'StructCons':
1050		(n, typ) = parse_typ (bits, n + 1)
1051		(n, xs) = parse_list (parse_struct_elem, bits, n)
1052		return (n, Expr ('StructCons', typ, vals = dict(xs)))
1053	elif bits[n] == 'Num':
1054		v = parse_int (bits[n + 1])
1055		(n, typ) = parse_typ (bits, n + 2)
1056		return (n, Expr ('Num', typ, val = v))
1057	elif bits[n] == 'Op':
1058		op = bits[n + 1]
1059		op = smt_to_ops.get (op, op)
1060		assert op in ops, op
1061		(n, typ) = parse_typ (bits, n + 2)
1062		(n, xs) = parse_list (parse_expr, bits, n)
1063		assert len (xs) == ops[op]
1064		return (n, Expr ('Op', typ, name = op, vals = xs))
1065	elif bits[n] == 'Type':
1066		(n, typ) = parse_typ (bits, n + 1, symbolic_types = True)
1067		return (n, Expr ('Type', builtinTs['Type'], val = typ))
1068	else:
1069		assert not 'expr parseable', (n, bits[n : n + 3], bits)
1070
1071def parse_lval (bits, n):
1072	name = bits[n]
1073	(n, typ) = parse_typ (bits, n + 1)
1074	return (n, (name, typ))
1075
1076def parse_lval_and_val (bits, n):
1077	(n, (name, typ)) = parse_lval (bits, n)
1078	(n, val) = parse_expr (bits, n)
1079	return (n, ((name, typ), val))
1080
1081def parse_node (bits, n):
1082	if bits[n] == 'Basic':
1083		cont = node_name(bits[n + 1])
1084		(n, upds) = parse_list (parse_lval_and_val, bits, n + 2)
1085		return Node ('Basic', cont, upds)
1086	elif bits[n] == 'Cond':
1087		left = node_name(bits[n + 1])
1088		right = node_name(bits[n + 2])
1089		(n, cond) = parse_expr (bits, n + 3)
1090		return Node ('Cond', (left, right), cond)
1091	else:
1092		assert bits[n] == 'Call'
1093		cont = node_name(bits[n + 1])
1094		name = bits[n + 2]
1095		(n, args) = parse_list (parse_expr, bits, n + 3)
1096		(n, saves) = parse_list (parse_lval, bits, n)
1097		return Node ('Call', cont, (name, args, saves))
1098
1099true_term = Expr ('Op', boolT, name = 'True', vals = [])
1100false_term = Expr ('Op', boolT, name = 'False', vals = [])
1101unspecified_precond_term = Expr ('Op', boolT, name = 'UnspecifiedPrecond', vals = [])
1102
1103def parse_all(lines):
1104	'''Toplevel parser for input information. Accepts an iterator over
1105lines. See syntax.quick_reference for an explanation.'''
1106
1107	if hasattr (lines, 'name'):
1108		trace ('Loading syntax from %s' % lines.name)
1109	else:
1110		trace ('Loading syntax (from anonymous source).')
1111
1112	structs = {}
1113	functions = {}
1114	const_globals = {}
1115	cfg_warnings = []
1116	for line in lines:
1117		bits = line.split()
1118		# empty lines and #-comments ignored
1119		if not bits or bits[0][0] == '#':
1120			continue
1121		if bits[0] == 'Struct':
1122			# Struct <name> <size> <alignment>
1123			# followed by block of StructField lines
1124			assert bits[1] not in structs
1125			current_struct = Struct (bits[1], parse_int (bits[2]),
1126				parse_int (bits[3]))
1127			structs[bits[1]] = current_struct
1128		elif bits[0] == 'StructField':
1129			# StructField <name> <type (encoded)> <offset>
1130			(_, typ) = parse_typ(bits, 2, symbolic_types = True)
1131			current_struct.add_field (bits[1], typ,
1132				parse_int (bits[-1]))
1133		elif bits[0] == 'ConstGlobalDef':
1134			# ConstGlobalDef <name> <value>
1135			name = bits[1]
1136			(_, val) = parse_expr (bits, 2)
1137			const_globals[name] = val
1138		elif bits[0] == 'Function':
1139			# Function <name> <inputs> <outputs>
1140			# followed by optional block of node lines
1141			# concluded by EntryPoint line
1142			fname = bits[1]
1143			(n, inputs) = parse_list (parse_arg, bits, 2)
1144			(_, outputs) = parse_list (parse_arg, bits, n)
1145			current_function = Function (fname, inputs, outputs)
1146			assert fname not in functions, fname
1147			functions[fname] = current_function
1148		elif bits[0] == 'EntryPoint':
1149			# EntryPoint <entry point>
1150			entry = node_name(bits[1])
1151			# instead of setting function.entry to this value,
1152			# create a dummy node. this ensures there is always
1153			# at least one node (EntryPoint Ret is valid) and
1154			# also that the entry point is not in a loop
1155			name = fresh_node (current_function.nodes)
1156			current_function.nodes[name] = Node ('Basic',
1157				entry, [])
1158			current_function.entry = name
1159			# ensure that the function graph is closed
1160			check_cfg (current_function, warnings = cfg_warnings)
1161			current_function = None
1162		else:
1163			# <node name> <node (encoded)>
1164			name = node_name(bits[0])
1165			assert name not in current_function.nodes, (name, bits)
1166			current_function.nodes[name] = parse_node (bits, 1)
1167
1168	print_cfg_warnings (cfg_warnings)
1169	trace ('Loaded %d functions, %d structs, %d globals.'
1170		% (len (functions), len (structs), len (const_globals)))
1171
1172	return (structs, functions, const_globals)
1173
1174def parse_and_install_all (lines, tag, skip_functions=None):
1175	if skip_functions == None:
1176	    skip_functions = []
1177	(structs, functions, const_globals) = parse_all (lines)
1178	for f in skip_functions:
1179	    if f in functions:
1180		del functions[f]
1181	target_objects.structs.update (structs)
1182	target_objects.functions.update (functions)
1183	target_objects.const_globals.update (const_globals)
1184	if tag != None:
1185		target_objects.functions_by_tag[tag] = set (functions)
1186	return (structs, functions, const_globals)
1187
1188# ===============================
1189# simple accessor code and checks
1190
1191def visit_rval (vs):
1192	def visit (expr):
1193		if expr.kind == 'Var':
1194			v = expr.name
1195			if v in vs:
1196				assert vs[v] == expr.typ, (expr, vs[v])
1197			vs[v] = expr.typ
1198		if expr.is_op ('MemAcc'):
1199			[m, p] = expr.vals
1200			assert p.typ == word32T, expr
1201		if expr.is_op ('PGlobalValid'):
1202			[htd, typ_expr, p] = expr.vals
1203			typ = typ_expr.val
1204			get_global_wrapper (typ)
1205
1206	return visit
1207
1208def visit_lval (vs):
1209	def visit ((name, typ)):
1210		assert vs.get(name, typ) == typ, (name, vs[name], typ)
1211		vs[name] = typ
1212
1213	return visit
1214
1215def get_expr_vars (expr, vs):
1216	expr.visit (visit_rval (vs))
1217
1218def get_expr_var_set (expr):
1219	vs = {}
1220	get_expr_vars (expr, vs)
1221	return set (vs.items ())
1222
1223def get_lval_vars (lval, vs):
1224	assert len(lval) == 2
1225	assert vs.get(lval[0], lval[1]) == lval[1]
1226	vs[lval[0]] = lval[1]
1227
1228def get_node_vars (node, vs):
1229	node.visit (visit_lval (vs), visit_rval (vs))
1230
1231def get_node_rvals (node, vs = None):
1232	if vs == None:
1233		vs = {}
1234	node.visit (lambda l: (), visit_rval (vs))
1235	return vs
1236
1237def get_vars(function):
1238	vs = dict(function.inputs + function.outputs)
1239	for node in function.nodes.itervalues():
1240		get_node_vars(node, vs)
1241	return vs
1242
1243def get_lval_typ(lval):
1244	assert len(lval) == 2
1245	return lval[1]
1246
1247def get_expr_typ(expr):
1248	return expr.typ
1249
1250def check_cfg (fun, warnings = None):
1251	dead_arcs = [(n, n2) for (n, node) in fun.nodes.iteritems ()
1252		for n2 in node.get_conts ()
1253		if n2 not in fun.nodes and n2 not in ['Ret', 'Err']]
1254	for (n, n2) in dead_arcs:
1255		assert type (n2) != str
1256		# OK if multiple dead arcs and we save over n2 twice
1257		fun.nodes[n2] = Node ('Basic', 'Err', [])
1258	if warnings == None:
1259		print_cfg_warnings ([(fun, n, n2) for (n, n2) in dead_arcs])
1260	else:
1261		warnings.extend ([(fun, n, n2) for (n, n2) in dead_arcs])
1262
1263def print_cfg_warnings (warnings):
1264	post_calls = set ([(fun.nodes[n].fname, fun.name)
1265		for (fun, n, n2) in warnings
1266		if fun.nodes[n].kind == 'Call'])
1267	import logic
1268	for (call, sites) in logic.dict_list (post_calls).iteritems ():
1269		trace ('Missing nodes after calls to %s' % call)
1270		trace ('  in %s' % str (sites))
1271	for (fun, n, n2) in warnings:
1272		if fun.nodes[n].kind != 'Call':
1273			trace ('Warning: dead arc in %s: %s -> %s'
1274				% (fun.name, n, n2))
1275			trace ('  (follows %s node!)' % fun.nodes[n].kind)
1276
1277def check_funs (functions, verbose = False):
1278	for (f, fun) in functions.iteritems():
1279		if not fun:
1280			continue
1281		if verbose:
1282			trace ('Checking %s' % f)
1283		check_cfg (fun)
1284		get_vars(fun)
1285		for (n, node) in fun.nodes.iteritems():
1286			if node.kind == 'Call':
1287				c = functions[node.fname]
1288				assert map(get_expr_typ, node.args) == \
1289					map (get_lval_typ, c.inputs), (
1290						 node.fname, node.args, c.inputs)
1291				assert map (get_lval_typ, node.rets) == \
1292					map (get_lval_typ, c.outputs), (
1293						 node.fname, node.rets, c.outputs)
1294			elif node.kind == 'Basic':
1295				for (lv, v) in node.upds:
1296					assert get_lval_typ(lv) == get_expr_typ(v)
1297			elif node.kind == 'Cond':
1298				assert get_expr_typ(node.cond) == boolT
1299
1300def get_extensions (v):
1301	extensions = set ()
1302	rm = builtinTs['RoundingMode']
1303	def visitor (expr):
1304		if expr.typ == rm or expr.typ.kind == 'FloatingPoint':
1305			extensions.add ('FloatingPoint')
1306	v.gen_visit (lambda l: (), visitor)
1307	return extensions
1308
1309# =========================================
1310# common constructors for basic expressions
1311
1312def mk_var (nm, typ):
1313	return Expr ('Var', typ, name = nm)
1314
1315def mk_token (nm):
1316	return Expr ('Token', builtinTs['Token'], name = nm)
1317
1318def mk_plus (x, y):
1319	assert x.typ == y.typ
1320	return Expr ('Op', x.typ, name = 'Plus', vals = [x, y])
1321
1322def mk_uminus (x):
1323	zero = Expr ('Num', x.typ, val = 0)
1324	return mk_minus (zero, x)
1325
1326def mk_minus (x, y):
1327	assert x.typ == y.typ
1328	return Expr ('Op', x.typ, name = 'Minus', vals = [x, y])
1329
1330def mk_times (x, y):
1331	assert x.typ == y.typ
1332	return Expr ('Op', x.typ, name = 'Times', vals = [x, y])
1333
1334def mk_divide (x, y):
1335	assert x.typ == y.typ
1336	return Expr ('Op', x.typ, name = 'DividedBy', vals = [x, y])
1337
1338def mk_modulus (x, y):
1339	assert x.typ == y.typ
1340	return Expr ('Op', x.typ, name = 'Modulus', vals = [x, y])
1341
1342def mk_bwand (x, y):
1343	assert x.typ == y.typ
1344	assert x.typ.kind == 'Word'
1345	return Expr ('Op', x.typ, name = 'BWAnd', vals = [x, y])
1346
1347def mk_eq (x, y):
1348	assert x.typ == y.typ
1349	return Expr ('Op', boolT, name = 'Equals', vals = [x, y])
1350
1351def mk_less_eq (x, y, signed = False):
1352	assert x.typ == y.typ
1353	name = {False: 'LessEquals', True: 'SignedLessEquals'}[signed]
1354	return Expr ('Op', boolT, name = name, vals = [x, y])
1355
1356def mk_less (x, y, signed = False):
1357	assert x.typ == y.typ
1358	name = {False: 'Less', True: 'SignedLess'}[signed]
1359	return Expr ('Op', boolT, name = name, vals = [x, y])
1360
1361def mk_implies (x, y):
1362	assert x.typ == boolT
1363	assert y.typ == boolT
1364	return Expr ('Op', boolT, name = 'Implies', vals = [x, y])
1365
1366def mk_n_implies (xs, y):
1367	imp = y
1368	for x in reversed (xs):
1369		imp = mk_implies (x, imp)
1370	return imp
1371
1372def mk_and (x, y):
1373	assert x.typ == boolT
1374	assert y.typ == boolT
1375	return Expr ('Op', boolT, name = 'And', vals = [x, y])
1376
1377def mk_or (x, y):
1378	assert x.typ == boolT
1379	assert y.typ == boolT
1380	return Expr ('Op', boolT, name = 'Or', vals = [x, y])
1381
1382def mk_not (x):
1383	assert x.typ == boolT
1384	return Expr ('Op', boolT, name = 'Not', vals = [x])
1385
1386def mk_shift_gen (name, x, n):
1387	assert x.typ.kind == 'Word'
1388	if type (n) == int:
1389		n = Expr ('Num', x.typ, val = n)
1390	return Expr ('Op', x.typ, name = name, vals = [x, n])
1391
1392mk_shiftr = lambda x, n: mk_shift_gen ('ShiftRight', x, n)
1393
1394def mk_clz (x):
1395	return Expr ('Op', x.typ, name = "CountLeadingZeroes", vals = [x])
1396
1397def mk_word_reverse (x):
1398	return Expr ('Op', x.typ, name = "WordReverse", vals = [x])
1399
1400def mk_ctz (x):
1401	return mk_clz (mk_word_reverse (x))
1402
1403def foldr1 (f, xs):
1404	x = xs[-1]
1405	for i in reversed (range (len (xs) - 1)):
1406		x = f (xs[i], x)
1407	return x
1408
1409def mk_num (x, typ):
1410	import logic
1411	if logic.is_int (typ):
1412		typ = Type ('Word', typ)
1413	assert typ.kind == 'Word', typ
1414	assert logic.is_int (x), x
1415	return Expr ('Num', typ, val = x)
1416
1417def mk_word32 (x):
1418	return mk_num (x, word32T)
1419
1420def mk_word8 (x):
1421	return mk_num (x, word8T)
1422
1423def mk_word32_maybe(x):
1424	import logic
1425	if logic.is_int (x):
1426		return mk_word32 (x)
1427	else:
1428		assert x.typ == word32T
1429		return x
1430
1431def mk_cast (x, typ):
1432	if x.typ == typ:
1433		return x
1434	else:
1435		assert x.typ.kind == 'Word', x.typ
1436		assert typ.kind == 'Word', typ
1437		return Expr ('Op', typ, name = 'WordCast', vals = [x])
1438
1439def mk_memacc(m, p, typ):
1440	assert m.typ == builtinTs['Mem']
1441	assert p.typ == word32T
1442	return Expr ('Op', typ, name = 'MemAcc', vals = [m, p])
1443
1444def mk_memupd(m, p, v):
1445	assert m.typ == builtinTs['Mem']
1446	assert p.typ == word32T
1447	return Expr ('Op', m.typ, name = 'MemUpdate', vals = [m, p, v])
1448
1449def mk_arr_index (arr, i):
1450	assert arr.typ.kind == 'Array'
1451	return Expr ('Op', arr.typ.el_typ, name = 'ArrayIndex',
1452		vals = [arr, i])
1453
1454def mk_arroffs(p, typ, i):
1455	assert typ.kind == 'Array'
1456	import logic
1457	if logic.is_int (i):
1458		assert i < typ.num
1459		offs = i * typ.el_typ.size()
1460		assert offs == i or offs % 4 == 0
1461		return mk_plus (p, mk_word32 (offs))
1462	else:
1463		sz = typ.el_typ.size()
1464		return mk_plus (p, mk_times (i, mk_word32 (sz)))
1465
1466def mk_if (P, x, y):
1467	assert P.typ == boolT
1468	assert x.typ == y.typ
1469	return Expr ('Op', x.typ, name = 'IfThenElse', vals = [P, x, y])
1470
1471def mk_meta_typ (typ):
1472	return Expr ('Type', builtinTs['Type'], val = typ)
1473
1474def mk_pvalid (htd, typ, p):
1475	return Expr ('Op', boolT, name = 'PValid',
1476		vals = [htd, mk_meta_typ (typ), p])
1477
1478def mk_rel_wrapper (nm, vals):
1479	return Expr ('Op', builtinTs['RelWrapper'], name = nm, vals = vals)
1480
1481def adjust_op_vals (expr, vals):
1482	assert expr.kind == 'Op'
1483	return Expr ('Op', expr.typ, expr.name, vals = vals)
1484
1485mks = (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand,
1486mk_eq, mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32,
1487mk_word8, mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index,
1488mk_arroffs, mk_if, mk_meta_typ, mk_pvalid)
1489
1490# ====================================================================
1491# pretty printing code for the syntax - only used for printing reports
1492
1493def pretty_type (typ):
1494	if typ.kind == 'Word':
1495		return 'Word%d' % typ.num
1496	elif typ.kind == 'WordArray':
1497		[ix, num] = typ.nums
1498		return 'Word%d[%d]' % (num, ix)
1499	elif typ.kind == 'Ptr':
1500		return 'Ptr(%s)' % pretty_type (typ.el_typ_symb)
1501	elif typ.kind == 'Struct':
1502		return 'struct %s' % typ.name
1503	elif typ.kind == 'Builtin':
1504		return typ.name
1505	else:
1506		assert not 'type pretty-printable', typ
1507
1508pretty_opers = {'Plus': '+', 'Minus': '-', 'Times': '*'}
1509
1510known_typ_change = set (['ROData', 'MemAcc', 'IfThenElse', 'WordArrayUpdate',
1511	'MemDom'])
1512
1513def pretty_expr (expr, print_type = False):
1514	if print_type:
1515		return '((%s) (%s))' % (pretty_type (expr.typ),
1516			pretty_expr (expr))
1517	elif expr.kind == 'Var':
1518		return repr (expr.name)
1519	elif expr.kind == 'Num':
1520		return '%d' % expr.val
1521	elif expr.kind == 'Op' and expr.name in pretty_opers:
1522		[x, y] = expr.vals
1523		return '(%s %s %s)' % (pretty_expr (x), pretty_opers[expr.name],
1524			pretty_expr (y))
1525	elif expr.kind == 'Op':
1526		if expr.name in known_typ_change:
1527			vals = [pretty_expr (v) for v in expr.vals]
1528		else:
1529			vals = [pretty_expr (v, print_type = v.typ != expr.typ)
1530				for v in expr.vals]
1531		return '%s(%s)' % (expr.name, ', '.join (vals))
1532	elif expr.kind == 'Token':
1533		return "''%s''" % expr.name
1534	else:
1535		assert not 'expr pretty-printable', expr
1536
1537
1538# =================================================
1539# some helper code that's needed all over the place
1540
1541def fresh_name (n, D, v=True):
1542	if n not in D:
1543		D[n] = v
1544		return n
1545
1546	x = 1
1547	y = 1
1548	while ('%s.%d' % (n, x)) in D:
1549		y = x
1550		x = x * 2
1551	while y < x:
1552		z = (y + x) / 2
1553		if ('%s.%d' % (n, z)) in D:
1554			y = z + 1
1555		else:
1556			x = z
1557	n = '%s.%d' % (n, x)
1558	assert n not in D
1559
1560	D[n] = v
1561	return n
1562
1563def fresh_node (ns, hint = 1):
1564	n = hint
1565	# use names that are *not* multiples of 4
1566	n = (n | 15) + 2
1567	while n in ns:
1568		n += 16
1569	return n
1570
1571