1
(* Utility functions for working with BDDs and term-BDDs. *)
3
4structure bddTools =
5struct
6
7local
8
9open Globals HolKernel Parse
10infixr 3 -->;
11
12open Psyntax bossLib pairTheory pred_setTheory pred_setLib stringLib
13     listTheory simpLib pairSyntax pairLib PrimitiveBddRules
14     DerivedBddRules Binarymap PairRules pairTools boolSyntax Drule
15     Tactical Conv Rewrite Tactic boolTheory listSyntax stringTheory
16     boolSimps pureSimps listSimps numLib HolSatLib metisLib
17
18open stringBinTree reachTheory commonTools
19
(* prefix prepended to the debug trace tags this structure passes to dbgTools *)
val dpfx = "bto_"
21
22in
23
(* t2tb vm t: convert the HOL term t into a term-BDD over the varmap vm, using
   the converter currently stored in DerivedBddRules.termToTermBddFun (the
   reference is read at call time). *)
fun t2tb vm t =
    let
        val converter = !DerivedBddRules.termToTermBddFun
    in
        DerivedBddRules.GenTermToTermBdd converter vm t
    end
25
(* mk_tb_res_subst red res vm: pair each variable in red with the matching
   entry of res, producing (BddVar true vm v, BddCon b vm) where b is true
   exactly when the residue is the HOL constant T.
   ListPair.map stops at the end of the shorter list. *)
fun mk_tb_res_subst red res vm =
    let
        fun toBddPair (v, c) = (BddVar true vm v, BddCon (c = T) vm)
    in
        ListPair.map toBddPair (red, res)
    end
27
28
(* BddListConj vm tbs: conjoin a list of term-BDDs, nested to the right
   ([a,b,c] gives a AND (b AND c)). A singleton list is returned unchanged;
   the empty list gives the constant-true term-BDD over vm. *)
fun BddListConj vm [] = PrimitiveBddRules.BddCon true vm
  | BddListConj vm [h] = h
  | BddListConj vm (h::t) = PrimitiveBddRules.BddOp (bdd.And, h, BddListConj vm t);
31
(* BddListDisj vm tbs: disjoin a list of term-BDDs, nested to the right
   ([a,b,c] gives a OR (b OR c)). A singleton list is returned unchanged;
   the empty list gives the constant-false term-BDD over vm. *)
fun BddListDisj vm [] = PrimitiveBddRules.BddCon false vm
  | BddListDisj vm [h] = h
  | BddListDisj vm (h::t) = PrimitiveBddRules.BddOp (bdd.Or, h, BddListDisj vm t);
34
35
(* bdd2dnf vm b: return the BDD b as a DNF term, with one disjunct per path
   from the root of b to the TRUE leaf (this mirrors bdd.printset).
   Paths through a node's high branch come before those through its low
   branch. Each disjunct conjoins the literals met along its path, with the
   deepest variable first. BDD variable numbers are named by looking them up
   in vm.
   Used when the term part of a term-BDD is higher order but the boolean
   equivalent is needed, and unwinding the higher-order parts would be
   inefficient. bddToTerm instead returns a nested if-then-else term that can
   grow far too large.
   Fails (failwith) if a BDD variable has no name in vm. *)
fun bdd2dnf vm b =
    if bdd.equal b bdd.TRUE then T
    else if bdd.equal b bdd.FALSE then F
    else
        let
            val named = Binarymap.listItems vm
            fun nameOf n =
                case assoc2 n named of
                    SOME (str, _) => mk_var (str, bool)
                  | NONE => failwith ("bdd2dnf: Node " ^ Int.toString n ^ " has no name")
            (* walk node path acc = (paths from node to TRUE, in reverse) @ acc;
               the high branch is explored before the low branch *)
            fun walk node path acc =
                if bdd.equal node bdd.TRUE then path :: acc
                else if bdd.equal node bdd.FALSE then acc
                else
                    let
                        val v = nameOf (bdd.var node)
                        val afterHigh = walk (bdd.high node) (v :: path) acc
                    in
                        walk (bdd.low node) (mk_neg v :: path) afterHigh
                    end
        in
            list_mk_disj (List.map list_mk_conj (List.rev (walk b [] [])))
        end;
59
(* bdd2cnf vm b: return the BDD b as a CNF term, with one clause per path
   from the root of b to the FALSE leaf. Paths through a node's low branch
   come before those through its high branch. Each clause negates the
   decisions on its path: v for a low branch, ~v for a high branch, with the
   deepest variable first. BDD variable numbers are named by looking them up
   in vm.
   Fails (failwith) if a BDD variable has no name in vm. *)
fun bdd2cnf vm b =
    if bdd.equal b bdd.TRUE then T
    else if bdd.equal b bdd.FALSE then F
    else
        let
            val named = Binarymap.listItems vm
            fun nameOf n =
                case assoc2 n named of
                    SOME (str, _) => mk_var (str, bool)
                  | NONE => failwith ("bdd2cnf: Node " ^ Int.toString n ^ " has no name")
            (* walk node path acc = (clauses for paths from node to FALSE, in
               reverse) @ acc; the low branch is explored before the high one *)
            fun walk node path acc =
                if bdd.equal node bdd.TRUE then acc
                else if bdd.equal node bdd.FALSE then path :: acc
                else
                    let
                        val v = nameOf (bdd.var node)
                        val afterLow = walk (bdd.low node) (v :: path) acc
                    in
                        walk (bdd.high node) (mk_neg v :: path) afterLow
                    end
        in
            list_mk_conj (List.map list_mk_disj (List.rev (walk b [] [])))
        end;
79
(* getIntForVar vm s: the BDD variable number that vm assigns to the
   variable named s. Raises Binarymap.NotFound if s is not in vm. *)
fun getIntForVar vm (s : string) =
    Binarymap.find (vm, s);
81
(* getVarForInt vm i: reverse lookup in vm. Returns SOME name for the first
   entry (in Binarymap.listItems key order) whose variable number is i, or
   NONE if no name maps to i. *)
fun getVarForInt vm (i:int) =
    case List.find (fn (_, ki) => ki = i) (Binarymap.listItems vm) of
        SOME (ks, _) => SOME ks
      | NONE => NONE
85
(* termToBdd vm t: convert the HOL term t to a term-BDD over vm, using the
   converter currently stored in DerivedBddRules.termToTermBddFun, and return
   only its underlying BDD. *)
fun termToBdd vm t =
    let
        val tb = DerivedBddRules.GenTermToTermBdd (!DerivedBddRules.termToTermBddFun) vm t
    in
        PrimitiveBddRules.getBdd tb
    end
87
(* BddConv conv tb: transform the term part of the term-BDD tb with the
   conversion conv (via DerivedBddRules.BddApConv). If conv raises
   Conv.UNCHANGED, tb itself is returned. *)
fun BddConv conv tb =
    (DerivedBddRules.BddApConv conv tb)
    handle Conv.UNCHANGED => tb; (* nothing was rewritten: keep the input *)
90
(* gba b vm: spell out one state of the BDD b as a list of (name, value)
   pairs. There is one pair per variable in the assignment
   bdd.getAssignment (bdd.toAssignment_ b). The name is taken from the first
   vm entry (in Binarymap.listItems order) mapped to that variable number.
   Raises List.Empty if the assignment mentions a variable number that has
   no name in vm. *)
fun gba b vm =
    let
        val al = bdd.getAssignment (bdd.toAssignment_ b)
        (* loop-invariant: list the varmap once instead of once per variable *)
        val named = Binarymap.listItems vm
        fun lkp i =
            case List.find (fn (_, j) => j = i) named of
                SOME (k, _) => k
              | NONE => raise List.Empty
    in
        List.map (fn (i, bl) => (lkp i, bl)) al
    end
96
(* bdd_replace vm b subs: rename variables in the BDD b according to the HOL
   substitution subs. Every redex/residue term of subs is mapped to its BDD
   variable number by looking up its name (term_to_string2) in vm. The
   resulting number pairs are applied with bdd.replace. This hides the
   pair-set bookkeeping from callers.
   Raises Binarymap.NotFound if a name is not in vm. *)
fun bdd_replace vm b subs =
    let
        fun varNum tm = getIntForVar vm (term_to_string2 tm)
        fun numPair (redex, residue) = (varNum redex, varNum residue)
        val termPairs = List.map dest_subst subs
        val numPairs = List.map numPair termPairs
    in
        bdd.replace b (bdd.makepairSet numPairs)
    end
104
(* mk_bdd s: build the conjunction of literals described by s. The string
   is in the format printed by bdd.printset without the enclosing angle
   brackets: comma-separated "var:val" items.
   Each item contributes bdd.nithvar var when val is 0 and bdd.ithvar var
   otherwise. A string with no items yields bdd.TRUE.
   Raises Option if a field does not begin with an integer. *)
fun mk_bdd s =
    let
        fun splitOn sep str = String.tokens (fn c => c = sep) str
        fun toInt str = Option.valOf (Int.fromString str)
        (* first pass: parse every item into a (var, val) pair *)
        val items =
            List.map (fn fields => (toInt (List.hd fields), toInt (List.last fields)))
                     (List.map (splitOn #":") (splitOn #"," s))
        (* second pass: turn each pair into a BDD literal *)
        val literals =
            List.map (fn (v, 0) => bdd.nithvar v
                       | (v, _) => bdd.ithvar v)
                     items
    in
        List.foldl (fn (lit, acc) => bdd.AND (lit, acc)) bdd.TRUE literals
    end
117
(* mk_pt b vm: build the BDD of a single state of b, the one returned by
   bdd.fullsatone. Only variables that have a name in vm are kept.
   Returns bdd.FALSE when b is bdd.FALSE. *)
fun mk_pt b vm =
    let
        val _ = dbgTools.DEN dpfx "mpt" (*DBG*)
        fun named (vi, _) = Option.isSome (getVarForInt vm vi)
        fun literal (vi, true) = bdd.ithvar vi
          | literal (vi, false) = bdd.nithvar vi
        fun conjoin lits = List.foldl (fn (lit, acc) => bdd.AND (lit, acc)) bdd.TRUE lits
        val res =
            if bdd.equal bdd.FALSE b then bdd.FALSE
            else conjoin (List.map literal
                                   (List.filter named (bdd.getAssignment (bdd.fullsatone b))))
        val _ = dbgTools.DBD (dpfx^"mp_res") res (*DBG*)
        val _ = dbgTools.DEX dpfx "mpt" (*DBG*)
    in
        res
    end
131
(* mk_next state bR vm b1: image of the state set b1 under the transition
   relation bR. bR is assumed to relate the current-state variables of state
   to their next-state partners.
   state is a (possibly nested) pair of variables. For each variable v, v's
   next-state partner is looked up in vm under the name v ^ "'".
   The result is (?s. bR /\ b1) over the current-state variables s, with each
   primed variable then renamed back to its unprimed counterpart.
   The rename pairs come straight from ListPair.zip. Reversing both
   equal-length lists first (as the old List.foldl calls did) only reverses
   the order of the pairs, which does not change the renaming.
   Raises Binarymap.NotFound if a variable or its primed partner is not in vm. *)
fun mk_next state bR vm b1 =
    let
        val _ = dbgTools.DEN dpfx "mn" (*DBG*)
        fun varNum v = Binarymap.find (vm, v)
        val sv = List.map term_to_string (strip_pair state)
        val svi = List.map varNum sv
        val spi = List.map (fn v => varNum (v ^ "'")) sv
        val s = bdd.makeset svi
        (* rename each next-state variable v' back to v *)
        val sp2s = bdd.makepairSet (ListPair.zip (spi, svi))
        val res = bdd.replace (bdd.appex bR b1 bdd.And s) sp2s
        val _ = dbgTools.DEX dpfx "mn" (*DBG*)
    in
        res
    end
146
(* mk_prev state bR vm b1: preimage of the state set b1 under the transition
   relation bR. bR is assumed to relate the current-state variables of state
   to their next-state partners.
   state is a (possibly nested) pair of variables. For each variable v, v's
   next-state partner is looked up in vm under the name v ^ "'".
   b1 is first renamed so that each v becomes v'. The result is then
   (?s'. bR /\ b1') over the primed variables s'.
   The rename pairs come straight from ListPair.zip. Reversing both
   equal-length lists first (as the old List.foldl calls did) only reverses
   the order of the pairs, which does not change the renaming.
   Raises Binarymap.NotFound if a variable or its primed partner is not in vm. *)
fun mk_prev state bR vm b1 =
    let
        val _ = dbgTools.DEN dpfx "mpv" (*DBG*)
        fun varNum v = Binarymap.find (vm, v)
        val sv = List.map term_to_string (strip_pair state)
        val svi = List.map varNum sv
        val spi = List.map (fn v => varNum (v ^ "'")) sv
        val sp = bdd.makeset spi
        (* rename each current-state variable v to v' *)
        val s2sp = bdd.makepairSet (ListPair.zip (svi, spi))
        val res = bdd.appex bR (bdd.replace b1 s2sp) bdd.And sp
        val _ = dbgTools.DEX dpfx "mpv" (*DBG*)
    in
        res
    end
161
(* mk_g'' fvl (fvt,t) ofvl: collect the group of entries linked to t through
   shared free variables.
   fvl is the list of (free-variable set, term) entries still to scan. fvt is
   the free-variable set accumulated so far for t's group. ofvl is the list
   rescanned from the start whenever the group grows.
   Returns a Binaryset of (free-variable set, term) entries ordered by the
   term component only. *)
fun mk_g'' ((fvt',t')::fvl) (fvt,t) ofvl =
       (* no free variable in common with the group so far: skip this entry *)
       if (Binaryset.isEmpty(Binaryset.intersection(fvt,fvt')))
       then mk_g'' fvl (fvt,t) ofvl
       (* overlap: drop t' from the candidates (the fn's t shadows the outer t
          only inside the fn), widen the group's free-variable set, rescan the
          remaining candidates from the start, then record (fvt',t') *)
       else let val ofvl' = List.filter (fn (_,t) => not(Term.compare(t,t')=EQUAL)) ofvl
            in Binaryset.add(mk_g'' ofvl' (Binaryset.union(fvt,fvt'),t) ofvl',(fvt',t')) end
(* nothing left to scan: the group so far holds just t, paired with the
   accumulated free-variable set *)
| mk_g'' [] (fvt,t) ofvl = Binaryset.add (Binaryset.empty (Term.compare o (snd ## snd)),(fvt,t))
168
(* mk_g' fvl: partition the (free-variable set, term) list fvl into groups
   of entries linked by shared free variables.
   Each group is grown by mk_g'' from the first remaining entry. The entries
   it collected (compared by term) are then removed before grouping the rest.
   Returns one Binaryset per group, in discovery order. *)
fun mk_g' [] = []
  | mk_g' ((fvt,t)::rest) =
    let
        val group = mk_g'' rest (fvt,t) rest
        val others = Binaryset.addList (Binaryset.empty (Term.compare o (snd ## snd)), rest)
        val remaining = Binaryset.difference (others, group)
    in
        group :: mk_g' (Binaryset.listItems remaining)
    end
174
(* mk_g tc: group the terms in tc by shared free variables (via mk_g').
   Returns one (union of the group's free variables, group's terms) pair per
   group. Within a group, the terms appear in the reverse of the group set's
   listItems order. *)
fun mk_g tc =
    let
        fun fvSet tm = Binaryset.addList (Binaryset.empty Term.compare, free_vars tm)
        val fvl = ListPair.zip (List.map fvSet tc, tc)
        val groups = mk_g' fvl
        fun summarise entries =
            List.foldl (fn ((fvt, t), (fvAcc, tmAcc)) => (Binaryset.union (fvt, fvAcc), t :: tmAcc))
                       (Binaryset.empty Term.compare, [])
                       entries
    in
        List.map summarise (List.map Binaryset.listItems groups)
    end
180
(* mk_sb sb t: instantiate and simplify a list of terms under a boolean
   assignment.
   sb is a (variable name, boolean value) list, e.g. the output of gba. Each
   term of t has every named boolean variable replaced by T or F and is then
   simplified with SIMP_CONV std_ss []. If simplification raises any
   exception, the instantiated term is used unsimplified.
   Terms that become T are dropped. Each remaining term yields
   (simplified, SOME original) if it became F, and (simplified, NONE)
   otherwise.
   The original comment says t is meant to be the conjuncts of R as grouped
   by mk_g, giving a term representation of the next state of the state
   described by sb. *)
fun mk_sb sb t =
    let
        val hsb = List.map (fn (name, value) => mk_var (name, bool) |-> (if value then T else F)) sb
        fun simp tm =
            let
                val inst = Term.subst hsb tm
            in
                rhs (concl (SIMP_CONV std_ss [] inst)) handle _ => inst
            end
        val simplified = List.map (fn tm => (simp tm, tm)) t
        val survivors = List.filter (fn (s, _) => Term.compare (T, s) <> EQUAL) simplified
    in
        List.map (fn (s, orig) => if Term.compare (F, s) = EQUAL then (s, SOME orig) else (s, NONE))
                 survivors
    end
191
(* findAss t: return a satisfying assignment for the propositional term t as
   a HOL substitution mapping each variable to T or F.
   The assignment is read off the counterexample theorem (|- model ==> ...)
   that HolSatLib.SAT_PROVE raises via SAT_cex when asked to prove ~t.
   Genvars and any variable that prints as "x" are omitted.
   Fails (failwith) if t is unsatisfiable, i.e. SAT_PROVE succeeds in proving
   ~t. Previously the theorem |- ~t was taken apart as if it were a
   counterexample, yielding a meaningless substitution. *)
fun findAss t =
    let
        val th = (ignore (SAT_PROVE (mk_neg t));
                  failwith "findAss: term is unsatisfiable")
                 handle HolSatLib.SAT_cex cex => cex
        val lits = strip_conj (fst (dest_imp (concl th)))
        val t1 = List.filter (fn v => if is_neg v then not (is_genvar (dest_neg v))
                                      else not (is_genvar v)) lits
        fun ncompx v = not (String.compare (term_to_string v, "x") = EQUAL)
        val t2 = List.filter (fn v => if is_neg v then ncompx (dest_neg v) else ncompx v) t1
    in
        List.map (fn v => if is_neg v then (dest_neg v) |-> F else v |-> T) t2
    end
200
(* exv l ass: instantiate each term of l, in order, with the HOL
   substitution ass. Any result that is still a variable (ass did not assign
   it) is replaced by T. Intended for use with MAP_EVERY EXISTS_TAC. *)
fun exv l ass =
    List.map (fn v => let val w = subst ass v in if is_var w then T else w end) l;
207
(* SAT_EXISTS_TAC: for a goal ?v1 ... vn. P (P assumed propositional), find
   a satisfying assignment for P with findAss and supply the matching
   witnesses via MAP_EVERY EXISTS_TAC. Variables the assignment leaves
   unassigned are instantiated with T. *)
fun SAT_EXISTS_TAC (g as (_, w)) =
    let
        val (vl, body) = strip_exists w
        val witnesses = exv vl (findAss body)
    in
        MAP_EVERY EXISTS_TAC witnesses g
    end
215
(* pt_bdd2state state vm pb: take a point BDD pb (a single state) and return
   it as a concrete instance of state, annotated with variable names.
   Each variable v of the (possibly nested) pair state becomes:
   - v if pb's assignment makes it true;
   - ~v if pb's assignment makes it false;
   - plain v if v is missing from vm or from pb's assignment. *)
fun pt_bdd2state state vm pb =
    let val _ = dbgTools.DEN dpfx "pb2s"(*DBG*)
        val _ = dbgTools.DBD (dpfx^"pb2s_pb") pb(*DBG*)
        val _ = Vector.app (dbgTools.DNM (dpfx^"pb2s_pb_support")) (bdd.scanset (bdd.support pb)) (*DBG*)
        val i2val = list2imap((bdd.getAssignment o bdd.toAssignment_) pb)
        (* The if is parenthesised so the handler also covers the lookups in
           its condition. Unparenthesised, "else e handle ..." attaches the
           handler to the else branch only, and a missing variable escaped as
           Binarymap.NotFound. *)
        fun annotate v =
            (if Binarymap.find (i2val, Binarymap.find (vm, v))
             then mk_bool_var v
             else mk_neg (mk_bool_var v))
            handle Binarymap.NotFound => mk_bool_var v
        val res = list_mk_pair (List.map annotate (List.map term_to_string2 (strip_pair state)))
        val _ = dbgTools.DEX dpfx "pb2s"(*DBG*)
    in res end
229
230
(* mk_varmap state bvm: build the varmap used for BDD construction.
   If bvm is SOME names, the names are numbered 0, 1, 2, ... in list order.
   If bvm is NONE, the default order interleaves next- and current-state
   variables: for each variable v of the pair state, its primed version
   (prime v) comes immediately before v.
   If BuDDy's current variable count is smaller than the number of names,
   it is raised with bdd.setVarnum.
   FIXME: do a better default ordering. *)
fun mk_varmap state bvm =
    let
        val names =
            case bvm of
                SOME order => order
              | NONE =>
                let
                    val cur = strip_pair state
                    val nxt = List.map prime cur
                    val interleaved = List.concat (List.map (fn (v, v') => [v', v]) (ListPair.zip (cur, nxt)))
                in
                    List.map term_to_string2 interleaved
                end
        val count = List.length names
        val numbered = ListPair.zip (names, List.tabulate (count, fn i => i))
        val vm = List.foldr (fn (entry, acc) => Varmap.insert entry acc) Varmap.empty numbered
        (* this tells BuDDy where and what the vars are *)
        val _ = if bdd.getVarnum () < count then bdd.setVarnum count else ()
    in
        vm
    end
245
246end
247end
248
249(*
250(*FIXME: move this comment into documentation *)
251(* debugging usage example *)
(* This assumes I1, R1, T1, ks_def and wfKS_ks have already been computed; see alu.sml, ahb_total.sml or scratch.sml for how to do that. *)
253load "cearTools";
254load "debugTools";
255open cearTools;
256open debugTools;
257val sc = DerivedBddRules.statecount;
258val dtb = PrimitiveBddRules.dest_term_bdd;
259open PrimitiveBddRules;
260        val vm = List.foldr (fn(v,vm') => Varmap.insert v vm') (Varmap.empty)
261                            (ListPair.zip(bvm,(List.tabulate(List.length bvm,fn x => x))))
262        val _ = bdd.setVarnum (List.length bvm) (* this tells BuDDy where and what the vars are *)
263        val tbRS = muTools.RcomputeReachable (R1,I1) vm;
        val brs = PrimitiveBddRules.getBdd tbRS;
265        val Ree = Array.fromList []
266        val RTm = muCheck.RmakeTmap T1 vm
267        val Tm = List.map (fn (nm,tb) => (nm,getTerm tb)) (Binarymap.listItems RTm) (* using nontotal R for ahbapb composition *)
268        val (dks_def,wfKS_dks) = muCheck.mk_wfKS Tm I1 NONE NONE
        val chk = fn mf => muCheck.muCheck RTm Ree I1 mf (dks_def,wfKS_dks) vm NONE handle ex => raise ex;
(* Note how dbg below returns debugging information such as a bad state, its next and previous states, and more readable versions of those. *)
271fun chk2 cf = let val tb2 = chk (ctl2mu cf)
272                  val bb2 = PrimitiveBddRules.getBdd tb2
273                  val bd = bdd.DIFF(brs,bdd.AND(brs,bb2))
274                  val dbg = if bdd.equal bi (bdd.AND(bi,bdd.AND(brs,bb2))) then NONE
275                            else let val Rtb = Binarymap.find (RTm,".")
276                                     val b2 = mk_pt bd vm
277                                     val bn = mk_next I1 (getBdd Rtb) vm b2
278                                     val bp = mk_prev I1 (getBdd Rtb) vm b2
279                                     val sb = mk_sb (gba b2 vm) (strip_conj (getTerm Rtb))
280                                 in SOME (Rtb,b2,bn,bp,sb) end
281              in (tb2,bdd.AND(brs,bb2),bd,dbg) end;
282
283*)
284