1structure tailrecLib :> tailrecLib =
2struct
3
4open HolKernel boolLib bossLib Parse;
5open tailrecTheory helperLib sumSyntax pairSyntax;
6
(* Local Parse that shadows the global one so that every ``...``
   quotation in this file is parsed with tailrecTheory's grammars.
   All other Parse functionality is re-exported unchanged. *)
structure Parse = struct
  open Parse
  local
    val grammar_parsers = parse_from_grammars tailrecTheory.tailrec_grammars
  in
    val Type = #1 grammar_parsers
    val Term = #2 grammar_parsers
  end
end
12
(* Every definition made by tailrec_define_from_step (function and its
   _pre side condition), newest first; used by TAILREC_TAC. *)
val tailrec_definitions : thm list ref = ref [];
14
(* Generic tactics and conversions used below; candidates for moving into helperLib. *)
16
(* Split a right-nested tuple (a,b,c) into [a,b,c]; a non-pair term t
   yields [t].  Left components are not split further. *)
fun dest_tuple tm = let
  fun strip acc t =
    if is_pair t then
      let val (first_part, rest_part) = dest_pair t
      in strip (first_part :: acc) rest_part end
    else rev (t :: acc)
  in strip [] tm end;
19
(* Unfold a let whose first binding is "simple": the bound pattern and
   the bound value are tuples of the same length, and every component
   of the value is a variable or a conjunction of variables.  The LET
   constant is rewritten away with LET_DEF and the resulting paired
   beta-redexes are reduced.  Any HOL_ERR (not a let, shape mismatch,
   non-variable component) is reported as NO_CONV failure. *)
fun EXPAND_BASIC_LET_CONV tm = let
  val (bindings, _) = dest_anylet tm
  val (pattern_tm, value_tm) = hd bindings
  val pattern_parts = dest_tuple pattern_tm
  val value_parts = dest_tuple value_tm
  fun simple_component c = List.all is_var (list_dest dest_conj c)
  val _ = if length value_parts <> length pattern_parts then fail () else ()
  val _ = if List.all simple_component value_parts then () else fail ()
  val unfold = RATOR_CONV (RATOR_CONV (REWRITE_CONV [LET_DEF]))
               THENC DEPTH_CONV PairRules.PBETA_CONV
  in unfold tm end
  handle HOL_ERR _ => NO_CONV tm;
32
(* STRIP_TAC restricted to universally quantified goals; fails otherwise
   (so REPEAT stops as soon as the outer forall is gone). *)
fun STRIP_FORALL_TAC (g as (_, w)) =
  (if is_forall w then STRIP_TAC else NO_TAC) g
35
(* Generalise the term x to a fresh variable, case-split on it, and
   simplify each resulting goal with the basic rewrites. *)
fun SPEC_AND_CASES_TAC x = let
  val fresh = genvar (type_of x)
  in EVERY [SPEC_TAC (x, fresh), Cases, REWRITE_TAC []] end
38
(* Generalise each term of the list (in order) to a fresh variable, then
   split any universally quantified pairs with FORALL_PROD. *)
fun GENSPEC_TAC terms =
  EVERY (map (fn x => SPEC_TAC (x, genvar (type_of x))) terms)
  THEN SIMP_TAC pure_ss [pairTheory.FORALL_PROD];
41
(* Expand every simple let in the goal (see EXPAND_BASIC_LET_CONV), then
   strip the leading universal quantifiers. *)
val EXPAND_BASIC_LET_TAC =
  EVERY [CONV_TAC (DEPTH_CONV EXPAND_BASIC_LET_CONV),
         REPEAT STRIP_FORALL_TAC]
45
(* Take one step of deconstruction on the subterm [finder goal]:
   - a conditional is case-split on its guard;
   - a let has the components of its first bound value generalised and
     is then expanded (for a boolean bound value, an address$GUARD
     subterm is generalised instead when one exists);
   - anything else succeeds only if REWRITE_TAC [] closes the goal.
   Exceptions raised by [finder] propagate, which ends a REPEAT. *)
fun AUTO_DECONSTRUCT_TAC finder (g as (_, goal)) = let
  val target = finder goal
  fun split_cond () = SPEC_AND_CASES_TAC (#1 (dest_cond target))
  fun split_let () = let
    val (_, bound) = hd (fst (dest_anylet target))
    val bound =
      if type_of bound <> ``:bool`` then bound
      else (find_term (can (match_term ``address$GUARD x b``)) bound
            handle HOL_ERR _ => bound)
    in GENSPEC_TAC (dest_tuple bound) THEN EXPAND_BASIC_LET_TAC end
  in
    (if is_cond target then split_cond ()
     else if is_let target then split_let ()
     else REWRITE_TAC [] THEN NO_TAC) g
  end
59
(* End of helperLib candidates. *)
61
62
(* Merge an optional side-condition tree into the function-body tree t.
   A side leaf that is the variable "cond" or T adds nothing; any other
   side leaf wraps t in FUN_COND.  A side FUN_COND is pushed outward.
   FUN_IF / FUN_LET nodes must agree exactly (same guard, same binding)
   in both trees; any other mismatch fails with HOL_ERR. *)
fun merge_side t NONE = t
  | merge_side t (SOME side) =
      (case (t, side) of
         (_, FUN_VAL c) =>
           if c = mk_var ("cond", ``:bool``) orelse c = T then t
           else FUN_COND (c, t)
       | (_, FUN_COND (c, rest)) => FUN_COND (c, merge_side t (SOME rest))
       | (FUN_IF (b, x, y), FUN_IF (b', x', y')) =>
           if b <> b' then fail ()
           else FUN_IF (b, merge_side x (SOME x'), merge_side y (SOME y'))
       | (FUN_LET (v, e, body), FUN_LET (v', e', body')) =>
           if v = v' andalso e = e'
           then FUN_LET (v, e, merge_side body (SOME body'))
           else fail ()
       | _ => fail ())
73
(* Map f over every FUN_VAL leaf of an ftree, keeping the tree shape. *)
fun leaves tree f = let
  fun go (FUN_VAL tm) = FUN_VAL (f tm)
    | go (FUN_COND (c, t)) = FUN_COND (c, go t)
    | go (FUN_IF (b, t1, t2)) = FUN_IF (b, go t1, go t2)
    | go (FUN_LET (v, e, t)) = FUN_LET (v, e, go t)
  in go tree end
78
(* Remove every FUN_COND node from an ftree, keeping what it guards. *)
fun rm_conds tree =
  case tree of
    FUN_COND (_, rest) => rm_conds rest
  | FUN_IF (b, t1, t2) => FUN_IF (b, rm_conds t1, rm_conds t2)
  | FUN_LET (v, e, rest) => FUN_LET (v, e, rm_conds rest)
  | FUN_VAL tm => FUN_VAL tm
83
(* Define [func_name] as SHORT_TAILREC step_fun and [func_name ^ "_pre"]
   as SHORT_TAILREC_PRE step_fun, record both definitions in
   tailrec_definitions, and prove the characterising equations.

   step_fun is a paired abstraction whose body leaves are pairs
   (INL next_arg, c) for a recursive call or (INR result, c) for a final
   result, as built by tailrec_define_full.

   tm_option:
     SOME (eqn, pre_option) - prove the user's equation (and optional
       precondition equation) with the new constants substituted in;
       a missing precondition gives the trivial goal T.
     NONE - derive both goals from step_fun itself.

   Returns (def_result, def_thm, side_result, side_thm). *)
fun tailrec_define_from_step func_name step_fun tm_option = let
  (* definitions *)
  val tailrec_tm = (fst o dest_eq o concl o ISPEC step_fun) SHORT_TAILREC_def
  val def_thm = new_definition (func_name,
                  mk_eq (mk_var (func_name, type_of tailrec_tm), tailrec_tm))
  val new_def_tm = (fst o dest_eq o concl) def_thm
  val pre_tm = (fst o dest_eq o concl o ISPEC step_fun) SHORT_TAILREC_PRE_def
  val side_thm = new_definition (func_name ^ "_pre",
                   mk_eq (mk_var (func_name ^ "_pre", type_of pre_tm), pre_tm))
  val new_side_tm = (fst o dest_eq o concl) side_thm
  val _ = tailrec_definitions := def_thm :: side_thm :: (!tailrec_definitions)
  (* goals *)
  (* Rewrite each leaf (INL a, c) with f1 (a,c) and each (INR r, c) with
     f2 (r,c).  Uses sumSyntax.is_inl rather than re-parsing an INL
     pattern quotation for every leaf. *)
  fun leaves_inl body f1 f2 =
    ftree2tm (leaves (tm2ftree body) (fn leaf => let
      val (tag, c) = dest_pair leaf
      val arg = cdr tag
      in if sumSyntax.is_inl tag then f1 (arg, c) else f2 (arg, c) end))
  val cond_var = mk_var ("cond", ``:bool``)
  (* Instantiate the placeholder "cond" with T and re-associate conjuncts. *)
  val inst_cond_var = snd o dest_eq o concl o
                      QCONV (REWRITE_CONV [GSYM CONJ_ASSOC]) o
                      subst [cond_var |-> T]
  val (def_goal, side_goal) = case tm_option of
      SOME (tm, pre_option) => let
        val func_tm = repeat car (fst (dest_eq tm))
        val (old_side_tm, side_tm) = (case pre_option of
              NONE => (new_side_tm, T)
            | SOME pre => (repeat car (fst (dest_eq pre)), pre))
        in (subst [func_tm |-> new_def_tm] tm,
            subst [old_side_tm |-> new_side_tm] side_tm) end
    | NONE => let
        val (args, body) = dest_pabs step_fun
        (* function: recursive leaves call the new constant, final leaves
           return their result; side conditions are dropped *)
        val def_body = (ftree2tm o rm_conds o tm2ftree) body
        val def_body = leaves_inl def_body (fn (a, _) => mk_comb (new_def_tm, a)) fst
        val def_goal = mk_eq (mk_comb (new_def_tm, args), def_body)
        (* precondition: recursive leaves need pre(next) /\ c, final
           leaves need only c *)
        val side_body = leaves_inl body
              (fn (a, c) => mk_conj (mk_comb (new_side_tm, a), c)) snd
        val side_goal = mk_eq (mk_comb (new_side_tm, args), side_body)
        val side_goal = if side_body = T then side_goal
                        else inst_cond_var side_goal
        in (def_goal, side_goal) end
  (* prove exported theorems: unfold one step of SHORT_TAILREC, fold the
     new constants back, then deconstruct the subterm chosen by [finder] *)
  fun tac finder =
    PURE_REWRITE_TAC [def_thm,side_thm]
    THEN CONV_TAC (RATOR_CONV (PURE_ONCE_REWRITE_CONV [SHORT_TAILREC_THM]))
    THEN PURE_REWRITE_TAC [GSYM def_thm,GSYM side_thm]
    THEN CONV_TAC (DEPTH_CONV PairRules.PBETA_CONV)
    THEN PURE_REWRITE_TAC [AND_CLAUSES]
    THEN REPEAT (AUTO_DECONSTRUCT_TAC finder)
    THEN ASM_SIMP_TAC std_ss [sumTheory.ISL,sumTheory.ISR,sumTheory.OUTL,
           sumTheory.OUTR,LET_DEF, AC CONJ_COMM CONJ_ASSOC]
    THEN EQ_TAC THEN SIMP_TAC std_ss []
  val def_finder = rand o rand o rand o fst o dest_eq
  val def_result = auto_prove "tailrec_define" (def_goal, tac def_finder)
  val side_result =
    RW [] (auto_prove "tailrec_define_with_pre" (side_goal, tac rand))
  in (def_result,def_thm,side_result,side_thm) end
139
(* Turn a user precondition equation  f_pre x = body  into the local side
   condition of one step.  In every leaf of body, the conjuncts that are
   calls of the precondition itself (same head as the lhs) are dropped.
   A leaf left with no conjuncts becomes T instead of making
   list_mk_conj [] fail.  The placeholder variable "cond" is then
   instantiated to T and the result is simplified with REWRITE_CONV []. *)
fun prepare_pre pre_tm = let
  val (pre_lhs, pre_rhs) = dest_eq pre_tm
  fun is_pre_call c = is_comb c andalso car c = car pre_lhs
  fun drop_pre_calls leaf =
    case filter (not o is_pre_call) (list_dest dest_conj leaf) of
      [] => T   (* leaf only required the recursive precondition *)
    | conjs => list_mk_conj conjs
  val stripped = ftree2tm (leaves (tm2ftree pre_rhs) drop_pre_calls)
  val cond_var = mk_var ("cond", ``:bool``)
  val stripped = subst [cond_var |-> T] stripped
  in (snd o dest_eq o concl o SPEC_ALL o QCONV (REWRITE_CONV [])) stripped end
152
(* Define a tail-recursive function from its equation  f x = body  (and
   an optional precondition equation).  Builds the step function, whose
   leaves are (INL arg, cond) for recursive calls f arg and
   (INR result, cond) otherwise.  Delegates to tailrec_define_from_step
   and saves the results as <f>_def and <f>_pre_def.
   Returns (def_result, def_thm, side_result, side_thm). *)
fun tailrec_define_full tm pre_option = let
  val (lhs, rhs) = dest_eq tm
  val func_tm = repeat car lhs
  val func_name =
    fst (dest_var func_tm) handle HOL_ERR _ => fst (dest_const func_tm)
  (* merge the optional precondition into the body tree *)
  val body_tree = tm2ftree rhs
  val pre_tree = Option.map (tm2ftree o prepare_pre) pre_option
  val merged = merge_side body_tree pre_tree
  val fun_ty_args = (snd o dest_type o type_of) func_tm
  val input_type = el 1 fun_ty_args
  val output_type = el 2 fun_ty_args
  val cond_var = mk_var ("cond", ``:bool``)
  fun tag_leaf leaf = let
    val recursive = (car leaf = func_tm) handle HOL_ERR _ => false
    val tagged = if recursive then mk_inl (cdr leaf, output_type)
                 else mk_inr (leaf, input_type)
    in mk_pair (tagged, cond_var) end
  val step_body = subst [cond_var |-> T] (ftree2tm (leaves merged tag_leaf))
  val step_fun = mk_pabs (cdr lhs, step_body)
  val result as (def_result, _, side_result, _) =
    tailrec_define_from_step func_name step_fun (SOME (tm, pre_option))
  val _ = save_thm (func_name ^ "_def", def_result)
  val _ = save_thm (func_name ^ "_pre_def", side_result)
  in result end;
179
(* Define a tail-recursive function without a user precondition;
   returns only the defining equation. *)
fun tailrec_define tm = #1 (tailrec_define_full tm NONE);
182
(* Define a tail-recursive function together with a user precondition;
   returns (defining equation, precondition equation). *)
fun tailrec_define_with_pre tm pre =
  case tailrec_define_full tm (SOME pre) of
    (def_th, _, pre_th, _) => (def_th, pre_th);
185
(* Simplify, then unfold once with every definition recorded so far.
   The definition list is read when the tactic is applied, so later
   definitions are picked up. *)
fun TAILREC_TAC goal = let
  val defs = !tailrec_definitions
  in (REWRITE_TAC [] THEN ONCE_REWRITE_TAC defs) goal end;
188
189
190end;
191