1(*
2set_trace "Unicode" 0;
3set_trace "pp_annotations" 0;
4*)
5
6use (HOLDIR^"/src/pfl/defchoose");
7
8(* quietdec := true; *)
9open numSyntax optionSyntax pairSyntax optionTheory;
10(* quietdec := false; *)
11
12
(* The term SUC 0. BASE_DEFINED_TAC passes it as the index (depth) argument
   of an indexed function when checking definedness at depth one. *)
val suc_zero = ``SUC 0``;
14
15(*---------------------------------------------------------------------------*)
16(* Examples                                                                  *)
17(*---------------------------------------------------------------------------*)
18
(* Ackermann's function, in three formulations:
     ack_tm  - one equation using if-then-else and subtraction;
     ack1_tm - clause-style equations with free variables;
     ack2_tm - the same clauses with explicit universal quantifiers. *)
val ack_tm =
 ``ack m n =
    if m=0 then n + 1 else
    if n=0 then ack (m-1) 1 else
    ack (m-1) (ack m (n-1))``;

val ack1_tm =
 ``(ack 0 n = n+1) /\
   (ack m 0 = ack (m-1) 1) /\
   (ack (SUC m) (SUC n) = ack m (ack (SUC m) n))``;

val ack2_tm =
 ``(!n. ack 0 n = n+1) /\
   (!m. ack m 0 = ack (m-1) 1) /\
   (!m n. ack (SUC m) (SUC n) = ack m (ack (SUC m) n))``;
34
(* Factorial and variants exercising different shapes of recursive call:
   a recursive call nested in arithmetic (fact1_tm), two recursive calls
   (fact2_tm), a let whose binding is not recursive (fact3_tm), a let that
   binds the recursive call (fact4_tm), and a simultaneous "let ... and ..."
   binding (fact5_tm). *)
val fact_tm = ``fact n = if n=0 then 1 else n * fact (n-1)``;

val fact1_tm = ``fact n = if n=0 then 1 else (n + n) * (13 * fact (n-1) + n)``;

val fact2_tm = ``fact n = if n=0 then 1 else
                 (n + fact (n-1)) * (13 * fact (n-1) + n)``;

val fact3_tm = ``fact n = if n=0 then 1 else
                  let x = n + n
                  in x * fact (n-1)``;

val fact4_tm = ``fact n = if n=0 then 1 else
                  let x = fact (n-1)
                  in x + n * x``;

val fact5_tm = ``fact n = if n=0 then 1 else
                  let x = fact (n-1)
                  and y = n + n
                  in y + n * x``;
54
(* Further examples:
   - map over lists;
   - McCarthy's 91 function (f91_tm) and a variant with an extra addition
     (f91a_tm), both with nested recursive calls;
   - Z_tm, with triply nested recursive calls.
   Z1_tm calls Z (a free function variable there) as well as itself.
   Z2_tm also defines a function named Z1, but calls only itself. *)
val map_tm = ``(map f [] = []) /\ (map f (h::t) = f h :: map f t)``;
val f91_tm = ``f91 n = if n>100 then n-10 else f91(f91(n+11))``;
val f91a_tm = ``f91 n = if n>100 then n-10 else n + f91(f91(n+11))``;

val Z_tm = ``Z n = if n=0 then 0 else Z(Z(Z(n-1)))``;
val Z1_tm = ``Z1 n = if n=0 then 0 else SUC(SUC(Z(n-1))) + Z1 (n-1)``;
val Z2_tm = ``Z1 n = if n=0 then 0 else SUC(SUC(Z1(n-1))) + Z1 (n-1)``;
62
(* List partitioning:
   - partle_tm splits on the test h <= x;
   - part_tm splits on an arbitrary predicate P.
   Quicksort variants:
   - via FILTER with nested lets (qsort0_tm, qsort1_tm);
   - via FILTER with a simultaneous "let ... and ..." (qsort2_tm);
   - via part with a tuple-pattern let (qsort3_tm), and the same
     parameterized by a relation R (qsort4_tm).
   NOTE(review): qsort0_tm and qsort1_tm are currently identical. *)
val partle_tm =
  ``(part x [] = ([],[])) /\
    (part x (h::t) =
      let (l1,l2) = part x t
      in if h <= x then (h::l1, l2) else (l1,h::l2))``;

val part_tm =
  ``(part P [] = ([],[])) /\
    (part P (h::t) =
      let (l1,l2) = part P t
      in if P h then (h::l1, l2) else (l1,h::l2))``;

val qsort0_tm =
  ``(qsort [] = []) /\
    (qsort (h::t) =
      let l1 = FILTER (\x. x <= h) t in
      let l2 = FILTER (\x. x > h) t
      in qsort l1 ++ [h] ++ qsort l2)``;

val qsort1_tm =
  ``(qsort [] = []) /\
    (qsort (h::t) =
      let l1 = FILTER (\x. x <= h) t in
      let l2 = FILTER (\x. x > h) t
      in qsort l1 ++ [h] ++ qsort l2)``;

val qsort2_tm =
  ``(qsort [] = []) /\
    (qsort (h::t) =
      let l1 = FILTER (\x. x <= h) t
      and l2 = FILTER (\x. x > h) t
      in qsort l1 ++ [h] ++ qsort l2)``;

val qsort3_tm =
  ``(qsort [] = []) /\
    (qsort (h::t) =
      let (l1,l2) = part (\x. x <= h) t
      in qsort l1 ++ [h] ++ qsort l2)``;

val qsort4_tm =
  ``(qsort R [] = []) /\
    (qsort R (h::t) =
      let (l1,l2) = part (\x. R x h) t
      in qsort R l1 ++ [h] ++ qsort R l2)``;
107
108(*---------------------------------------------------------------------------*)
109(* Code                                                                      *)
110(*---------------------------------------------------------------------------*)
111(*
112val pair_case_tm = prim_mk_const{Name="pair_case",Thy="pair"};
113
114fun mk_pair_case (f,a) =
115  let val theta = match_type (type_of pair_case_tm) (type_of f)
116  in list_mk_comb(inst theta pair_case_tm,[f,a])
117  end;
118*)
119
(* Wrap d in nested option case splits, one per (a,v) pair of u, the first
   pair being the outermost split: when a is NONE the result is NONE (at
   element type ty), otherwise v is bound to the contents of a. *)
fun list_mk_option_cases u d ty =
 let fun wrap ((a,v), inner) =
       mk_option_case (mk_none ty, mk_abs(v, inner), a)
 in List.foldr wrap d u
 end;
124
125(*---------------------------------------------------------------------------*)
126(* Splits tmlist into a triple (u,d,vs) in which u is a list of pairs of     *)
(* form (term,var), d is a list of terms, and vs is a list of used variables.*)
128(* The intent is that tmlist represents a list of terms, some of which are   *)
129(* statically known to be defined, i.e. are of the form "SOME t". The other  *)
130(* terms are not statically known to be defined. It is possible that they    *)
131(* might statically be known to be undefined, i.e. NONE, but that is not     *)
132(* catered for at the moment. If a term is not of the form SOME t, then, we  *)
133(* have to put the definedness test into the term structure. This is done    *)
134(* by making case analysis on an option. So ... d is the transformed tmlist  *)
135(* and u contains the elements of tmlist that can't be trivially shown to    *)
136(* be defined.                                                               *)
137(*---------------------------------------------------------------------------*)
138
fun split_args tmlist fvs =
 let (* Process from the right so that u and d keep tmlist's order. *)
     fun step (t, (u,d,vs)) =
       if not (is_some t) then
         let val fresh = variant vs (mk_var("v", dest_option (type_of t)))
         in ((t,fresh)::u, fresh::d, fresh::vs)
         end
       else (u, dest_some t :: d, vs)
 in List.foldr step ([],[],fvs) tmlist
 end;
146
147(*---------------------------------------------------------------------------*)
148(* Partiality transformation. Note that it is, at present, really only       *)
149(* applicable to first-order terms. Explicit lambdas are handled, but not    *)
150(* other terms of functional type.                                           *)
151(*---------------------------------------------------------------------------*)
152
(*---------------------------------------------------------------------------*)
(* partialize env tm builds an option-valued version of tm. For a lambda,   *)
(* the body is partialized under the binder instead. env is a substitution  *)
(* from function variables to option-valued replacements (as built by       *)
(* optional and indexed). Any application whose head is a redex of env is   *)
(* rewritten to use the replacement. Subterms known statically to be        *)
(* defined are kept in the form SOME t, so enclosing constructs need no     *)
(* case split on them.                                                      *)
(*---------------------------------------------------------------------------*)
fun partialize env tm =
 if is_var tm orelse is_const tm then mk_some tm else
 if is_abs tm then
    let val (v,M) = dest_abs tm
    in mk_abs(v,partialize env M)
    end else
 if is_cond tm then
    let val (b,t1,t2) = dest_cond tm
        val b' = partialize env b
        val t1' = partialize env t1
        val t2' = partialize env t2
    in
      if is_some b' andalso is_some t1' andalso is_some t2'
        then mk_some (mk_cond(dest_some b', dest_some t1', dest_some t2'))
        else
      if is_some b' then mk_cond (dest_some(b'), t1', t2')
        else
      (* The test itself may be undefined: case-split on it first. *)
      let val rty = type_of tm
          val v = variant (free_vars tm) (mk_var("v",bool))
      in mk_option_case (mk_none rty, mk_abs(v, mk_cond(v,t1',t2')), b')
      end
    end else
 if TypeBase.is_case tm then
    let val (_,ob,clauses) = TypeBase.dest_case tm
        val ob' = partialize env ob
        val clauses' = map (I ## partialize env) clauses
    in
      if is_some ob' andalso Lib.all (is_some o snd) clauses'
        then mk_some (TypeBase.mk_case
                        (dest_some ob',map (I##dest_some) clauses'))
        else
      if is_some ob' then TypeBase.mk_case (dest_some(ob'), clauses')
      else let val rty = type_of tm
               val v = variant (free_vars tm) (mk_var("v",type_of ob))
           in mk_option_case
                (mk_none rty, mk_abs(v, TypeBase.mk_case(v,clauses')), ob')
           end
    end else
 if is_let tm then
    let val (bindings,M) = dest_anylet tm
        val rty = type_of tm
        val bindings' = zip (map fst bindings)
                            (map (partialize env o snd) bindings)
        val M' = partialize env M
        (* Nest the bindings, first binding outermost. A binding whose value
           is not statically defined becomes a case split. Its NONE branch
           is the value of the whole let, so it must be NONE : rty option
           (the let's type), not an option of the bound value's type. *)
        fun bind (v,t) body =
          if is_some t
            then mk_anylet ([(v,dest_some t)],body)
            else mk_option_case (mk_none rty, mk_pabs(v,body), t)
    in itlist bind bindings' M'
    end
 else (* is_comb tm *)
 let val (f,args) = strip_comb tm
     val args' = map (partialize env) args
     val (u,d,vs) = split_args args' (free_vars tm)
     (* Recursive functions (redexes of env) already return options. *)
     val fapp = case subst_assoc (equal f) env
                 of NONE => mk_some(list_mk_comb(f,d))
                  | SOME g => list_mk_comb(g,d)
 in
   list_mk_option_cases u fapp (type_of tm)
 end;
213
(* Substitution pair  f |-> pf,  where a function variable
   f : t1 -> ... -> tn -> r gets the option-valued counterpart
   pf : t1 -> ... -> tn -> r option  (same name prefixed by "p"). *)
fun optional fvar =
 let val (name, fty) = dest_var fvar
     val (doms, rng) = strip_fun fty
     val pvar = mk_var ("p"^name, list_mk_fun (doms, mk_option rng))
 in fvar |-> pvar
 end;
222
(* Substitution pair  f |-> if d,  where a function variable
   f : t1 -> ... -> tn -> r is replaced by the depth-indexed variable
   if : num -> t1 -> ... -> tn -> r option  (name prefixed by "i")
   applied to the index term d. *)
fun indexed d fvar =
 let val (name, fty) = dest_var fvar
     val (doms, rng) = strip_fun fty
     val ivar = mk_var ("i"^name, list_mk_fun (num::doms, mk_option rng))
 in fvar |-> mk_comb (ivar, d)
 end;
231
(* The residue that theta assigns to v; raises Option.Option if v is not a
   redex of theta. *)
fun mysubst theta v =
 case subst_assoc (equal v) theta
  of SOME residue => residue
   | NONE => raise Option.Option;
233
(* One fresh variable per type in tylist, in order, each named after name
   (numbered as needed), distinct from vlist and from each other. *)
fun mk_typed_vars name vlist tylist =
 let fun step (ty, (avoid, made)) =
       let val fresh = numvariant avoid (mk_var(name,ty))
       in (fresh::avoid, fresh::made)
       end
     val (_, made) = List.foldl step (vlist, []) tylist
 in List.rev made
 end;
242
(* The one-element list containing x. *)
fun single x = x :: [];
244
(*---------------------------------------------------------------------------*)
(* For each function f defined in eqns, insert the clause                    *)
(*     f 0 v1 ... vn = NONE                                                  *)
(* just before the first equation for f. v1 ... vn are fresh variables,     *)
(* avoiding vars, typed like f's arguments after the first (index) one.     *)
(* Returns the extended equation list and the inserted clauses, most        *)
(* recently inserted first.                                                 *)
(*---------------------------------------------------------------------------*)
fun new_base_cases eqns vars =
 let fun munge _ [] _ bcases = ([],bcases)
       | munge fns (h::t) avoid bcases =
         let (* Take lhs and rhs from the same (quantifier-stripped) body. *)
             val body = snd(strip_forall h)
             val (f,args) = strip_comb(lhs body)
         in
           if mem f fns then
             let val (rest,bcs) = munge fns t avoid bcases
             in (h::rest, bcs)
             end
           else
             let val h0_vars = mk_typed_vars "v" avoid (tl (map type_of args))
                 val h0 = mk_eq(list_mk_comb(f,zero_tm::h0_vars),
                                mk_none(dest_option(type_of(rhs body))))
                 val (rest,bcs) =
                     munge (f::fns) t (h0_vars@avoid) (h0::bcases)
             in (h0::h::rest, bcs)
             end
         end
 in munge [] eqns vars []
 end;
260
(* Flattening law for option_case: when the scrutinee is itself an option
   case with a NONE default, the outer split can be pushed into the inner
   continuation. Proved by cases on ob. *)
val option_case_rewrite = Q.prove
(`option_case a g (option_case NONE f ob) =
  option_case a (\v. option_case a g (f v)) ob`,
 Cases_on `ob` THEN RW_TAC std_ss []);

(* Conversion that simplifies a term with std_ss plus option_case_rewrite.
   It is wrapped in QCONV so that an unchanged term still yields an equation
   (mk_peqns and alt_eqns take rhs o concl of its result). *)
val linearize_case = QCONV (SIMP_CONV std_ss [option_case_rewrite]);
267
(* Partial (option-valued) equations: both sides of each L_i = R_i are
   partialized with every function of ufns replaced by its "p"-prefixed
   option-valued variable. Nested option cases are then flattened with
   linearize_case. *)
fun mk_peqns ufns L R =
 let val theta = map optional ufns
     val partial = map (partialize theta)
     val (lefts, rights) = (partial L, partial R)
     val raw = map mk_eq (zip lefts rights)
 in
   map (rhs o concl o linearize_case) raw
 end;
276
277(*---------------------------------------------------------------------------*)
278(* Run term through the partiality transformation and build the partial and  *)
279(* indexed versions of the equations.                                        *)
280(*---------------------------------------------------------------------------*)
281
fun alt_eqns tm =
 let val clauses = map (snd o strip_forall) (strip_conj tm)
     val (lefts, rights) = unzip (map dest_eq clauses)
     val (heads, argss) = unzip (map strip_comb lefts)
     val ufns = mk_set heads   (* unique fns, should all be variables *)
     val peqns = mk_peqns ufns lefts rights
     (* Indexed equations: each f becomes  i<f> (SUC d)  on the left and
        i<f> d  on the right, with d fresh. *)
     val avoid = all_vars tm
     val d = variant avoid (mk_var("d",num))
     val thetaL = map (indexed (mk_suc d)) ufns
     val thetaR = map (indexed d) ufns
     val iheads = map (mysubst thetaL) heads
     val ilefts = map2 (curry list_mk_comb) iheads argss
     val irights = map (partialize thetaR) rights
     val (extended, base_cases) =
         new_base_cases (map mk_eq (zip ilefts irights)) (d::avoid)
     val ieqns = map (rhs o concl o linearize_case) extended
 in
   (peqns, ieqns, ufns, base_cases)
 end;
302
(*---------------------------------------------------------------------------*)
(* Input: a base clause  ifn 0 v1 ... vn = NONE.                            *)
(* Passes  !v1 ... vn. ?ifnLim. IS_SOME (ifn ifnLim v1 ... vn)  (with ifn  *)
(* as a constant) to DEFCHOOSE under the names "<ifn>Lim_spec" and          *)
(* "<ifn>Lim", presumably introducing a limit-depth constant satisfying it. *)
(*---------------------------------------------------------------------------*)

fun limit_spec_def base_case =
 let val (ifn, actuals) = strip_comb (lhs base_case)
     val vs = tl actuals                      (* drop the 0 index *)
     val (iname, ity) = dest_var ifn
     val ic = mk_const (iname, ity)
     val lname = iname ^ "Lim"
     val lim = numvariant (all_vars base_case) (mk_var (lname, num))
     val defined = mk_is_some (list_mk_comb (ic, lim::vs))
 in
   DEFCHOOSE (lname ^ "_spec", lname,
              list_mk_forall (vs, mk_exists (lim, defined)))
 end
 handle e => raise wrap_exn "Index" "limit_spec_def" e;
320
(* Define  p<fname> v1 ... vn = <ifn application>,  where the right side is
   the argument of IS_SOME in the consequent of the quantified implication
   limspec_thm, and v1 ... vn are that application's arguments after the
   first. *)
fun pfn_def fname limspec_thm =
 let val (_, consequent) = dest_imp (snd (strip_forall (concl limspec_thm)))
     val iapp = dest_is_some consequent
     val vs = tl (snd (strip_comb iapp))
     val pname = "p" ^ fname
     val pvar = mk_var (pname, list_mk_fun (map type_of vs, type_of iapp))
 in
   new_definition (pname ^ "_def", mk_eq (list_mk_comb (pvar, vs), iapp))
 end
 handle e => raise wrap_exn "Index" "pfn_def" e;
334
(* Define  name v1 ... vn = THE (pfn v1 ... vn)  from the equation
   pfn v1 ... vn = ...  of pfn_def. The new function's range is the
   element type of pfn's option range. *)
fun fn_def name pfn_def =
 let val papp = lhs (snd (strip_forall (concl pfn_def)))
     val (pfn, vs) = strip_comb papp
     val (doms, rng) = strip_fun (type_of pfn)
     val fvar = mk_var (name, list_mk_fun (doms, dest_option rng))
 in
   new_definition (name ^ "_def",
                   mk_eq (list_mk_comb (fvar, vs), mk_the papp))
 end
 handle e => raise wrap_exn "Index" "fn_def" e;
346
(* Domain predicate: from  f v1 ... vn = THE t,  define
   in_dom_<name> v1 ... vn = IS_SOME t. *)
fun in_dom_def name fn_def =
 let val (fapp, the_app) = dest_eq (snd (strip_forall (concl fn_def)))
     val vs = snd (strip_comb fapp)
     val body = mk_is_some (dest_the the_app)
     val dname = "in_dom_" ^ name
     val dvar = mk_var (dname, list_mk_fun (map type_of vs, bool))
 in
   new_definition (dname ^ "_def", mk_eq (list_mk_comb (dvar, vs), body))
 end
 handle e => raise wrap_exn "Index" "in_dom_def" e;
358
(*---------------------------------------------------------------------------*)
(* Make all definitions for the recursion equations tm:                     *)
(* - the indexed function(s), via tDefine (termination by the index);       *)
(* - for each base clause, a limit spec, the partial function p<f>, the     *)
(*   function f = THE o p<f>, and the domain predicate in_dom_<f>.          *)
(*---------------------------------------------------------------------------*)
fun mk_defs tm =
 let open TotalDefn
     val (peqns, ieqns, _, base_cases) = alt_eqns tm
     val ivar = fst(strip_comb(lhs(snd(strip_forall(hd ieqns)))))
     val iname = fst(dest_var ivar)
     val idef = tDefine iname `^(list_mk_conj ieqns)` (WF_REL_TAC`measure FST`)
     val limspec_thms = map limit_spec_def base_cases
     (* Take each function's name from its own base clause (the indexed
        variable is "i" ^ name). This keeps names paired with the right limit
        spec even though base_cases is not ordered like the function list. *)
     fun base_name bc =
       String.extract (fst (dest_var (fst (strip_comb (lhs bc)))), 1, NONE)
     val names = map base_name base_cases
     val pfn_defs = map2 pfn_def names limspec_thms
     val fn_defs =  map2 fn_def names pfn_defs
     val dom_defs = map2 in_dom_def names fn_defs
 in
  (peqns,idef,limspec_thms,pfn_defs,fn_defs,dom_defs)
 end;
373
(* Every pair (a,b) with a from l1 and b from l2, grouped by a in l1's order. *)
fun cross_prod l1 l2 =
  List.concat (map (fn a => map (fn b => (a,b)) l2) l1);
376
(* Join a test path (alist, b) with a branch path (blist, e): the test b,
   negated when pn is false, goes between the two contexts; the leaf is e. *)
fun merge pn ((alist,b),(blist,e)) =
  let val test = if pn then b else mk_neg b
  in (alist @ (test :: blist), e)
  end;
379
(* Paths through  if b then t1 else t2:  every path of b followed by every
   then-path (with b asserted), then the same with every else-path (with
   b negated). *)
fun merge_paths bpaths pospaths negpaths =
  let val thens = map (merge true) (cross_prod bpaths pospaths)
      val elses = map (merge false) (cross_prod bpaths negpaths)
  in thens @ elses
  end;
383
384(*---------------------------------------------------------------------------*)
385(* Collect maximal non-if terms on rhs, plus accumulate recursive calls.     *)
386(*---------------------------------------------------------------------------*)
387
(* paths vfns tm: one (ctxt, e) pair per branch through the conditionals,
   case splits and lets of tm. e is the leaf term reached and ctxt lists the
   conditions (tests, case equations, let equations) leading to it. A leaf
   that applies one of vfns also gets IS_SOME of itself in its context. *)
fun paths vfns tm =
 if is_cond tm
   then let val (b,t1,t2) = dest_cond tm
        in merge_paths (paths vfns b) (paths vfns t1) (paths vfns t2)
        end else
 if TypeBase.is_case tm then
   let
     val (_,ob,clauses) = TypeBase.dest_case tm
     val (pats,rhsl) = unzip clauses
     val plists = map (paths vfns) rhsl
     fun patch pat plist = map (fn (ctxt,e) => (mk_eq(ob,pat)::ctxt,e)) plist
     val patched = map2 patch pats plists
   in flatten patched
   end
 else if is_let tm
   then let val (binds, body) = dest_anylet tm
            val plist = paths vfns body
            (* Record each binding as an equation in the context but keep
               the branch's own leaf e, as the cond and case branches do
               (not the whole let body). *)
            fun patch (x,M) (ctxt,e) = (mk_eq(x,M)::ctxt, e)
        in map (fn path => itlist patch binds path) plist
       end else
 if is_comb tm
   then let val (f,_) = strip_comb tm
        in if mem f vfns
            then [([mk_is_some tm],tm)]
            else [([],tm)]
       end
 else [([],tm)];
415
(* Build  c1 /\ ... /\ cn ==> b,  or just b when there are no conditions. *)
fun list_mk_conj_imp (hyps, goal_tm) =
  case hyps
   of [] => goal_tm
    | _ => mk_imp (list_mk_conj hyps, goal_tm);
418
419(*---------------------------------------------------------------------------*)
420(* Adding index pattern (0,SUC) may duplicate clauses, which is unfortunate, *)
421(* because I can't then just use peqns to generate my proof obligations.     *)
422(* Instead, I have to use the post-pattern match translation arising from    *)
423(* idef and then change those into corresponding pfn equations. That will    *)
424(* ensure that the idef will always be able to be applied as a rewrite in    *)
425(* proofs. Example : ack 0 n = SOME(n+1), but pattern-match translation with *)
426(* depth added results in this turning into two equations                    *)
427(*                                                                           *)
428(* iack (SUC d) 0 0 = SOME(0+1) and iack (SUC d) 0 (SUC n) = SOME(SUC n + 1) *)
429(*                                                                           *)
430(* so if we generated a pack 0 n = SOME (n+1) goal that wouldn't work because*)
431(* the rewrite rules for iack are too specific to rewrite that.              *)
432(*---------------------------------------------------------------------------*)
433
(* Proof obligations for the pfn equations. Each clause of idef (except the
   first) has its indexed function rewritten to the corresponding pfn
   constant. Each resulting equation  lhs = rhs  is split by paths into
   goals of the form  ctxt ==> lhs = e. *)
fun pgoals peqns idef =
 let fun fn_of clause = fst(strip_comb(lhs(snd(strip_forall clause))))
     val vpfns = mk_set(map fn_of peqns)
     val cpfns = map (mk_const o dest_var) vpfns
     val idefs = strip_conj(snd(strip_forall(concl idef)))
     (* Drop the first clause, presumably the inserted  ifn 0 ... = NONE
        base clause (see new_base_cases). *)
     val idefs' = map (snd o strip_forall) (tl idefs)
     val ifns = mk_set(map fn_of idefs')
     val _ = assert (fn () => length ifns = length cpfns) ()
     (* ifn n = pfn, asserted with mk_thm. It is used only to compute the
        rewritten term; the theorems themselves are discarded. *)
     fun mk_rule ifn pfn = mk_thm([],``^ifn ^(mk_var("n",num)) = ^pfn``)
     val rules = map2 mk_rule ifns cpfns
     fun transform ieqn = rhs(concl(QCONV(REWRITE_CONV rules) ieqn))
     val peqns' = map transform idefs'
     (* After transform the equations mention the pfn constants, so
        recursive calls must be recognised against cpfns: the variables in
        vpfns never compare equal to constants. *)
     val plists = map (fn eqn => (lhs eqn, paths cpfns (rhs eqn))) peqns'
     fun mk_goal left (ctxt,e) = list_mk_conj_imp(ctxt,mk_eq(left,e))
     fun mk_goals (left,list) = map (mk_goal left) list
 in
   flatten (map mk_goals plists)
 end;
453
(* Make all definitions for tm, then call set_goal once per generated proof
   obligation (in reverse order), and return the definitions. *)
fun test tm =
  let val defs as (peqns,idef,_,_,_,_) = mk_defs tm
      val obligations = pgoals peqns idef
  in
    List.app (fn g => ignore (set_goal ([], g))) (List.rev obligations);
    defs
  end;
461
(* NOTE(review): STOP is not bound anywhere in this file. Presumably it is
   left unbound on purpose so that loading the file with use() stops here.
   What follows appears to be interactive scratch work. *)
STOP;
463
464
465
(* Tactic for goals whose left-hand side is an indexed-function application
   ifn d x1 ... xn:
   1. prove IS_SOME (ifn (SUC 0) x1 ... xn) by rewriting with idef,
      IS_SOME_DEF and the assumptions;
   2. pass that fact through limspec and then ifn_def_pos to get an
      existential  ?e. <depth> = SUC e;
   3. substitute it throughout the goal with CHOOSE_THEN SUBST_ALL_TAC.
   NOTE(review): MATCH_MP limspec fact assumes limspec's antecedent matches
   an IS_SOME fact directly. If limspec has an existential antecedent, an
   EXISTS step may be needed first - confirm. *)
fun BASE_DEFINED_TAC idef limspec ifn_def_pos (asl,goal) =
 let val (ifn,args) = strip_comb(lhs goal)
     val args' = tl args
     val subgoal = mk_is_some(list_mk_comb(ifn,suc_zero::args'))
     val fact = EQT_ELIM
                  (REWRITE_CONV ([idef,IS_SOME_DEF]@map ASSUME asl) subgoal)
 in
  CHOOSE_THEN SUBST_ALL_TAC
        (MATCH_MP ifn_def_pos (MATCH_MP limspec fact))
 end (asl,goal)
 handle e => raise wrap_exn "Index" "BASE_DEFINED_TAC" e;
477
(* Tactic for a base-case obligation:
   - strip the goal and eliminate variable equations;
   - unfold pfn_def;
   - establish definedness at a successor depth with BASE_DEFINED_TAC;
   - finish by rewriting with idef and the assumptions. *)
fun pbase_clause pfn_def idef limspec ifn_def_pos =
 REPEAT STRIP_TAC THEN REPEAT BasicProvers.VAR_EQ_TAC THEN
 PURE_REWRITE_TAC [pfn_def] THEN
 BASE_DEFINED_TAC idef limspec ifn_def_pos THEN
 ASM_REWRITE_TAC [idef]
483
(* TODO: the recursive-case tactic has not been written yet. The body "..."
   is a placeholder, not valid SML, so this declaration does not compile
   as it stands. *)
fun precursive_clause pfn_def idef limspec = ...
485
(* Interactive session: Ackermann (ack_tm). *)
val (peqns,idef,limspec_thms,pfn_defs,fn_defs,dom_defs) = test ack_tm;

(* Base case *)
val [pfn_def] = pfn_defs;
val [limspec] = limspec_thms;
(* Any defined value of iack occurs at a successor depth. *)
val ifn_def_pos = Q.prove
(`!d m n. IS_SOME(iack d m n) ==> ?e. d = SUC e`,
 Cases THEN RW_TAC std_ss [idef]);

e (pbase_clause pfn_def idef limspec ifn_def_pos);

(* Recursive case: TODO (no tactic yet; see precursive_clause). *)

(* Discard 12 proof attempts from the goal stack before the next example. *)
dropn 12;
501
(* Interactive session: factorial (fact_tm). *)
val (peqns,idef,limspec_thms,pfn_defs,fn_defs,dom_defs) = test fact_tm;

val [pfn_def] = pfn_defs;
val [limspec] = limspec_thms;
(* Any defined value of ifact occurs at a successor depth. *)
val ifn_def_pos = Q.prove
(`!d n. IS_SOME(ifact d n) ==> ?e. d = SUC e`,
 Cases THEN RW_TAC std_ss [idef]);

(* Base case *)
e (pbase_clause pfn_def idef limspec ifn_def_pos);

(* Recursive case: TODO *)
515
(* Interactive session: McCarthy's 91 function (f91_tm). *)
val (peqns,idef,limspec_thms,pfn_defs,fn_defs,dom_defs) = test f91_tm;

val [pfn_def] = pfn_defs;
val [limspec] = limspec_thms;
(* Any defined value of if91 occurs at a successor depth. *)
val ifn_def_pos = Q.prove
(`!d n. IS_SOME(if91 d n) ==> ?e. d = SUC e`,
 Cases THEN RW_TAC std_ss [idef]);

(* Base case *)
e (pbase_clause pfn_def idef limspec ifn_def_pos);
527
(* Ack by complex patterns ... fails because of pattern expansion. *)
val (peqns,idef,limspec_thms,pfn_defs,fn_defs,dom_defs) = test ack1_tm;

(* Base case *)
val [pfn_def] = pfn_defs;
val [limspec] = limspec_thms;
(* Any defined value of iack occurs at a successor depth. *)
val ifn_def_pos = Q.prove
(`!d m n. IS_SOME(iack d m n) ==> ?e. d = SUC e`,
 Cases THEN RW_TAC std_ss [idef]);

e (pbase_clause pfn_def idef limspec ifn_def_pos);
(* Discard the current proof and apply the same tactic to the next goal. *)
drop();
e (pbase_clause pfn_def idef limspec ifn_def_pos);

(* Recursive case: TODO *)
544