(*
  Code to recall that some partial functions (of type 'a -> 'b option)
  can be represented as sorted alists, and derive a fast conversion on
  applications of those functions.
*)
structure alist_treeLib :> alist_treeLib =
struct

open HolKernel Parse boolLib simpLib bossLib

open alist_treeTheory comparisonTheory

(* Syntax *)

val alookup_tm = prim_mk_const {Name = "ALOOKUP", Thy = "alist"}

(* Look up a constant of the alist_tree theory by name. *)
fun mkc nm = prim_mk_const {Name = nm, Thy = "alist_tree"}

val count_append_tm = mkc "count_append"
val is_insert_tm = mkc "is_insert"
val option_choice_tm = mkc "option_choice_f"
val is_lookup_tm = mkc "is_lookup"
val repr_tm = mkc "sorted_alist_repr"

(* trivia *)
val err = mk_HOL_ERR "alist_treeLib"

(* The repr set object.
   R_thm : caller-supplied theorem about the key relation R; it is the
           MATCH_MP argument of alist_repr_refl when building base
           representations (test_rs shows one way to obtain it).
   conv  : conversion used to discharge the side conditions of the
           is_insert / is_lookup theorems.
   dest  : maps a key term to an ML value; cmp orders those values.
   dict  : mutable table from registered function terms to either a
           sorted_alist_repr theorem or an equation that rewrites the
           function to another registered term. *)
datatype 'a alist_reprs = AList_Reprs of {R_thm: thm, conv: conv,
  dest: term -> 'a, cmp: ('a * 'a) -> order,
  dict: (term, thm) Redblackmap.dict ref}

(* Create an empty repr set. *)
fun mk_alist_reprs R_thm conv dest cmp
  = AList_Reprs {R_thm = R_thm, conv = conv, cmp = cmp,
    dest = dest, dict = ref (Redblackmap.mkDict Term.compare)}

(* The function terms currently registered in the repr set. *)
fun peek_functions_in_rs (AList_Reprs inn_rs)
  = Redblackmap.listItems (! (#dict inn_rs)) |> map fst

(* The stored theorem for function term tm, if one is registered. *)
fun peek_repr (AList_Reprs inn_rs) tm = Redblackmap.peek (! (#dict inn_rs), tm)

(* constructing is_insert thms *)

(* Search a worklist of trees for the first entry, or the last entry when
   is_last is true.  Each tree is either a list literal or a node
   count_append n l r.  Empty lists are skipped.  A node is replaced by its
   children l and r, which are visited right-to-left when is_last is true.
   Raises Empty if every list encountered is empty. *)
fun find_key_rec is_last [] = raise Empty
  | find_key_rec is_last (t :: ts) =
    if listSyntax.is_nil t
    then find_key_rec is_last ts
    else let
        val (f, xs) = strip_comb t
        val do_rev = if is_last then rev else I
      in
        if same_const f count_append_tm
        then find_key_rec is_last (do_rev (tl xs) @ ts)
        else hd (do_rev (fst (listSyntax.dest_list t)))
      end

(* The key (first pair component) of the leftmost / rightmost entry of a tree. *)
fun hd_key t = find_key_rec false [t] |> pairSyntax.dest_pair |> fst
fun last_key t = find_key_rec true [t] |> pairSyntax.dest_pair |> fst

fun mk_singleton x = listSyntax.mk_list ([x], type_of x)

val simp_count_append = SIMP_CONV bool_ss [count_append_HD_LAST, pairTheory.FST]

(* Turn every antecedent of an implication chain into an assumption.
   Each antecedent is first simplified with simp_count_append.  One that
   reduces to T is discharged with TRUTH instead of being assumed. *)
fun assume_prems thm =
  if not (is_imp (concl thm)) then thm
  else let
      val thm = CONV_RULE (RATOR_CONV simp_count_append) thm
      val l_asm = fst (dest_imp (concl thm))
      val prem = if l_asm ~~ T then TRUTH else ASSUME l_asm
    in
      assume_prems (MP thm prem)
    end

(* Match the antecedent of mp_thm against the conclusion of arg_thm.
   Instantiate mp_thm with that match plus the extra term instantiation
   insts, then apply Modus Ponens.  On failure, print diagnostics and raise
   a HOL_ERR naming the failing step. *)
fun do_inst_mp insts mp_thm arg_thm = let
    val (prem, _) = dest_imp (concl mp_thm)
    fun rerr e s = let
        val m = s ^ ": " ^ #message e
      in print ("error in do_inst_mp: " ^ m ^ "\n");
        print_thm mp_thm; print "\n"; print_thm arg_thm; print "\n";
        raise (err "do_inst_mp" m) end
    val (more_insts, ty_insts) = match_term prem (concl arg_thm)
        handle HOL_ERR e => rerr e "match_term"
    val ty_i_thm = INST_TYPE ty_insts mp_thm
    val ithm = INST (more_insts @ insts) ty_i_thm
        handle HOL_ERR e => rerr e "INST"
  in MP ithm arg_thm handle HOL_ERR e => rerr e "MP" end

(* Build an is_insert theorem for inserting (k, x) into a tree, by recursion
   on the tree shape.  Keys are compared with cmp after applying dest.
   At a node count_append n l r, the entry goes:
   - right, when k is not less than the first key of r;
   - into the centre, when k is greater than the last key of l;
   - left, otherwise.
   Side conditions are simplified by assume_prems and left as hypotheses;
   prove_insert discharges them. *)
fun build_insert (dest : term -> 'a) cmp R k x =
  let
    val dest_k = dest k
    fun chk thm = if same_const is_insert_tm (fst (strip_comb (concl thm)))
      then thm
      else (print "Not an insert_tm:\n"; print_thm thm; print "\n";
        raise (err "build_insert" "check"))
    val pp = chk o assume_prems
    fun build t = if listSyntax.is_nil t
      then pp (ISPECL [R, k, x] is_insert_to_empty)
      else if listSyntax.is_cons t then let
        (* leaves are expected to be singleton lists *)
        val (xs, _) = listSyntax.dest_list t
        val _ = length xs = 1 orelse raise (err "build_insert" "malformed")
        val v = hd xs
      in case cmp (dest_k, dest (fst (pairSyntax.dest_pair v))) of
          EQUAL => pp (ISPECL [R, k, x, v] is_insert_overwrite)
        | GREATER => pp (ISPECL [R, k, x, t] is_insert_far_right)
        | LESS => pp (ISPECL [R, k, x, t] is_insert_far_left)
      end else let
        val (f, xs) = strip_comb t
        val _ = same_const count_append_tm f
            orelse raise (err "build_insert" "unknown")
        val (n, l, r) = case xs of [n, l, r] => (n, l, r)
            | _ => raise (err "build_insert" "num args")
        fun vsub nm v = [mk_var (nm, type_of v) |-> v]
      in if not (cmp (dest_k, dest (hd_key r)) = LESS)
        then do_inst_mp (vsub "l" l) (SPEC n is_insert_r) (build r)
        else if cmp (dest_k, dest (last_key l)) = GREATER
        then ISPECL [R, n, k, x] is_insert_centre
            |> INST (vsub "l" l @ vsub "r" r) |> pp
        else do_inst_mp (vsub "r" r) (SPEC n is_insert_l) (build l)
      end
  in build end

(* Discharge the first antecedent of thm.  The antecedent is rewritten with
   conv and the result is discharged with TRUTH, so conv must reduce it to T.
   On failure, print the original and reduced antecedent, then re-raise. *)
fun prove_assum_by_conv conv thm = let
    val (x, _) = dest_imp (concl thm)
    val thm = CONV_RULE ((RATOR_CONV o RAND_CONV) conv) thm
  in MP thm TRUTH handle HOL_ERR e =>
    (print "Failed to prove assum by conv:\n";
     print_term x;
     print "\n -- reduced to:\n";
     print_term (fst (dest_imp (concl thm)));
     raise HOL_ERR e)
  end

(* Discharge every antecedent of an implication chain in turn with
   prove_assum_by_conv, using the same conversion each time. *)
fun discharge_all_by_conv conv thm =
  if is_imp (concl thm)
  then discharge_all_by_conv conv (prove_assum_by_conv conv thm)
  else thm

(* balancing count_append trees *)

(* Depth of a tree: the numeral in the first argument of a count_append
   node, or 1 for a list literal. *)
fun get_depth tm = let
    val (f, xs) = strip_comb tm
  in if same_const f count_append_tm
    then numSyntax.int_of_term (hd xs)
    else if listSyntax.is_cons tm then 1
    else raise (err "get_depth" "unknown")
  end

(* Conversion that rebalances a count_append tree.
   It acts only on count_append nodes whose count is ARB, unless bias is not
   "N"; anything else raises UNCHANGED.  Both children are balanced first.
   If one side is more than one level deeper (or deeper at all, in the
   direction given by bias "R" / "L"):
   - that side is rebalanced with the opposite bias;
   - the node is rewritten with balance_r / balance_l;
   - the children are balanced again.
   Once the two depths differ by less than 2, the count is set to
   1 + max depth with set_count; otherwise the whole step repeats.
   It gives up after 1000 iterations. *)
fun balance iter bias tm = if iter > 1000 then
    (print "error: looping balance\n"; print_term tm;
     raise (err "balance" "looping"))
  else let
    val (f, xs) = strip_comb tm
    val _ = same_const f count_append_tm orelse raise UNCHANGED
    val _ = is_arb (hd xs) orelse bias <> "N" orelse raise UNCHANGED
    val step_conv = (RAND_CONV (balance 0 "N"))
        THENC (RATOR_CONV (RAND_CONV (balance 0 "N")))
    val thm = QCONV step_conv tm
    val tm = rhs (concl thm)
    val l_sz = get_depth (rand (rator tm))
    val r_sz = get_depth (rand tm)
    val reb = if l_sz > (r_sz + 1) orelse (bias = "R" andalso l_sz > r_sz)
      then "R"
      else if r_sz > (l_sz + 1) orelse (bias = "L" andalso r_sz > l_sz)
      then "L" else "N"
    val conv1 = if reb = "R" then RATOR_CONV (RAND_CONV (balance 0 "L"))
      else if reb = "L" then RAND_CONV (balance 0 "R") else ALL_CONV
    val conv2 = if reb = "R" then REWR_CONV balance_r
      else if reb = "L" then REWR_CONV balance_l else ALL_CONV
    val thm = CONV_RULE (QCONV (RHS_CONV (conv1 THENC conv2 THENC step_conv))) thm
    val tm = rhs (concl thm)
    val l_sz = get_depth (rand (rator tm))
    val r_sz = get_depth (rand tm)
    val sz = numSyntax.term_of_int (1 + Int.max (l_sz, r_sz))
    val set = REWR_CONV (set_count |> SPEC sz)
    val final = if Int.abs (l_sz - r_sz) < 2 then set else balance (iter + 1) "N"
  in CONV_RULE (RHS_CONV final) thm end

(* Prove the is_insert theorem for (k, x) and tree al, with every side
   condition discharged by conv.  The last argument of the conclusion
   (presumably the resulting tree) is then rebalanced with balance. *)
fun prove_insert R conv dest cmp k x al =
  build_insert dest cmp R k x al
  |> DISCH_ALL
  |> discharge_all_by_conv conv
  |> CONV_RULE (RAND_CONV (balance 0 "N"))

(* making repr theorems *)

(* Extend prev_repr, a theorem of the form sorted_alist_repr R al f, with
   the pair k_x, using repr_insert. *)
fun mk_insert_repr (AList_Reprs rs) prev_repr k_x = let
    val (k, x) = pairSyntax.dest_pair k_x
    val (R, al) = case strip_comb (concl prev_repr)
      of (_, [R, al, _]) => (R, al)
      | _ => raise (err "mk_insert_repr" "unexpected")
    val insert = prove_insert R (#conv rs) (#dest rs) (#cmp rs) k x al
  in MATCH_MP repr_insert (CONJ prev_repr insert) end

(* Recognise ALOOKUP applied to an explicit list of at most one element.
   ALOOKUP [] gives SOME NONE, ALOOKUP [y] gives SOME (SOME y), and anything
   else gives NONE. *)
fun dest_alookup_single tm = let
    val (f, xs) = strip_comb tm
  in if not (length xs = 1 andalso same_const f alookup_tm)
    then NONE
    else if listSyntax.is_nil (hd xs) then SOME NONE
    else case total listSyntax.dest_cons (hd xs) of
        SOME (y, ys) => if listSyntax.is_nil ys then SOME (SOME y) else NONE
      | NONE => NONE
  end

(* One step towards a sorted_alist_repr theorem for tm:
   - ALOOKUP of an empty or singleton list: alist_repr_refl with R_thm;
     sortedness is proved by simp.
   - option_choice_f (ALOOKUP [kv]) g: represent g, then insert kv.
   - option_choice_f f g with f registered: represent f, then combine with
     a representation of option_choice_f (ALOOKUP al) g.
   - otherwise: rewrite tm with SIMP_CONV, returning an equation for
     mk_repr to follow; raises if the rewrite makes no progress. *)
fun mk_repr_step rs tm = let
    val (AList_Reprs inn_rs) = rs
    val (f, xs) = strip_comb tm
    val is_short = Option.isSome (dest_alookup_single tm)
    val is_merge = same_const option_choice_tm f
    val is_repr_merge = is_merge andalso Option.isSome (peek_repr rs (hd xs))
    val (is_insert, insert_tm) = if not is_merge then (false, T)
      else case dest_alookup_single (hd xs) of
        SOME (SOME t) => (true, t) | _ => (false, T)
  in if is_short
    then MATCH_MP (ISPEC (hd xs) alist_repr_refl) (#R_thm inn_rs)
      |> prove_assum_by_conv (SIMP_CONV list_ss [sortingTheory.SORTED_DEF])
    else if is_insert
    then mk_insert_repr rs (mk_repr rs (rand tm)) insert_tm
    else if is_repr_merge
    then let
      val l_repr_thm = mk_repr rs (hd xs)
      val l_repr_al = rand (rator (concl l_repr_thm))
      val look = mk_icomb (alookup_tm, l_repr_al)
      val half_repr = list_mk_icomb (option_choice_tm, [look, List.last xs])
      val next_repr = mk_repr rs half_repr
    in MATCH_MP alist_repr_choice_trans_left (CONJ l_repr_thm next_repr) end
    else CHANGED_CONV (SIMP_CONV bool_ss [alookup_to_option_choice,
        option_choice_f_assoc, alookup_empty_option_choice_f,
        count_append_def, alookup_append_option_choice_f,
        empty_is_ALOOKUP]) tm
      handle HOL_ERR _ => raise err "mk_repr_step"
        ("no progress from SIMP_CONV: " ^ Parse.term_to_string tm)
  end
(* Use the registered theorem for tm if there is one, else take a step. *)
and mk_repr_known_step rs tm =
  case peek_repr rs tm of
    SOME thm => thm
  | NONE => mk_repr_step rs tm
(* Build a sorted_alist_repr theorem for tm.  Equations returned by a step
   are followed to their right-hand side, and the result is rewritten back
   to mention tm. *)
and mk_repr rs tm = let
    val thm = mk_repr_known_step rs tm
  in if is_eq (concl thm)
    then mk_repr rs (rhs (concl thm))
      |> CONV_RULE (RAND_CONV (REWR_CONV (SYM thm)))
    else thm
  end

(* Register a definition |- f = rhs in the repr set (mutates its dict).
   If rhs already has a stored theorem, the stored value is:
   - the composed equation, when that theorem is an equation;
   - the definition itself, otherwise.
   If rhs has no stored theorem, a representation of rhs is built and
   rewritten to mention f. *)
fun add_alist_repr rs thm = let
    val AList_Reprs inn_rs = rs
    val (f, rhs) = dest_eq (concl thm)
    val repr_thm = case peek_repr rs rhs of
        SOME rhs_thm => if is_eq (concl rhs_thm)
          then TRANS thm rhs_thm
          else thm
      | NONE => (mk_repr rs rhs
          |> CONV_RULE (RAND_CONV (REWR_CONV (SYM thm))))
  in
    #dict inn_rs := Redblackmap.insert (! (#dict inn_rs), f, repr_thm)
  end

(* Apply f to v, print the elapsed wall time labelled with msg, and return
   the result. *)
fun timeit msg f v = let
    val start = Portable.timestamp ()
    val r = f v
    val time = Time.-(Portable.timestamp (), start)
  in print ("Time to " ^ msg ^ ": " ^ Portable.time_to_string time ^ "\n");
    r end

(* testing *)

(* A repr set for num keys, using EVAL as the side-condition conversion. *)
fun test_rs () = let
    val thm1 = DB.fetch "comparison" "good_cmp_Less_irrefl_trans"
    val thm2 = DB.fetch "comparison" "num_cmp_good"
    val R_thm = MATCH_MP thm1 thm2
  in mk_alist_reprs R_thm EVAL numSyntax.int_of_term Int.compare end

(* ALOOKUP over the pairs ((i * 157) mod 1000, i) for each i in ns. *)
fun test_mk_alookup ns = let
    open numSyntax
    fun f i = ((i * 157) mod 1000)
    fun el i = pairSyntax.mk_pair (term_of_int (f i), term_of_int i)
  in mk_icomb (alookup_tm, listSyntax.mk_list (map el ns, type_of (el 0))) end

fun test_200 rs = let
    val al1 = test_mk_alookup (upto 1 200)
    val al2 = test_mk_alookup [1, 4, 3]
    val merge = list_mk_icomb (option_choice_tm, [al1, al2])
  in mk_repr rs merge end

(*
val rs = test_rs ()
val thm_200 = timeit "build repr" test_200 rs
*)

(* proving and using is_lookup thms *)

(* Build an is_lookup theorem for key k in a tree, by recursion on the tree
   shape.  The search direction at count_append nodes follows the same
   rules as build_insert.  Side conditions are left as hypotheses for
   prove_lookup to discharge. *)
fun build_lookup (dest : term -> 'a) cmp R k =
  let
    val dest_k = dest k
    fun chk thm = if same_const is_lookup_tm (fst (strip_comb (concl thm)))
      then thm
      else (print "Not a lookup_tm:\n"; print_thm thm; print "\n";
        raise (err "build_lookup" "check"))
    val pp = chk o assume_prems
    fun build t = if listSyntax.is_nil t
      then pp (ISPECL [R, k, t] is_lookup_empty)
      else if listSyntax.is_cons t then let
        (* leaves are expected to be singleton lists *)
        val (xs, _) = listSyntax.dest_list t
        val _ = length xs = 1 orelse raise (err "build_lookup" "malformed")
        val (k', v) = pairSyntax.dest_pair (hd xs)
      in case cmp (dest_k, dest k') of
          EQUAL => pp (ISPECL [R, k, k', v] is_lookup_hit)
        | GREATER => pp (ISPECL [R, k, k', v] is_lookup_far_right)
        | LESS => pp (ISPECL [R, k, k', v] is_lookup_far_left)
      end else let
        val (f, xs) = strip_comb t
        val _ = same_const count_append_tm f
            orelse raise (err "build_lookup" "unknown")
        val (n, l, r) = case xs of [n, l, r] => (n, l, r)
            | _ => raise (err "build_lookup" "num args")
        fun vsub nm v = [mk_var (nm, type_of v) |-> v]
      in if not (cmp (dest_k, dest (hd_key r)) = LESS)
        then do_inst_mp (vsub "l" l) (SPEC n is_lookup_r) (build r)
        else if cmp (dest_k, dest (last_key l)) = GREATER
        then pp (ISPECL [R, n, l, r, k] is_lookup_centre)
        else do_inst_mp (vsub "r" r) (SPEC n is_lookup_l) (build l)
      end
  in build end

(* Prove the is_lookup theorem for k in tree al, with every side condition
   discharged by conv. *)
fun prove_lookup R conv dest cmp k al =
  build_lookup dest cmp R k al
  |> DISCH_ALL
  |> discharge_all_by_conv conv

(* Given repr_thm of the form sorted_alist_repr R al f, prove a lookup of
   key k in al and transfer it to f via lookup_repr. *)
fun repr_prove_lookup conv dest cmp repr_thm k = let
    val (f, xs) = strip_comb (concl repr_thm)
    val _ = same_const f repr_tm orelse
      raise (err "repr_prove_lookup" "unexpected")
    val (R, al, f) = case xs of [R, al, f] => (R, al, f)
      | _ => raise (err "repr_prove_lookup" "num args")
    val lookup = prove_lookup R conv dest cmp k al
  in MATCH_MP lookup_repr (CONJ repr_thm lookup) end

(* Conversion on applications f x, where f is registered in rs.
   If the stored theorem is an equation, it rewrites f and the conversion
   repeats.  Otherwise the lookup of x is proved via repr_prove_lookup.
   Raises UNCHANGED when tm is not an application or f is unregistered. *)
fun reprs_conv rs tm = let
    val AList_Reprs inn_rs = rs
    val (f, x) = dest_comb tm handle HOL_ERR _ => raise UNCHANGED
    val repr_thm = case peek_repr rs f of
        NONE => raise UNCHANGED | SOME thm => thm
  in if is_eq (concl repr_thm)
    then (RATOR_CONV (REWR_CONV repr_thm) THENC reprs_conv rs) tm
    else repr_prove_lookup (#conv inn_rs) (#dest inn_rs) (#cmp inn_rs)
      repr_thm x
  end

fun extract_test f rs i = mk_comb (f, numSyntax.term_of_int i) |> reprs_conv rs

(* Define f as a 300-entry ALOOKUP, register it, then look up keys 1..1000.
   The lookups must use the defined constant from the left-hand side of
   f_def.  The dict is keyed by that constant, and the variable f used to
   make the definition would not be found, so reprs_conv would raise
   UNCHANGED. *)
fun extract_test_1000 rs = let
    val alookup = test_mk_alookup (upto 1 300)
    val f = mk_var ("f", type_of alookup)
    val f_def = new_definition ("f", mk_eq (f, alookup))
    val f_const = lhs (concl f_def)
    val _ = timeit "add def" (add_alist_repr rs) f_def
    val res2 = timeit "map extract" (map (extract_test f_const rs)) (upto 1 1000)
  in res2 end

end