1structure core_decompilerLib :> core_decompilerLib =
2struct
3
4open HolKernel Parse boolLib bossLib
5open helperLib tripleTheory tripleSyntax
6
(* Build a HOL_ERR tagged with this structure's name: ERR function message. *)
fun ERR f m = Feedback.mk_HOL_ERR "core_decompilerLib" f m
8
(* ========================================================================== *

  The decompiler has three phases:
   1. derive specs for each instruction
   2. calculate the CFG and split it into separate 'decompilation rounds'
   3. for each round: compose specs together to produce a function

 * ========================================================================== *)
17
18(* Hooks for ISA specific tools.  *)
19
(* Mutable hooks filled in per ISA.  [configure] below installs get_triple,
   initialise, pc, pc_size, at_pc_conv and swap_primes.  code_parser is not
   touched by [configure]; while it is NONE, stage_1 falls back to
   helperLib.quote_to_strings. *)
val get_triple : (string -> helperLib.instruction) ref =
   ref (fn _ => raise ERR "get_triple" "uninitialised")
val initialise : (unit -> unit) ref = ref Lib.I
val pc : term ref = ref boolSyntax.F
val pc_size : int ref = ref 0
val at_pc_conv : (conv -> conv) ref = ref Lib.I
val swap_primes : (term -> term) ref = ref Lib.I
val code_parser : (string quotation -> string list) option ref = ref NONE
29
(* Prime a variable: a variable named x becomes x' with the same type. *)
fun add_prime v =
   let
      val (nm, ty) = Term.dest_var v
   in
      Term.mk_var (nm ^ "'", ty)
   end
31
(* Install the ISA-specific hooks.  pc_size is taken from the word width of
   the pc term, and swap_primes becomes a substitution exchanging every
   component variable with its primed counterpart (in both directions). *)
fun configure {triple_fn = f, init_fn = i, pc_tm = p, pc_conv = c,
               component_vars = vs} =
   let
      fun swap_pair v =
         let val v' = add_prime v in [v |-> v', v' |-> v] end
   in
      get_triple := f
      ; initialise := i
      ; pc := p
      ; at_pc_conv := c
      ; pc_size := Arbnum.toInt (wordsSyntax.size_of p)
      ; swap_primes := Term.subst (List.concat (List.map swap_pair vs))
   end
51
(* Code-set definitions created by abbreviate_code, most recent first. *)
val code_abbrevs : thm list ref = ref []
fun add_code_abbrev th = code_abbrevs := th :: !code_abbrevs

(* Earlier decompilation results as (name, theorem, length), most recent
   first.  An "insert:" line in the code is resolved against this list. *)
val decomp_mem : (string * thm * int) list ref = ref []
fun add_decomp name th len = decomp_mem := (name, th, len) :: !decomp_mem
57
58(* PHASE 1 -- evaluate model *)
59
(* Substitution replacing the pc term by "pc + n", where n becomes a word
   literal of the configured pc width. *)
fun add_to_pc n =
   let
      val offset = wordsSyntax.mk_wordii (n, !pc_size)
   in
      [!pc |-> wordsSyntax.mk_word_add (!pc, offset)]
   end
62
local
   (* vt100 escape string *)
   val ESC = String.str (Char.chr 0x1B)
   (* Selected once, when this structure is loaded.  In interactive sessions
      a plain newline is used; otherwise VT100 sequences are emitted,
      presumably to redraw the status line in place. *)
   val prefix =
      if !Globals.interactive
         then "\n"
      else ESC ^ "[1K" ^ "\n" ^ ESC ^ "[A"
in
   fun inPlaceEcho s = helperLib.echo 1 (prefix ^ s)
end
72
local
   (* Conversions aimed at the pc inside the pre- and post-condition of a
      triple's conclusion. *)
   val POST_ASSERT = RAND_CONV o RAND_CONV
   val PRE_ASSERT = RATOR_CONV o RATOR_CONV o POST_ASSERT
   val ARITH_SUB_CONV = wordsLib.WORD_ARITH_CONV THENC wordsLib.WORD_SUB_CONV
   (* A pc expression is left alone when it is already "v + k" with v a
      variable, or when it is a conditional; anything else is simplified. *)
   fun is_reducible tm =
      case Lib.total wordsSyntax.dest_word_add tm of
         SOME (v, _) => not (Term.is_var v)
       | NONE => not (boolSyntax.is_cond tm)
   fun PC_CONV tm =
      if is_reducible tm then (!at_pc_conv) ARITH_SUB_CONV tm else ALL_CONV tm
   val PC_RULE = Conv.CONV_RULE (PRE_ASSERT PC_CONV THENC POST_ASSERT PC_CONV)
   (* Relocate a spec to offset n: pc := pc + n, simplify, and shift the
      relative successor offset j by n. *)
   fun set_pc n (th, l, j) =
      (PC_RULE (Thm.INST (add_to_pc n) th), l, Option.map (fn i => n + i) j)
   val insert_prefix = "insert:"
   (* Find a fragment previously registered with add_decomp.  Failing here
      with a descriptive message is much easier to diagnose than the generic
      "unsatisfied predicate" error Lib.first would raise. *)
   fun lookup_decomp name =
      case List.find (fn (nm, _, _) => nm = name) (!decomp_mem) of
         SOME entry => entry
       | NONE =>
           raise ERR "derive_specs"
                    ("no previously decompiled code named " ^ Lib.quote name)
   (* Spec(s) for one code line at offset n; returns the offset of the next
      line together with the relocated spec pair.  A line "insert: NAME"
      splices in a stored fragment of recorded length l that continues at
      offset n + l. *)
   fun derive1 hex n =
      if String.isPrefix insert_prefix hex
         then let
                 val name =
                    helperLib.remove_whitespace
                       (String.extract (hex, size insert_prefix, NONE))
                 val (_, th, l) = lookup_decomp name
              in
                 (n + l, (set_pc n (UNDISCH_ALL th, l, SOME l), NONE))
              end
      else let
              val (x as (_, l, _), y) = (!get_triple) hex
           in
              (n + l, (set_pc n x, Option.map (set_pc n) y))
           end
   (* Walk the code lines, threading the current offset n and the countdown
      l of lines still to process (used only for progress output). *)
   fun derive _ [] aux _ = List.rev aux
     | derive n (x :: xs) aux l =
         let
            val () = inPlaceEcho (" " ^ Int.toString l)
            val (n', (x', y)) = derive1 x n
         in
            derive n' xs ((n, (x', y)) :: aux) (l - 1)
         end
in
   (* Phase 1: derive an (offset, spec) pair for every code line, in order.
      Raises HOL_ERR if an "insert:" line names unknown code. *)
   fun derive_specs name code =
      let
         val l = length code
         val s = if l = 1 then "" else "s"
      in
         echo 1 ("\nDeriving instruction spec" ^ s ^ " for " ^
                 Lib.quote name ^ "...\n\n")
         ; derive 0 code [] l before
           inPlaceEcho
              (" Finished " ^ Int.toString l ^ " instruction" ^ s ^ ".\n\n")
      end
end
122
local
   (* Small set-theoretic lemmas, all proved by unfolding SUBSET and UNION. *)
   val tac =
      SIMP_TAC (srw_ss()) [pred_setTheory.SUBSET_DEF, pred_setTheory.UNION_DEF]
   val IN_DEFN = Q.prove(`(c = b) ==> a IN b ==> a IN c`, tac)
   val SUBSET_DEFN = Q.prove(`(c = b) ==> a SUBSET b ==> a SUBSET c`, tac)
   val IN_LEFT_DEFN = Q.prove(`(c = b UNION d) ==> a IN b ==> a IN c`, tac)
   (* Despite the name, this one concludes a SUBSET fact (right operand). *)
   val IN_RIGHT_DEFN =
      Q.prove(`(c = b UNION d) ==> a SUBSET d ==> a SUBSET c`, tac)
   val SUBSET_REST = Q.prove(`a SUBSET b ==> a SUBSET (b UNION d)`, tac)
   val SUBSET_UNION2 = Thm.CONJUNCT2 pred_setTheory.SUBSET_UNION
   (* Rewrite the antecedent (LAND_CONV) of an implication with the subset
      lemmas plus rwts, then simplify the implication with IMP_CLAUSES. *)
   fun subset_conv rwts =
      Conv.LAND_CONV
        (PURE_REWRITE_CONV ([boolTheory.AND_CLAUSES,
                             pred_setTheory.EMPTY_SUBSET,
                             pred_setTheory.INSERT_SUBSET,
                             pred_setTheory.UNION_SUBSET] @ rwts))
      THENC PURE_REWRITE_CONV [boolTheory.IMP_CLAUSES]
   val list_mk_union =
      HolKernel.list_mk_lbinop (Lib.curry pred_setSyntax.mk_union)
   val strip_union = HolKernel.strip_binop (Lib.total pred_setSyntax.dest_union)
   (* Lower-cased name of the model constant of a triple theorem; used as the
      suffix of the generated code definition. *)
   val get_model_name =
      helperLib.to_lower o fst o Term.dest_const o tripleSyntax.dest_model o
      Thm.concl
   (* Code set of an instruction's first spec (the second, if any, is
      ignored here). *)
   fun extract_code (_, ((th, _, _), _)) = tripleSyntax.dest_code (Thm.concl th)
   (* Split all code sets into union components, partition them into
      explicit INSERT sets (flattened into their distinct elements) and
      other components (deduplicated), returning (elements, other parts). *)
   val get_code =
      (((Lib.mk_set o List.concat o List.map pred_setSyntax.strip_set) ##
        Lib.mk_set) o
       List.partition pred_setSyntax.is_insert o
       List.concat o List.map (strip_union o extract_code)):
       (int * helperLib.instruction) list -> term list * term list
   (* Despite the names, these rewrite with the first and second conjuncts of
      boolTheory.OR_CLAUSES (disjunction lemmas); they close the disjunctions
      produced by IN_INSERT in expand_in below. *)
   val (CONJ1_CONV, CONJ2_CONV) =
      case Drule.CONJUNCTS (Drule.SPEC_ALL boolTheory.OR_CLAUSES) of
         c1 :: c2 :: _ => (Conv.REWR_CONV c1, Conv.REWR_CONV c2)
       | _ => fail()
in
   (* abbreviate_code name thms: define a constant NAME_MODEL (named from the
      model of the first spec) equal to the union of every code fragment in
      thms, record the definition with add_code_abbrev, and re-state each
      spec with its code widened to that constant via TRIPLE_EXTEND.  The
      resulting side condition "code SUBSET constant" is discharged with
      proved membership (rwts1) and subset (rwts2) facts.
      Raises HOL_ERR on an empty list. *)
   fun abbreviate_code _ [] = raise ERR "abbreviate_code" "no code"
     | abbreviate_code name (thms as (_: int, ((th, _, _), _)) :: _) =
      let
         (* newcode: individual code elements; acode: other set terms
            (e.g. previously abbreviated code). *)
         val (newcode, acode) = get_code thms
         (* cs is the full code set; l is its explicit-set part and r its
            union-of-abbreviations part (F stands in for a missing part). *)
         val (cs, l, r) =
            if List.null acode
               then let
                       val l = pred_setSyntax.mk_set newcode
                    in
                       (l, l, boolSyntax.F)
                    end
            else if List.null newcode
               then let
                       val r = list_mk_union acode
                    in
                       (r, boolSyntax.F, r)
                    end
            else let
                     val l = pred_setSyntax.mk_set newcode
                     val r = list_mk_union acode
                 in
                     (pred_setSyntax.mk_union (l, r), l, r)
                 end
         val def_name = name ^ "_" ^ get_model_name th
         (* The constant is parameterised by all free variables of cs. *)
         val x = pairSyntax.list_mk_pair (Term.free_vars cs)
         val v =
            Term.mk_var (def_name, Term.type_of (pairSyntax.mk_pabs (x, cs)))
         val code_def =
            Definition.new_definition
               (def_name ^ "_def", boolSyntax.mk_eq (Term.mk_comb (v, x), cs))
         val () = add_code_abbrev code_def
         val spec_code_def = Drule.SPEC_ALL code_def
         (* tm is the applied constant, i.e. the new code set. *)
         val tm = boolSyntax.lhs (Thm.concl spec_code_def)
         val sty = Term.type_of cs
         val ty = pred_setSyntax.dest_set_type sty
         val inst_ty = Thm.INST_TYPE [Type.alpha |-> ty]
         fun f thm = MATCH_MP thm spec_code_def
         (* Bridging lemmas from membership in l / inclusion in r to
            membership / inclusion in tm; TRUTH when the part is absent. *)
         val (in_left, in_right) =
            if List.null newcode
               then (TRUTH, f SUBSET_DEFN)
            else if List.null acode
               then (f IN_DEFN, TRUTH)
            else (f IN_LEFT_DEFN, f IN_RIGHT_DEFN)
         (* For the i-th element c of newcode prove "c IN tm": expand
            IN_INSERT i times to reach c, close with REFL_CLAUSE, then fold
            the disjunctions back up to T. *)
         val rwts1 =
            if List.null newcode
               then []
            else let
                    val a = Term.mk_var ("a", ty)
                    val refl_cnv =
                       Conv.REWR_CONV (inst_ty boolTheory.REFL_CLAUSE)
                    val in_cnv =
                       Conv.REWR_CONV (inst_ty pred_setTheory.IN_INSERT)
                    fun expand_in n tm =
                       if n <= 0
                          then (in_cnv
                                THENC Conv.LAND_CONV refl_cnv
                                THENC CONJ1_CONV) tm
                       else (in_cnv
                             THENC Conv.RAND_CONV (expand_in (n - 1))
                             THENC CONJ2_CONV) tm
                    fun cnv i = EQT_ELIM o expand_in i
                 in
                    Lib.mapi
                       (fn i => fn c =>
                           MP (Thm.INST [a |-> c] in_left)
                              (cnv i (pred_setSyntax.mk_in (c, l)))) newcode
                 end
         (* For the i-th element c of acode prove "c SUBSET tm": r is a
            left-nested union, so peel n - i outer unions with SUBSET_REST
            and finish with SUBSET_UNION2 / SUBSET_REFL. *)
         val rwts2 =
            if List.null acode
               then []
            else let
                    val a = Term.mk_var ("a", sty)
                    val n = List.length acode - 1
                    val tac1 = MATCH_MP_TAC (inst_ty SUBSET_REST)
                    val tac2 =
                       REWRITE_TAC
                         [inst_ty SUBSET_UNION2, pred_setTheory.SUBSET_REFL]
                 in
                     Lib.mapi
                        (fn i => fn c =>
                            MP (Thm.INST [a |-> c] in_right)
                            (Tactical.prove
                               (pred_setSyntax.mk_subset (c, r),
                                NTAC (n - i) tac1 THEN tac2))) acode
                 end
         (* Widen a spec's code to tm and discharge the side condition. *)
         val rule = CONV_RULE (subset_conv (rwts1 @ rwts2)) o
                    Thm.SPEC tm o MATCH_MP TRIPLE_EXTEND
      in
         List.map (fn (i, x) => (i, helperLib.instruction_apply rule x)) thms
      end
end
249
(* Phase 1 driver.  Pick the parser (the configured code_parser, else
   helperLib.quote_to_strings), run the ISA initialiser, derive one spec
   per code line, then abbreviate the code sets. *)
fun stage_1 name qcode =
   let
      val parse =
         case !code_parser of
            SOME p => p
          | NONE => helperLib.quote_to_strings
      val () = (!initialise) ()
   in
      abbreviate_code name (derive_specs name (parse qcode))
   end
257
258(* Testing
259val name = "test"
260val qcode = `e59f322c  00012f94
261             e59f222c  00012f80
262             edd37a00`
263val (_, ((th, _, _), _)) = hd thms
264*)
265
266
267(* PHASE 2 -- compute CFG *)
268
(* Successor list of the program: one (offset, successor option) entry for
   an instruction with a single spec, two entries when it has a second
   (branch) spec.  Order follows the input. *)
fun extract_graph (thms : (int * helperLib.instruction) list) =
   List.foldr
      (fn ((i, ((_, _, j), NONE)), acc) => (i, j) :: acc
        | ((i, ((_, _, j), SOME (_, _, k))), acc) => (i, j) :: (i, k) :: acc)
      [] thms
273
(* Keep only the entries whose successor is known, as concrete edges. *)
fun jumps2edges (jumps : (int * int option) list) : (int * int) list =
   List.mapPartial
      (fn (src, dst) => Option.map (fn d => (src, d)) dst)
      jumps
278
(* Remove duplicates from a list (delegates to Lib.mk_set). *)
fun all_distinct xs = Lib.mk_set xs
280
(* Discard leading elements until one satisfies P.  That element and
   everything after it are returned; the result is [] if none satisfies P. *)
fun drop_until P xs =
   case xs of
      [] => []
    | x :: rest => if P x then xs else drop_until P rest
283
(* List-as-set inclusion: true iff every element of xs occurs in ys. *)
fun subset xs ys = List.all (fn x => mem x ys) xs
286
local
   (* Enumerate every path from node i through the edge list, extending the
      prefix.  A path stops when a node has no successors, or when the next
      successor already occurs on the path; in that case the repeated node
      is appended, which is how loops are detected (see is_loop). *)
   fun all_paths_from edges i prefix =
      let
         fun f [] = []
           | f ((k, j) :: xs) = if i = k then j :: f xs else f xs
        val next = all_distinct (f edges)
        val prefix = prefix @ [i]
        val xs = map (fn x => if mem x prefix
                                 then [prefix @ [x]]
                              else all_paths_from edges x prefix) next
        val xs = if xs = [] then [[prefix]] else xs
      in
         Lib.flatten xs
      end

   (* A path is a loop when its last node already occurred earlier. *)
   fun is_loop xs = Lib.mem (List.last xs) (Lib.butlast xs)

   (* clean loop tails *)
   (* Strip from each tail the leading nodes that belong to the body xs;
      tails that become empty are dropped. *)
   fun clean_tails (i, xs, tails) =
      (i, xs,
       List.mapPartial
          (fn t => let
                      val l = drop_until (fn x => not (mem x xs)) t
                   in
                      if List.null l then NONE else SOME l
                   end) tails)

   (* Cartesian product of two lists, as pairs. *)
   fun cross [] ys = []
     | cross (x :: xs) ys = map (fn y => (x, y)) ys @ cross xs ys

   (* The path starts at i and reaches j afterwards. *)
   fun sat_goal ((i, j), path) = hd path = i andalso mem j (tl path)

   (* One merge step over loop groups (entries, body, tails).  Build a graph
      between entry points: group entries -> other groups' entry points
      that occur in its body or at the head of one of its tails.  Then find
      an edge j -> i with a path from i back to j.  The two groups holding i
      and j are merged into one group, and tails are re-cleaned.  When no
      such pair exists, hd raises Empty.  Presumably this is how the
      Lib.repeat loop in extract_loops terminates; confirm Lib.repeat's
      behaviour. *)
   fun find_and_merge zs =
      let
         val ls = Lib.flatten (map (fn (x, y, z) => x) zs)
         val qs = map (fn (x, y, z) => (x, y, map hd z)) zs
         fun f ys = filter (fn x => mem x ls andalso (not (mem x ys)))
         val qs = map (fn (x, y, z) => (x, all_distinct (f x y @ f x z))) qs
         val edges = Lib.flatten (map (fn (x,y) => cross x y) qs)
         val paths = map (fn i => all_paths_from edges i []) ls
         val goals = map (fn (x,y) => (y,x)) edges
         val (i, j) =
            fst (hd (filter sat_goal (cross goals (Lib.flatten paths))))
         val (p1, q1, x1) = hd (filter (fn (x,y,z) => mem i x) zs)
         val (p2, q2, x2) = hd (filter (fn (x,y,z) => mem j x) zs)
         val (p, q, x) = (p1 @ p2, all_distinct (q1 @ q2), x1 @ x2)
         val zs =
            (p,q,x) ::
            filter (fn (x, y, z) => not (mem i x) andalso not (mem j x)) zs
         val zs = map clean_tails zs
      in
         zs
      end

   (* x occurs in every list of the given list of lists. *)
   fun mem_all x = List.all (Lib.mem x)

   (* Exit points of a loop group.  Use the first node of the first tail
      that occurs in every other tail, as a single common exit.  Otherwise
      use the distinct first nodes of all tails.  With no tails at all,
      both hd calls raise Empty and the result is []. *)
   fun find_exit_points (x, y, z) =
      let
         val q = hd (filter (fn x => mem_all x (tl z)) (hd z))
      in
         (x, [q])
      end
      handle Empty => (x, all_distinct (map hd z))

   (* x occurs in the list before any occurrence of y.  This also holds when
      neither occurs; it is false when x = y. *)
   fun list_before x y [] = true
     | list_before x y (z :: zs) =
           z <> y andalso (z = x orelse list_before x y zs)

   val int_sort = Lib.sort (Lib.curry (op Int.<=))
in
   (* Phase 2: from the jump list, compute the decompilation rounds as a
      list of (entry points, exit points), each sorted by int_sort.  Loop
      headers come from looping paths from offset 0.  Mutually reachable
      loops are merged, and an extra round ([0], program exits) is appended
      unless an existing round already starts at 0 and covers every program
      exit. *)
   fun extract_loops jumps =
      let
         (* find all possible paths *)
         val edges = jumps2edges jumps
         val paths = all_paths_from edges 0 []
         (* get looping points *)
         val loops = all_distinct (map last (filter is_loop paths))
         (* find loop bodies and tails *)
         (* Body of loop i: nodes on looping paths that return to i, counted
            from the first i.  Tails: the remainders, from i onwards, of the
            other paths through i, with body nodes stripped by clean_tails. *)
         fun loop_body_tail i =
            let
               val bodies = filter (fn xs => last xs = i) paths
               val bodies = filter is_loop bodies
               val bodies = map (drop_until (fn x => x = i) o butlast) bodies
               val bodies = all_distinct (Lib.flatten bodies)
               val tails =
                  filter (fn xs => mem i xs andalso not (last xs = i)) paths
               val tails = map (drop_until (fn x => x = i)) tails
            in
               (fn (x, y, z) => ([x], y, z)) (clean_tails (i, bodies, tails))
            end
         val zs = map loop_body_tail loops
         (* merge combined loops *)
         val zs = repeat find_and_merge zs
         (* attempt to find common exit point *)
         val zs = map find_exit_points zs
         (* finalise *)
         (* Program exits: end nodes of the non-looping paths. *)
         val exit = (all_distinct o map last o filter (not o is_loop)) paths
         val zero = ([0], exit)
         val zs =
            if List.null
                 (filter (fn (x, y) => mem 0 x andalso subset exit y) zs)
               then zs @ [zero]
            else zs
         (* Ordering on rounds: compares first entry points along the first
            path containing both.  Presumably this puts rounds whose entry is
            reached later first, so the round from 0 is processed last;
            confirm against Lib.sort's argument convention. *)
         fun compare (xs, _) (ys, _) =
            let
               val x = hd xs
               val y = hd ys
               val p = hd (filter (fn xs => mem x xs andalso mem y xs) paths)
            in
               not (list_before x y p)
            end
            handle Empty => false
         val loops = sort compare zs
         (* sort internal  *)
         val loops = map (int_sort ## int_sort) loops
      in
         loops
      end
end
406
local
   (* Scan left to right.  A key that occurs again later is recorded once,
      and all of its later occurrences are removed before continuing. *)
   fun collect found [] = List.rev found
     | collect found ((k, _) :: rest) =
         let
            val (dups, others) = List.partition (fn (k', _) => k' = k) rest
         in
            if List.null dups
               then collect found rest
            else collect (k :: found) others
         end
in
   (* Keys occurring more than once, in order of first occurrence.  Applied
      to the output of extract_graph, these are the offsets of instructions
      with two successor entries. *)
   fun forks jumps = collect [] jumps
end
417
(* Phases 1 and 2: derive the instruction specs, then compute the list of
   decompilation rounds (entry points, exit points) from the control flow. *)
fun stage_12 name qcode =
   let
      val thms = stage_1 name qcode
      val jumps = extract_graph thms
      (* For a single loop-free round from 0 to n, give every fork point its
         own round ending at n (forks sorted with op >=), followed by the
         round from 0. *)
      fun split_at_forks n =
         List.map (fn f => ([f], [n]))
            (Lib.sort (Lib.curry (op >=)) (forks jumps)) @ [([0], [n])]
      val loops =
         case extract_loops jumps of
            [([0], [n])] => split_at_forks n
          | rounds => rounds
   in
      (thms, loops)
   end
435
436
437(* PHASE 3 -- compose and extract *)
438
(* Plan of the composition for one decompilation round.  It is built by
   build_compose_tree and folded into a single theorem by round. *)
datatype compose_tree =
    End of int  (* reached one of the round's exit offsets *)
  | Repeat of int  (* returned to one of the round's entry offsets *)
  | Cons of thm * compose_tree  (* spec followed by the rest of the path *)
  | Merge of term * compose_tree * compose_tree
    (* branch on the term; each arm has its own continuation *)
  | ConsMerge of term * thm * thm * compose_tree
    (* branch on the term; both arm specs continue at the same offset *)
445
(* True when some path in the tree ends in Repeat, i.e. the round loops back
   to one of its entry points. *)
fun is_rec tree =
   case tree of
      Repeat _ => true
    | End _ => false
    | Cons (_, rest) => is_rec rest
    | Merge (_, left, right) => is_rec left orelse is_rec right
    | ConsMerge (_, _, _, rest) => is_rec rest
451
local
   val (_, _, _, is_abbrev) = HolKernel.syntax_fns1 "marker" "Abbrev"
   (* The term wrapped by Abbrev in the theorem's only hypothesis.  For a
      branching instruction it is used as the branch condition handed to
      merge.  Raises HOL_ERR unless there is exactly one hypothesis
      containing an Abbrev. *)
   fun get_Abbrev th =
      case Thm.hyp th of
         [h] => Lib.with_exn (Term.rand o HolKernel.find_term is_abbrev) h
                   (ERR "get_Abbrev" "Abbrev not found")
       | _ => raise ERR "get_Abbrev" "not a single hyp"
in
   (* Build the composition tree for the round with entry points b and exit
      points e, starting from the first entry point.  At each offset i:
      - an exit point yields End i;
      - an entry point reached again (not at the start) yields Repeat i;
      - otherwise the spec at i is consumed.  A one-successor spec becomes
        Cons; a two-successor spec becomes ConsMerge when both successors
        coincide, Merge otherwise.
      Raises HOL_ERR when a successor offset is unknown (NONE) or when no
      spec exists at a reached offset. *)
   fun build_compose_tree (b, e) thms =
      let
         (* The first matching entry wins, so summaries prepended to thms by
            earlier rounds shadow the original instruction spec. *)
         fun find_next i =
            case List.find (fn (n, _:helperLib.instruction) => n = i) thms of
               SOME entry => entry
             | NONE =>
                 raise ERR "build_compose_tree"
                          ("no instruction spec at offset " ^ Int.toString i)
         fun sub _ NONE =
               raise ERR "build_compose_tree" "cannot handle bad exits"
           | sub init (SOME i) =
             if mem i e
                then End i
             else if not init andalso mem i b
                then Repeat i
             else case find_next i of
                     (_, ((th1, _, x1), NONE)) => Cons (th1, sub false x1)
                   | (_, ((th1, _, x1), SOME (th2, _, x2))) =>
                     if x1 = x2
                        then ConsMerge (get_Abbrev th1, th1, th2, sub false x1)
                     else let
                             val t1 = Cons (th1, sub false x1)
                             val t2 = Cons (th2, sub false x2)
                          in
                             Merge (get_Abbrev th1, t1, t2)
                          end
      in
         sub true (SOME (hd b))
      end
end
485
(* Debugging aids: when compose or merge fails with HOL_ERR, the input
   theorems are stored in l1/l2 (and, for merge, the condition in l3) before
   the error is re-raised. *)
val l1 : thm ref = ref TRUTH
val l2 : thm ref = ref TRUTH
val l3 : term ref = ref T
489
local
   (* Apply conversion c to the rand of the rand of a theorem's conclusion. *)
   fun VALUE_RULE c = CONV_RULE (RAND_CONV (RAND_CONV c))
   (* Rewrite the antecedent of an implication with pairTheory.PAIR. *)
   val PAIR_RULE =
      CONV_RULE ((RATOR_CONV o RAND_CONV) (PURE_REWRITE_CONV [pairTheory.PAIR]))
   (* ii lhs rhs: when rhs is a (right-nested) pair, map its first component
      to FST lhs and recurse on the second against SND lhs.  Otherwise map
      rhs itself to lhs.  The result is used with INST, so the components
      are presumably variables; TODO confirm. *)
   fun ii lhs rhs =
      let
         val (x, y) = pairSyntax.dest_pair rhs
         val x1 = pairSyntax.mk_fst lhs
         val y1 = pairSyntax.mk_snd lhs
      in
         (x |-> x1) :: ii y1 y
      end
      handle HOL_ERR _ => [rhs |-> lhs]
in
   (* compose th1 th2: sequential composition via TRIPLE_COMPOSE_ALT.  In
      round, th1 is an instruction spec and th2 the composition of
      everything after it.  If the composed theorem has a hypothesis
      "lhs = rhs", the first one is eliminated in four steps:
        1. bind rhs with a LET and replace it by lhs;
        2. discharge the hypothesis;
        3. instantiate rhs's components with FST/SND projections of lhs
           and fold them with PAIR;
        4. close the resulting "lhs = lhs" antecedent with REFL.
      This presumably assumes at most one hypothesis; TODO confirm.
      On HOL_ERR, th1 and th2 are saved in l1/l2 and the error is re-raised. *)
   fun compose th1 th2 =
      let
         val th3 = MATCH_MP (MATCH_MP TRIPLE_COMPOSE_ALT th2) th1
      in
         case Thm.hyp th3 of
            [] => th3
          | tm :: _ =>
              let
                 val lemma = SYM (ASSUME tm)
                 val (lhs, rhs) = dest_eq tm
                 val th4 = VALUE_RULE
                              (PairRules.UNPBETA_CONV rhs
                               THENC REWR_CONV (SYM (SPEC_ALL LET_THM))
                               THENC RAND_CONV (fn _ => lemma)) th3
              in
                 MP (PAIR_RULE (INST (ii lhs rhs) (DISCH_ALL th4))) (REFL lhs)
              end
      end
      handle HOL_ERR e => (l1 := th1; l2 := th2; raise HOL_ERR e)
end
524
525(*
526val th1 = !l1
527val th2 = !l2
528val tm = !l3
529*)
530
local
   (* Unfold Abbrev and CONTAINER markers in a branch spec. *)
   val unfold_markers =
      REWRITE_RULE [markerTheory.Abbrev_def, addressTheory.CONTAINER_def]
   (* Tidy the rand of the rand of the merged conclusion with bool_ss. *)
   val simp_result = CONV_RULE ((RAND_CONV o RAND_CONV) (SIMP_CONV bool_ss []))
in
   (* merge tm th1 th2: combine the spec valid when tm holds (th1) with the
      spec valid when ~tm holds (th2) into one spec via COND_MERGE.  On
      HOL_ERR the original inputs are saved in l1/l2/l3 and the error is
      re-raised. *)
   fun merge tm th1 th2 =
      let
         val when_true = DISCH tm (unfold_markers th1)
         val when_false = DISCH (mk_neg tm) (unfold_markers th2)
      in
         simp_result (MATCH_MP COND_MERGE (CONJ when_true when_false))
      end
      handle HOL_ERR e => (l1 := th1; l2 := th2; l3 := tm; raise HOL_ERR e)
end
545
546(*
547  fun fan (End i) = 1
548    | fan (Repeat i) = 1
549    | fan (Cons (th,t)) = fan t
550    | fan (Merge (tm,t1,t2)) = fan t1 + fan t2
551    | fan (ConsMerge (tm,th1,th2,t)) = fan t + fan t
552
553  fan t
554*)
555
(* Evaluate case_sum applications by rewriting with its definition. *)
val case_sum_conv =
   SIMP_CONV (std_ss ++ simpLib.rewrites [tripleTheory.case_sum_def]) []
(* Apply case_sum_conv to the operator part of a theorem's conclusion. *)
val case_sum_rule = Conv.CONV_RULE (Conv.RATOR_CONV case_sum_conv)

local
   (* In the antecedent of an implication: beta-reduce paired lambdas, then
      rewrite FST. *)
   val reduce_antecedent =
      (RATOR_CONV o RAND_CONV)
         (DEPTH_CONV PairedLambda.GEN_BETA_CONV
          THENC REWRITE_CONV [pairTheory.FST])
in
   (* Reduce the antecedent, then simplify the implication with IMP_CLAUSES. *)
   val beta_fst_rule =
      CONV_RULE (reduce_antecedent THENC REWRITE_CONV [boolTheory.IMP_CLAUSES])
end

(* pure_ss extended with FORALL_PROD, to split quantifiers over pairs. *)
val forall_prod_ss = pure_ss ++ simpLib.rewrites [pairTheory.FORALL_PROD]
565
(* round (input, get_assert, triple_refl) name (b, e) thms: perform one
   decompilation round.
   - Compose the specs on all paths from the first entry point in b to the
     exit points e into one triple.
   - If the round loops (is_rec), package it with SHORT_TERM_TAILREC.
   - Define the constant name as the function term read off the result.
   - Return (definition, final theorem); its hypotheses are discharged as
     antecedents.
   Parameters (as supplied by core_decompile):
     input       - tuple of variables the generated function takes
     get_assert  - offset -> precondition, a paired abstraction over input
     triple_refl - term -> reflexivity triple for that post-condition
   Progress is echoed as dots. *)
fun round (input, get_assert, triple_refl) =
   fn name => fn (b, e) => fn thms =>
   let
      val () = inPlaceEcho (name ^ ": ")
      val () = echo 1 "."
      val t = build_compose_tree (b, e) thms
      val loop = is_rec t
      val pre = get_assert (hd b)
      val post = get_assert (hd e)
      (* Leaf theorems for the tree.  In a loop, Repeat/End leaves become
         case_sum triples over INL/INR of input.  Otherwise only End
         occurs; it uses post input, and enter_th is an unused TRUTH. *)
      val (enter_th, exit_th) =
         if loop
            then let
                    fun refl_sum f =
                       triple_refl
                          (tripleSyntax.mk_case_sum
                             (pre, post, f (input, Term.type_of input)))
                 in
                    (refl_sum sumSyntax.mk_inl, refl_sum sumSyntax.mk_inr)
                 end
         else (boolTheory.TRUTH, triple_refl (Term.mk_comb (post, input)))
      (* perform composition *)
      val () = echo 1 "."
      (* Fold the tree bottom-up.  ConsMerge composes both branch specs onto
         one shared continuation. *)
      fun comp (End i) = exit_th
        | comp (Repeat i) = enter_th
        | comp (Cons (th, t)) = compose th (comp t)
        | comp (Merge (tm, t1, t2)) = merge tm (comp t1) (comp t2)
        | comp (ConsMerge (tm, th1, th2, t)) =
            let
               val res = comp t
            in
               merge tm (compose th1 res) (compose th2 res)
            end
      (* Abstract the result part over input (paired beta-expansion). *)
      val th =
         CONV_RULE ((RAND_CONV o RAND_CONV) (PairRules.UNPBETA_CONV input))
            (comp t)
      val th =
         if loop
            then let
                    val () = echo 1 "."
                    (* apply loop rule *)
                    (* Re-express the precondition as a case_sum term, prove
                       the step theorem for an arbitrary x, and feed it to
                       SHORT_TERM_TAILREC. *)
                    val lemma = case_sum_conv (mk_comb (pre, input))
                    val th = CONV_RULE ((RATOR_CONV o RATOR_CONV o RAND_CONV)
                                           (fn _ => GSYM lemma)) th
                    val x = Term.mk_var ("x", Term.type_of input)
                    val tm = mk_forall (x, subst [input |-> x] (Thm.concl th))
                    val lemma =
                       prove
                         (tm,
                          FULL_SIMP_TAC forall_prod_ss []
                          THEN REPEAT STRIP_TAC
                          THEN MATCH_MP_TAC (DISCH T th)
                          THEN FULL_SIMP_TAC std_ss [])
                    val th =
                       MATCH_MP SHORT_TERM_TAILREC lemma
                       |> beta_fst_rule
                       |> SPEC input
                       |> CONV_RULE ((RATOR_CONV o RATOR_CONV o RAND_CONV)
                                         PairRules.PBETA_CONV)
                 in
                    th
                 end
         else th
      val () = echo 1 "."
      (* Define name as the function term from the conclusion, then fold
         the definition back into the theorem. *)
      val ff = th |> concl |> rand |> rand |> rator
      val def =
         new_definition (name ^ "_def", mk_eq (mk_var (name, type_of ff), ff))
      val th = th |> CONV_RULE
                       ((RAND_CONV o RAND_CONV o RATOR_CONV) (fn _ => GSYM def))
      (* clean up result *)
      (* Assume "name input = swap_primes input" and use it to replace the
         result.  DISCH_ALL below turns the assumption into an antecedent. *)
      val lemma = mk_eq(mk_comb (boolSyntax.lhs (concl def), input),
                        (!swap_primes) input)
                  |> ASSUME
      val result =
         th |> CONV_RULE ((RAND_CONV o RAND_CONV) (fn _ => lemma)
                          THENC RAND_CONV PairRules.PBETA_CONV)
            |> DISCH_ALL
      val () = echo 1 "."
   in
      (def, result)
   end
646
(* Top-level entry point: decompile the code in qcode under the given name.
   - Run phases 1-2 (stage_12), then one round per (entry points, exit
     points) group, in list order.
   - Every round but the last is named NAME_partN.  Its result is
     prepended to thms as a summary spec at its entry point, ahead of the
     original entry.
   - The last round is named NAME.  Its theorem is recorded with
     add_decomp (so later code can reuse it via "insert:") and returned
     with the conjunction of all generated definitions.
   Both phases are timed. *)
fun core_decompile name qcode =
   let
      val (thms, loops) = time (stage_12 name) qcode
      (* The spec at offset 0 fixes the model, code and state shape. *)
      val (_, ((th, _, _), _)) = Lib.first (fn (x, _) => x = 0) thms
      val (model, pre, code, _) = tripleSyntax.dest_triple (Thm.concl th)
      val (cnd, asrt) = pairSyntax.dest_pair pre
      (* Function input: the condition plus every free variable of the
         assertion except the pc. *)
      val affected_vars = asrt |> Term.free_vars |> filter (fn v => v <> !pc)
      val inp = pairSyntax.list_mk_pair (cnd :: affected_vars)
      (* Precondition at offset i, abstracted over the input tuple. *)
      fun get_assert i = pairSyntax.mk_pabs (inp, Term.subst (add_to_pc i) pre)
      val refl = Thm.SPEC code (Drule.ISPEC model tripleTheory.TRIPLE_REFL)
      fun triple_refl t = case_sum_rule (Thm.SPEC t refl)
      val round = round (inp, get_assert, triple_refl)
      (* n counts the rounds still to come after this one; n = 0 marks the
         final round. *)
      fun rounds loops thms defs =
         let
            val (b, e) = hd loops
            val loops = tl loops
            val n = length loops
            val part_name = if n = 0 then name
                            else name ^ "_part" ^ (int_to_string n)
            val (def, result) = round part_name (b, e) thms
            val thms =
               (hd b, ((UNDISCH_ALL result, 0, SOME (hd e)), NONE)) :: thms
         in
            if n = 0
               then (result, rev (def :: defs))
            else rounds loops thms (def :: defs)
         end
      val () = echo 1 "\nProcessing...\n\n"
      val (res, defs) =
         time (fn () => rounds loops thms [] before inPlaceEcho "Finished.\n\n")
              ()
      (* Recorded length: the first exit point of the final round
         (presumably the end of the code). *)
      val () = add_decomp name res (loops |> last |> snd |> hd)
   in
      (res, LIST_CONJ defs)
   end
682
683end
684