1structure tailrecLib :> tailrecLib =
2struct
3
4open HolKernel boolLib bossLib Parse;
5open tailrecTheory helperLib sumSyntax pairSyntax;
6
(* Local Parse structure shadowing the opened global Parse: everything is
   inherited, except that Type and Term are rebound to parsers built from
   tailrecTheory.tailrec_grammars.  (Assumed: HOL's quotation filter routes
   ``...`` through Parse.Term / Parse.Type, so the quotations later in this
   file are parsed with these grammars.) *)
structure Parse =
struct
   open Parse
   local
     val tailrec_parsers = parse_from_grammars tailrecTheory.tailrec_grammars
   in
     val Type = #1 tailrec_parsers
     val Term = #2 tailrec_parsers
   end
end
12
(* Every definition theorem (function and its _pre) created by
   tailrec_define_from_step is pushed here; TAILREC_TAC rewrites with them. *)
val tailrec_definitions : thm list ref = ref [];
14
(* Tactic helpers; candidates for moving into helperLib. *)
16
(* Split a right-nested pair (a,b,...,z) into [a,b,...,z].  Only the right
   spine is followed, so ((a,b),c) gives [(a,b), c]; a term that is not a
   pair gives the singleton list [tm]. *)
fun dest_tuple tm =
  case (SOME (dest_pair tm) handle HOL_ERR _ => NONE) of
    NONE => [tm]
  | SOME (head_tm, tail_tm) => head_tm :: dest_tuple tail_tm;
19
(* Conversion that unfolds the first binding of a let-term with LET_DEF and
   then beta-reduces paired abstractions.  It only fires when the bound
   pattern and the bound value split into the same number of tuple
   components, and every component of the value is a variable or a
   conjunction of variables.  Otherwise, or on any HOL_ERR, it acts as
   NO_CONV.
   NOTE(review): assumes dest_anylet never returns an empty binding list;
   hd would raise Empty there, which the HOL_ERR handler does not catch. *)
fun EXPAND_BASIC_LET_CONV tm = let
  val (binds, _) = dest_anylet tm
  val (pat, value) = hd binds
  val pat_parts = dest_tuple pat
  val value_parts = dest_tuple value
  fun simple part = List.all is_var (list_dest dest_conj part)
  val _ = if length value_parts <> length pat_parts then fail () else ()
  val _ = if List.all simple value_parts then () else fail ()
  in (RATOR_CONV (RATOR_CONV (REWRITE_CONV [LET_DEF]))
      THENC DEPTH_CONV PairRules.PBETA_CONV) tm end
  handle HOL_ERR _ => NO_CONV tm;
32
(* STRIP_TAC restricted to universally quantified goals; on any other goal
   it fails (NO_TAC), so REPEAT STRIP_FORALL_TAC strips only leading
   quantifiers. *)
fun STRIP_FORALL_TAC (g as (_, w)) =
  if not (is_forall w) then NO_TAC g else STRIP_TAC g
35
(* Generalise the term x to a fresh variable, case-split on that variable
   with Cases, and tidy each case with REWRITE_TAC []. *)
fun SPEC_AND_CASES_TAC x = let
  val fresh = genvar (type_of x)
  in SPEC_TAC (x, fresh) THEN Cases THEN REWRITE_TAC [] end
38
(* Generalise each listed term, in list order, to a fresh variable; finally
   split universally quantified pairs into their components with
   pairTheory.FORALL_PROD. *)
fun GENSPEC_TAC terms =
  case terms of
    [] => SIMP_TAC pure_ss [pairTheory.FORALL_PROD]
  | t :: rest =>
      let val fresh = genvar (type_of t)
      in SPEC_TAC (t, fresh) THEN GENSPEC_TAC rest end;
41
(* Unfold every let accepted by EXPAND_BASIC_LET_CONV anywhere in the goal,
   then strip the goal's leading universal quantifiers. *)
val EXPAND_BASIC_LET_TAC = let
  val unfold_lets = CONV_TAC (DEPTH_CONV EXPAND_BASIC_LET_CONV)
  val strip_foralls = REPEAT STRIP_FORALL_TAC
  in unfold_lets THEN strip_foralls end
45
(* Take apart the subterm that finder selects from the goal:
   - a conditional: case split on its guard (SPEC_AND_CASES_TAC);
   - a let-term: generalise the tuple components of the first bound value
     and expand the let (EXPAND_BASIC_LET_TAC).  If that value is boolean
     and contains a subterm matching address$GUARD x b, the first such
     subterm is generalised instead;
   - anything else: succeed only if REWRITE_TAC [] solves the goal.
   Exceptions raised by finder propagate (REPEAT stops on HOL_ERR). *)
fun AUTO_DECONSTRUCT_TAC finder (g as (_, goal)) = let
  val target = finder goal
  fun let_tac () = let
    val (_, value) = (hd o fst o dest_anylet) target
    val value =
      if type_of value = ``:bool`` then
        (find_term (can (match_term ``address$GUARD x b``)) value
         handle HOL_ERR _ => value)
      else value
    in GENSPEC_TAC (dest_tuple value) THEN EXPAND_BASIC_LET_TAC end
  in
    if is_cond target then
      SPEC_AND_CASES_TAC (#1 (dest_cond target)) g
    else if is_let target then
      let_tac () g
    else
      (REWRITE_TAC [] THEN NO_TAC) g
  end
59
(* End of the tactic helpers that are candidates for helperLib. *)
61
62
(* Merge an optional precondition tree `side` into the function-body tree t.
   - A leaf of `side` becomes a FUN_COND guard around t, unless the leaf is
     the variable `cond` or T, which add nothing.
   - FUN_COND nodes of `side` are kept around the merged result.
   - FUN_IF / FUN_LET nodes must meet a node of t with an alpha-equivalent
     test / binding, and are merged branch-wise.
   Any other combination fails with HOL_ERR (via fail ()). *)
fun merge_side t side =
  case (t, side) of
    (_, NONE) => t
  | (_, SOME (FUN_VAL tm)) =>
      if tm ~~ mk_var("cond",``:bool``) orelse Teq tm then t
      else FUN_COND (tm, t)
  | (_, SOME (FUN_COND (tm, t2))) => FUN_COND (tm, merge_side t (SOME t2))
  | (FUN_IF (b, x, y), SOME (FUN_IF (b2, x2, y2))) =>
      if not (b ~~ b2) then fail ()
      else FUN_IF (b, merge_side x (SOME x2), merge_side y (SOME y2))
  | (FUN_LET (x, y, body), SOME (FUN_LET (x2, y2, body2))) =>
      if x ~~ x2 andalso y ~~ y2
      then FUN_LET (x, y, merge_side body (SOME body2))
      else fail ()
  | _ => fail ()
75
(* Map f over every leaf (FUN_VAL) term of an ftree, keeping the FUN_COND,
   FUN_IF and FUN_LET structure unchanged. *)
fun leaves tree f = let
  fun go (FUN_VAL v) = FUN_VAL (f v)
    | go (FUN_COND (c, sub)) = FUN_COND (c, go sub)
    | go (FUN_IF (b, l, r)) = FUN_IF (b, go l, go r)
    | go (FUN_LET (v, e, sub)) = FUN_LET (v, e, go sub)
  in go tree end
80
(* Drop every FUN_COND guard from an ftree, keeping leaves, if-splits and
   let-bindings. *)
fun rm_conds tree =
  case tree of
    FUN_COND (_, sub) => rm_conds sub
  | FUN_IF (b, l, r) => FUN_IF (b, rm_conds l, rm_conds r)
  | FUN_LET (v, e, sub) => FUN_LET (v, e, rm_conds sub)
  | FUN_VAL _ => tree
85
(* Define a tail-recursive function and its precondition from a step function.

   func_name : name of the new constant; a second constant named
               func_name ^ "_pre" is defined for the precondition.
   step_fun  : paired abstraction whose ftree leaves are pairs
               (INL a | INR r, c).  As built by tailrec_define_full, INL a
               means "recurse on a", INR r means "return r", and c is the
               leaf's side condition.
   tm_option : SOME (eqn, pre_option) gives the equation to prove, stated
               with the caller's own function term, plus an optional
               precondition equation.  NONE derives both goals from step_fun.

   Side effects: two new_definition calls, and both definitions are pushed
   onto tailrec_definitions.
   Returns (def_result, def_thm, side_result, side_thm).  The *_result
   theorems are the proved equations; the *_thm theorems are the raw
   definitions. *)
fun tailrec_define_from_step func_name step_fun tm_option = let
  (* definitions *)
  (* func_name := the LHS of SHORT_TAILREC_def instantiated at step_fun
     (presumably `SHORT_TAILREC step_fun`). *)
  val thm = ISPEC step_fun SHORT_TAILREC_def
  val def_rhs = (fst o dest_eq o concl) thm
  val def_tm = mk_eq (mk_var(func_name,type_of def_rhs),def_rhs)
  val def_thm = new_definition(func_name,def_tm)
  val new_def_tm = (fst o dest_eq o concl) def_thm
  (* func_name_pre := the LHS of SHORT_TAILREC_PRE_def at step_fun. *)
  val side = ISPEC step_fun SHORT_TAILREC_PRE_def
  val side_rhs = (fst o dest_eq o concl) side
  val side_tm = mk_eq (mk_var(func_name ^ "_pre",type_of side_rhs),side_rhs)
  val side_thm = new_definition(func_name ^ "_pre",side_tm)
  val new_side_tm = (fst o dest_eq o concl) side_thm
  val _ = tailrec_definitions := def_thm::side_thm::(!tailrec_definitions)
  (* goals *)
  fun is_inl tm = can (match_term ``(INL x):'a + 'b``) tm
  (* Rebuild body with each leaf (INL a, c) replaced by f1 (a, c) and each
     leaf (INR r, c) replaced by f2 (r, c). *)
  fun leaves_inl body f1 f2 = ftree2tm (leaves (tm2ftree body) (fn tm =>
          if is_inl (fst (dest_pair tm))
          then f1 (cdr (fst (dest_pair tm)),snd (dest_pair tm))
          else f2 (cdr (fst (dest_pair tm)),snd (dest_pair tm))))
  (* Instantiate the placeholder variable `cond` to T, then re-associate
     conjunctions to the right; returns the right-hand side of the
     resulting equation. *)
  val inst_cond_var = snd o dest_eq o concl o QCONV (REWRITE_CONV [GSYM CONJ_ASSOC]) o
                      subst [mk_var("cond",``:bool``) |-> T]
  val (def_goal,side_goal) = case tm_option of
      SOME (tm,pre_option) => let
        (* Restate the caller's equations with the new constants in place
           of the caller's function (and precondition) terms.  Without a
           precondition equation the side goal is just T. *)
        val (lhs,rhs) = dest_eq tm
        val func_tm = repeat car lhs
        val (old_side_tm,tm2) = (case pre_option of
              NONE => (new_side_tm,T)
            | SOME x => (repeat car (fst (dest_eq x)),x))
        in (subst [func_tm |-> new_def_tm] tm,
            subst [old_side_tm |-> new_side_tm] tm2) end
    | NONE => let
        (* Function goal: f args = body without guards, with INL a ↦ f a
           and INR r ↦ r.
           Precondition goal: f_pre args = body with (INL a, c) ↦ f_pre a /\ c
           and (INR r, c) ↦ c. *)
        val (args,body) = dest_pabs step_fun
        val def_body = (ftree2tm o rm_conds o tm2ftree) body
        val def_body = leaves_inl def_body (fn (tm,_) => mk_comb(new_def_tm,tm)) fst
        val def_goal = mk_eq(mk_comb(new_def_tm,args),def_body)
        val side_body = leaves_inl body (fn (tm,c) => mk_conj(mk_comb(new_side_tm,tm),c)) snd
        val side_goal = mk_eq(mk_comb(new_side_tm,args),side_body)
        val side_goal = if Teq side_body then side_goal
                        else inst_cond_var (side_goal)
        in (def_goal,side_goal) end
  (* prove exported theorems *)
  (* Shared proof.  Unfold both definitions, then rewrite once with
     SHORT_TAILREC_THM inside the operator `$= lhs` of the goal (so only the
     left-hand side changes).  Fold the definitions back and beta-reduce
     paired abstractions.  Then repeatedly split the conditionals/lets that
     `finder` locates, simplify with the sum-type lemmas, and split any
     remaining iff with EQ_TAC. *)
  fun tac finder =
    PURE_REWRITE_TAC [def_thm,side_thm]
    THEN CONV_TAC (RATOR_CONV (PURE_ONCE_REWRITE_CONV [SHORT_TAILREC_THM]))
    THEN PURE_REWRITE_TAC [GSYM def_thm,GSYM side_thm]
    THEN CONV_TAC (DEPTH_CONV PairRules.PBETA_CONV)
    THEN PURE_REWRITE_TAC [AND_CLAUSES]
    THEN REPEAT (AUTO_DECONSTRUCT_TAC finder)
    THEN ASM_SIMP_TAC std_ss [sumTheory.ISL,sumTheory.ISR,sumTheory.OUTL,
           sumTheory.OUTR,LET_DEF, AC CONJ_COMM CONJ_ASSOC]
    THEN EQ_TAC THEN SIMP_TAC std_ss []
  (* Function goal: search `rand` three times into the goal's left-hand
     side. *)
  val finder = (rand o rand o rand o fst o dest_eq)
  val def_result = auto_prove "tailrec_define" (def_goal,tac finder)
  (* Precondition goal: search the goal's rand (its right-hand side when it
     is an equation). *)
  val finder = rand
  val side_result = RW [] (auto_prove "tailrec_define_with_pre" (side_goal,tac finder))
  in (def_result,def_thm,side_result,side_thm) end
142
(* Turn a precondition equation `f_pre a = body` into the term given to
   merge_side.  At every leaf of body's ftree, drop the conjuncts whose
   operator (car) is alpha-equivalent to that of the left-hand side, i.e.
   the recursive calls `f_pre a'`.  A leaf made only of such calls becomes
   T; previously list_mk_conj [] raised HOL_ERR here, e.g. for
   `f_pre x = if x = 0 then T else f_pre (x - 1)`.  Finally instantiate the
   placeholder variable `cond` to T and simplify with REWRITE_CONV [].
   Returns the simplified term. *)
fun prepare_pre pre_tm = let
  val (x,y) = dest_eq pre_tm
  fun is_rec_call c = is_comb c andalso (car c ~~ car x)
  fun keep_side_conds tm =
    case filter (not o is_rec_call) (list_dest dest_conj tm) of
      [] => T  (* only recursive calls at this leaf: no extra condition *)
    | conds => list_mk_conj conds
  val pre_tm = ftree2tm (leaves (tm2ftree y) keep_side_conds)
  val cond_var = mk_var("cond",``:bool``)
  val pre_tm = subst [cond_var|->T] pre_tm
  val pre_tm = (snd o dest_eq o concl o SPEC_ALL o QCONV
                (REWRITE_CONV [])) pre_tm
  in pre_tm end
154
(* Define a tail-recursive function from its equation `f a = body`, with an
   optional precondition equation.
   - The precondition, if given, is merged into body's ftree (merge_side).
   - A leaf that is an application of f becomes INL of its argument
     (recurse); every other leaf becomes INR (return).
   - Each leaf is paired with the placeholder variable `cond`, which is then
     substituted by T.
   The resulting step function goes to tailrec_define_from_step.  The two
   proved equations are saved as <name>_def and <name>_pre_def.
   Returns the same 4-tuple as tailrec_define_from_step. *)
fun tailrec_define_full tm pre_option = let
  val (lhs,rhs) = dest_eq tm
  val func_tm = repeat car lhs
  val func_name =
    if is_var func_tm then fst (dest_var func_tm)
    else fst (dest_const func_tm)
  (* construct step function *)
  val body_tree = tm2ftree rhs
  val side_tree = Option.map (tm2ftree o prepare_pre) pre_option
  val merged = merge_side body_tree side_tree
  val arg_tys = (snd o dest_type o type_of) func_tm
  val input_type = el 1 arg_tys
  val output_type = el 2 arg_tys
  val cond_var = mk_var("cond",``:bool``)
  fun is_rec_call leaf = (car leaf ~~ func_tm) handle HOL_ERR _ => false
  fun to_step_leaf leaf =
    if is_rec_call leaf
    then mk_pair (mk_inl (cdr leaf, output_type), cond_var)
    else mk_pair (mk_inr (leaf, input_type), cond_var)
  val step_body = subst [cond_var |-> T] (ftree2tm (leaves merged to_step_leaf))
  val step_fun = mk_pabs (cdr lhs, step_body)
  val results as (def_result, _, side_result, _) =
        tailrec_define_from_step func_name step_fun (SOME (tm, pre_option))
  val _ = save_thm (func_name ^ "_def", def_result)
  val _ = save_thm (func_name ^ "_pre_def", side_result)
  in results end;
181
(* Define a tail-recursive function from `f a = body` with no precondition
   equation; returns the proved equation for f. *)
fun tailrec_define tm = #1 (tailrec_define_full tm NONE);
184
(* Like tailrec_define, but also given a precondition equation
   `f_pre a = ...`; returns the proved equations for f and for f_pre. *)
fun tailrec_define_with_pre tm pre = let
  val (def_th, _, pre_th, _) = tailrec_define_full tm (SOME pre)
  in (def_th, pre_th) end;
187
(* Tidy the goal with REWRITE_TAC [], then rewrite once (ONCE_REWRITE_TAC)
   with every definition recorded in tailrec_definitions.  The list is read
   when the tactic is applied, not when it is defined. *)
fun TAILREC_TAC goal = let
  val defs = !tailrec_definitions
  in (REWRITE_TAC [] THEN ONCE_REWRITE_TAC defs) goal end;
190
191
192end;
193