(* non-interactive mode
*)
structure ho_discrimTools :> ho_discrimTools =
struct
open HolKernel Parse boolLib;

(* interactive mode
val () = loadPath := union ["..", "../finished"] (!loadPath);
val () = app load
  ["bossLib",
   "realLib",
   "rich_listTheory",
   "arithmeticTheory",
   "numTheory",
   "pred_setTheory",
   "pairTheory",
   "combinTheory",
   "listTheory",
   "dividesTheory",
   "primeTheory",
   "gcdTheory",
   "probLib",
   "HurdUseful"];
val () = show_assums := true;
*)

(* Term discriminators: tries of pattern tokens used to retrieve the stored
   entries whose terms match a given term, either first-order
   (discrim_fo_match) or higher-order with a proof thunk (discrim_ho_match).
   vdiscrim entries additionally carry the variables that may be
   instantiated; ovdiscrim entries also record their insertion order so that
   results come back latest first. *)

open hurdUtils ho_basicTools;

(* Tactic-combinator infix aliases; none of them is referenced by the
   definitions below. *)
infixr 0 oo ## ++ << || THENC ORELSEC THENR ORELSER;
infix 1 >>;

val op++ = op THEN;
val op<< = op THENL;
val op|| = op ORELSE;
val op>> = op THEN1;
val !! = REPEAT;

(* ------------------------------------------------------------------------- *)
(* Type/term substitutions                                                   *)
(* ------------------------------------------------------------------------- *)

(* The empty raw substitution, in the nested shape that raw_match' takes
   apart: ((term subst, term id set), (type subst, type id list)). *)
val empty_raw_subst : raw_substitution = (([], empty_tmset), ([], []));

(* raw_match' tm1 tm2 sub extends the raw substitution sub so that it also
   matches tm1 against tm2.  It only reorders the arguments of raw_match,
   which takes the two id collections first.  Failure is whatever raw_match
   raises; the matchers below run it under total. *)
fun raw_match' tm1 tm2 ((tmS, tmIds), (tyS, tyIds)) =
  raw_match tyIds tmIds tm1 tm2 (tmS, tyS);

(* type_raw_match ty1 ty2 sub matches the type ty1 against ty2 by matching
   the constants NIL : ty1 list and NIL : ty2 list with raw_match'. *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    val tm1 = mk_const ("NIL", mk_type ("list", [ty1]))
    val tm2 = mk_const ("NIL", mk_type ("list", [ty2]))
  in
    raw_match' tm1 tm2 sub
  end;

(* Turn an accumulated raw substitution into a finished substitution. *)
val finalize_subst : raw_substitution -> substitution = norm_subst;

(* ------------------------------------------------------------------------- *)
(* A term discriminator.                                                     *)
(* ------------------------------------------------------------------------- *)

(* A term is flattened into a prefix-order stream of tokens (see
   tm_pat_break): an application becomes COMB_BEGIN, rator, rand, COMB_END;
   an abstraction becomes ABS_BEGIN ty, body, ABS_END. *)
datatype pattern
  = COMB_BEGIN
  | COMB_END
  | ABS_BEGIN of hol_type          (* type of the abstraction's bound var *)
  | ABS_END
  | CONSTANT of term
  | BVAR of int                    (* bound variable, as given by dest_bv *)
  | FVAR of term * int list;       (* pattern variable, from dest_ho_pat *)

(* Equality on tokens: terms are compared up to alpha-conversion (aconv),
   types with Type.compare. *)
fun pat_eq p1 p2 =
  case (p1, p2) of
    (COMB_BEGIN, COMB_BEGIN) => true
  | (COMB_END, COMB_END) => true
  | (ABS_BEGIN ty1, ABS_BEGIN ty2) => Type.compare (ty1, ty2) = EQUAL
  | (ABS_END, ABS_END) => true
  | (CONSTANT t1, CONSTANT t2) => aconv t1 t2
  | (BVAR i1, BVAR i2) => i1 = i2
  | (FVAR fv1, FVAR fv2) => pair_eq aconv equal fv1 fv2
  | _ => false;

(* A discriminator: the number of entries added so far, and a forest of
   token tries whose leaves carry the stored values. *)
datatype 'a discrim = DISCRIM of int * (pattern, 'a) tree list;

val empty_discrim = DISCRIM (0, []);
fun discrim_size (DISCRIM (i, _)) = i;

(* Rebuilding terms from token lists.  The machine state is (bvs, stack):
   bvs holds the bound variables of the enclosing abstractions, innermost
   first; stack holds the completed subterms, most recent first. *)
local
  val bv_prefix = "bv";

  fun advance COMB_BEGIN state = state
    | advance COMB_END (bvs, s1 :: s2 :: srest) =
      (bvs, mk_comb (s2, s1) :: srest)
    | advance (ABS_BEGIN ty) (bvs, stack) =
      let
        (* bound variable named from bv_prefix and the current depth *)
        val bv = mk_var (mk_string_fn bv_prefix [int_to_string (length bvs)], ty)
      in
        (bv :: bvs, stack)
      end
    | advance ABS_END (bv :: bvs, s1 :: stack) =
      (bvs, mk_abs (bv, s1) :: stack)
    | advance (CONSTANT tm) (bvs, stack) = (bvs, tm :: stack)
    | advance (BVAR n) (bvs, stack) = (bvs, mk_bv bvs n :: stack)
    | advance (FVAR fvc) (bvs, stack) = (bvs, mk_ho_pat bvs fvc :: stack)
    | advance _ _ = raise BUG "pats_to_term" "not a valid pattern list"

  (* A complete token list leaves exactly one term on the stack. *)
  fun final (_, [tm]) = tm
    | final _ = raise BUG "pats_to_term" "not a complete pattern list"

  fun final_a a state = (final state, a)
in
  fun pats_to_term_bvs (bvs, pats) = final (trans advance (bvs, []) pats)
  fun pats_to_term pats = pats_to_term_bvs ([], pats)

  (* All (term, value) entries of a discriminator.  Each term is rebuilt
     from its trie path, so abstractions get generated bound-variable
     names rather than the originally inserted ones. *)
  fun dest_discrim (DISCRIM (_, d)) =
    flatten (map (tree_trans advance final_a ([], [])) d)
end;

(* tm_pat_break (bvs, work) expands the unexpanded term (a RIGHT entry) at
   the head of the work list by one level into tokens (LEFT entries) and
   subterms (RIGHT entries).  Variables must be recognised by is_bv or
   is_ho_pat; any other variable is a BUG.  Entering an abstraction pushes
   its bound variable onto bvs. *)
fun tm_pat_break (bvs, RIGHT tm :: rest) =
    if is_bv bvs tm then
      (bvs, LEFT (BVAR (dest_bv bvs tm)) :: rest)
    else if is_ho_pat bvs tm then
      (bvs, LEFT (FVAR (dest_ho_pat bvs tm)) :: rest)
    else
      (case dest_term tm of
         CONST _ =>
           (bvs, LEFT (CONSTANT tm) :: rest)
       | COMB (Rator, Rand) =>
           (bvs,
            LEFT COMB_BEGIN :: RIGHT Rator :: RIGHT Rand ::
            LEFT COMB_END :: rest)
       | LAMB (Bvar, Body) =>
           (Bvar :: bvs,
            LEFT (ABS_BEGIN (type_of Bvar)) :: RIGHT Body ::
            LEFT ABS_END :: rest)
       | VAR _ => raise BUG "tm_pat_break" "shouldn't be a var")
  | tm_pat_break (_, []) =
    raise BUG "tm_pat_break" "nothing to break"
  | tm_pat_break (_, LEFT _ :: _) =
    raise BUG "tm_pat_break" "can't break apart a pattern!";

(* tm_pat_correct pat bvs keeps the bound-variable stack in step with a
   consumed token: leaving an abstraction (ABS_END) pops one variable. *)
fun tm_pat_correct ABS_END ((_ : term) :: bvs) = bvs
  | tm_pat_correct ABS_END [] =
    raise BUG "tm_pat_correct" "no bvs left at an ABS_END"
  | tm_pat_correct _ bvs = bvs;

(* Flatten a term into its complete token list, treating the variables in
   bvs as bound. *)
local
  fun tm_pats res (_, []) = res
    | tm_pats res (bvs, LEFT p :: rest) =
      tm_pats (p :: res) (tm_pat_correct p bvs, rest)
    | tm_pats res unbroken =
      tm_pats res (tm_pat_break unbroken)
in
  fun term_to_pats_bvs (bvs, tm) = rev (tm_pats [] (bvs, [RIGHT tm]))
  fun term_to_pats tm = term_to_pats_bvs ([], tm)
end;

(* discrim_add (tm, a) d inserts tm's token list into the trie, following any
   existing branch whose token is pat_eq and ending in a new leaf holding a.
   The entry count is incremented. *)
local
  fun add a ts [] = LEAF a :: ts
    | add a [] (pat :: next) = [BRANCH (pat, add a [] next)]
    | add a ((b as BRANCH (pat', ts')) :: rest) (ps as pat :: next) =
      if pat_eq pat pat' then BRANCH (pat', add a ts' next) :: rest
      else b :: add a rest ps
    | add _ (LEAF _ :: _) (_ :: _) =
      raise BUG "discrim_add" "expected a branch, got a leaf"
in
  fun discrim_add (tm, a) (DISCRIM (i, d)) =
    DISCRIM (i + 1, add a d (term_to_pats tm));
end;

(* Build a discriminator by adding each (term, value) pair of l. *)
fun mk_discrim l = trans discrim_add empty_discrim l;

(* pat_ho_match (sub, thks) pat (bvs, pat') matches one stored token pat
   against one token pat' of the term.  It extends the raw substitution and
   the stack of proof thunks:
   - COMB_END combines the top two thunks with MK_COMB;
   - ABS_END wraps the top thunk with GEN and MK_ABS over the innermost
     bound variable.
   FVAR tokens are consumed by the caller, never here.  Mismatches raise
   via ERR. *)
fun pat_ho_match _ (FVAR _) _ =
    raise BUG "pat_ho_match" "can't match variables"
  | pat_ho_match sub_thks COMB_BEGIN (_, COMB_BEGIN) = sub_thks
  | pat_ho_match (sub, thks) (ABS_BEGIN ty) (_, ABS_BEGIN ty') =
    (type_raw_match ty ty' sub, thks)
  | pat_ho_match (sub, th1 :: th2 :: thks) COMB_END (_, COMB_END) =
    (sub, (fn () => MK_COMB (th2 (), th1 ())) :: thks)
  | pat_ho_match (sub, th :: thks) ABS_END (bv :: _, ABS_END) =
    (sub, (fn () => MK_ABS (GEN bv (th ()))) :: thks)
  | pat_ho_match (sub, thks) (BVAR n) (bvs, BVAR n') =
    if n = n' then (sub, (fn () => REFL (mk_bv bvs n)) :: thks)
    else raise ERR "pat_ho_match" "different bound vars"
  | pat_ho_match (sub, thks) (CONSTANT c) (_, CONSTANT c') =
    (raw_match' c c' sub, (fn () => REFL c') :: thks)
  | pat_ho_match _ _ _ =
    raise ERR "pat_ho_match" "pats fundamentally different";

(* discrim_ho_match d tm higher-order matches tm against every entry of d.

   The walk state is ((bvs, work), (sub, thks)), where work is the part of
   tm not yet consumed.  An FVAR token in the trie consumes a whole subterm
   via ho_raw_match; other tokens first break the term with tm_pat_break.

   Each successful path yields ((substitution, thunk), a), where the thunk
   returns SYM of the theorem assembled from the per-token thunks.  Paths
   on which advance fails are dropped, since advance runs under total.  A
   HOL_ERR that still escapes is reported through err_BUG. *)
local
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), (sub, thks)) =
      let
        val (sub', thk') = ho_raw_match (fv, fbs) (bvs, tm) sub
      in
        ((bvs, rest), (sub', thk' :: thks))
      end
    | advance pat (state as (_, RIGHT _ :: _), ho_sub) =
      advance pat (tm_pat_break state, ho_sub)
    | advance pat ((bvs, LEFT pat' :: rest), ho_sub) =
      let
        val ho_sub' = pat_ho_match ho_sub pat (bvs, pat')
      in
        ((tm_pat_correct pat' bvs, rest), ho_sub')
      end
    | advance _ ((_, []), _) =
      raise BUG "discrim_ho_match" "no pats left in list"

  (* A finished path has consumed the whole term and left one thunk. *)
  fun finally a ((_, []), (sub, [thk])) =
      SOME ((finalize_subst sub, fn () => SYM (thk ())), a)
    | finally _ _ = raise BUG "discrim_ho_match" "weird state at end"

  fun tree_match tm =
    tree_partial_trans
      (fn x => total (advance x))
      finally
      (([], [RIGHT tm]), (empty_raw_subst, []))
in
  fun discrim_ho_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_ho_match" h;
end;

(* pat_fo_match sub pat pat' is the first-order analogue of pat_ho_match:
   only the raw substitution is threaded, with no proof thunks. *)
fun pat_fo_match _ (FVAR _) _ =
    raise BUG "pat_fo_match" "can't match variables"
  | pat_fo_match sub COMB_BEGIN COMB_BEGIN = sub
  | pat_fo_match sub (ABS_BEGIN ty) (ABS_BEGIN ty') = type_raw_match ty ty' sub
  | pat_fo_match sub COMB_END COMB_END = sub
  | pat_fo_match sub ABS_END ABS_END = sub
  | pat_fo_match sub (BVAR n) (BVAR n') =
    if n = n' then sub else raise ERR "pat_fo_match" "different bound vars"
  | pat_fo_match sub (CONSTANT c) (CONSTANT c') = raw_match' c c' sub
  | pat_fo_match _ _ _ =
    raise ERR "pat_fo_match" "pats fundamentally different";

(* discrim_fo_match d tm first-order matches tm against every entry of d,
   returning (substitution, a) for each successful path.  FVAR tokens
   consume a whole subterm via fo_raw_match.  Paths on which advance fails
   are dropped, since advance runs under total.

   A HOL_ERR that still escapes is passed to err_BUG, as in
   discrim_ho_match, rather than being discarded. *)
local
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), sub) =
      let
        val sub' = fo_raw_match (fv, fbs) (bvs, tm) sub
      in
        ((bvs, rest), sub')
      end
    | advance pat (state as (_, RIGHT _ :: _), sub) =
      advance pat (tm_pat_break state, sub)
    | advance pat ((bvs, LEFT pat' :: rest), sub) =
      let
        val sub' = pat_fo_match sub pat pat'
      in
        ((tm_pat_correct pat' bvs, rest), sub')
      end
    | advance _ ((_, []), _) =
      raise BUG "discrim_fo_match" "no pats left in list"

  fun finally a ((_, []), sub) =
      SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "discrim_fo_match" "weird state at end"

  fun tree_match tm =
    tree_partial_trans
      (fn x => total (advance x))
      finally
      (([], [RIGHT tm]), empty_raw_subst)
in
  fun discrim_fo_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_fo_match" h;
end;

(* ------------------------------------------------------------------------- *)
(* Variable Term Discriminators                                              *)
(* Terms come with a list of variables that may be instantiated, the rest    *)
(* should be treated as constants.                                           *)
(* ------------------------------------------------------------------------- *)

(* Each stored value is paired with the variables given at insertion. *)
type 'a vdiscrim = (vars * 'a) discrim;

val empty_vdiscrim : 'a vdiscrim = empty_discrim;

val vdiscrim_size : 'a vdiscrim -> int = discrim_size;

(* Entries as ((vars, term), value), the shape vdiscrim_add accepts. *)
local
  fun dest (tm : term, (vars : vars, a)) = ((vars, tm), a)
in
  fun dest_vdiscrim d = (map dest o dest_discrim) d;
end;

local
  fun prepare (vars : vars) (tm, a) = (tm, (vars, a))
in
  fun vdiscrim_add ((vars, tm), a) = discrim_add (prepare vars (tm, a));
end;

fun mk_vdiscrim l = trans vdiscrim_add empty_vdiscrim l;

(* Keep only the matches for which subst_vars_in_set holds of the match's
   substitution and the entry's stored variables.  Presumably this means
   only those variables were instantiated; confirm in hurdUtils. *)
local
  fun vars_check (ho_sub as (sub, _) : ho_substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (ho_sub, a) else NONE
  fun check_ho_match ml = partial_map vars_check ml
in
  fun vdiscrim_ho_match d tm = check_ho_match (discrim_ho_match d tm);
end;

local
  fun vars_check (sub : substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (sub, a) else NONE
  fun check_fo_match ml = partial_map vars_check ml
in
  fun vdiscrim_fo_match d tm = check_fo_match (discrim_fo_match d tm);
end;

(* ------------------------------------------------------------------------- *)
(* Ordered (Variable) Term Discriminators                                    *)
(* Entries are returned in the order they arrived (latest first).            *)
(* ------------------------------------------------------------------------- *)

(* Each value is tagged with the discriminator size at the time it was
   added, which serves as its arrival index. *)
type 'a ovdiscrim = (int * 'a) vdiscrim;

val empty_ovdiscrim : 'a ovdiscrim = empty_vdiscrim;

val ovdiscrim_size : 'a ovdiscrim -> int = vdiscrim_size;

local
  (* Move the arrival index to the front and sort on it, highest (latest)
     first, then drop it again. *)
  fun transfer (a, (n : int, b)) = (n, (a, b));
  fun order (m, _) (n, _) = m > n;
  fun dest dl = (map snd o sort order o map transfer) dl;
in
  fun dest_ovdiscrim d = (dest o dest_vdiscrim) d;

  fun ovdiscrim_ho_match d tm = dest (vdiscrim_ho_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_ho_match" h;

  fun ovdiscrim_fo_match d tm = dest (vdiscrim_fo_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_fo_match" h;
end;

(* vars_tm is the (vars, term) pair expected by vdiscrim_add. *)
fun ovdiscrim_add (vars_tm, a) d =
  vdiscrim_add (vars_tm, (vdiscrim_size d, a)) d;

fun mk_ovdiscrim l = trans ovdiscrim_add empty_ovdiscrim l;

(* non-interactive mode
*)
end;