structure tailrecLib :> tailrecLib =
struct

open HolKernel boolLib bossLib Parse;
open tailrecTheory helperLib sumSyntax pairSyntax;

(* Local Parse so that ``...`` quotations in this file are parsed with the
   grammars exported by tailrecTheory. *)
structure Parse =
struct
  open Parse
  val (Type, Term) = parse_from_grammars tailrecTheory.tailrec_grammars
end

(* Every definition theorem created by tailrec_define_from_step (both the
   function and its _pre side-condition) is pushed here; TAILREC_TAC uses
   this list as rewrites. *)
val tailrec_definitions = ref ([]:thm list);

(* ------------------------------------------------------------------------ *)
(* Generic term/tactic helpers (candidates for moving to helperLib).        *)
(* ------------------------------------------------------------------------ *)

(* Splits a right-nested tuple ``(a,b,c)`` into [a,b,c]; a non-pair term
   yields the singleton list [tm]. *)
fun dest_tuple tm =
  let val (x,y) = dest_pair tm in x :: dest_tuple y end
  handle HOL_ERR _ => [tm];

(* Expands the outermost let of tm (via LET_DEF and paired beta-reduction)
   only when the bound tuple and the bound value have the same number of
   components and every value component is a variable or a conjunction of
   variables. Otherwise behaves as NO_CONV. *)
fun EXPAND_BASIC_LET_CONV tm = let
  val (xs,x) = dest_anylet tm
  val (lhs,rhs) = hd xs
  val ys = dest_tuple lhs
  val zs = dest_tuple rhs
  val _ = if length zs = length ys then () else fail()
  (* succeeds (returning true) only if p holds for every element *)
  fun every p [] = true
    | every p (x::xs) = if p x then every p xs else fail()
  val _ = every (fn x => every is_var (list_dest dest_conj x)) zs
  in (((RATOR_CONV o RATOR_CONV) (REWRITE_CONV [LET_DEF]))
      THENC DEPTH_CONV PairRules.PBETA_CONV) tm end
  handle HOL_ERR _ => NO_CONV tm;

(* STRIP_TAC restricted to universally quantified goals; fails otherwise,
   so that REPEAT STRIP_FORALL_TAC stops at the first non-forall goal. *)
fun STRIP_FORALL_TAC (hs,tm) =
  if is_forall tm then STRIP_TAC (hs,tm) else NO_TAC (hs,tm)

(* Generalises the term x to a fresh variable, case-splits on it, then
   rewrites with the basic rewrites. *)
fun SPEC_AND_CASES_TAC x =
  SPEC_TAC (x,genvar(type_of x)) THEN Cases THEN REWRITE_TAC []

(* Generalises each listed term to a fresh variable, then splits paired
   universal quantifiers with FORALL_PROD. *)
fun GENSPEC_TAC [] = SIMP_TAC pure_ss [pairTheory.FORALL_PROD]
  | GENSPEC_TAC (x::xs) = SPEC_TAC (x,genvar(type_of x)) THEN GENSPEC_TAC xs;

(* Expands every "basic" let (see EXPAND_BASIC_LET_CONV) in the goal, then
   strips any leading universal quantifiers. *)
val EXPAND_BASIC_LET_TAC =
  CONV_TAC (DEPTH_CONV EXPAND_BASIC_LET_CONV)
  THEN REPEAT STRIP_FORALL_TAC

(* Looks at the subterm selected by finder:
   - conditional: case-split on its guard;
   - let: generalise the components of the bound value (for a boolean value,
     an ``address$GUARD x b`` subterm is used when one is present) and expand
     the let;
   - anything else: REWRITE_TAC [] and then fail, which ends a REPEAT. *)
fun AUTO_DECONSTRUCT_TAC finder (hs,goal) = let
  val tm = finder goal
  in if is_cond tm then let
       val (b,_,_) = dest_cond tm
       in SPEC_AND_CASES_TAC b (hs,goal) end
     else if is_let tm then let
       val (_,c) = (hd o fst o dest_anylet) tm
       val c = if not (type_of c = ``:bool``) then c else
                 (find_term (can (match_term ``address$GUARD x b``)) c
                  handle HOL_ERR _ => c)
       val cs = dest_tuple c
       in (GENSPEC_TAC cs THEN EXPAND_BASIC_LET_TAC) (hs,goal) end
     else (REWRITE_TAC [] THEN NO_TAC) (hs,goal) end

(* ------------------------------------------------------------------------ *)
(* Function-tree manipulation (ftree constructors come from helperLib).     *)
(* ------------------------------------------------------------------------ *)

(* Merges a precondition tree (second argument) into a function-body tree.
   A precondition leaf that is the variable `cond` or T adds nothing; any
   other leaf becomes a FUN_COND guard around the matching body subtree.
   If/let nodes of the precondition must coincide with those of the body,
   otherwise this fails. *)
fun merge_side t NONE = t
  | merge_side t (SOME (FUN_VAL tm)) =
      if tm = mk_var("cond",``:bool``) then t else
      if tm = T then t else FUN_COND (tm,t)
  | merge_side t (SOME (FUN_COND (tm,t2))) =
      FUN_COND (tm,merge_side t (SOME t2))
  | merge_side (FUN_IF (b,x,y)) (SOME (FUN_IF (b2,x2,y2))) =
      if (b = b2)
      then FUN_IF (b, merge_side x (SOME x2), merge_side y (SOME y2))
      else fail()
  | merge_side (FUN_LET (x,y,t)) (SOME (FUN_LET (x2,y2,t2))) =
      if (x = x2) andalso (y = y2)
      then FUN_LET (x,y,merge_side t (SOME t2))
      else fail()
  | merge_side _ _ = fail ()

(* Applies f to every leaf term of an ftree, keeping its shape. *)
fun leaves (FUN_VAL tm)      f = FUN_VAL (f tm)
  | leaves (FUN_COND (c,t))  f = FUN_COND (c, leaves t f)
  | leaves (FUN_IF (a,b,c))  f = FUN_IF (a, leaves b f, leaves c f)
  | leaves (FUN_LET (v,y,t)) f = FUN_LET (v, y, leaves t f)

(* Removes every FUN_COND guard from an ftree. *)
fun rm_conds (FUN_VAL tm)      = FUN_VAL tm
  | rm_conds (FUN_COND (_,t))  = rm_conds t
  | rm_conds (FUN_IF (a,b,c))  = FUN_IF (a, rm_conds b, rm_conds c)
  | rm_conds (FUN_LET (v,y,t)) = FUN_LET (v, y, rm_conds t)

(* Defines func_name as SHORT_TAILREC step_fun and func_name_pre as
   SHORT_TAILREC_PRE step_fun, records both definitions in
   tailrec_definitions, and proves the "nice" equations for them.
   tm_option:
     SOME (tm, pre_option) - prove the user's equation tm (with its function
                             constant replaced by the new one) and either the
                             user's pre equation or, when pre_option = NONE,
                             the trivial goal T;
     NONE                  - derive both goals from the shape of step_fun.
   Returns (def_result, def_thm, side_result, side_thm). *)
fun tailrec_define_from_step func_name step_fun tm_option = let
  (* definitions *)
  (* def_rhs: lhs of SHORT_TAILREC_def instantiated at step_fun; it becomes
     the rhs of the new definition *)
  val thm = ISPEC step_fun SHORT_TAILREC_def
  val def_rhs = (fst o dest_eq o concl) thm
  val def_tm = mk_eq (mk_var(func_name,type_of def_rhs),def_rhs)
  val def_thm = new_definition(func_name,def_tm)
  val new_def_tm = (fst o dest_eq o concl) def_thm
  val side = ISPEC step_fun SHORT_TAILREC_PRE_def
  val side_rhs = (fst o dest_eq o concl) side
  val side_tm = mk_eq (mk_var(func_name ^ "_pre",type_of side_rhs),side_rhs)
  val side_thm = new_definition(func_name ^ "_pre",side_tm)
  val new_side_tm = (fst o dest_eq o concl) side_thm
  val _ = tailrec_definitions := def_thm::side_thm::(!tailrec_definitions)
  (* goals *)
  fun is_inl tm = can (match_term ``(INL x):'a + 'b``) tm
  (* rewrites each leaf (sum_value, cond) of body: INL leaves via f1, other
     leaves via f2; both receive (payload of the sum, cond) *)
  fun leaves_inl body f1 f2 = ftree2tm (leaves (tm2ftree body) (fn tm =>
    if is_inl (fst (dest_pair tm))
    then f1 (cdr (fst (dest_pair tm)),snd (dest_pair tm))
    else f2 (cdr (fst (dest_pair tm)),snd (dest_pair tm))))
  val inst_cond_var =
    snd o dest_eq o concl o QCONV (REWRITE_CONV [GSYM CONJ_ASSOC]) o
    subst [mk_var("cond",``:bool``) |-> T]
  val (def_goal,side_goal) = case tm_option of
      SOME (tm,pre_option) => let
        val (lhs,_) = dest_eq tm
        val func_tm = repeat car lhs
        val (old_side_tm,tm2) = (case pre_option of
                   NONE => (new_side_tm,T)
                 | SOME x => (repeat car (fst (dest_eq x)),x))
        in (subst [func_tm |-> new_def_tm] tm,
            subst [old_side_tm |-> new_side_tm] tm2) end
    | NONE => let
        val (args,body) = dest_pabs step_fun
        val def_body = (ftree2tm o rm_conds o tm2ftree) body
        val def_body = leaves_inl def_body
                         (fn (tm,_) => mk_comb(new_def_tm,tm)) fst
        val def_goal = mk_eq(mk_comb(new_def_tm,args),def_body)
        val side_body = leaves_inl body
                          (fn (tm,c) => mk_conj(mk_comb(new_side_tm,tm),c)) snd
        val side_goal = mk_eq(mk_comb(new_side_tm,args),side_body)
        val side_goal = if side_body = ``T`` then side_goal
                        else inst_cond_var (side_goal)
        in (def_goal,side_goal) end
  (* prove exported theorems: unfold one step of SHORT_TAILREC on the lhs,
     fold the definitions back, then case-split/expand until the
     simplifier can close the goal *)
  fun tac finder =
    PURE_REWRITE_TAC [def_thm,side_thm]
    THEN CONV_TAC (RATOR_CONV (PURE_ONCE_REWRITE_CONV [SHORT_TAILREC_THM]))
    THEN PURE_REWRITE_TAC [GSYM def_thm,GSYM side_thm]
    THEN CONV_TAC (DEPTH_CONV PairRules.PBETA_CONV)
    THEN PURE_REWRITE_TAC [AND_CLAUSES]
    THEN REPEAT (AUTO_DECONSTRUCT_TAC finder)
    THEN ASM_SIMP_TAC std_ss [sumTheory.ISL,sumTheory.ISR,sumTheory.OUTL,
                              sumTheory.OUTR,LET_DEF, AC CONJ_COMM CONJ_ASSOC]
    THEN EQ_TAC THEN SIMP_TAC std_ss []
  val finder = (rand o rand o rand o fst o dest_eq)
  val def_result = auto_prove "tailrec_define" (def_goal,tac finder)
  val finder = rand
  val side_result =
    RW [] (auto_prove "tailrec_define_with_pre" (side_goal,tac finder))
  in (def_result,def_thm,side_result,side_thm) end

(* Turns a user precondition equation ``f_pre x = body`` into a plain
   boolean tree. In every leaf it drops the conjuncts that are recursive
   calls to f_pre. A leaf consisting only of such calls becomes T.
   (Previously this case called list_mk_conj [], which fails.) `cond` is
   then instantiated to T and the result is simplified with REWRITE_CONV []. *)
fun prepare_pre pre_tm = let
  val (x,y) = dest_eq pre_tm
  fun not_rec_call c = not (is_comb c andalso (car c = car x))
  fun strip_rec_calls tm =
    (case filter not_rec_call (list_dest dest_conj tm) of
         [] => T
       | cs => list_mk_conj cs)
  val pre_tm = ftree2tm (leaves (tm2ftree y) strip_rec_calls)
  val cond_var = mk_var("cond",``:bool``)
  val pre_tm = subst [cond_var|->T] pre_tm
  val pre_tm = (snd o dest_eq o concl o SPEC_ALL o QCONV
                 (REWRITE_CONV [])) pre_tm
  in pre_tm end

(* Defines the tail-recursive function described by the equation tm
   (optionally with a precondition equation). It builds a step function
   whose leaves are (INL arg, cond) for recursive calls and (INR value, cond)
   for returns, with cond instantiated to T. It calls
   tailrec_define_from_step, then saves the proved equations as
   <name>_def and <name>_pre_def. *)
fun tailrec_define_full tm pre_option = let
  val (lhs,rhs) = dest_eq tm
  val func_tm = repeat car lhs
  val func_name = fst (dest_var func_tm)
                  handle HOL_ERR _ => fst (dest_const func_tm)
  (* construct step function *)
  val t = merge_side (tm2ftree rhs)
                     (Option.map (tm2ftree o prepare_pre) pre_option)
  val ty = (snd o dest_type o type_of) func_tm
  val input_type = el 1 ty
  val output_type = el 2 ty
  val cond_var = mk_var("cond",``:bool``)
  fun step (FUN_IF (b,t1,t2)) = FUN_IF (b,step t1,step t2)
    | step (FUN_LET (x,y,t)) = FUN_LET (x,y,step t)
    | step (FUN_COND (c,t)) = FUN_COND (c,step t)
    | step (FUN_VAL tm) =
        if ((car tm = func_tm) handle HOL_ERR _ => false)
        then FUN_VAL (mk_pair(mk_inl(cdr tm,output_type),cond_var))
        else FUN_VAL (mk_pair(mk_inr(tm,input_type),cond_var))
  val tm2 = subst [cond_var|->T] (ftree2tm (step t))
  val step_fun = mk_pabs(cdr lhs,tm2)
  val tm_option = SOME (tm,pre_option)
  val (def_result,def_thm,side_result,side_thm) =
    tailrec_define_from_step func_name step_fun tm_option
  val _ = save_thm(func_name ^ "_def", def_result)
  val _ = save_thm(func_name ^ "_pre_def", side_result)
  in (def_result,def_thm,side_result,side_thm) end;

(* Defines a tail-recursive function without a user precondition; returns
   the proved equation. *)
fun tailrec_define tm =
  let val (th,_,_,_) = tailrec_define_full tm NONE in th end;

(* Defines a tail-recursive function together with its precondition;
   returns (equation, precondition equation). *)
fun tailrec_define_with_pre tm pre =
  let val (th,_,pre,_) = tailrec_define_full tm (SOME pre) in (th,pre) end;

(* Basic rewriting, then one rewrite with every recorded tailrec definition.
   Eta-expanded so the reference is read when the tactic runs. *)
fun TAILREC_TAC (hs,g) =
  (REWRITE_TAC [] THEN ONCE_REWRITE_TAC (!tailrec_definitions)) (hs,g);

end;