(* non-interactive mode
*)
structure ho_discrimTools :> ho_discrimTools =
struct
open HolKernel Parse boolLib;

(* interactive mode
val () = loadPath := union ["..", "../finished"] (!loadPath);
val () = app load
  ["bossLib",
   "realLib",
   "rich_listTheory",
   "arithmeticTheory",
   "numTheory",
   "pred_setTheory",
   "pairTheory",
   "combinTheory",
   "listTheory",
   "dividesTheory",
   "primeTheory",
   "gcdTheory",
   "probLib",
   "HurdUseful"];
val () = show_assums := true;
*)

open HurdUseful ho_basicTools;

infixr 0 oo ## ++ << || THENC ORELSEC THENR ORELSER;
infix 1 >>;

val op++ = op THEN;
val op<< = op THENL;
val op|| = op ORELSE;
val op>> = op THEN1;
val !! = REPEAT;

(* ------------------------------------------------------------------------- *)
(* Type/term substitutions                                                   *)
(* ------------------------------------------------------------------------- *)

(* The raw substitution that binds nothing: empty term and type
   substitutions and empty id collections. *)
val empty_raw_subst : raw_substitution = (([], empty_tmset), ([], []));

(* raw_match' tm1 tm2 sub extends sub so that tm1 matches tm2.  It only
   unpacks sub into the argument order expected by HOL's raw_match. *)
fun raw_match' tm1 tm2 ((tmS, tmIds), (tyS, tyIds)) =
  raw_match tyIds tmIds tm1 tm2 (tmS, tyS);

(* Match type ty1 against type ty2 by matching the constant NIL at list type
   ty1 against NIL at list type ty2.  Both sides are the same constant, so
   only the types can contribute to the match. *)
fun type_raw_match ty1 ty2 (sub : raw_substitution) =
  let
    val tm1 = mk_const ("NIL", mk_type ("list", [ty1]))
    val tm2 = mk_const ("NIL", mk_type ("list", [ty2]))
  in
    raw_match' tm1 tm2 sub
  end;

(* Convert an accumulated raw substitution into an ordinary substitution. *)
val finalize_subst : raw_substitution -> substitution = norm_subst;

(* ------------------------------------------------------------------------- *)
(* A term discriminator.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Terms are flattened (see term_to_pats) into prefix-order pattern lists:
     COMB_BEGIN, rator patterns, rand patterns, COMB_END  -- an application;
     ABS_BEGIN ty, body patterns, ABS_END  -- an abstraction whose bound
                                              variable has type ty;
     CONSTANT c  -- a constant;
     BVAR n      -- a bound variable, as indexed by dest_bv;
     FVAR p      -- a subterm accepted by is_ho_pat, as given by dest_ho_pat. *)
datatype pattern
  = COMB_BEGIN
  | COMB_END
  | ABS_BEGIN of hol_type
  | ABS_END
  | CONSTANT of term
  | BVAR of int
  | FVAR of term * int list;

(* DISCRIM (n, trees): n counts the entries added by discrim_add, and trees
   is a trie keyed on pattern lists whose LEAF nodes hold the stored values. *)
datatype 'a discrim = DISCRIM of int * (pattern, 'a) tree list;

val empty_discrim = DISCRIM (0, []);
fun discrim_size (DISCRIM (i, _)) = i;

(* Rebuilding terms from pattern lists.  The state is a pair: the bound
   variables in scope, innermost first, and a stack of completed subterms. *)
local
  val bv_prefix = "bv";

  (* Consume one pattern, updating the state. *)
  fun advance COMB_BEGIN state = state
    | advance COMB_END (bvs, s1 :: s2 :: srest) =
      (bvs, mk_comb (s2, s1) :: srest)
    | advance (ABS_BEGIN ty) (bvs, stack) =
      let
        (* Name the new bound variable from bv_prefix and the number of
           bound variables already in scope. *)
        val bv = mk_var (mk_string_fn bv_prefix [int_to_string (length bvs)], ty)
      in
        (bv :: bvs, stack)
      end
    | advance ABS_END (bv :: bvs, s1 :: stack) = (bvs, mk_abs (bv, s1) :: stack)
    | advance (CONSTANT tm) (bvs, stack) = (bvs, tm :: stack)
    | advance (BVAR n) (bvs, stack) = (bvs, mk_bv bvs n :: stack)
    | advance (FVAR fvc) (bvs, stack) = (bvs, mk_ho_pat bvs fvc :: stack)
    | advance _ _ = raise BUG "pats_to_term" "not a valid pattern list"

  (* A complete pattern list leaves exactly one term on the stack. *)
  fun final (_, [tm]) = tm
    | final _ = raise BUG "pats_to_term" "not a complete pattern list"

  fun final_a a state = (final state, a)
in
  fun pats_to_term_bvs (bvs, pats) = final (trans advance (bvs, []) pats)
  fun pats_to_term pats = pats_to_term_bvs ([], pats)

  (* The stored entries as (term, value) pairs, each term rebuilt from its
     trie path; presumably one pair per leaf of the trie. *)
  fun dest_discrim (DISCRIM (_, d)) =
    flatten (map (tree_trans advance final_a ([], [])) d)
end;

(* Work lists mix flattened patterns (LEFT) with terms not yet flattened
   (RIGHT).  tm_pat_break expands the RIGHT term at the head of the list by
   one level.  On entering an abstraction it pushes the bound variable, which
   tm_pat_correct pops again at the matching ABS_END. *)
fun tm_pat_break (bvs, RIGHT tm :: rest) =
    if is_bv bvs tm then
      (bvs, LEFT (BVAR (dest_bv bvs tm)) :: rest)
    else if is_ho_pat bvs tm then
      (bvs, LEFT (FVAR (dest_ho_pat bvs tm)) :: rest)
    else
      (case dest_term tm of
         CONST _ => (bvs, LEFT (CONSTANT tm) :: rest)
       | COMB (Rator, Rand) =>
         (bvs,
          LEFT COMB_BEGIN :: RIGHT Rator :: RIGHT Rand ::
          LEFT COMB_END :: rest)
       | LAMB (Bvar, Body) =>
         (Bvar :: bvs,
          LEFT (ABS_BEGIN (type_of Bvar)) :: RIGHT Body ::
          LEFT ABS_END :: rest)
       | VAR _ => raise BUG "tm_pat_break" "shouldn't be a var")
  | tm_pat_break (_, []) =
    raise BUG "tm_pat_break" "nothing to break"
  | tm_pat_break (_, LEFT _ :: _) =
    raise BUG "tm_pat_break" "can't break apart a pattern!";

(* Applied to each pattern as it is consumed: leaving an abstraction pops
   its bound variable; every other pattern leaves bvs unchanged. *)
fun tm_pat_correct ABS_END ((_ : term) :: bvs) = bvs
  | tm_pat_correct ABS_END [] =
    raise BUG "tm_pat_correct" "no bvs left at an ABS_END"
  | tm_pat_correct _ bvs = bvs;

(* Flatten a term into its full pattern list.  bvs holds the bound variables
   already in scope; term_to_pats starts with none. *)
local
  fun tm_pats res (_, []) = res
    | tm_pats res (bvs, LEFT p :: rest) =
      tm_pats (p :: res) (tm_pat_correct p bvs, rest)
    | tm_pats res unbroken =
      tm_pats res (tm_pat_break unbroken)
in
  fun term_to_pats_bvs (bvs, tm) = rev (tm_pats [] (bvs, [RIGHT tm]))
  fun term_to_pats tm = term_to_pats_bvs ([], tm)
end;

(* Insert a (term, value) entry and increment the entry count.  Trie
   branches are shared with existing entries that have a common pattern
   prefix.  Entries with identical pattern lists are all kept, as sibling
   leaves. *)
local
  fun add a ts [] = LEAF a :: ts
    | add a [] (pat :: next) = [BRANCH (pat, add a [] next)]
    | add a ((b as BRANCH (pat', ts')) :: rest) (ps as pat :: next) =
      if pat = pat' then BRANCH (pat', add a ts' next) :: rest
      else b :: add a rest ps
    | add _ (LEAF _ :: _) (_ :: _) =
      raise BUG "discrim_add" "expected a branch, got a leaf"
in
  fun discrim_add (tm, a) (DISCRIM (i, d)) =
    DISCRIM (i + 1, add a d (term_to_pats tm));
end;

(* Add every (term, value) pair of l to an empty discriminator, using trans
   from HurdUseful. *)
fun mk_discrim l = trans discrim_add empty_discrim l;

(* Match one stored pattern against the next pattern of the object term.
   The state is (sub, thks): a raw substitution and a stack of theorem
   thunks, one per completed subterm.
   - Closing an application combines the top two thunks with MK_COMB.
   - Closing an abstraction wraps the top thunk with MK_ABS of GEN bv.
   - Bound variables and constants push a REFL thunk.
   Fails with ERR when the two patterns cannot match.  A FVAR on the stored
   side is a BUG, because advance handles it before reaching here. *)
fun pat_ho_match _ (FVAR _) _ =
    raise BUG "pat_ho_match" "can't match variables"
  | pat_ho_match sub_thks COMB_BEGIN (_, COMB_BEGIN) = sub_thks
  | pat_ho_match (sub, thks) (ABS_BEGIN ty) (_, ABS_BEGIN ty') =
    (type_raw_match ty ty' sub, thks)
  | pat_ho_match (sub, th1 :: th2 :: thks) COMB_END (_, COMB_END) =
    (sub, (fn () => MK_COMB (th2 (), th1 ())) :: thks)
  | pat_ho_match (sub, th :: thks) ABS_END (bv :: _, ABS_END) =
    (sub, (fn () => MK_ABS (GEN bv (th ()))) :: thks)
  | pat_ho_match (sub, thks) (BVAR n) (bvs, BVAR n') =
    if n = n' then (sub, (fn () => REFL (mk_bv bvs n)) :: thks)
    else raise ERR "pat_ho_match" "different bound vars"
  | pat_ho_match (sub, thks) (CONSTANT c) (_, CONSTANT c') =
    (raw_match' c c' sub, (fn () => REFL c') :: thks)
  | pat_ho_match _ _ _ =
    raise ERR "pat_ho_match" "pats fundamentally different";

(* Higher-order matching of a term against every entry of a discriminator.
   Each trie path is walked in step with the object term.  The object term
   is broken into patterns lazily, only as far as the stored patterns
   require.  A stored FVAR is matched against the whole corresponding
   subterm with ho_raw_match.
   Each match is returned as ((substitution, thunk), value), where the thunk
   applies SYM to the theorem built up during the walk.
   A HOL_ERR escaping the walk is re-raised through err_BUG. *)
local
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), (sub, thks)) =
      let
        val (sub', thk') = ho_raw_match (fv, fbs) (bvs, tm) sub
      in
        ((bvs, rest), (sub', thk' :: thks))
      end
    | advance pat (state as (_, RIGHT _ :: _), ho_sub) =
      advance pat (tm_pat_break state, ho_sub)
    | advance pat ((bvs, LEFT pat' :: rest), ho_sub) =
      let
        val ho_sub' = pat_ho_match ho_sub pat (bvs, pat')
      in
        ((tm_pat_correct pat' bvs, rest), ho_sub')
      end
    | advance _ ((_, []), _) =
      raise BUG "discrim_ho_match" "no pats left in list"

  (* At a leaf, the object term must be fully consumed and exactly one
     theorem thunk must remain. *)
  fun finally a ((_, []), (sub, [thk])) =
      SOME ((finalize_subst sub, fn () => SYM (thk ())), a)
    | finally _ _ = raise BUG "discrim_ho_match" "weird state at end"

  (* A failing step yields NONE through total.  Presumably
     tree_partial_trans then abandons that path of the trie. *)
  fun tree_match tm =
    tree_partial_trans
      (fn x => total (advance x))
      finally
      (([], [RIGHT tm]), (empty_raw_subst, []))
in
  fun discrim_ho_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_ho_match" h;
end;

(* First-order counterpart of pat_ho_match: match one stored pattern against
   one object pattern, threading only the raw substitution. *)
fun pat_fo_match _ (FVAR _) _ =
    raise BUG "pat_fo_match" "can't match variables"
  | pat_fo_match sub COMB_BEGIN COMB_BEGIN = sub
  | pat_fo_match sub (ABS_BEGIN ty) (ABS_BEGIN ty') = type_raw_match ty ty' sub
  | pat_fo_match sub COMB_END COMB_END = sub
  | pat_fo_match sub ABS_END ABS_END = sub
  | pat_fo_match sub (BVAR n) (BVAR n') =
    if n = n' then sub else raise ERR "pat_fo_match" "different bound vars"
  | pat_fo_match sub (CONSTANT c) (CONSTANT c') = raw_match' c c' sub
  | pat_fo_match _ _ _ =
    raise ERR "pat_fo_match" "pats fundamentally different";

(* First-order matching of a term against every entry of a discriminator.
   The walk is the same as in discrim_ho_match, but stored FVARs use
   fo_raw_match and no theorems are built.  Each match is returned as
   (substitution, value).  A HOL_ERR escaping the walk is re-raised through
   err_BUG, so the original error is not lost. *)
local
  fun advance (FVAR (fv, fbs)) ((bvs, RIGHT tm :: rest), sub) =
      let
        val sub' = fo_raw_match (fv, fbs) (bvs, tm) sub
      in
        ((bvs, rest), sub')
      end
    | advance pat (state as (_, RIGHT _ :: _), sub) =
      advance pat (tm_pat_break state, sub)
    | advance pat ((bvs, LEFT pat' :: rest), sub) =
      let
        val sub' = pat_fo_match sub pat pat'
      in
        ((tm_pat_correct pat' bvs, rest), sub')
      end
    | advance _ ((_, []), _) =
      raise BUG "discrim_fo_match" "no pats left in list"

  (* At a leaf, the object term must be fully consumed. *)
  fun finally a ((_, []), sub) =
      SOME (finalize_subst sub, a)
    | finally _ _ = raise BUG "discrim_fo_match" "weird state at end"

  fun tree_match tm =
    tree_partial_trans (fn x => total (advance x))
      finally (([], [RIGHT tm]), empty_raw_subst)
in
  fun discrim_fo_match (DISCRIM (_, d)) tm =
    (flatten o map (tree_match tm)) d
    handle (h as HOL_ERR _) => raise err_BUG "discrim_fo_match" h;
end;

(* ------------------------------------------------------------------------- *)
(* Variable Term Discriminators                                              *)
(* Terms come with a list of variables that may be instantiated, the rest    *)
(* should be treated as constants.                                           *)
(* ------------------------------------------------------------------------- *)

(* Each entry stores its instantiable variables alongside its value. *)
type 'a vdiscrim = (vars * 'a) discrim;

val empty_vdiscrim : 'a vdiscrim = empty_discrim;

val vdiscrim_size : 'a vdiscrim -> int = discrim_size;

(* The stored entries as ((vars, term), value) pairs. *)
local
  fun dest (tm : term, (vars : vars, a)) = ((vars, tm), a)
in
  fun dest_vdiscrim d = (map dest o dest_discrim) d;
end;

(* Add an entry given as ((vars, term), value).  The term is the trie key;
   vars is stored alongside the value. *)
local
  fun prepare (vars : vars) (tm, a) = (tm, (vars, a))
in
  fun vdiscrim_add ((vars, tm), a) = discrim_add (prepare vars (tm, a));
end;

fun mk_vdiscrim l = trans vdiscrim_add empty_vdiscrim l;

(* Higher-order matches, keeping only those whose substitution passes
   subst_vars_in_set for the entry's variables.  Presumably this means it
   instantiates nothing outside vars. *)
local
  fun vars_check (ho_sub as (sub, _) : ho_substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (ho_sub, a) else NONE
  fun check_ho_match ml = partial_map vars_check ml
in
  fun vdiscrim_ho_match d tm = check_ho_match (discrim_ho_match d tm);
end;

(* First-order matches, filtered by the same variable check. *)
local
  fun vars_check (sub : substitution, (vars, a)) =
    if subst_vars_in_set sub vars then SOME (sub, a) else NONE
  fun check_fo_match ml = partial_map vars_check ml
in
  fun vdiscrim_fo_match d tm = check_fo_match (discrim_fo_match d tm);
end;

(* ------------------------------------------------------------------------- *)
(* Ordered (Variable) Term Discriminators                                    *)
(* Entries are returned in the order they arrived (latest first).            *)
(* ------------------------------------------------------------------------- *)

(* Each value is tagged with the discriminator's size at the time it was
   added (see ovdiscrim_add), which serves as its arrival index. *)
type 'a ovdiscrim = (int * 'a) vdiscrim;

val empty_ovdiscrim : 'a ovdiscrim = empty_vdiscrim;

val ovdiscrim_size : 'a ovdiscrim -> int = vdiscrim_size;

local
  (* Move the arrival index to the front so results can be sorted on it. *)
  fun transfer (a, (n : int, b)) = (n, (a, b));
  (* Larger, that is later, indices are ordered first. *)
  fun order (m, _) (n, _) = m > n;
  (* Sort by arrival index, latest first, then drop the index. *)
  fun dest dl = (map snd o sort order o map transfer) dl;
in
  fun dest_ovdiscrim d = (dest o dest_vdiscrim) d;

  fun ovdiscrim_ho_match d tm = dest (vdiscrim_ho_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_ho_match" h;

  fun ovdiscrim_fo_match d tm = dest (vdiscrim_fo_match d tm)
    handle (h as HOL_ERR _) => raise err_BUG "ovdiscrim_fo_match" h;
end;

fun ovdiscrim_add (tm, a) d = vdiscrim_add (tm, (vdiscrim_size d, a)) d;

fun mk_ovdiscrim l = trans ovdiscrim_add empty_ovdiscrim l;

(* non-interactive mode
*)
end;