1(*
  Code to record that some partial functions (of type 'a -> 'b option)
  can be represented as sorted alists, and to derive a fast conversion on
  applications of those functions.
5*)
6structure alist_treeLib :> alist_treeLib =
7struct
8
9open HolKernel Parse boolLib simpLib bossLib
10
11open alist_treeTheory comparisonTheory
12
13(* Syntax *)
14
(* The ALOOKUP constant from alistTheory. *)
val alookup_tm = prim_mk_const {Thy = "alist", Name = "ALOOKUP"}

(* Fetch a constant of alist_treeTheory by name. *)
fun mkc name = prim_mk_const {Thy = "alist_tree", Name = name}

val count_append_tm  = mkc "count_append"
val is_insert_tm     = mkc "is_insert"
val option_choice_tm = mkc "option_choice_f"
val is_lookup_tm     = mkc "is_lookup"
val repr_tm          = mkc "sorted_alist_repr"
24
25(* trivia *)
(* Build a HOL_ERR tagged with this library's name: err function_name message. *)
fun err fname msg = mk_HOL_ERR "alist_treeLib" fname msg
27
28(* the repr set object *)
(* A set of sorted-alist representation theorems.
   - R_thm: theorem about the key ordering, consumed via alist_repr_refl
   - conv:  conversion used to discharge side conditions
   - dest:  maps a HOL key term to an ML value
   - cmp:   ML comparison on those values (presumably mirroring the HOL
            ordering -- the caller must ensure they agree)
   - dict:  mutable cache from function terms to their repr theorem, or to
            an equation that unfolds them *)
datatype 'a alist_reprs = AList_Reprs of {
    R_thm : thm,
    conv : conv,
    dest : term -> 'a,
    cmp : ('a * 'a) -> order,
    dict : (term, thm) Redblackmap.dict ref
  }

(* Create a repr set with a fresh, empty cache ordered by Term.compare. *)
fun mk_alist_reprs R_thm conv dest cmp =
  let
    val empty_cache = Redblackmap.mkDict Term.compare
  in
    AList_Reprs {R_thm = R_thm, conv = conv, dest = dest, cmp = cmp,
                 dict = ref empty_cache}
  end
36
(* The function terms currently registered in the repr set. *)
fun peek_functions_in_rs (AList_Reprs {dict, ...}) =
  map fst (Redblackmap.listItems (!dict))

(* The cached theorem registered for function term tm, if any. *)
fun peek_repr (AList_Reprs {dict, ...}) tm =
  Redblackmap.peek (!dict, tm)
41
42(* constructing is_insert thms *)
43
(* Find the first (or, when from_end is true, the last) element of an alist
   term built from literal lists and count_append nodes, working through a
   worklist of subterms. Nil components are skipped. Raises Empty when no
   element exists; other term shapes make listSyntax.dest_list fail. *)
fun find_key_rec _ [] = raise Empty
  | find_key_rec from_end (t :: rest) =
    if listSyntax.is_nil t then find_key_rec from_end rest
    else
      let
        val order = if from_end then List.rev else (fn l => l)
        val (head, args) = strip_comb t
      in
        if same_const head count_append_tm
        then find_key_rec from_end (order (tl args) @ rest)
        else hd (order (#1 (listSyntax.dest_list t)))
      end

(* Key of the first / last (key, value) pair of an alist term. *)
fun hd_key t = fst (pairSyntax.dest_pair (find_key_rec false [t]))
fun last_key t = fst (pairSyntax.dest_pair (find_key_rec true [t]))
56
(* The one-element HOL list [elem]. *)
fun mk_singleton elem =
  let val elem_ty = type_of elem
  in listSyntax.mk_list ([elem], elem_ty) end
58
(* Simplifier for side conditions mentioning HD/LAST of count_append trees. *)
val simp_count_append =
  SIMP_CONV bool_ss [count_append_HD_LAST, pairTheory.FST]

(* Strip every antecedent of an implication. Each antecedent is first
   simplified with simp_count_append. It is then discharged with TRUTH if it
   became T; otherwise it is ASSUMEd, so it remains a hypothesis of the result. *)
fun assume_prems thm =
  if is_imp (concl thm) then
    let
      val simped = CONV_RULE (RATOR_CONV simp_count_append) thm
      val (ante, _) = dest_imp (concl simped)
      val discharge = if ante ~~ T then TRUTH else ASSUME ante
    in
      assume_prems (MP simped discharge)
    end
  else thm
69
(* Instantiate the implication mp_thm so that its antecedent matches the
   conclusion of arg_thm, additionally applying the term substitution insts,
   then apply MP. On failure, print both theorems and raise a HOL_ERR naming
   the failing stage (match_term, INST or MP). *)
fun do_inst_mp insts mp_thm arg_thm =
  let
    val (ante, _) = dest_imp (concl mp_thm)
    fun report e stage =
      let
        val msg = stage ^ ": " ^ #message e
      in
        print ("error in do_inst_mp: " ^ msg ^ "\n");
        print_thm mp_thm; print "\n";
        print_thm arg_thm; print "\n";
        raise (err "do_inst_mp" msg)
      end
    val (tm_insts, ty_insts) =
      match_term ante (concl arg_thm)
      handle HOL_ERR e => report e "match_term"
    val ty_thm = INST_TYPE ty_insts mp_thm
    val inst_thm =
      INST (tm_insts @ insts) ty_thm
      handle HOL_ERR e => report e "INST"
  in
    MP inst_thm arg_thm handle HOL_ERR e => report e "MP"
  end
83
(* Build an is_insert theorem for inserting key k with value x into an alist
   term t (the final curried argument). The recursion follows the shape of t:
   the empty list, a one-element literal list, or a count_append node. Keys
   are compared in ML via dest/cmp. Side conditions are simplified and, if
   not trivially true, left as assumptions (see assume_prems). *)
fun build_insert (dest : term -> 'a) cmp R k x =
  let
    val key = dest k
    fun check th =
      if same_const is_insert_tm (#1 (strip_comb (concl th))) then th
      else
        (print "Not an insert_tm:\n"; print_thm th; print "\n";
         raise (err "build_insert" "check"))
    fun finish th = check (assume_prems th)
    fun var_for name v = [mk_var (name, type_of v) |-> v]
    (* t is a literal list; it must have exactly one element *)
    fun leaf t =
      let
        val (elems, _) = listSyntax.dest_list t
        val _ = length elems = 1 orelse raise (err "build_insert" "malformed")
        val kv = hd elems
        val other = dest (#1 (pairSyntax.dest_pair kv))
      in
        case cmp (key, other) of
          EQUAL => finish (ISPECL [R, k, x, kv] is_insert_overwrite)
        | GREATER => finish (ISPECL [R, k, x, t] is_insert_far_right)
        | LESS => finish (ISPECL [R, k, x, t] is_insert_far_left)
      end
    fun build t =
      if listSyntax.is_nil t then finish (ISPECL [R, k, x] is_insert_to_empty)
      else if listSyntax.is_cons t then leaf t
      else node t
    (* t must be count_append n l r: go right, insert in the gap, or go left *)
    and node t =
      let
        val (head, args) = strip_comb t
        val _ = same_const count_append_tm head
                orelse raise (err "build_insert" "unknown")
        val (n, l, r) =
          case args of
            [n, l, r] => (n, l, r)
          | _ => raise (err "build_insert" "num args")
      in
        if cmp (key, dest (hd_key r)) <> LESS then
          do_inst_mp (var_for "l" l) (SPEC n is_insert_r) (build r)
        else if cmp (key, dest (last_key l)) = GREATER then
          ISPECL [R, n, k, x] is_insert_centre
          |> INST (var_for "l" l @ var_for "r" r)
          |> finish
        else
          do_inst_mp (var_for "r" r) (SPEC n is_insert_l) (build l)
      end
  in
    build
  end
117
(* Discharge the antecedent of an implication by rewriting it with conv; the
   antecedent must reduce to T. On failure, print the original and reduced
   antecedents and re-raise the HOL_ERR from MP. *)
fun prove_assum_by_conv conv thm =
  let
    val (orig_ante, _) = dest_imp (concl thm)
    val reduced = CONV_RULE ((RATOR_CONV o RAND_CONV) conv) thm
    fun report e =
      (print "Failed to prove assum by conv:\n";
       print_term orig_ante;
       print "\n  -- reduced to:\n";
       print_term (#1 (dest_imp (concl reduced)));
       raise HOL_ERR e)
  in
    MP reduced TRUTH handle HOL_ERR e => report e
  end
128
129(* balancing count_append trees *)
(* Depth annotation of an alist tree term: the numeral that is the first
   argument of a count_append node, or 1 for a non-empty literal list.
   Any other shape (including nil) is an error. *)
fun get_depth tm =
  let
    val (head, args) = strip_comb tm
    val is_node = same_const head count_append_tm
  in
    case (is_node, listSyntax.is_cons tm) of
      (true, _) => numSyntax.int_of_term (hd args)
    | (false, true) => 1
    | (false, false) => raise (err "get_depth" "unknown")
  end
137
(* Conversion that rebalances a count_append tree. A node is processed only if
   its depth annotation is ARB, or if bias is "L"/"R"; otherwise, or if tm is
   not a count_append node, it raises UNCHANGED.
   Both children are balanced first. If their depths differ by more than one
   (or the deeper side matches bias), the deeper child is itself balanced
   with the opposite bias, the node is rotated with balance_r/balance_l, and
   the children are rebalanced.
   Finally the depth is set to 1 + max child depth via set_count. If the
   children still differ by 2 or more, the whole step is repeated instead;
   the attempt is abandoned after 1000 repetitions. *)
fun balance iter bias tm =
  if iter > 1000 then
    (print "error: looping balance\n"; print_term tm;
     raise (err "balance" "looping"))
  else
    let
      val (head, args) = strip_comb tm
      val _ = same_const head count_append_tm orelse raise UNCHANGED
      val _ = is_arb (hd args) orelse bias <> "N" orelse raise UNCHANGED
      val sub = balance 0 "N"
      (* balance the right child, then the left child *)
      val step_conv = RAND_CONV sub THENC RATOR_CONV (RAND_CONV sub)
      fun child_depths th =
        let val t = rhs (concl th)
        in (get_depth (rand (rator t)), get_depth (rand t)) end
      val th1 = QCONV step_conv tm
      val (ld, rd) = child_depths th1
      val direction =
        if ld > rd + 1 orelse (bias = "R" andalso ld > rd) then "R"
        else if rd > ld + 1 orelse (bias = "L" andalso rd > ld) then "L"
        else "N"
      val (pre, rotate) =
        case direction of
          "R" => (RATOR_CONV (RAND_CONV (balance 0 "L")), REWR_CONV balance_r)
        | "L" => (RAND_CONV (balance 0 "R"), REWR_CONV balance_l)
        | _ => (ALL_CONV, ALL_CONV)
      val th2 =
        CONV_RULE (QCONV (RHS_CONV (pre THENC rotate THENC step_conv))) th1
      val (ld2, rd2) = child_depths th2
      val depth_tm = numSyntax.term_of_int (1 + Int.max (ld2, rd2))
      val set_conv = REWR_CONV (SPEC depth_tm set_count)
      val finish =
        if Int.abs (ld2 - rd2) < 2 then set_conv else balance (iter + 1) "N"
    in
      CONV_RULE (RHS_CONV finish) th2
    end
167
(* Prove an is_insert theorem for inserting k |-> x into al. All side
   conditions are discharged with conv, then the resulting tree (the rand of
   the conclusion) is rebalanced. *)
fun prove_insert R conv dest cmp k x al =
  let
    val raw = build_insert dest cmp R k x al
    fun discharge th =
      if is_imp (concl th) then discharge (prove_assum_by_conv conv th)
      else th
  in
    CONV_RULE (RAND_CONV (balance 0 "N")) (discharge (DISCH_ALL raw))
  end
173
174(* making repr theorems *)
175
(* Extend a sorted_alist_repr theorem prev_repr with one more binding k_x
   (a HOL pair term), combining it with an is_insert theorem via repr_insert. *)
fun mk_insert_repr (AList_Reprs {conv, dest, cmp, ...}) prev_repr k_x =
  let
    val (k, x) = pairSyntax.dest_pair k_x
    val (R, al) =
      case #2 (strip_comb (concl prev_repr)) of
        [R, al, _] => (R, al)
      | _ => raise (err "mk_insert_repr" "unexpected")
    val ins = prove_insert R conv dest cmp k x al
  in
    MATCH_MP repr_insert (CONJ prev_repr ins)
  end
183
(* Recognise ALOOKUP applied to a literal list with at most one element.
   Returns SOME NONE for ALOOKUP [], SOME (SOME p) for ALOOKUP [p], and
   NONE for any other term. *)
fun dest_alookup_single tm =
  case strip_comb tm of
    (head, [lst]) =>
      if not (same_const head alookup_tm) then NONE
      else if listSyntax.is_nil lst then SOME NONE
      else
        (case total listSyntax.dest_cons lst of
           SOME (p, tail) =>
             if listSyntax.is_nil tail then SOME (SOME p) else NONE
         | NONE => NONE)
  | _ => NONE
193
(* One step towards a sorted_alist_repr theorem for the partial-function
   term tm. The result is either a repr theorem or an equation tm = tm';
   mk_repr follows such equations to the end. Cases, in order:
   - ALOOKUP of a list of length <= 1: repr via alist_repr_refl, with its
     side condition proved using SORTED_DEF
   - option_choice_f (ALOOKUP [p]) g: insert p into the repr of g
   - option_choice_f f g with f registered in rs: combine reprs via
     alist_repr_choice_trans_left
   - otherwise: one round of simplification, which must change tm. *)
fun mk_repr_step rs tm =
  let
    val AList_Reprs {R_thm, ...} = rs
    val (head, args) = strip_comb tm
    val is_short = isSome (dest_alookup_single tm)
    val is_merge = same_const option_choice_tm head
    val is_repr_merge = is_merge andalso isSome (peek_repr rs (hd args))
    val insert_pair =
      if not is_merge then NONE
      else
        (case dest_alookup_single (hd args) of
           SOME (SOME p) => SOME p
         | _ => NONE)
  in
    if is_short then
      MATCH_MP (ISPEC (hd args) alist_repr_refl) R_thm
      |> prove_assum_by_conv (SIMP_CONV list_ss [sortingTheory.SORTED_DEF])
    else
      case insert_pair of
        SOME p => mk_insert_repr rs (mk_repr rs (rand tm)) p
      | NONE =>
        if is_repr_merge then
          let
            val left_thm = mk_repr rs (hd args)
            val left_al = rand (rator (concl left_thm))
            val left_look = mk_icomb (alookup_tm, left_al)
            val rest = list_mk_icomb (option_choice_tm, [left_look, List.last args])
            val rest_thm = mk_repr rs rest
          in
            MATCH_MP alist_repr_choice_trans_left (CONJ left_thm rest_thm)
          end
        else
          CHANGED_CONV (SIMP_CONV bool_ss [alookup_to_option_choice,
                      option_choice_f_assoc, alookup_empty_option_choice_f,
                      count_append_def, alookup_append_option_choice_f,
                      empty_is_ALOOKUP]) tm
          handle HOL_ERR _ => raise err "mk_repr_step"
              ("no progress from SIMP_CONV: " ^ Parse.term_to_string tm)
  end
(* Use the cached theorem for tm if registered, else take a step. *)
and mk_repr_known_step rs tm =
  (case peek_repr rs tm of
     SOME cached => cached
   | NONE => mk_repr_step rs tm)
(* Build a sorted_alist_repr theorem for tm by chaining steps, rewriting the
   final theorem back to mention tm through each intermediate equation. *)
and mk_repr rs tm =
  let
    val step = mk_repr_known_step rs tm
  in
    if not (is_eq (concl step)) then step
    else
      let
        val inner = mk_repr rs (rhs (concl step))
      in
        CONV_RULE (RAND_CONV (REWR_CONV (SYM step))) inner
      end
  end
236
(* Register the definition thm : |- f = body in rs. The cached theorem for f
   depends on what is already known about body:
   - body unregistered: f's repr theorem, built from body with mk_repr and
     rewritten back to mention f
   - body cached as an equation: thm chained with it (TRANS)
   - body cached as a repr theorem: thm itself, so reprs_conv unfolds f to body. *)
fun add_alist_repr rs thm =
  let
    val AList_Reprs {dict, ...} = rs
    val (f, body) = dest_eq (concl thm)
    val to_store =
      case peek_repr rs body of
        NONE =>
          let val built = mk_repr rs body
          in CONV_RULE (RAND_CONV (REWR_CONV (SYM thm))) built end
      | SOME cached => if is_eq (concl cached) then TRANS thm cached else thm
  in
    dict := Redblackmap.insert (!dict, f, to_store)
  end
249
(* Compute f v, print the elapsed time (measured with Portable.timestamp)
   labelled with msg, and return f v. *)
fun timeit msg f v =
  let
    val t0 = Portable.timestamp ()
    val result = f v
    val elapsed = Time.- (Portable.timestamp (), t0)
    val _ = print (String.concat
              ["Time to ", msg, ": ", Portable.time_to_string elapsed, "\n"])
  in
    result
  end
256
257(* testing *)
258
(* Test repr set for num-keyed alists. The ordering theorem comes from
   comparisonTheory; side conditions are discharged by EVAL, and keys are
   compared in ML with Int.compare. *)
fun test_rs () =
  let
    val R_thm = MATCH_MP (DB.fetch "comparison" "good_cmp_Less_irrefl_trans")
                         (DB.fetch "comparison" "num_cmp_good")
  in
    mk_alist_reprs R_thm EVAL numSyntax.int_of_term Int.compare
  end
264
(* Test helper: ALOOKUP applied to the literal alist of pairs (key i, i) for
   i in ns, where key i = (157 * i) mod 1000 scrambles the key order.
   (Removed a dead `val _ = I` and a needless local `open numSyntax`.) *)
fun test_mk_alookup ns =
  let
    fun key i = (i * 157) mod 1000
    fun el i = pairSyntax.mk_pair (numSyntax.term_of_int (key i),
                                   numSyntax.term_of_int i)
  in
    mk_icomb (alookup_tm, listSyntax.mk_list (map el ns, type_of (el 0)))
  end
271
(* Test: build a repr for option_choice_f of a 200-entry ALOOKUP and a
   3-entry ALOOKUP. *)
fun test_200 rs =
  mk_repr rs (list_mk_icomb (option_choice_tm,
                [test_mk_alookup (upto 1 200), test_mk_alookup [1, 4, 3]]))
277
278(*
279val rs = test_rs ()
280val thm_200 = timeit "build repr" test_200 rs
281*)
282
283(* proving and using is_lookup thms *)
284
(* Build an is_lookup theorem for key k in an alist term t (the final curried
   argument). The recursion follows the shape of t: the empty list, a
   one-element literal list, or a count_append node. Keys are compared in ML
   via dest/cmp. Side conditions are simplified and, if not trivially true,
   left as assumptions (see assume_prems). *)
fun build_lookup (dest : term -> 'a) cmp R k =
  let
    val key = dest k
    fun check th =
      if same_const is_lookup_tm (#1 (strip_comb (concl th))) then th
      else
        (print "Not a lookup_tm:\n"; print_thm th; print "\n";
         raise (err "build_lookup" "check"))
    fun finish th = check (assume_prems th)
    fun var_for name v = [mk_var (name, type_of v) |-> v]
    (* t is a literal list; it must have exactly one element *)
    fun leaf t =
      let
        val (elems, _) = listSyntax.dest_list t
        (* the error is tagged with this function's name (was "build_insert") *)
        val _ = length elems = 1 orelse raise (err "build_lookup" "malformed")
        val (k', v) = pairSyntax.dest_pair (hd elems)
      in
        case cmp (key, dest k') of
          EQUAL => finish (ISPECL [R, k, k', v] is_lookup_hit)
        | GREATER => finish (ISPECL [R, k, k', v] is_lookup_far_right)
        | LESS => finish (ISPECL [R, k, k', v] is_lookup_far_left)
      end
    fun build t =
      if listSyntax.is_nil t then finish (ISPECL [R, k, t] is_lookup_empty)
      else if listSyntax.is_cons t then leaf t
      else node t
    (* t must be count_append n l r: search right, the gap, or left *)
    and node t =
      let
        val (head, args) = strip_comb t
        val _ = same_const count_append_tm head
                orelse raise (err "build_lookup" "unknown")
        val (n, l, r) =
          case args of
            [n, l, r] => (n, l, r)
          | _ => raise (err "build_lookup" "num args")
      in
        if cmp (key, dest (hd_key r)) <> LESS then
          do_inst_mp (var_for "l" l) (SPEC n is_lookup_r) (build r)
        else if cmp (key, dest (last_key l)) = GREATER then
          finish (ISPECL [R, n, l, r, k] is_lookup_centre)
        else
          do_inst_mp (var_for "r" r) (SPEC n is_lookup_l) (build l)
      end
  in
    build
  end
317
(* Prove an is_lookup theorem for key k in al, discharging every side
   condition with conv. *)
fun prove_lookup R conv dest cmp k al =
  let
    fun discharge th =
      if is_imp (concl th) then discharge (prove_assum_by_conv conv th)
      else th
  in
    discharge (DISCH_ALL (build_lookup dest cmp R k al))
  end
323
(* Given repr_thm : sorted_alist_repr R al f, prove the value of f at key k
   by building an is_lookup theorem on al and combining the two via
   lookup_repr. The shape check no longer rebinds `f` to a bool, and the
   unused third argument is matched with a wildcard. *)
fun repr_prove_lookup conv dest cmp repr_thm k =
  let
    val (head, args) = strip_comb (concl repr_thm)
    val _ = same_const head repr_tm
            orelse raise (err "repr_prove_lookup" "unexpected")
    val (R, al) =
      case args of
        [R, al, _] => (R, al)
      | _ => raise (err "repr_prove_lookup" "num args")
    val lookup = prove_lookup R conv dest cmp k al
  in
    MATCH_MP lookup_repr (CONJ repr_thm lookup)
  end
332
(* Conversion for f x where f is registered in rs. Cached equations for f
   are unfolded first (then the conversion recurses on the result); a repr
   theorem is used to prove the lookup at x. Raises UNCHANGED if tm is not
   an application or its operator is not registered. *)
fun reprs_conv rs tm =
  let
    val AList_Reprs {conv, dest, cmp, ...} = rs
    val (f, x) = dest_comb tm handle HOL_ERR _ => raise UNCHANGED
  in
    case peek_repr rs f of
      NONE => raise UNCHANGED
    | SOME th =>
        if is_eq (concl th)
        then (RATOR_CONV (REWR_CONV th) THENC reprs_conv rs) tm
        else repr_prove_lookup conv dest cmp th x
  end
343
(* Test helper: apply reprs_conv to f applied to the numeral i. *)
fun extract_test f rs i =
  let val app = mk_comb (f, numSyntax.term_of_int i)
  in reprs_conv rs app end
345
(* Test: define a constant f as ALOOKUP of a 300-entry alist and register it
   in rs. Then evaluate f 1 .. f 1000 with reprs_conv; these keys include
   both present and absent ones. Returns the resulting theorems, timing both
   phases. *)
fun extract_test_1000 rs =
  let
    val alookup = test_mk_alookup (upto 1 300)
    val f_var = mk_var ("f", type_of alookup)
    val f_def = new_definition ("f", mk_eq (f_var, alookup))
    (* new_definition makes f a constant, and add_alist_repr registers that
       constant (the lhs of f_def). Applications must be built from it:
       the ML variable f_var compares differently under Term.compare, so
       peek_repr would miss it and reprs_conv would raise UNCHANGED. *)
    val f_const = lhs (concl f_def)
    val _ = timeit "add def" (add_alist_repr rs) f_def
  in
    timeit "map extract" (map (extract_test f_const rs)) (upto 1 1000)
  end
353
354end
355