#
# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
#
# SPDX-License-Identifier: BSD-2-Clause
#

# Syntax and simple operations for types, expressions, graph nodes
# and graph functions (functions in graph-language format).

from target_objects import structs, trace
import target_objects

quick_reference = """
Quick reference on the graph language and its syntax.
=====================================================

Example
=======

Let's build up the syntax that roughly mirrors the C statement x = y + 1;

This is a quick example. A more thorough reference is below.

We have two example types: Bool and Word 32.

Here are two atomic expressions of type Word 32:
Var y Word 32
Num 1 Word 32

These encode the variable y and the number 1, both of type Word 32.

Expressions in this language are built up by concatenating strings of tokens.
Any whitespace-delimited string can be a token, and thus a variable name.

Compound expressions are built with operators, e.g. y + 1:
Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This is quite verbose. For simplicity, the Op syntax includes the type of
the whole expression (Word 32) and the number of arguments (2), even though
much of this information is redundant.

Finally, we can encode the statement x = y + 1; as the node at address 14:

14 Basic 15 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32

This specifies that node 14 is a basic node (which updates variables). The
successor node is 15. It updates 1 variable (basic nodes may simultaneously
update many variables). The variable to be updated is specified by x Word 32.
The remainder of the line is the expression y + 1 as seen before.

We can encode the statement if (x == y) { return; } as follows:
15 Cond 16 17 Op Equals Bool 2 Var x Word 32 Var y Word 32
16 Basic Ret 0

Conditional node 15 continues to node 16 if x == y, and to 17 otherwise. Node
16 is a basic node which updates no variables, i.e. a skip statement. It
continues to the special address Ret, which designates return from the current
function.

Reference
=========

A graph program consists of a number of functions, each
of whose control flow is structured as a graph of nodes
with integer addresses. Nodes are of three types:
  - 'Basic' updates variable values.
  - 'Cond' chooses between two successor nodes.
  - 'Call' makes calls to functions.

A top-level syntax file (interpreted by syntax.parse_all)
contains two kinds of sections:
  - Function blocks introduce functions.
  - Struct blocks introduce aggregate types (e.g. struct in C).

All lines in the input syntax are read as a whitespace-separated
list of tokens. Empty lines and lines whose first non-whitespace
character is '#' are ignored. All syntax is prefix encoded, so
each clause has a first token which determines what kind of clause
it is, followed by the concatenation of all of its subclauses.

The two kinds of lines that appear in Struct blocks illustrate this format:
Struct "struct name" <int size> <int alignment>
StructField "name" <type> <int offset>

On notation: above we denote the specific tokens 'Struct' and 'StructField',
the arbitrary tokens "name" etc, the special integer subclauses <int size> etc,
and a <type> subclause. Integer subclauses always consist of a single token.
A leading character '-' or '~' indicates a negative number. A leading '0x' or
'0b' indicates hexadecimal or binary encoding, otherwise decimal, e.g. 1, 0x1A,
-0x2b, ~0b1100, etc. Naming tokens can be anything at all that does not contain
whitespace, so x y'v a,b c#d # etc are all valid struct or field names.

Struct blocks begin with the 'Struct' line, which specifies their name, total
size and required alignment. The struct block then contains a number of
StructField lines, where each field has a type and an offset from the start of
the structure. The Struct block ends where any other Struct or Function block
begins.

Function blocks begin with a function declaration line:
  Function "name of function" <int>*("input argument name" <type>)
    <int>*("output argument name" <type>)

The notation <int>*(...) above denotes an arbitrary-length list of subclauses.
The first token of this clause will be a decimal integer specifying the length
of the list. For instance, the classic 'XOR' function with inputs x, y and
output z could be specified as follows:
Function XOR 2 x Bool y Bool 1 z Bool

Each function has a name, a list of input arguments, and a list of output
arguments. The state in the graph language is entirely encoded in variables,
so functions will typically have global objects such as the heap as both input
and output arguments.

The function block may end immediately, for functions with an unspecified body,
or may contain a number of graph node lines, followed by an entry point line:
EntryPoint <int>
The entry point line ends the function block.

A graph node line has one of these formats:
<int> Basic <next node> <int>*("var name" <type> <expression>)
<int> Cond <next node> <next node> <expression>
<int> Call <next node> "fun name" <int>*(<expression>) <int>*("var name" <type>)

Each node line begins with an integer which is the address of the node. A
<next node> is either the integer address of the following node, or one of the
special address tokens Ret and Err.

Basic nodes update the value of a (possibly empty) list of variables. The
variables are all updated simultaneously to the computed value of the
expressions given.

Cond nodes evaluate an expression which must be of boolean type. They specify
two possible next nodes, left and right. The left node is visited when the
expression evaluates to true, the right on false.

Call nodes call another function of a given name. The list of argument
expressions is evaluated and must match the list of input arguments of the
given function (same length and types). These arguments define the starting
variable environment of that function. Its output arguments are saved to the
list of variables given.

The <type> clauses are in one of these formats:
Word <int>  -- BitVec <int> is a synonym
Array <type> <int>
Struct "struct name"
Ptr <type>
One of the builtin types as a single token:
  'Bool Mem Dom HTD PMS UNIT Type Token RoundingMode'
FloatingPoint <int> <int>

These types encode a machine word with the given number of bits, an array of
some type, a struct previously defined by a Struct clause, or a pointer to
another type. The Array, Struct and Ptr types are provided only for use in
pointer validity assertions, which are discussed below. The FloatingPoint type
exactly mirrors the SMTLIB2 (_ FloatingPoint eb sb) type, and is available as
an optional extension. Using the FloatingPoint and RoundingMode types in any
problem changes the SMT theory needed and may limit which solvers and features
can be used.

Of the builtin types, booleans are standard, UNIT is the singleton type, and
memory (Mem) is a mapping from 32-bit addresses to 8-bit values. We aim to
include 64-bit support soon. The Dom type is a set of 32-bit values, used to
encode the valid domain of memory operations. The heap type description HTD is
used by pointer-validity operators. The phantom machine state type PMS is
unspecified and used to represent unspecified aspects of the environment. The
type Type is the type of Type expressions, which are used in pointer-validity
operators.

The <expression> clauses have one of these formats:
Var "name" <type>
Op "name" <type> <int>*(<expression>)
Num <int> <type>
Type <type>
Symbol "name" <type>

Most of the expressions of the language are composed from variables, numerals
and operator applications, which should be self-explanatory. The Type
expression wraps a type into an expression (of type Type) so it may be passed
to a pointer-validity operator.

Symbol clauses are used in source-language functions to denote the values
that symbols will have in the binary. They are replaced by specific numerals
in a first pass.

Many of the builtin operators are equivalent to SMTLIB2 operators, and are also
available with their SMTLIB2 names. The operators are:
  - Binary arithmetic operators on words:
    + Plus/bvadd, Minus/bvsub, Times/bvmul, Modulus/bvurem, DividedBy/bvudiv,
      BWAnd/bvand, BWOr/bvor, BWXOR/bvxor,
      ShiftLeft/bvshl, ShiftRight/bvlshr, SignedShiftRight/bvashr
  - Binary operators on booleans:
    + And/and, Or/or, Implies
  - Unary operators on bools:
    + Not/not
  - Booleans (nullary operators, i.e. constants):
    + True/true, False/false
  - Equals relation in any type:
    + Equals
  - Comparison relations on words (result is bool):
    + Less/bvult, LessEquals/bvule, SignedLess/bvslt, SignedLessEquals/bvsle
  - Unary operators on words:
    + BWNot/bvnot, CountLeadingZeroes, CountTrailingZeroes, WordReverse
  - Cast operators on words (the result type may differ from the argument type):
    + WordCast, WordCastSigned
  - Memory operations:
    + MemAcc (args [m :: Mem, ptr :: Word 32], result :: any word type)
    + MemUpdate (args [m :: Mem, ptr :: Word 32, v :: any word type],
      result :: Mem)
  - Pointer-validity operators:
    + PValid, PWeakValid, PAlignValid, PGlobalValid, PArrayValid
  - Miscellaneous:
    + MemDom, HTDUpdate
    + IfThenElse/ite (args [b :: bool, x :: any type T, y :: T])
  - FloatingPoint operations from the SMTLIB2 floating point specification:
    + roundNearestTiesToEven/RNE, roundNearestTiesToAway/RNA,
      roundTowardPositive/RTP, roundTowardNegative/RTN,
      roundTowardZero/RTZ
    + fp.abs, fp.neg, fp.add, fp.sub, fp.mul, fp.div, fp.fma, fp.sqrt,
      fp.rem, fp.roundToIntegral, fp.min, fp.max, fp.leq, fp.lt,
      fp.geq, fp.gt, fp.eq, fp.isNormal, fp.isSubnormal, fp.isZero,
      fp.isInfinite, fp.isNaN, fp.isNegative, fp.isPositive
    + ToFloatingPoint, ToFloatingPointSigned, ToFloatingPointUnsigned,
      FloatingPointCast

Operators with SMTLIB2 equivalents have the same semantics. Mem and Dom
operations wrap the SMTLIB2 BitVec Array type with more convenient
operations.

Memory accesses and updates can operate on various word types.

The pointer-validity operators PValid, PWeakValid, PGlobalValid all take 3
arguments of type HTD, Type, and Word 32. The PValid operator asserts that
the heap-type-description contains this type starting at this address. This
is exclusive with any incompatible type. PGlobalValid additionally asserts that
this is a global object, and therefore not contained within any larger object.
PWeakValid asserts that the type could be valid (it is aligned and within the
range of available addresses in the heap-type-description) and is needed only
in rare circumstances. PAlignValid omits the HTD argument and asserts only that
the pointer is appropriately aligned and that the object does not start at or
wrap around the 0 address. PArrayValid takes an additional argument (HTD,
Type, Word 32, Word 32) where the final argument specifies the number of
entries in an array.

The MemDom operator takes argument types [Word 32, Dom] and returns a boolean
stating whether this pointer is in this domain.

The IfThenElse operator takes a bool and any two arguments of the same type.

The FloatingPoint operators are mostly taken from the SMTLIB2 floating
point standard. The conversions ToFloatingPoint ([Word] to FP),
ToFloatingPointSigned and ToFloatingPointUnsigned ([RoundingMode, Word] to FP)
and FloatingPointCast (FP to FP) represent the variants of to_fp in the SMTLIB2
standard.
"""
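
# A minimal sketch of the line-level rule stated in the quick reference:
# each line is split into whitespace-separated tokens, and empty lines and
# lines whose first non-whitespace character is '#' are skipped. The helper
# name _sketch_tokenise_lines is hypothetical, and it is not the parser this
# module actually uses (parse_all is not reproduced here).
def _sketch_tokenise_lines (lines):
	result = []
	for line in lines:
		# str.split () with no argument discards all surrounding whitespace
		bits = line.split ()
		if not bits or bits[0].startswith ('#'):
			continue
		result.append (bits)
	return result

# For example, _sketch_tokenise_lines (['# note', '', '16 Basic Ret 0'])
# keeps only the token list ['16', 'Basic', 'Ret', '0'].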

class Type:
	def __init__ (self, kind, name, el_typ=None):
		self.kind = kind
		if kind in ['Array', 'Word', 'TokenWords']:
			self.num = int (name)
		else:
			self.name = name
		if kind in ['Array', 'Ptr']:
			self.el_typ_symb = el_typ
			self.el_typ = concrete_type (el_typ)
		if kind in ['WordArray', 'FloatingPoint']:
			self.nums = [int (name), int (el_typ)]

	def __repr__ (self):
		if self.kind == 'Array':
			return 'Type ("Array", %r, %r)' % (self.num,
				self.el_typ_symb)
		elif self.kind in ('Word', 'TokenWords'):
			return 'Type (%r, %r)' % (self.kind, self.num)
		elif self.kind == 'Ptr':
			# Ptr types are constructed as Type ('Ptr', '', el_typ)
			return 'Type ("Ptr", "", %r)' % self.el_typ_symb
		elif self.kind in ('WordArray', 'FloatingPoint'):
			return 'Type (%r, %r, %r)' % tuple ([self.kind]
				+ self.nums)
		else:
			return 'Type (%r, %r)' % (self.kind, self.name)

	def __eq__ (self, other):
		if not other:
			return False
		if self.kind != other.kind:
			return False
		if self.kind in ['Array', 'Word', 'TokenWords']:
			if self.num != other.num:
				return False
		elif self.kind in ['WordArray', 'FloatingPoint']:
			if self.nums != other.nums:
				return False
		else:
			if self.name != other.name:
				return False
		if self.kind in ['Array', 'Ptr']:
			if self.el_typ_symb != other.el_typ_symb:
				return False
		return True

	def __ne__ (self, other):
		return not other or not (self == other)

	def __hash__ (self):
		return hash(str(self))

	def __cmp__ (self, other):
		self_ss = []
		self.serialise (self_ss)
		other_ss = []
		other.serialise (other_ss)
		return cmp (self_ss, other_ss)

	def subtypes (self):
		if self.kind == 'Struct':
			return structs[self.name].subtypes()
		elif self.kind == 'Array':
			return [self] + self.el_typ.subtypes()
		else:
			return [self]

	def size (self):
		if self.kind == 'Struct':
			return structs[self.name].size
		elif self.kind == 'Array':
			return self.el_typ.size() * self.num
		elif self.kind == 'Word':
			assert self.num % 8 == 0, self
			return self.num // 8
		elif self.kind == 'FloatingPoint':
			sz = sum (self.nums)
			assert sz % 8 == 0, self
			return sz // 8
		elif self.kind == 'Ptr':
			return 4
		else:
			assert not 'type has size'

	def align (self):
		if self.kind == 'Struct':
			return structs[self.name].align
		elif self.kind == 'Array':
			return self.el_typ.align ()
		elif self.kind in ('Word', 'FloatingPoint'):
			return self.size ()
		elif self.kind == 'Ptr':
			return 4
		else:
			assert not 'type has alignment'

	def serialise (self, xs):
		if self.kind in ('Word', 'TokenWords'):
			xs.append (self.kind)
			xs.append (str (self.num))
		elif self.kind in ('WordArray', 'FloatingPoint'):
			xs.append (self.kind)
			xs.extend ([str (n) for n in self.nums])
		elif self.kind == 'Builtin':
			xs.append (self.name)
		elif self.kind == 'Array':
			xs.append ('Array')
			self.el_typ_symb.serialise (xs)
			xs.append (str (self.num))
		elif self.kind == 'Struct':
			xs.append ('Struct')
			xs.append (self.name)
		elif self.kind == 'Ptr':
			xs.append ('Ptr')
			self.el_typ_symb.serialise (xs)
		else:
			assert not 'type serialisable', self.kind
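
# A minimal sketch of the size and alignment rules in Type.size and
# Type.align above, restated over plain tuples so that it runs on its own.
# The tuple encoding ('Word', bits), ('Ptr', el_typ), ('Array', el_typ, count)
# and the _sketch_* names are hypothetical; Struct types are omitted because
# they depend on the global structs table.
def _sketch_type_size (typ):
	kind = typ[0]
	if kind == 'Word':
		# words must occupy a whole number of bytes
		assert typ[1] % 8 == 0, typ
		return typ[1] // 8
	elif kind == 'Ptr':
		# pointers are 32-bit words in this graph language
		return 4
	elif kind == 'Array':
		(_, el_typ, count) = typ
		return _sketch_type_size (el_typ) * count
	raise ValueError ('type has no size: %r' % (typ, ))

def _sketch_type_align (typ):
	kind = typ[0]
	if kind == 'Array':
		# an array is aligned like its element type
		return _sketch_type_align (typ[1])
	elif kind == 'Ptr':
		return 4
	# a word is aligned to its own size
	return _sketch_type_size (typ)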

class Expr:
	def __init__ (self, kind, typ, name = None, struct = None,
			field = None, val = None, vals = None):
		self.kind = kind
		self.typ = typ
		if name != None:
			self.name = name
		if struct != None:
			self.struct = struct
		if field != None:
			self.field = field
		if val != None:
			self.val = val
		if vals != None:
			self.vals = vals
		if kind == 'Op':
			assert type (self.vals) == list

	def binds (self):
		binds = []
		if self.kind in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
			binds.append(('name', self.name))
		elif self.kind in ['Array', 'StructCons']:
			binds.append(('vals', self.vals))
		elif self.kind == 'Field':
			binds.append(('struct', self.struct))
			binds.append(('field', self.field))
		elif self.kind == 'FieldUpd':
			binds.append(('struct', self.struct))
			binds.append(('field', self.field))
			binds.append(('val', self.val))
		elif self.kind == 'Num':
			binds.append(('val', self.val))
		elif self.kind == 'Op':
			binds.append(('name', self.name))
			binds.append(('vals', self.vals))
		elif self.kind == 'Type':
			binds.append(('val', self.val))
		elif self.kind == 'SMTExpr':
			binds.append(('val', self.val))
		else:
			assert not 'expression understood for repr', self.kind
		return binds

	def __repr__ (self):
		bits = [repr(self.kind), repr(self.typ)]
		bits.extend(['%s = %r' % b for b in self.binds()])
		return 'Expr (%s)' % ', '.join(bits)

	def __eq__ (self, other):
		return (other and self.kind == other.kind
			and self.typ == other.typ
			and self.binds() == other.binds())

	def __ne__ (self, other):
		return not other or not (self == other)

	def __hash__ (self):
		return hash_tuplify (self.kind, self.typ, self.binds ())

	def __cmp__ (self, other):
		return cmp ((self.kind, self.typ, self.binds ()),
			(other.kind, other.typ, other.binds ()))

	def is_var (self, (nm, typ)):
		return self.kind == 'Var' and all([self.name == nm,
			self.typ == typ])

	def is_op (self, nm):
		if type (nm) == str:
			return self.kind == 'Op' and self.name == nm
		else:
			return self.kind == 'Op' and self.name in nm

	def visit (self, visit):
		visit (self)
		if self.kind == 'Var':
			pass
		elif self.kind == 'Op' or self.kind == 'Array':
			for x in self.vals:
				x.visit (visit)
		elif self.kind == 'StructCons':
			for x in self.vals.itervalues ():
				x.visit (visit)
		elif self.kind == 'Field':
			self.struct.visit (visit)

			struct_typ = self.struct.typ
			assert struct_typ.kind == 'Struct'
			struct = structs[struct_typ.name]
			(name, typ) = self.field
			assert struct.fields[name][0] == typ
		elif self.kind == 'FieldUpd':
			self.val.visit (visit)
			self.struct.visit (visit)

			struct_typ = self.struct.typ
			assert struct_typ.kind == 'Struct'
			struct = structs[struct_typ.name]
			(name, typ) = self.field
			assert struct.fields[name][0] == typ
		elif self.kind == 'ConstGlobal':
			assert (target_objects.const_globals[self.name].typ
				== self.typ)
		elif self.kind in set (['Num', 'Symbol', 'Type', 'Token']):
			pass
		else:
			assert not 'expr understood', self

	def gen_visit (self, visit_lval, visit_rval):
		self.visit (visit_rval)

	def subst (self, substor, ss = None):
		ret = False
		if self.kind == 'Op':
			# pass the kind filter ss down so nested subexpressions
			# are filtered the same way as this one
			subst_vals = subst_list (substor, self.vals, ss = ss)
			if subst_vals:
				self = Expr ('Op', self.typ, name = self.name,
					vals = subst_vals)
				ret = True
		if (ss == None or self.kind in ss or self.is_op (ss)):
			r = substor (self)
			if r != None:
				return r
		if ret:
			return self
		return

	def add_const_ranges (self, ranges):
		def visit (expr):
			if expr.kind == 'ConstGlobal':
				# symbol table entries are (start, size, section)
				(start, size, _) = target_objects.symbols[expr.name]
				assert size == expr.typ.size ()
				ranges[expr.name] = (start, start + size - 1)

		self.visit (visit)

	def get_mem_access (self):
		if self.is_op ('MemAcc'):
			[m, p] = self.vals
			return [('MemAcc', p, self, m)]
		elif self.is_op ('MemUpdate'):
			[m, p, v] = self.vals
			return [('MemUpdate', p, v, m)]
		else:
			return []

	def get_mem_accesses (self):
		accesses = []
		def visit (expr):
			accesses.extend (expr.get_mem_access ())
		self.visit (visit)
		return accesses

	def serialise (self, xs):
		xs.append (self.kind)
		if self.kind == 'Op':
			xs.append (self.name)
			self.typ.serialise (xs)
			xs.append (str (len (self.vals)))
			for v in self.vals:
				v.serialise (xs)
		elif self.kind == 'Num':
			xs.append (str (self.val))
			self.typ.serialise (xs)
		elif self.kind == 'Var':
			xs.append (self.name)
			self.typ.serialise (xs)
		elif self.kind == 'Type':
			self.val.serialise (xs)
		elif self.kind == 'Token':
			# the kind token was already appended above
			xs.append (self.name)
			self.typ.serialise (xs)
		else:
			assert not 'expr serialisable', self.kind
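
# A minimal sketch of the prefix encoding produced by Expr.serialise, using
# hypothetical nested tuples instead of Expr objects:
#   ('Var', name, bits), ('Num', value, bits), ('Op', name, bits, [args])
# where every type is assumed to be 'Word <bits>'. Joining the tokens for
# y + 1 reproduces the Op example from the quick reference.
def _sketch_serialise_expr (expr, xs):
	kind = expr[0]
	xs.append (kind)
	if kind == 'Op':
		(_, name, bits, args) = expr
		# operator name, result type, then the argument count
		xs.extend ([name, 'Word', str (bits), str (len (args))])
		for arg in args:
			_sketch_serialise_expr (arg, xs)
	elif kind in ('Var', 'Num'):
		(_, name_or_val, bits) = expr
		xs.extend ([str (name_or_val), 'Word', str (bits)])
	else:
		raise ValueError ('unsupported kind: %r' % (kind, ))
	return xs

# ' '.join (_sketch_serialise_expr (('Op', 'Plus', 32,
#	[('Var', 'y', 32), ('Num', 1, 32)]), []))
# gives 'Op Plus Word 32 2 Var y Word 32 Num 1 Word 32'.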

class Struct:
	def __init__ (self, name, size, align):
		self.name = name
		self.size = size
		self.align = align
		self.field_list = []
		self.fields = {}
		self._subtypes = None
		self.typ = Type ('Struct', name)

	def add_field (self, name, typ, offset):
		concrete = concrete_type (typ)
		self.field_list.append ((name, concrete))
		self.fields[name] = (concrete, offset, typ)
		assert self._subtypes == None

	def subtypes (self):
		if self._subtypes != None:
			return self._subtypes
		xs = [self.typ]
		for (concrete, offs, typ2) in self.fields.itervalues():
			xs.extend(typ2.subtypes())
		self._subtypes = xs
		return xs
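
# A minimal sketch of reading a Struct block as described in the quick
# reference: a 'Struct name size align' line followed by
# 'StructField name type offset' lines. For simplicity it accepts decimal
# integers only (unlike parse_int), keeps each field type as a raw token
# list, and returns a hypothetical plain dict rather than a Struct object.
def _sketch_parse_struct_block (lines):
	header = lines[0].split ()
	assert header[0] == 'Struct', header
	struct = {'name': header[1], 'size': int (header[2]),
		'align': int (header[3]), 'fields': []}
	for line in lines[1:]:
		bits = line.split ()
		assert bits[0] == 'StructField', bits
		# the type is every token between the field name and the offset
		struct['fields'].append ((bits[1], bits[2:-1], int (bits[-1])))
	return struct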

def tuplify (x):
	if type(x) == tuple or type(x) == list:
		return tuple ([tuplify (y) for y in x])
	if type(x) == dict:
		return tuple ([tuplify (y) for y in x.iteritems ()])
	else:
		return x

def hash_tuplify (* xs):
	return hash (tuplify (xs))

def subst_list (substor, xs, ss = None):
	ys = [x.subst (substor, ss = ss) for x in xs]
	if [y for y in ys if y != None]:
		xs = list (xs)
		for (i, y) in enumerate (ys):
			if y != None:
				xs[i] = y
		return xs
	else:
		return
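
# A minimal sketch of the copy-on-change convention used by Expr.subst and
# subst_list: a substitution returns None when nothing changed, so callers
# can keep the original object instead of rebuilding it. This hypothetical
# stand-alone version works over plain lists of arbitrary values.
def _sketch_subst_list (substor, xs):
	ys = [substor (x) for x in xs]
	if all (y is None for y in ys):
		# nothing changed: signal this with None, as subst_list does
		return None
	return [x if y is None else y for (x, y) in zip (xs, ys)]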

class Node:
	def __init__ (self, kind, conts, args):
		self.kind = kind

		if type (conts) == list:
			if len (conts) == 1:
				the_cont = conts[0]
		else:
			the_cont = conts

		if kind == 'Basic':
			self.cont = the_cont
			self.upds = [(lv, v) for (lv, v) in args
				if not v.is_var (lv)]
		elif kind == 'Call':
			self.cont = the_cont
			(self.fname, self.args, self.rets) = args
		elif kind == 'Cond':
			(self.left, self.right) = conts
			self.cond = args
		else:
			assert not 'node kind understood', self.kind

	def __repr__ (self):
		return 'Node (%r, %r, %r)' % (self.kind,
			self.get_conts (), self.get_args ())

	def __hash__ (self):
		if self.kind == 'Call':
			return hash ((self.fname, tuple (self.args),
				tuple (self.rets), self.cont))
		elif self.kind == 'Basic':
			return hash (tuple (self.upds))
		elif self.kind == 'Cond':
			return hash ((self.cond, self.left, self.right))
		else:
			assert not 'node kind understood', self.kind

	def __eq__ (self, other):
		return all ([self.kind == other.kind,
			self.get_conts () == other.get_conts (),
			self.get_args () == other.get_args ()])

	def __ne__ (self, other):
		return not other or not self == other

	def get_args (self):
		if self.kind == 'Basic':
			return self.upds
		elif self.kind == 'Call':
			return (self.fname, self.args, self.rets)
		else:
			return self.cond

	def get_conts (self):
		if self.kind == 'Cond':
			return [self.left, self.right]
		else:
			return [self.cont]

	def get_lvals (self):
		if self.kind == 'Basic':
			return [lv for (lv, v) in self.upds]
		elif self.kind == 'Call':
			return self.rets
		else:
			return []

	def is_noop (self):
		if self.kind == 'Basic':
			return self.upds == []
		elif self.kind == 'Cond':
			return self.left == self.right
		else:
			return False

	def visit (self, visit_lval, visit_rval):
		if self.kind == 'Basic':
			for (lv, v) in self.upds:
				visit_lval (lv)
				v.visit (visit_rval)
		elif self.kind == 'Cond':
			self.cond.visit (visit_rval)
		elif self.kind == 'Call':
			for v in self.args:
				v.visit (visit_rval)
			for r in self.rets:
				visit_lval (r)

	def gen_visit (self, visit_lval, visit_rval):
		self.visit (visit_lval, visit_rval)

	def subst_exprs (self, substor, ss = None):
		if self.kind == 'Basic':
			rvs = subst_list (substor,
				[v for (lv, v) in self.upds], ss = ss)
			if rvs == None:
				return self
			return Node ('Basic', self.cont,
				zip ([lv for (lv, v) in self.upds], rvs))
		elif self.kind == 'Cond':
			r = self.cond.subst (substor, ss = ss)
			if r == None:
				return self
			return Node ('Cond', [self.left, self.right], r)
		elif self.kind == 'Call':
			args = subst_list (substor, self.args, ss = ss)
			if args == None:
				return self
			return Node ('Call', self.cont, (self.fname,
				args, self.rets))

	def get_mem_accesses (self):
		accesses = []
		def visit (expr):
			accesses.extend (expr.get_mem_access ())
		self.visit (lambda x: (), visit)
		return accesses

	def err_cond (self):
		if self.kind != 'Cond':
			return None
		if self.left != 'Err':
			if self.right == 'Err':
				return mk_not (self.cond)
			else:
				return None
		else:
			if self.right == 'Err':
				return true_term
			else:
				return self.cond

	def serialise (self, xs):
		xs.append (self.kind)
		xs.extend ([str (c) for c in self.get_conts ()])
		if self.kind == 'Basic':
			xs.append (str (len (self.upds)))
			for (lv, v) in self.upds:
				xs.append (lv[0])
				lv[1].serialise (xs)
				v.serialise (xs)
		elif self.kind == 'Cond':
			self.cond.serialise (xs)
		elif self.kind == 'Call':
			xs.append (self.fname)
			xs.append (str (len (self.args)))
			for arg in self.args:
				arg.serialise (xs)
			xs.append (str (len (self.rets)))
			for (nm, typ) in self.rets:
				xs.append (nm)
				typ.serialise (xs)
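
# A minimal sketch of the case analysis in Node.err_cond, over a
# hypothetical representation where the branch condition is a plain string
# and the result is a string formula. It returns the condition under which
# a Cond node transfers control to the special Err address, or None if
# neither branch leads to Err.
def _sketch_err_cond (left, right, cond):
	if left == 'Err' and right == 'Err':
		return 'true'
	elif left == 'Err':
		# the left branch is taken when the condition holds
		return cond
	elif right == 'Err':
		# the right branch is taken when the condition fails
		return '(not %s)' % cond
	return None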

def rename_lval ((name, typ), renames):
	return (renames.get (name, name), typ)

def do_subst (expr, substor, ss = None):
	r = expr.subst (substor, ss = ss)
	if r == None:
		return expr
	else:
		return r

standard_expr_kinds = set (['Symbol', 'ConstGlobal', 'Var', 'Op', 'Num',
	'Type'])

def rename_expr_substor (renames):
	def ren (expr):
		if expr.kind == 'Var' and expr.name in renames:
			return mk_var (renames[expr.name], expr.typ)
		else:
			return
	return ren

def rename_expr (expr, renames):
	return do_subst (expr, rename_expr_substor (renames),
		ss = set (['Var']))

def copy_rename (node, renames):
	(vs, ns) = renames
	nf = lambda n: ns.get (n, n)
	node = node.subst_exprs (rename_expr_substor (vs))
	if node.kind == 'Call':
		return Node ('Call', nf (node.cont), (node.fname, node.args,
			[rename_lval (l, vs) for l in node.rets]))
	elif node.kind == 'Basic':
		return Node ('Basic', nf (node.cont),
			[(rename_lval (lv, vs), v) for (lv, v) in node.upds])
	elif node.kind == 'Cond':
		return Node ('Cond', [nf (node.left), nf (node.right)],
			node.cond)
	else:
		assert not 'node kind understood', node.kind
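
# A minimal sketch of how copy_rename remaps successor addresses: any
# address missing from the renaming map, including the special Ret and Err
# addresses, is kept unchanged. The graph is a hypothetical dict from node
# address to a list of successor addresses. Unlike copy_rename, which leaves
# renaming the node's own address to its caller, this sketch renames the
# keys too.
def _sketch_rename_graph (graph, node_renames):
	nf = lambda n: node_renames.get (n, n)
	return dict ((nf (n), [nf (c) for c in conts])
		for (n, conts) in graph.items ())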

class Function:
	def __init__ (self, name, inputs, outputs):
		self.name = name
		self.inputs = inputs
		self.outputs = outputs
		self.entry = None
		self.nodes = {}

	def __hash__ (self):
		en = self.entry
		if not en:
			en = -1
		return hash (tuple ([self.name, tuple (self.inputs),
				tuple (self.outputs), en])
			+ tuple (sorted (self.nodes.iteritems ())))

	def reachable_nodes (self, simplify = False):
		if not self.entry:
			return {}
		rs = {}
		vs = [self.entry]
		while vs:
			n = vs.pop()
			if type (n) == str:
				continue
			rs[n] = True
			node = self.nodes[n]
			if simplify:
				import logic
				node = logic.simplify_node_elementary (node)
			for c in node.get_conts ():
				if not c in rs:
					vs.append (c)
		return rs

	def serialise (self):
		xs = ['Function', self.name, str (len (self.inputs))]
		for (nm, typ) in self.inputs:
			xs.append (nm)
			typ.serialise (xs)
		xs.append (str (len (self.outputs)))
		for (nm, typ) in self.outputs:
			xs.append (nm)
			typ.serialise (xs)
		ss = [' '.join (xs)]
		if not self.entry:
			return ss
		for n in self.nodes:
			xs = [str (n)]
			self.nodes[n].serialise (xs)
			ss.append (' '.join (xs))
		ss.append ('EntryPoint %d' % self.entry)
		return ss

	def as_problem (self, Problem, name = 'temp'):
		p = Problem(None, 'Function (%s)' % self.name)
		p.clone_function (self, name)
		p.compute_preds ()
		return p

	def function_call_addrs (self):
		return [(n, self.nodes[n].fname)
			for n in self.nodes if self.nodes[n].kind == 'Call']

	def function_calls (self):
		return set ([fn for (n, fn) in self.function_call_addrs ()])

	def compile_hints (self, Problem):
		xs = ['Hints %s' % self.name]

		p = self.as_problem (Problem)

		for n in p.nodes:
			ys = ['VarDeps', str (n)]
			for (nm, typ) in p.var_deps[n]:
				ys.append (nm)
				typ.serialise (ys)
			xs.append (' '.join (ys))
			if self.nodes[n].kind != 'Basic':
				continue
		return xs

	def save_graph (self, fname):
		import problem
		problem.save_graph (self.nodes, fname)
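
# A minimal sketch of the worklist search in Function.reachable_nodes, over
# a hypothetical dict mapping each integer node address to its list of
# successors. String addresses such as 'Ret' and 'Err' are exits rather
# than nodes, so they are never added to the result.
def _sketch_reachable (succs, entry):
	reached = set ()
	todo = [entry]
	while todo:
		n = todo.pop ()
		if isinstance (n, str) or n in reached:
			continue
		reached.add (n)
		todo.extend (c for c in succs[n] if c not in reached)
	return reached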

def mk_builtinTs ():
	# RoundingMode is listed as a builtin type in the quick reference
	return dict([(n, Type('Builtin', n)) for n
	in 'Bool Mem Dom HTD PMS UNIT Type Token RoundingMode RelWrapper'.split()])
builtinTs = mk_builtinTs ()
boolT = builtinTs['Bool']
word32T = Type ('Word', '32')
word64T = Type ('Word', '64')
word16T = Type ('Word', '16')
word8T = Type ('Word', '8')

phantom_types = set ([builtinTs[t] for t
	in 'Dom HTD PMS UNIT Type'.split ()])

def concrete_type (typ):
	if typ.kind == 'Ptr':
		return word32T
	else:
		return typ

global_wrappers = {}
def get_global_wrapper (typ):
	if typ in global_wrappers:
		return global_wrappers[typ]
	struct_name = fresh_name ('Global (%s)' % typ, structs)
	struct = Struct (struct_name, typ.size (), typ.align ())
	struct.add_field ('v', typ, 0)
	structs[struct_name] = struct

	global_wrappers[typ] = struct.typ
	return struct.typ

# ==========================================================
# parsing code for types, expressions, structs and functions

def parse_int (s):
	if s.startswith ('-') or s.startswith ('~'):
		return (- parse_int (s[1:]))
	if s.startswith ('0x') or s.startswith ('0b'):
		return int (s, 0)
	else:
		return int (s)

def parse_typ(bits, n, symbolic_types = False):
	if bits[n] == 'Word' or bits[n] == 'BitVec':
		return (n + 2, Type('Word', parse_int (bits[n + 1])))
	elif bits[n] == 'WordArray' or bits[n] == 'FloatingPoint':
		return (n + 3, Type(bits[n],
			parse_int (bits[n + 1]), parse_int (bits[n + 2])))
	elif bits[n] in builtinTs:
		return (n + 1, builtinTs[bits[n]])
	elif bits[n] == 'Array':
		(n, typ) = parse_typ (bits, n + 1, True)
		return (n + 1, Type ('Array', parse_int (bits[n]), typ))
	elif bits[n] == 'Struct':
		return (n + 2, Type ('Struct', bits[n + 1]))
	elif bits[n] == 'Ptr':
		(n, typ) = parse_typ (bits, n + 1, True)
		if symbolic_types:
			return (n, Type ('Ptr', '', typ))
		else:
			return (n, word32T)
	else:
		assert not 'type encoded', (n, bits[n:], bits)
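
# A minimal sketch of the (next index, value) cursor convention used by
# parse_typ and the other parse_* helpers: each parser reads tokens starting
# at position n and returns where the next clause begins. Only Word/BitVec
# and a few builtin names are handled, and types come back as hypothetical
# tuples rather than Type objects.
_SKETCH_BUILTINS = set (['Bool', 'Mem', 'Dom', 'HTD', 'PMS', 'UNIT'])

def _sketch_parse_typ (bits, n):
	if bits[n] in ('Word', 'BitVec'):
		# BitVec is a synonym for Word; both take one integer token
		return (n + 2, ('Word', int (bits[n + 1])))
	elif bits[n] in _SKETCH_BUILTINS:
		return (n + 1, (bits[n], ))
	raise ValueError ('type not understood at %d: %r' % (n, bits[n:]))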
939
def node_name (name):
	if name in ('Ret', 'Err'):
		return name
	else:
		try:
			return parse_int (name)
		except ValueError:
			assert not 'node name understood', name

def parse_list (parser, bits, n, extra=None):
	try:
		num = parse_int (bits[n])
	except ValueError:
		assert not 'number parseable', (n, bits[n:], bits)
	n = n + 1
	xs = []
	for i in range (num):
		if extra:
			(n, x) = parser(bits, n, extra)
		else:
			(n, x) = parser(bits, n)
		xs.append(x)
	return (n, xs)

def parse_arg (bits, n):
	nm = bits[n]
	(n, typ) = parse_typ (bits, n + 1)
	return (n, (nm, typ))

ops = {'Plus':2, 'Minus':2, 'Times':2, 'Modulus':2,
	'DividedBy':2, 'BWAnd':2, 'BWOr':2, 'BWXOR':2, 'And':2,
	'Or':2, 'Implies':2, 'Equals':2, 'Less':2,
	'LessEquals':2, 'SignedLess':2, 'SignedLessEquals':2,
	'ShiftLeft':2, 'ShiftRight':2, 'CountLeadingZeroes':1,
	'CountTrailingZeroes':1, 'WordReverse':1, 'SignedShiftRight':2,
	'Not':1, 'BWNot':1, 'WordCast':1, 'WordCastSigned':1,
	'True':0, 'False':0, 'UnspecifiedPrecond':0,
	'MemUpdate':3, 'MemAcc':2, 'IfThenElse':3, 'ArrayIndex':2,
	'ArrayUpdate':3, 'MemDom':2,
	'PValid':3, 'PWeakValid':3, 'PAlignValid':2, 'PGlobalValid':3,
	'PArrayValid':4,
	'HTDUpdate':5, 'WordArrayAccess':2, 'WordArrayUpdate':3,
	'TokenWordsAccess':2, 'TokenWordsUpdate':3,
	'ROData':1, 'StackWrapper':2,
	'ToFloatingPoint':1, 'ToFloatingPointSigned':2,
	'ToFloatingPointUnsigned':2, 'FloatingPointCast':1,
}

ops_to_smt = {'Plus':'bvadd', 'Minus':'bvsub', 'Times':'bvmul', 'Modulus':'bvurem',
	'DividedBy':'bvudiv', 'BWAnd':'bvand', 'BWOr':'bvor', 'BWXOR':'bvxor',
	'And':'and',
	'Or':'or', 'Implies':'=>', 'Equals':'=', 'Less':'bvult',
	'LessEquals':'bvule', 'SignedLess':'bvslt', 'SignedLessEquals':'bvsle',
	'ShiftLeft':'bvshl', 'ShiftRight':'bvlshr', 'SignedShiftRight':'bvashr',
	'Not':'not', 'BWNot':'bvnot',
	'True':'true', 'False':'false',
	'UnspecifiedPrecond': 'unspecified-precond',
	'IfThenElse':'ite', 'MemDom':'mem-dom',
	'ROData': 'rodata', 'ImpliesROData': 'implies-rodata',
	'WordArrayAccess':'select', 'WordArrayUpdate':'store'}

ex_smt_ops = """roundNearestTiesToEven RNE roundNearestTiesToAway RNA
    roundTowardPositive RTP roundTowardNegative RTN
    roundTowardZero RTZ
    fp.abs fp.neg fp.add fp.sub fp.mul fp.div fp.fma fp.sqrt fp.rem
    fp.roundToIntegral fp.min fp.max fp.leq fp.lt fp.geq fp.gt fp.eq
    fp.isNormal fp.isSubnormal fp.isZero fp.isInfinite fp.isNaN
    fp.isNegative fp.isPositive""".split ()

ops_to_smt.update (dict ([(smt, smt) for smt in ex_smt_ops]))

smt_to_ops = dict ([(smt, oper) for (oper, smt) in ops_to_smt.iteritems ()])

def parse_struct_elem (bits, n):
	name = bits[n]
	(n, typ) = parse_typ (bits, n + 1)
	(n, val) = parse_expr (bits, n)
	return (n, (name, val))

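# Example (illustrative only; inputs are made up): the expression y + 1
# from quick_reference. An Op gives its result type, then the argument
# count read by parse_list, then the arguments themselves.
#   bits = 'Op Plus Word 32 2 Var y Word 32 Num 1 Word 32'.split ()
#   parse_expr (bits, 0) -> (13, Op Plus at Word 32 over [Var y, Num 1])
# Operator names may also be written in their SMT form, which smt_to_ops
# maps back, so 'Op bvadd Word 32 2 ...' yields the same Plus expression.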
def parse_expr (bits, n):
	if bits[n] in set (['Symbol', 'Var', 'ConstGlobal', 'Token']):
		kind = bits[n]
		name = bits[n + 1]
		(n, typ) = parse_typ (bits, n + 2)
		return (n, Expr (kind, typ, name = name))
	elif bits[n] == 'Array':
		(n, typ) = parse_typ (bits, n + 1)
		assert typ.kind == 'Array'
		(n, xs) = parse_list (parse_expr, bits, n)
		assert len(xs) == typ.num
		return (n, Expr ('Array', typ, vals = xs))
	elif bits[n] == 'Field':
		(n, typ) = parse_typ (bits, n + 1)
		name = bits[n]
		(n, typ2) = parse_typ (bits, n + 1)
		(n, struct) = parse_expr (bits, n)
		assert struct.typ == typ
		return (n, Expr ('Field', typ2, struct = struct,
			field = (name, typ2)))
	elif bits[n] == 'FieldUpd':
		(n, typ) = parse_typ (bits, n + 1)
		name = bits[n]
		(n, typ2) = parse_typ (bits, n + 1)
		(n, val) = parse_expr (bits, n)
		(n, struct) = parse_expr (bits, n)
		return (n, Expr ('FieldUpd', typ, struct = struct,
			field = (name, typ2), val = val))
	elif bits[n] == 'StructCons':
		(n, typ) = parse_typ (bits, n + 1)
		(n, xs) = parse_list (parse_struct_elem, bits, n)
		return (n, Expr ('StructCons', typ, vals = dict(xs)))
	elif bits[n] == 'Num':
		v = parse_int (bits[n + 1])
		(n, typ) = parse_typ (bits, n + 2)
		return (n, Expr ('Num', typ, val = v))
	elif bits[n] == 'Op':
		op = bits[n + 1]
		op = smt_to_ops.get (op, op)
		assert op in ops, op
		(n, typ) = parse_typ (bits, n + 2)
		(n, xs) = parse_list (parse_expr, bits, n)
		assert len (xs) == ops[op]
		return (n, Expr ('Op', typ, name = op, vals = xs))
	elif bits[n] == 'Type':
		(n, typ) = parse_typ (bits, n + 1, symbolic_types = True)
		return (n, Expr ('Type', builtinTs['Type'], val = typ))
	else:
		assert not 'expr parseable', (n, bits[n : n + 3], bits)

def parse_lval (bits, n):
	name = bits[n]
	(n, typ) = parse_typ (bits, n + 1)
	return (n, (name, typ))

def parse_lval_and_val (bits, n):
	(n, (name, typ)) = parse_lval (bits, n)
	(n, val) = parse_expr (bits, n)
	return (n, ((name, typ), val))

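# Example node lines (illustrative only; names are made up): parse_all
# reads the node name from the first token and parses the rest with
# parse_node (bits, 1).
#   14 Basic 15 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32
#   15 Cond 16 17 Op Equals Bool 2 Var x Word 32 Var y Word 32
#   16 Basic Ret 0
#   17 Call 18 f 1 Var x Word 32 1 r Word 32
# Node 17 calls a hypothetical function f on x and saves its single
# result in r; each argument and result list is prefixed by its length.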
def parse_node (bits, n):
	if bits[n] == 'Basic':
		cont = node_name(bits[n + 1])
		(n, upds) = parse_list (parse_lval_and_val, bits, n + 2)
		return Node ('Basic', cont, upds)
	elif bits[n] == 'Cond':
		left = node_name(bits[n + 1])
		right = node_name(bits[n + 2])
		(n, cond) = parse_expr (bits, n + 3)
		return Node ('Cond', (left, right), cond)
	else:
		assert bits[n] == 'Call'
		cont = node_name(bits[n + 1])
		name = bits[n + 2]
		(n, args) = parse_list (parse_expr, bits, n + 3)
		(n, saves) = parse_list (parse_lval, bits, n)
		return Node ('Call', cont, (name, args, saves))

true_term = Expr ('Op', boolT, name = 'True', vals = [])
false_term = Expr ('Op', boolT, name = 'False', vals = [])
unspecified_precond_term = Expr ('Op', boolT, name = 'UnspecifiedPrecond', vals = [])

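# Example input (illustrative only) for a hypothetical function f that
# computes x = y + 1, with one input y and one output x:
#   Function f 1 y Word 32 1 x Word 32
#   14 Basic Ret 1 x Word 32 Op Plus Word 32 2 Var y Word 32 Num 1 Word 32
#   EntryPoint 14
# Passing these lines (e.g. as a list of strings) to parse_all should give
# ({}, {'f': <Function f>}, {}), with a fresh dummy node 17 installed as
# f.entry that continues to node 14.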
def parse_all(lines):
	'''Toplevel parser for input information. Accepts an iterator over
lines. See syntax.quick_reference for an explanation.'''

	if hasattr (lines, 'name'):
		trace ('Loading syntax from %s' % lines.name)
	else:
		trace ('Loading syntax (from anonymous source).')

	structs = {}
	functions = {}
	const_globals = {}
	cfg_warnings = []
	for line in lines:
		bits = line.split()
		# empty lines and #-comments ignored
		if not bits or bits[0][0] == '#':
			continue
		if bits[0] == 'Struct':
			# Struct <name> <size> <alignment>
			# followed by block of StructField lines
			assert bits[1] not in structs
			current_struct = Struct (bits[1], parse_int (bits[2]),
				parse_int (bits[3]))
			structs[bits[1]] = current_struct
		elif bits[0] == 'StructField':
			# StructField <name> <type (encoded)> <offset>
			(_, typ) = parse_typ(bits, 2, symbolic_types = True)
			current_struct.add_field (bits[1], typ,
				parse_int (bits[-1]))
		elif bits[0] == 'ConstGlobalDef':
			# ConstGlobalDef <name> <value>
			name = bits[1]
			(_, val) = parse_expr (bits, 2)
			const_globals[name] = val
		elif bits[0] == 'Function':
			# Function <name> <inputs> <outputs>
			# followed by optional block of node lines
			# concluded by EntryPoint line
			fname = bits[1]
			(n, inputs) = parse_list (parse_arg, bits, 2)
			(_, outputs) = parse_list (parse_arg, bits, n)
			current_function = Function (fname, inputs, outputs)
			assert fname not in functions, fname
			functions[fname] = current_function
		elif bits[0] == 'EntryPoint':
			# EntryPoint <entry point>
			entry = node_name(bits[1])
			# instead of setting function.entry to this value,
			# create a dummy node. this ensures there is always
			# at least one node (EntryPoint Ret is valid) and
			# also that the entry point is not in a loop
			name = fresh_node (current_function.nodes)
			current_function.nodes[name] = Node ('Basic',
				entry, [])
			current_function.entry = name
			# ensure that the function graph is closed
			check_cfg (current_function, warnings = cfg_warnings)
			current_function = None
		else:
			# <node name> <node (encoded)>
			name = node_name(bits[0])
			assert name not in current_function.nodes, (name, bits)
			current_function.nodes[name] = parse_node (bits, 1)

	print_cfg_warnings (cfg_warnings)
	trace ('Loaded %d functions, %d structs, %d globals.'
		% (len (functions), len (structs), len (const_globals)))

	return (structs, functions, const_globals)

def parse_and_install_all (lines, tag, skip_functions=None):
	if skip_functions is None:
		skip_functions = []
	(structs, functions, const_globals) = parse_all (lines)
	for f in skip_functions:
		if f in functions:
			del functions[f]
	target_objects.structs.update (structs)
	target_objects.functions.update (functions)
	target_objects.const_globals.update (const_globals)
	if tag is not None:
		target_objects.functions_by_tag[tag] = set (functions)
	return (structs, functions, const_globals)

# ===============================
# simple accessor code and checks

def visit_rval (vs):
	def visit (expr):
		if expr.kind == 'Var':
			v = expr.name
			if v in vs:
				assert vs[v] == expr.typ, (expr, vs[v])
			vs[v] = expr.typ
		if expr.is_op ('MemAcc'):
			[m, p] = expr.vals
			assert p.typ == word32T, expr
		if expr.is_op ('PGlobalValid'):
			[htd, typ_expr, p] = expr.vals
			typ = typ_expr.val
			get_global_wrapper (typ)

	return visit

def visit_lval (vs):
	def visit ((name, typ)):
		assert vs.get(name, typ) == typ, (name, vs[name], typ)
		vs[name] = typ

	return visit

def get_expr_vars (expr, vs):
	expr.visit (visit_rval (vs))

def get_expr_var_set (expr):
	vs = {}
	get_expr_vars (expr, vs)
	return set (vs.items ())

def get_lval_vars (lval, vs):
	assert len(lval) == 2
	assert vs.get(lval[0], lval[1]) == lval[1]
	vs[lval[0]] = lval[1]

def get_node_vars (node, vs):
	node.visit (visit_lval (vs), visit_rval (vs))

def get_node_rvals (node, vs = None):
	if vs is None:
		vs = {}
	node.visit (lambda l: (), visit_rval (vs))
	return vs

def get_vars(function):
	vs = dict(function.inputs + function.outputs)
	for node in function.nodes.itervalues():
		get_node_vars(node, vs)
	return vs

def get_lval_typ(lval):
	assert len(lval) == 2
	return lval[1]

def get_expr_typ(expr):
	return expr.typ

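# Example behavior (illustrative only; node numbers are made up): if
# node 14 of fun continues to node 20 but no node 20 exists, check_cfg
# closes the graph by adding
#   fun.nodes[20] = Node ('Basic', 'Err', [])
# and reports the dead arc 14 -> 20, either directly through
# print_cfg_warnings or by appending (fun, 14, 20) to warnings.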
def check_cfg (fun, warnings = None):
	dead_arcs = [(n, n2) for (n, node) in fun.nodes.iteritems ()
		for n2 in node.get_conts ()
		if n2 not in fun.nodes and n2 not in ['Ret', 'Err']]
	for (n, n2) in dead_arcs:
		assert type (n2) != str
		# OK if multiple dead arcs and we save over n2 twice
		fun.nodes[n2] = Node ('Basic', 'Err', [])
	if warnings is None:
		print_cfg_warnings ([(fun, n, n2) for (n, n2) in dead_arcs])
	else:
		warnings.extend ([(fun, n, n2) for (n, n2) in dead_arcs])

def print_cfg_warnings (warnings):
	post_calls = set ([(fun.nodes[n].fname, fun.name)
		for (fun, n, n2) in warnings
		if fun.nodes[n].kind == 'Call'])
	import logic
	for (call, sites) in logic.dict_list (post_calls).iteritems ():
		trace ('Missing nodes after calls to %s' % call)
		trace ('  in %s' % str (sites))
	for (fun, n, n2) in warnings:
		if fun.nodes[n].kind != 'Call':
			trace ('Warning: dead arc in %s: %s -> %s'
				% (fun.name, n, n2))
			trace ('  (follows %s node!)' % fun.nodes[n].kind)

def check_funs (functions, verbose = False):
	for (f, fun) in functions.iteritems():
		try:
			if not fun:
				continue
			if verbose:
				trace ('Checking %s' % f)
			check_cfg (fun)
			get_vars(fun)
			for (n, node) in fun.nodes.iteritems():
				if node.kind == 'Call':
					c = functions[node.fname]
					assert map(get_expr_typ, node.args) == \
						map (get_lval_typ, c.inputs), (
							 node.fname, node.args, c.inputs)
					assert map (get_lval_typ, node.rets) == \
						map (get_lval_typ, c.outputs), (
							 node.fname, node.rets, c.outputs)
				elif node.kind == 'Basic':
					for (lv, v) in node.upds:
						assert get_lval_typ(lv) == get_expr_typ(v)
				elif node.kind == 'Cond':
					assert get_expr_typ(node.cond) == boolT
		except Exception:
			print "check_funs: failed for " + f
			# a bare raise keeps the original traceback
			raise

def get_extensions (v):
	extensions = set ()
	rm = builtinTs['RoundingMode']
	def visitor (expr):
		if expr.typ == rm or expr.typ.kind == 'FloatingPoint':
			extensions.add ('FloatingPoint')
	v.gen_visit (lambda l: (), visitor)
	return extensions

# =========================================
# common constructors for basic expressions

def mk_var (nm, typ):
	return Expr ('Var', typ, name = nm)

def mk_token (nm):
	return Expr ('Token', builtinTs['Token'], name = nm)

def mk_plus (x, y):
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'Plus', vals = [x, y])

def mk_uminus (x):
	zero = Expr ('Num', x.typ, val = 0)
	return mk_minus (zero, x)

def mk_minus (x, y):
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'Minus', vals = [x, y])

def mk_times (x, y):
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'Times', vals = [x, y])

def mk_divide (x, y):
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'DividedBy', vals = [x, y])

def mk_modulus (x, y):
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'Modulus', vals = [x, y])

def mk_bwand (x, y):
	assert x.typ == y.typ
	assert x.typ.kind == 'Word'
	return Expr ('Op', x.typ, name = 'BWAnd', vals = [x, y])

def mk_eq (x, y):
	assert x.typ == y.typ
	return Expr ('Op', boolT, name = 'Equals', vals = [x, y])

def mk_less_eq (x, y, signed = False):
	assert x.typ == y.typ
	name = {False: 'LessEquals', True: 'SignedLessEquals'}[signed]
	return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_less (x, y, signed = False):
	assert x.typ == y.typ
	name = {False: 'Less', True: 'SignedLess'}[signed]
	return Expr ('Op', boolT, name = name, vals = [x, y])

def mk_implies (x, y):
	assert x.typ == boolT
	assert y.typ == boolT
	return Expr ('Op', boolT, name = 'Implies', vals = [x, y])

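# Example (illustrative only), for boolean expressions a, b and c:
#   mk_n_implies ([a, b], c)  builds  Implies (a, Implies (b, c))
#   mk_n_implies ([], c)      is just c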
def mk_n_implies (xs, y):
	imp = y
	for x in reversed (xs):
		imp = mk_implies (x, imp)
	return imp

def mk_and (x, y):
	assert x.typ == boolT
	assert y.typ == boolT
	return Expr ('Op', boolT, name = 'And', vals = [x, y])

def mk_or (x, y):
	assert x.typ == boolT
	assert y.typ == boolT
	return Expr ('Op', boolT, name = 'Or', vals = [x, y])

def mk_not (x):
	assert x.typ == boolT
	return Expr ('Op', boolT, name = 'Not', vals = [x])

def mk_shift_gen (name, x, n):
	assert x.typ.kind == 'Word'
	if type (n) == int:
		n = Expr ('Num', x.typ, val = n)
	return Expr ('Op', x.typ, name = name, vals = [x, n])

mk_shiftr = lambda x, n: mk_shift_gen ('ShiftRight', x, n)

def mk_clz (x):
	return Expr ('Op', x.typ, name = "CountLeadingZeroes", vals = [x])

def mk_word_reverse (x):
	return Expr ('Op', x.typ, name = "WordReverse", vals = [x])

def mk_ctz (x):
	return mk_clz (mk_word_reverse (x))

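# Example (illustrative only): foldr1 is a right fold without a seed
# value, so foldr1 (f, [a, b, c]) == f (a, f (b, c)) and xs must be
# non-empty. With plain integers:
#   foldr1 (lambda x, y: x - y, [10, 4, 1]) == 10 - (4 - 1) == 7
# Likewise foldr1 (mk_and, conds) would build a right-nested conjunction
# of a hypothetical non-empty list conds of boolean expressions.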
def foldr1 (f, xs):
	x = xs[-1]
	for i in reversed (range (len (xs) - 1)):
		x = f (xs[i], x)
	return x

def mk_num (x, typ):
	import logic
	if logic.is_int (typ):
		typ = Type ('Word', typ)
	assert typ.kind == 'Word', typ
	assert logic.is_int (x), x
	return Expr ('Num', typ, val = x)

def mk_word32 (x):
	return mk_num (x, word32T)

def mk_word8 (x):
	return mk_num (x, word8T)

def mk_word32_maybe(x):
	import logic
	if logic.is_int (x):
		return mk_word32 (x)
	else:
		assert x.typ == word32T
		return x

def mk_cast (x, typ):
	if x.typ == typ:
		return x
	else:
		assert x.typ.kind == 'Word', x.typ
		assert typ.kind == 'Word', typ
		return Expr ('Op', typ, name = 'WordCast', vals = [x])

def mk_memacc(m, p, typ):
	assert m.typ == builtinTs['Mem']
	assert p.typ == word32T
	return Expr ('Op', typ, name = 'MemAcc', vals = [m, p])

def mk_memupd(m, p, v):
	assert m.typ == builtinTs['Mem']
	assert p.typ == word32T
	return Expr ('Op', m.typ, name = 'MemUpdate', vals = [m, p, v])

def mk_arr_index (arr, i):
	assert arr.typ.kind == 'Array'
	return Expr ('Op', arr.typ.el_typ, name = 'ArrayIndex',
		vals = [arr, i])

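# Example (illustrative only; assumes Word 32 elements are 4 bytes): for
# a word32T pointer p and an array type typ with at least three Word 32
# elements, the offset is folded to a constant when the index is an int:
#   mk_arroffs (p, typ, 2)  builds  p + 8
#   mk_arroffs (p, typ, i)  builds  p + i * 4    (i a Word 32 expression)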
def mk_arroffs(p, typ, i):
	assert typ.kind == 'Array'
	import logic
	if logic.is_int (i):
		assert i < typ.num
		offs = i * typ.el_typ.size()
		assert offs == i or offs % 4 == 0
		return mk_plus (p, mk_word32 (offs))
	else:
		sz = typ.el_typ.size()
		return mk_plus (p, mk_times (i, mk_word32 (sz)))

def mk_if (P, x, y):
	assert P.typ == boolT
	assert x.typ == y.typ
	return Expr ('Op', x.typ, name = 'IfThenElse', vals = [P, x, y])

def mk_meta_typ (typ):
	return Expr ('Type', builtinTs['Type'], val = typ)

def mk_pvalid (htd, typ, p):
	return Expr ('Op', boolT, name = 'PValid',
		vals = [htd, mk_meta_typ (typ), p])

def mk_rel_wrapper (nm, vals):
	return Expr ('Op', builtinTs['RelWrapper'], name = nm, vals = vals)

def adjust_op_vals (expr, vals):
	assert expr.kind == 'Op'
	return Expr ('Op', expr.typ, name = expr.name, vals = vals)

mks = (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand,
mk_eq, mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32,
mk_word8, mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index,
mk_arroffs, mk_if, mk_meta_typ, mk_pvalid)

# ====================================================================
# pretty printing code for the syntax - only used for printing reports

def pretty_type (typ):
	if typ.kind == 'Word':
		return 'Word%d' % typ.num
	elif typ.kind == 'WordArray':
		[ix, num] = typ.nums
		return 'Word%d[%d]' % (num, ix)
	elif typ.kind == 'Ptr':
		return 'Ptr(%s)' % pretty_type (typ.el_typ_symb)
	elif typ.kind == 'Struct':
		return 'struct %s' % typ.name
	elif typ.kind == 'Builtin':
		return typ.name
	else:
		assert not 'type pretty-printable', typ

pretty_opers = {'Plus': '+', 'Minus': '-', 'Times': '*'}

known_typ_change = set (['ROData', 'MemAcc', 'IfThenElse', 'WordArrayUpdate',
	'MemDom'])

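# Example (illustrative only): with e the expression y + 1 (Op Plus over
# Var y and Num 1, at Word 32),
#   pretty_expr (e)                    -> "('y' + 1)"
#   pretty_expr (e, print_type = True) -> "((Word32) (('y' + 1)))"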
def pretty_expr (expr, print_type = False):
	if print_type:
		return '((%s) (%s))' % (pretty_type (expr.typ),
			pretty_expr (expr))
	elif expr.kind == 'Var':
		return repr (expr.name)
	elif expr.kind == 'Num':
		return '%d' % expr.val
	elif expr.kind == 'Op' and expr.name in pretty_opers:
		[x, y] = expr.vals
		return '(%s %s %s)' % (pretty_expr (x), pretty_opers[expr.name],
			pretty_expr (y))
	elif expr.kind == 'Op':
		if expr.name in known_typ_change:
			vals = [pretty_expr (v) for v in expr.vals]
		else:
			vals = [pretty_expr (v, print_type = v.typ != expr.typ)
				for v in expr.vals]
		return '%s(%s)' % (expr.name, ', '.join (vals))
	elif expr.kind == 'Token':
		return "''%s''" % expr.name
	else:
		assert not 'expr pretty-printable', expr


# =================================================
# some helper code that's needed all over the place

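# Example (illustrative only): fresh_name returns n if it is unused,
# otherwise a free name 'n.K' found by doubling K and then binary search
# (the smallest free K when the suffixes in use are contiguous), and
# records the chosen name in D.
#   D = {'x': True}
#   fresh_name ('x', D) -> 'x.1'
#   fresh_name ('x', D) -> 'x.2'
#   fresh_name ('y', D) -> 'y'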
def fresh_name (n, D, v=True):
	if n not in D:
		D[n] = v
		return n

	x = 1
	y = 1
	while ('%s.%d' % (n, x)) in D:
		y = x
		x = x * 2
	while y < x:
		z = (y + x) / 2
		if ('%s.%d' % (n, z)) in D:
			y = z + 1
		else:
			x = z
	n = '%s.%d' % (n, x)
	assert n not in D

	D[n] = v
	return n

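# Example (illustrative only): fresh_node starts from (hint | 15) + 2,
# which has the form 16k + 1, and steps by 16 until the name is unused.
#   fresh_node ({})             -> 17
#   fresh_node ({17: None})     -> 33
#   fresh_node ({}, hint = 20)  -> 33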
def fresh_node (ns, hint = 1):
	n = hint
	# pick names congruent to 1 modulo 16, which are never
	# multiples of 4
	n = (n | 15) + 2
	while n in ns:
		n += 16
	return n

