(* ========================================================================= *)
(* PREDICATE SUBTYPE PROVER                                                  *)
(* ========================================================================= *)

structure subtypeTools :> subtypeTools =
struct

open HolKernel Parse boolLib bossLib res_quanTools;

val ERR = mk_HOL_ERR "subtypeTools";
val Bug = mlibUseful.Bug;
val Error = ERR "";

(* ------------------------------------------------------------------------- *)
(* Helper proof tools.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Order on booleans in which true sorts before false. *)
fun bool_compare (true,false) = LESS
  | bool_compare (false,true) = GREATER
  | bool_compare _ = EQUAL;

(* Destruct / recognise terms of the form ``x IN s``. *)
val dest_in = dest_binop pred_setSyntax.in_tm (ERR "dest_in" "");

val is_in = can dest_in;

val abbrev_tm = ``Abbrev``;

(* Destruct ``Abbrev (v = t)`` into (v,t); raises ERR on any other term. *)
fun dest_abbrev tm =
    let
      val (c,t) = dest_comb tm
      val () = if same_const c abbrev_tm then () else raise ERR "dest_abbrev" ""
    in
      dest_eq t
    end;

val is_abbrev = can dest_abbrev;

(* Normalise a theorem before it is used by match_tac / cond_rewr_conv:
   simplify with restricted-quantifier support (resq_SS) and the
   quantifier/implication/conjunction theorems listed below. *)
val norm_rule =
    SIMP_RULE (simpLib.++ (pureSimps.pure_ss, resq_SS))
      [GSYM LEFT_FORALL_IMP_THM, GSYM RIGHT_FORALL_IMP_THM,
       AND_IMP_INTRO, GSYM CONJ_ASSOC];

(* Use a theorem as a tactic: after normalisation and stripping outer
   universals, an implication is applied with MATCH_MP_TAC, anything else
   with MATCH_ACCEPT_TAC. *)
fun match_tac th =
    let
      val th = norm_rule th
      val (_,tm) = strip_forall (concl th)
    in
      (if is_imp tm then MATCH_MP_TAC else MATCH_ACCEPT_TAC) th
    end;

(* Prove the term cond with solver, accepting either |- cond or
   |- cond = T from the solver.  Conclusions are compared up to alpha
   equivalence (aconv), so a solver that renames bound variables is not
   mistakenly reported as a bug. *)
fun flexible_solver solver cond =
    let
      val cond_th = solver cond
      val cond_tm = concl cond_th
    in
      if aconv cond_tm cond then cond_th
      else if aconv cond_tm (mk_eq (cond,T)) then EQT_ELIM cond_th
      else raise Bug "flexible_solver: solver didn't prove condition"
    end;

(* Turn a (possibly conditional) equation into a conversion parameterised
   by a condition solver.  If the normalised theorem has the shape
   c ==> (l = r), then after matching l against the input term, the
   instantiated condition c is discharged with flexible_solver.
   Raises (from match_term) when the term does not match l. *)
fun cond_rewr_conv rewr =
    let
      val rewr = SPEC_ALL (norm_rule rewr)
      val rewr_tm = concl rewr
      val (no_cond,eq) =
          case total dest_imp_only rewr_tm of
            NONE => (true,rewr_tm)
          | SOME (_,eq) => (false,eq)
      val pat = lhs eq
    in
      fn solver => fn tm =>
         let
           val sub = match_term pat tm
           val th = INST_TY_TERM sub rewr
         in
           if no_cond then th
           else MP th (flexible_solver solver (rand (rator (concl th))))
         end
    end;

(* Try each theorem in ths (as a cond_rewr_conv) in order, using the first
   that succeeds. *)
fun cond_rewrs_conv ths =
    let
      val solver_convs = map cond_rewr_conv ths
      fun mk_conv solver solver_conv = solver_conv solver
    in
      fn solver => FIRST_CONV (map (mk_conv solver) solver_convs)
    end;

local
  type cache = (term,thm) Binarymap.dict ref;

  (* A cached theorem for goal g is reused only if all of its hypotheses
     are among the current assumptions asl. *)
  fun in_cache cache (asl,g) =
      case Binarymap.peek (cache,g) of
        NONE => NONE
      | SOME th =>
        if List.all (fn h => mem h asl) (hyp th) then SOME th else NONE;
in
  fun cache_new () = ref (Binarymap.mkDict compare);

  (* On a hit, close the goal with the cached theorem.  On a miss, leave the
     goal unchanged and record the theorem produced for it in the cache when
     the validation runs. *)
  fun cache_tac (cache : cache) (goal as (_,g)) =
      case in_cache (!cache) goal of
        SOME th => ([], fn [] => th | _ => raise Fail "cache_tac: hit")
      | NONE =>
        ([goal],
         fn [th] => (cache := Binarymap.insert (!cache, g, th); th)
          | _ => raise Fail "cache_tac: miss");
end;

(* Print s, then behave as ALL_TAC. *)
fun print_tac s goal = (print s; ALL_TAC goal);

(* ------------------------------------------------------------------------- *)
(* Solver conversions.                                                       *)
(* ------------------------------------------------------------------------- *)

type solver_conv = Conv.conv -> Conv.conv;

(* Build an AC-normalising solver conversion for the binary operator
   described by info.  Terms headed by the operator are re-associated with
   assoc_th (whenever the left argument is itself a binop) and their operands
   ordered with comm_th / comm_th'.  Adjacent operands with the same
   underlying atom are merged with combine_ths / combine_ths'.  id_ths and
   simplify_ths are tried along the way.  The result fails (CHANGED_CONV)
   when the term is left unchanged. *)
fun binop_ac_conv info =
    let
      val {term_compare,
           dest_binop,
           dest_inv,
           dest_exp,
           assoc_th,
           comm_th,
           comm_th',
           id_ths,
           simplify_ths,
           combine_ths,
           combine_ths'} = info

      val is_binop = can dest_binop

      (* Strip an optional inverse, then an optional exponent, returning the
         underlying atom with flags: pos = no inverse, sing = no exponent. *)
      fun dest tm =
          let
            val (pos,tm) =
                case total dest_inv tm of
                  NONE => (true,tm)
                | SOME (_ : term, tm) => (false,tm)
            val (sing,tm) =
                case total dest_exp tm of
                  NONE => (true,tm)
                | SOME (_ : term, tm, _ : term) => (false,tm)
          in
            (tm,pos,sing)
          end

      (* Returns (ok,eq): ok means x may stay before y (ordered by atom, then
         non-inverted first, then non-exponentiated first; ties count as ok);
         eq means x and y share the same underlying atom. *)
      fun cmp (x,y) =
          let
            val (xt,xp,xs) = dest x
            and (yt,yp,ys) = dest y
          in
            case term_compare (xt,yt) of
              LESS => (true,false)
            | EQUAL =>
              (case bool_compare (xp,yp) of
                 LESS => (true,true)
               | EQUAL =>
                 (case bool_compare (xs,ys) of
                    LESS => (true,true)
                  | EQUAL => (true,true)
                  | GREATER => (false,true))
               | GREATER => (false,true))
            | GREATER => (false,false)
          end

      val assoc_conv = cond_rewr_conv assoc_th

      val comm_conv = cond_rewr_conv comm_th

      val comm_conv' = cond_rewr_conv comm_th'

      val id_conv = cond_rewrs_conv id_ths

      val term_simplify_conv = cond_rewrs_conv simplify_ths

      (* Combine two operands, then tidy up the result. *)
      val term_combine_conv =
          let
            val conv = cond_rewrs_conv combine_ths
          in
            fn solver =>
               conv solver THENC
               reduceLib.REDUCE_CONV THENC
               TRY_CONV (term_simplify_conv solver)
          end

      (* Combine the first two operands of a right-nested chain, then tidy
         up the combined left operand and try the identity rules. *)
      val term_combine_conv' =
          let
            val conv = cond_rewrs_conv combine_ths'
          in
            fn solver =>
               conv solver THENC
               LAND_CONV
                 (reduceLib.REDUCE_CONV THENC
                  TRY_CONV (term_simplify_conv solver)) THENC
               TRY_CONV (id_conv solver)
          end

      (* Insert the left operand into its sorted position in the (already
         normalised) right operand, combining equal atoms on the way. *)
      fun push_conv solver tm =
          TRY_CONV
            let
              val (_,a,b) = dest_binop tm
            in
              case total dest_binop b of
                NONE =>
                let
                  val (ok,eq) = cmp (a,b)
                in
                  (if ok then ALL_CONV else comm_conv solver) THENC
                  (if eq then TRY_CONV (term_combine_conv solver) else ALL_CONV)
                end
              | SOME (_,b,_) =>
                let
                  val (ok,eq) = cmp (a,b)
                in
                  (if ok then ALL_CONV else comm_conv' solver) THENC
                  ((if eq then term_combine_conv' solver else NO_CONV) ORELSEC
                   (if ok then ALL_CONV else push_conv' solver))
                end
            end tm
      and push_conv' solver =
          RAND_CONV (push_conv solver) THENC TRY_CONV (id_conv solver)

      (* Does not raise an exception *)
      fun ac_conv solver tm =
          (case total dest_binop tm of
             NONE => TRY_CONV (term_simplify_conv solver THENC ac_conv solver)
           | SOME (_,a,b) =>
             if is_binop a then
               TRY_CONV (assoc_conv solver THENC ac_conv solver)
             else
               ((id_conv solver ORELSEC
                 LAND_CONV (term_simplify_conv solver)) THENC
                ac_conv solver) ORELSEC
               (if is_binop b then
                  RAND_CONV (ac_conv solver) THENC push_conv solver
                else
                  (RAND_CONV (term_simplify_conv solver) THENC
                   ac_conv solver) ORELSEC
                  push_conv solver)) tm
    in
      (***trace_conv "alg_binop_ac_conv" o***) CHANGED_CONV o ac_conv
    end;

(* ------------------------------------------------------------------------- *)
(* Named conversions.                                                        *)
(* ------------------------------------------------------------------------- *)

type named_conv = {name : string, key : Term.term, conv : solver_conv};

(* Convert a named solver conversion into the record shape used for the
   convs field of simpLib.SSFRAG: the key gets an empty list of pattern
   variables, the solver is applied to the simplifier's term-list argument,
   and the trace level is fixed at 2. *)
fun named_conv_to_simpset_conv solver_conv =
    let
      val {name : string, key : term, conv : conv -> conv} = solver_conv
      val key = SOME ([] : term list, key)
      and conv = fn c => fn tms : term list => conv (c tms)
      and trace = 2
    in
      {name = name, key = key, conv = conv, trace = trace}
    end;

(* ------------------------------------------------------------------------- *)
(* Subtype contexts.                                                         *)
(* ------------------------------------------------------------------------- *)

(* When set, the decision procedure accepts every X IN Y goal via an oracle
   theorem instead of proving it. *)
val ORACLE = ref false;

fun ORACLE_solver goal =
    EQT_INTRO (mk_oracle_thm "algebra_dproc" ([],goal));

datatype context =
    Context of {rewrites : thm list,
                conversions : named_conv list,
                reductions : thm list,
                judgements : thm list,
                dproc_cache : (term,thm) Binarymap.dict ref};

(* Print a context as counts of its components. *)
(* NOTE(review): rewrites and reductions both print with suffix "r", so the
   output is ambiguous; consider a distinct suffix for reductions. *)
fun pp p context =
    let
      val Context {rewrites,conversions,reductions,judgements,...} = context
      val rewrites = length rewrites
      and conversions = length conversions
      and reductions = length reductions
      and judgements = length judgements
    in
      PP.begin_block p PP.INCONSISTENT 1;
      PP.add_string p ("<" ^ int_to_string rewrites ^ "r" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (int_to_string conversions ^ "c" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (int_to_string reductions ^ "r" ^ ",");
      PP.add_break p (1,0);
      PP.add_string p (int_to_string judgements ^ "j>");
      PP.end_block p
    end;

fun to_string context = PP.pp_to_string (!Globals.linewidth) pp context;

val empty =
    Context {rewrites = [], conversions = [],
             reductions = [], judgements = [],
             dproc_cache = cache_new ()};

(* The add_* functions append to one component and give the new context
   its own copy of the decision-procedure cache (ref (!m)), so later cache
   updates do not leak back into the old context. *)
fun add_rewrite x context =
    let
      val Context {rewrites = r, conversions = c, reductions = d,
                   judgements = j, dproc_cache = m} = context
    in
      Context {rewrites = r @ [x], conversions = c, reductions = d,
               judgements = j, dproc_cache = ref (!m)}
    end;

fun add_conversion x context =
    let
      val Context {rewrites = r, conversions = c, reductions = d,
                   judgements = j, dproc_cache = m} = context
    in
      Context {rewrites = r, conversions = c @ [x], reductions = d,
               judgements = j, dproc_cache = ref (!m)}
    end;

fun add_reduction x context =
    let
      val Context {rewrites = r, conversions = c, reductions = d,
                   judgements = j, dproc_cache = m} = context
    in
      Context {rewrites = r, conversions = c, reductions = d @ [x],
               judgements = j, dproc_cache = ref (!m)}
    end;

fun add_judgement x context =
    let
      val Context {rewrites = r, conversions = c,reductions = d,
                   judgements = j, dproc_cache = m} = context
    in
      Context {rewrites = r, conversions = c, reductions = d,
               judgements = j @ [x], dproc_cache = ref (!m)}
    end;

local
  (* The REDUCER context has type exn, so the decision procedure's state is
     carried by this local exception constructor. *)
  exception State of
      {assumptions : term list,
       reductions : tactic list,
       judgements : tactic list};

  local
    val abbrev_rule = prove
      (``!v t. Abbrev (v = t) ==> (!s. t IN s ==> v IN s)``,
       RW_TAC std_ss [markerTheory.Abbrev_def]);

    fun reduce_tac th = match_tac th THEN REPEAT CONJ_TAC;

    fun assume_reduction th (State {assumptions,reductions,judgements}) =
        let
(***
          val () = (print "assume_reduction: "; print_thm th; print "\n")
***)
        in
          State {assumptions = concl th :: assumptions,
                 reductions = reduce_tac th :: reductions,
                 judgements = judgements}
        end
      | assume_reduction _ _ = raise Fail "assume_reduction";

    fun assume_judgement th (State {assumptions,reductions,judgements}) =
        let
(***
          val () = (print "assume_judgement: "; print_thm th; print "\n")
***)
        in
          State {assumptions = concl th :: assumptions,
                 reductions = reductions,
                 judgements = reduce_tac th :: judgements}
        end
      | assume_judgement _ _ = raise Fail "assume_judgement";
  in
    fun initial_state reductions judgements =
        State {assumptions = [],
               reductions = map reduce_tac reductions,
               judgements = map reduce_tac judgements};

    (* Add context theorems: x IN s facts become reduction tactics,
       Abbrev (v = t) facts become judgement tactics via abbrev_rule,
       conjunctions are split, and anything else is ignored. *)
    fun state_add (s,[]) = s
      | state_add (s, th :: ths) =
        let
          val tm = concl th
        in
          if is_in tm then state_add (assume_reduction th s, ths)
          else if is_abbrev tm then
            state_add (assume_judgement (MATCH_MP abbrev_rule th) s, ths)
          else if is_conj tm then state_add (s, CONJUNCTS th @ ths)
          else state_add (s,ths)
        end;
  end;

  (* Decide a goal of the form X IN Y, returning |- goal = T.  Reduction
     tactics are applied repeatedly through the cache (printing "-" on each
     cache miss).  After that, each judgement is tried followed by a
     recursive search, falling back to REDUCE_TAC.  The proof must close
     every subgoal (NO_TAC). *)
  fun state_apply_dproc dproc_cache dproc_context goal =
      if not (is_in goal) then
        raise ERR "algebra_dproc" "not of form X IN Y"
      else if !ORACLE then ORACLE_solver goal
      else
        let
          val {context, solver = _, conv = _, relation = _, stack = _} = dproc_context
          val {assumptions,reductions,judgements} =
              case context of
                State state => state
              | _ => raise Bug "state_apply_dproc: wrong exception type"

          fun dproc_tac goal =
              (REPEAT (cache_tac dproc_cache
                       THEN print_tac "-"
                       THEN FIRST reductions)
               THEN (FIRST (map (fn tac => tac THEN dproc_tac) judgements)
                     ORELSE reduceLib.REDUCE_TAC)
               THEN NO_TAC) goal

(***
          val _ = (print "algebra_dproc: "; print_term goal; print "\n")
***)
          val th = TAC_PROOF ((assumptions,goal), dproc_tac)
        in
          EQT_INTRO th
        end;

  fun algebra_dproc reductions judgements dproc_cache =
      Traverse.REDUCER {name = NONE,
                        initial = initial_state reductions judgements,
                        addcontext = state_add,
                        apply = state_apply_dproc dproc_cache};
in
  (* Package a context as a simpset fragment: its rewrites and conversions,
     plus the X IN Y decision procedure built from its reductions and
     judgements. *)
  fun simpset_frag context =
      let
        val Context {rewrites, conversions, reductions,
                     judgements, dproc_cache} = context
        val convs = map named_conv_to_simpset_conv conversions
        val dproc = algebra_dproc reductions judgements dproc_cache
      in
        simpLib.SSFRAG
          {name = NONE, ac = [], congs = [], convs = convs, rewrs = rewrites,
           dprocs = [dproc], filter = NONE}
      end;

  fun simpset context = simpLib.++ (std_ss, simpset_frag context);
end;

(* ------------------------------------------------------------------------- *)
(* Subtype context pairs: one for simplification, the other for             *)
(* normalization.                                                            *)
(*                                                                           *)
(* By convention add_X2 adds to both contexts, add_X2' adds to just the      *)
(* simplify context, and add_X2'' adds to just the normalize context.        *)
(* ------------------------------------------------------------------------- *)

datatype context2 = Context2 of {simplify : context, normalize : context};

fun pp2 pp alg =
    let
      val Context2 {simplify,normalize} = alg
    in
      PP.begin_block pp PP.INCONSISTENT 1;
      PP.add_string pp ("{simplify = " ^ to_string simplify ^ ",");
      PP.add_break pp (1,0);
      PP.add_string pp ("normalize = " ^ to_string normalize ^ "}");
      PP.end_block pp
    end;

fun to_string2 context2 = PP.pp_to_string (!Globals.linewidth) pp2 context2;

fun dest2 (Context2 info) = info;

val empty2 =
    Context2 {simplify = empty, normalize = empty};

fun add_rewrite2' r (Context2 {simplify,normalize}) =
    Context2 {simplify = add_rewrite r simplify, normalize = normalize};

fun add_rewrite2'' r (Context2 {simplify,normalize}) =
    Context2 {simplify = simplify, normalize = add_rewrite r normalize};

fun add_rewrite2 r = add_rewrite2' r o add_rewrite2'' r;

fun add_conversion2' r (Context2 {simplify,normalize}) =
    Context2 {simplify = add_conversion r simplify, normalize = normalize};

fun add_conversion2'' r (Context2 {simplify,normalize}) =
    Context2 {simplify = simplify, normalize = add_conversion r normalize};

fun add_conversion2 c = add_conversion2' c o add_conversion2'' c;

fun add_reduction2' d (Context2 {simplify,normalize}) =
    Context2 {simplify = add_reduction d simplify, normalize = normalize};

fun add_reduction2'' d (Context2 {simplify,normalize}) =
    Context2 {simplify = simplify, normalize = add_reduction d normalize};

fun add_reduction2 d = add_reduction2' d o add_reduction2'' d;

fun add_judgement2' r (Context2 {simplify,normalize}) =
    Context2 {simplify = add_judgement r simplify, normalize = normalize};

fun add_judgement2'' r (Context2 {simplify,normalize}) =
    Context2 {simplify = simplify, normalize = add_judgement r normalize};

fun add_judgement2 j = add_judgement2' j o add_judgement2'' j;

fun simpset_frag2 (Context2 {simplify,normalize}) =
    {simplify = simpset_frag simplify,
     normalize = simpset_frag normalize};

fun simpset2 (Context2 {simplify,normalize}) =
    {simplify = simpset simplify, normalize = simpset normalize};

end