1(* ------------------------------------------------------------------------
2   CHERI step evaluator
3   ------------------------------------------------------------------------ *)
4
5structure cheri_stepLib :> cheri_stepLib =
6struct
7
8open HolKernel boolLib bossLib
9open blastLib cheriTheory cheri_stepTheory
10
(* Build a HOL_ERR exception attributed to this library, given the name of
   the originating function and a message. *)
fun ERR fname msg = Feedback.mk_HOL_ERR "cheri_stepLib" fname msg

(* Loading this library turns on the global show_assums flag, so theorems
   are printed together with their hypotheses (presumably because the step
   theorems below carry ASSUMEd side conditions). *)
val _ = show_assums := true
14
15(* ------------------------------------------------------------------------- *)
16
17(* Fetch *)
18
local
   (* The variable ``w : word32`` at which Fetch_default is instantiated. *)
   val word_var = Term.mk_var ("w", wordsSyntax.mk_int_word_type 32)
   (* Numeric type 32: the width argument of 32-bit v2w terms. *)
   val width32 = fcpSyntax.mk_int_numeric_type 32
   (* ``v2w [...] : word32`` built from a list of boolean terms, keeping the
      bits in the order given. *)
   fun to_word32 bits =
      bitstringSyntax.mk_v2w (listSyntax.mk_list (bits, Type.bool), width32)
in
   (* Check that v is a HOL list of booleans with at most 32 elements and
      return its elements left-padded with F to 32 (utilsLib.padLeft).
      Raises HOL_ERR ("pad_opcode", "bad opcode") otherwise. *)
   fun pad_opcode v =
      let
         val (bits, elem_ty) = listSyntax.dest_list v
         val well_formed =
            (elem_ty = Type.bool) andalso (List.length bits <= 32)
      in
         if well_formed
            then utilsLib.padLeft boolSyntax.F 32 bits
         else raise ERR "pad_opcode" "bad opcode"
      end

   (* Fetch theorem for opcode v: Fetch_default with w instantiated to the
      32-bit word formed from the padded bits of v. *)
   fun fetch v =
      Thm.INST [word_var |-> to_word32 (pad_opcode v)]
         cheri_stepTheory.Fetch_default

   (* As fetch, for an opcode given as a hexadecimal string. *)
   fun fetch_hex s = fetch (bitstringSyntax.bitstring_of_hexstring s)
end
44
45(* ------------------------------------------------------------------------- *)
46
47(* Decoder *)
48
local
   (* A generic 32-bit opcode: a v2w term over boolean variables, as built
      by bitstringSyntax.mk_vec. *)
   val opcode_word = bitstringSyntax.mk_vec 32 0
   (* Decode_def specialised to that opcode, with boolify32, let-bindings
      and bit extractions from v2w simplified away. *)
   val Decode =
      let
         val simp =
            REWRITE_CONV [cheriTheory.boolify32_v2w]
            THENC Conv.DEPTH_CONV PairedLambda.let_CONV
            THENC Conv.DEPTH_CONV bitstringLib.extract_v2w_CONV
      in
         Conv.RIGHT_CONV_RULE simp (Thm.SPEC opcode_word Decode_def)
      end
   (* The list of bit variables inside opcode_word; opcode patterns are
      matched against this list. *)
   val opcode_bits = fst (bitstringSyntax.dest_v2w opcode_word)
   val unpredictable_tm = ``cheri$Unpredictable``
   (* Rewrite with not31. If the right-hand side then reads
      ``if b then Unpredictable else ...``, assume ~b and rewrite with that
      assumption, so the Unpredictable branch is discharged. *)
   fun fix_unpredictable thm0 =
      let
         val thm = REWRITE_RULE [not31] thm0
         val unpredictable_guard =
            case Lib.total (boolSyntax.dest_cond o utilsLib.rhsc) thm of
               SOME (b, t, _) => if t = unpredictable_tm then SOME b else NONE
             | NONE => NONE
      in
         case unpredictable_guard of
            SOME b => REWRITE_RULE [ASSUME (boolSyntax.mk_neg b)] thm
          | NONE => thm
      end
in
   (* Specialise the decoder to the opcode pattern pat (a HOL list of
      boolean terms), then simplify and drop an Unpredictable branch where
      one guards the result. *)
   fun DecodeCHERI pat =
      let
         val (tm_subst, _) = Term.match_term opcode_bits pat
         val thm = Thm.INST tm_subst Decode
         val thm = REWRITE_RULE [] thm
         val thm =
            Conv.RIGHT_CONV_RULE (Conv.REPEATC PairedLambda.let_CONV) thm
      in
         fix_unpredictable thm
      end
end
84
(* Decoder opcode patterns, one (mnemonic, pattern) pair per instruction.
   Each string gives the 32 opcode bits, with the 6-bit major opcode first.
   T = 1, F = 0, and "_" marks a free bit, which presumably becomes a
   variable x0, x1, ... (numbered in order of occurrence) under
   utilsLib.pattern. The groups differ in which free 5-bit fields the
   decoder later specialises to zero. *)
local
   (* Pair each mnemonic with the HOL bit pattern parsed from its string. *)
   fun to_patterns entries =
      List.map (fn (name, bits) => (name, utilsLib.pattern bits)) entries
in
   (* First two free fields: presumably rs, rt. *)
   val cheri_ipatterns = to_patterns
      [
       ("ADDI",   "FFTFFF__________________________"),
       ("ADDIU",  "FFTFFT__________________________"),
       ("SLTI",   "FFTFTF__________________________"),
       ("SLTIU",  "FFTFTT__________________________"),
       ("ANDI",   "FFTTFF__________________________"),
       ("ORI",    "FFTTFT__________________________"),
       ("XORI",   "FFTTTF__________________________"),
       ("DADDI",  "FTTFFF__________________________"),
       ("DADDIU", "FTTFFT__________________________"),
       ("MULT",   "FFFFFF__________FFFFFFFFFFFTTFFF"),
       ("MULTU",  "FFFFFF__________FFFFFFFFFFFTTFFT"),
       ("DMULT",  "FFFFFF__________FFFFFFFFFFFTTTFF"),
       ("DMULTU", "FFFFFF__________FFFFFFFFFFFTTTFT"),
       ("MADD",   "FTTTFF__________FFFFFFFFFFFFFFFF"),
       ("MADDU",  "FTTTFF__________FFFFFFFFFFFFFFFT"),
       ("MSUB",   "FTTTFF__________FFFFFFFFFFFFFTFF"),
       ("MSUBU",  "FTTTFF__________FFFFFFFFFFFFFTFT"),
       ("MUL",    "FTTTFF_______________FFFFFFFFFTF"),
       ("BEQ",    "FFFTFF__________________________"),
       ("BNE",    "FFFTFT__________________________"),
       ("BEQL",   "FTFTFF__________________________"),
       ("BNEL",   "FTFTFT__________________________")
      ]

   (* Free fields: presumably rs, then rd and the hint field. *)
   val cheri_dpatterns = to_patterns
      [
       ("JALR",   "FFFFFF_____FFFFF__________FFTFFT")
      ]

   (* Three free fields: presumably rs, rt, rd. *)
   val cheri_rpatterns = to_patterns
      [
       ("SLLV",   "FFFFFF_______________FFFFFFFFTFF"),
       ("SRLV",   "FFFFFF_______________FFFFFFFFTTF"),
       ("SRAV",   "FFFFFF_______________FFFFFFFFTTT"),
       ("MOVZ",   "FFFFFF_______________FFFFFFFTFTF"),
       ("MOVN",   "FFFFFF_______________FFFFFFFTFTT"),
       ("DSLLV",  "FFFFFF_______________FFFFFFTFTFF"),
       ("DSRLV",  "FFFFFF_______________FFFFFFTFTTF"),
       ("DSRAV",  "FFFFFF_______________FFFFFFTFTTT"),
       ("ADD",    "FFFFFF_______________FFFFFTFFFFF"),
       ("ADDU",   "FFFFFF_______________FFFFFTFFFFT"),
       ("SUB",    "FFFFFF_______________FFFFFTFFFTF"),
       ("SUBU",   "FFFFFF_______________FFFFFTFFFTT"),
       ("AND",    "FFFFFF_______________FFFFFTFFTFF"),
       ("OR",     "FFFFFF_______________FFFFFTFFTFT"),
       ("XOR",    "FFFFFF_______________FFFFFTFFTTF"),
       ("NOR",    "FFFFFF_______________FFFFFTFFTTT"),
       ("SLT",    "FFFFFF_______________FFFFFTFTFTF"),
       ("SLTU",   "FFFFFF_______________FFFFFTFTFTT"),
       ("DADD",   "FFFFFF_______________FFFFFTFTTFF"),
       ("DADDU",  "FFFFFF_______________FFFFFTFTTFT"),
       ("DSUB",   "FFFFFF_______________FFFFFTFTTTF"),
       ("DSUBU",  "FFFFFF_______________FFFFFTFTTTT")
      ]

   (* Shifts by constant: free fields presumably rt, rd, then shift amount. *)
   val cheri_jpatterns = to_patterns
      [
       ("SLL",    "FFFFFFFFFFF_______________FFFFFF"),
       ("SRL",    "FFFFFFFFFFF_______________FFFFTF"),
       ("SRA",    "FFFFFFFFFFF_______________FFFFTT"),
       ("DSLL",   "FFFFFFFFFFF_______________TTTFFF"),
       ("DSRL",   "FFFFFFFFFFF_______________TTTFTF"),
       ("DSRA",   "FFFFFFFFFFF_______________TTTFTT"),
       ("DSLL32", "FFFFFFFFFFF_______________TTTTFF"),
       ("DSRL32", "FFFFFFFFFFF_______________TTTTTF"),
       ("DSRA32", "FFFFFFFFFFF_______________TTTTTT")
      ]

   (* Only the first free 5-bit field is specialised to zero. *)
   val cheri_patterns0 = to_patterns
      [
       ("LUI",     "FFTTTTFFFFF_____________________"),
       ("DIV",     "FFFFFF__________FFFFFFFFFFFTTFTF"),
       ("DIVU",    "FFFFFF__________FFFFFFFFFFFTTFTT"),
       ("DDIV",    "FFFFFF__________FFFFFFFFFFFTTTTF"),
       ("DDIVU",   "FFFFFF__________FFFFFFFFFFFTTTTT"),
       ("MTHI",    "FFFFFF_____FFFFFFFFFFFFFFFFTFFFT"),
       ("MTLO",    "FFFFFF_____FFFFFFFFFFFFFFFFTFFTT"),
       ("MFHI",    "FFFFFFFFFFFFFFFF_____FFFFFFTFFFF"),
       ("MFLO",    "FFFFFFFFFFFFFFFF_____FFFFFFTFFTF"),
       ("BLTZ",    "FFFFFT_____FFFFF________________"),
       ("BGEZ",    "FFFFFT_____FFFFT________________"),
       ("BLTZL",   "FFFFFT_____FFFTF________________"),
       ("BGEZL",   "FFFFFT_____FFFTT________________"),
       ("BLTZAL",  "FFFFFT_____TFFFF________________"),
       ("BGEZAL",  "FFFFFT_____TFFFT________________"),
       ("BLTZALL", "FFFFFT_____TFFTF________________"),
       ("BGEZALL", "FFFFFT_____TFFTT________________"),
       ("BLEZ",    "FFFTTF_____FFFFF________________"),
       ("BGTZ",    "FFFTTT_____FFFFF________________"),
       ("BLEZL",   "FTFTTF_____FFFFF________________"),
       ("BGTZL",   "FTFTTT_____FFFFF________________"),
       ("JR",      "FFFFFF_____FFFFFFFFFF_____FFTFFF")
      ]

   (*
   val cheri_cpatterns = to_patterns
      [
       ("MFC0",    "FTFFFFFFFFF__________FFFFFFFF___"),
       ("MTC0",    "FTFFFFFFTFF__________FFFFFFFF___")
      ]
   *)

   (* No register-field specialisation. *)
   val cheri_patterns = to_patterns
      [
       ("J",       "FFFFTF__________________________"),
       ("JAL",     "FFFFTT__________________________"),
       ("LDL",     "FTTFTF__________________________"),
       ("LDR",     "FTTFTT__________________________"),
       ("LB",      "TFFFFF__________________________"),
   (*  ("LH",      "TFFFFT__________________________"), *)
       ("LWL",     "TFFFTF__________________________"),
       ("LW",      "TFFFTT__________________________"),
       ("LBU",     "TFFTFF__________________________"),
   (*  ("LHU",     "TFFTFT__________________________"), *)
       ("LWR",     "TFFTTF__________________________"),
       ("LWU",     "TFFTTT__________________________"),
       ("SB",      "TFTFFF__________________________"),
   (*  ("SH",      "TFTFFT__________________________"), *)
       ("SW",      "TFTFTT__________________________"),
       ("LL",      "TTFFFF__________________________"),
       ("LLD",     "TTFTFF__________________________"),
       ("LD",      "TTFTTT__________________________"),
       ("SC",      "TTTFFF__________________________"),
       ("SCD",     "TTTTFF__________________________"),
       ("SD",      "TTTTTT__________________________")
   (*  ("ERET",    "FTFFFFTFFFFFFFFFFFFFFFFFFFFTTFFF")  *)
      ]
end
214
local
   (* Every decoder pattern, concatenated across the pattern groups. *)
   val patterns =
      List.concat [cheri_ipatterns, cheri_jpatterns, cheri_dpatterns,
                   cheri_rpatterns, cheri_patterns0, cheri_patterns]

   (* Opcode v (a HOL list of at most 32 booleans) as a HOL list term,
      left-padded with F by pad_opcode. *)
   fun padded_opcode v = listSyntax.mk_list (pad_opcode v, Type.bool)

   (* The opcode bit-list of a DecodeCHERI theorem: the operand of the
      operand of its left-hand side (presumably ``Decode (v2w l)``). *)
   val get_opc = boolSyntax.rand o boolSyntax.rand o utilsLib.lhsc

   (* Run DecodeCHERI on every (name, pattern) in l and index the resulting
      (name, theorem) pairs by opcode bit-list in a term net. *)
   fun mk_net l =
      List.foldl
         (fn ((s:string, p), nt) =>
            let
               val thm = DecodeCHERI p
            in
               LVTermNet.insert (nt, ([], get_opc thm), (s, thm))
            end)
         LVTermNet.empty l

   (* Lookup function over a net. For opcode v, returns
      SOME (name, pattern, theorem, subst) when exactly one entry matches the
      padded opcode (subst instantiates pattern to it). Returns NONE when no
      entry or several entries match. *)
   fun find_opcode net =
      let
         fun find_opc tm =
            case LVTermNet.match (net, tm) of
               [(([], opc), (name, thm))] => SOME (name:string, opc, thm:thm)
             | _ => NONE
      in
         fn v =>
            let
               val pv = padded_opcode v
            in
               Option.map
                  (fn (name, opc, thm) =>
                     (name, opc, thm, fst (Term.match_term opc pv)))
                  (find_opc pv)
            end
      end

   (* The boolean variable "x<i>". x0, x1, ... presumably name the free
      ("_") bits of a pattern, in order of occurrence. *)
   fun x i = Term.mk_var ("x" ^ Int.toString i, Type.bool)

   (* assign_bits (p, i, n): a substitution function that replaces
      x<p>, ..., x<p+n-1> by the bits of the n-bit representation of i. *)
   fun assign_bits (p, i, n) =
      let
         val l = (i, n) |> (Arbnum.fromInt ## Lib.I)
                        |> bitstringSyntax.padded_fixedwidth_of_num
                        |> bitstringSyntax.dest_v2w |> fst
                        |> listSyntax.dest_list |> fst
      in
         Term.subst (Lib.mapi (fn i => fn b => x (i + p) |-> b) l)
      end

   (* Zero the free-bit fields x0-x4, x5-x9 or x10-x14 respectively. *)
   val r0  = assign_bits (0, 0, 5)
   val r5  = assign_bits (5, 0, 5)
   val r10 = assign_bits (10, 0, 5)

   (* Lookup function over a freshly built net for the pattern list l. *)
   fun fnd l = find_opcode (mk_net l)

   (* Like fnd, but returns only (name, pattern). The net is built on the
      first lookup and then cached. Building it runs DecodeCHERI on every
      pattern in l, which is far too costly to repeat on every query (a
      curried "fun fnd2 l tm" would rebuild it for each tm). *)
   fun fnd2 l =
      let
         val cache = ref NONE
         fun lookup () =
            case !cache of
               SOME f => f
             | NONE => let val f = fnd l in cache := SOME f; f end
      in
         fn tm => Option.map (fn (s, t, _, _) => (s, t)) (lookup () tm)
      end

   (* comb (m, l): all m-element sub-lists of l, preserving order. *)
   fun comb (0, _    ) = [[]]
     | comb (_, []   ) = []
     | comb (m, x::xs) = map (fn y => x :: y) (comb (m-1, xs)) @ comb (m, xs)

   (* all_comb l: every sub-list of l, by increasing size (starting with []). *)
   fun all_comb l =
     List.concat (List.tabulate (List.length l + 1, fn i => comb (i, l)))

   (* sb l: given (suffix, substitution) options, every combination of them
      as a list of (name, pattern) transformers. Each transformer appends
      "_" ^ suffix to the name and applies the substitution to the pattern. *)
   fun sb l =
      all_comb
         (List.map
            (fn (x, f:term -> term) => (fn (s, t) => (s ^ "_" ^ x, f t))) l)

   val fnd_sb = fnd2 ## sb

   (* One (lookup, variant transformers) pair per pattern group, listing the
      free fields that may be specialised to zero. *)
   val fp = fnd_sb (cheri_patterns, [])
   val f0 = fnd_sb (cheri_patterns0, [("0", r0)])
   (* NOTE(review): JALR's free bits are presumably rs (x0-x4), rd (x5-x9)
      and hint (x10-x14), so r10 zeroes the hint field rather than rd.
      Confirm that "d0" is not meant to use r5 here. *)
   val fd = fnd_sb (cheri_dpatterns, [("d0", r10)])
   val fi = fnd_sb (cheri_ipatterns, [("s0", r0), ("t0", r5)])
   val fj = fnd_sb (cheri_jpatterns, [("t0", r0), ("d0", r5)])
   val fr = fnd_sb (cheri_rpatterns, [("s0", r0), ("t0", r5), ("d0", r10)])

   (* Coprocessor-0 variants, disabled together with cheri_cpatterns:
   val sel = assign_bits (10, 0, 3)
   val dbg = assign_bits (5, 23, 5) o sel
   val err = assign_bits (5, 26, 5) o sel
   val fc = (fnd2 cheri_cpatterns,
               [[fn (s, t) => (s ^ "_debug", dbg t)],
                [fn (s, t) => (s ^ "_errctl", err t)]])
   *)

   (* Take the first group whose lookup succeeds on tm and apply each of its
      transformer lists (left to right) to the (name, pattern) found.
      Returns [] when no group matches. *)
   fun try_patterns [] tm = []
     | try_patterns ((f, l) :: r) tm =
         (case f tm of
             SOME x => List.map (List.foldl (fn (f, a) => f a) x) l
           | NONE => try_patterns r tm)
   val find_opc = try_patterns [fi, fr, fp, fj, fd, f0]

   (* Lookup over the net of all patterns; built once, at load time. *)
   val cheri_find_opc_ = fnd patterns
in
   (* Hexadecimal string -> padded opcode as a HOL list term. *)
   val hex_to_padded_opcode =
      padded_opcode o bitstringSyntax.bitstring_of_hexstring

   (* Decode theorem for opcode v, instantiated to v's concrete bits.
      Raises HOL_ERR (origin "decode") unless exactly one pattern matches. *)
   fun cheri_decode v =
      case cheri_find_opc_ v of
         SOME (_, _, thm, s) => if List.null s then thm else Thm.INST s thm
       | NONE => raise ERR "decode" (utilsLib.long_term_to_string v)

   (* As cheri_decode, for an opcode given as a hexadecimal string. *)
   val cheri_decode_hex = cheri_decode o hex_to_padded_opcode

   (* Every (name, pattern) variant that matches opc. The candidates are the
      base pattern from the first matching group plus its register-field-zero
      specialisations (names suffixed "_s0", "_t0", "_d0" or "_0"). *)
   fun cheri_find_opc opc =
      let
         val l = find_opc opc
      in
         List.filter (fn (_, p) => Lib.can (Term.match_term p) opc) l
      end

   (* Mnemonic -> pattern dictionary covering all decoder patterns. *)
   val cheri_dict = Redblackmap.fromList String.compare patterns
   (* fun mk_cheri_pattern s =
         Redblackmap.peek (cheri_dict, utilsLib.uppercase s) *)
end
310
311(*
312  List.map (cheri_decode o snd) (Redblackmap.listItems cheri_dict)
313*)
314
315(* ------------------------------------------------------------------------- *)
316
317(* Evaluator *)
318
local
   (* Simplify the hypotheses of an instantiated step theorem: decide word
      equalities, then rewrite with v2w_0_rwts. *)
   val eval_simp_rule =
      utilsLib.ALL_HYP_CONV_RULE
         (Conv.DEPTH_CONV wordsLib.word_EQ_CONV
          THENC REWRITE_CONV [v2w_0_rwts])

   (* Instantiate the rewrite rwt to tm and simplify its hypotheses. Returns
      NONE when utilsLib.vacuous judges the result vacuous (presumably
      because of a false hypothesis). *)
   fun try_rewrite tm rwt =
      let
         val thm = eval_simp_rule (utilsLib.INST_REWRITE_CONV1 rwt tm)
      in
         if utilsLib.vacuous thm then NONE else SOME thm
      end

   (* Step theorems named in cheri_stepTheory.rwts, indexed by their
      left-hand sides. *)
   val step_thms = List.map (DB.fetch "cheri_step") cheri_stepTheory.rwts
   val candidates =
      utilsLib.find_rw (utilsLib.mk_rw_net utilsLib.lhsc step_thms)
in
   (* Evaluate tm with the unique applicable step theorem. When exactly two
      apply, the first is taken and the other is treated as the exception
      case. On failure, tm (and any ambiguous theorems) are printed and
      HOL_ERR with origin "eval" is raised. *)
   fun eval tm =
      let
         fun fail msg = (Parse.print_term tm; print "\n"; raise ERR "eval" msg)
         fun show_all thms =
            List.app (fn th => (Parse.print_thm th; print "\n")) thms
         fun pick [] = fail "no valid step theorem"
           | pick [th] = th
           | pick [th, _] = th (* ignore exception case *)
           | pick thms =
               (show_all thms; fail "more than one valid step theorem")
      in
         pick (List.mapPartial (try_rewrite tm) (candidates tm))
         handle HOL_ERR {message = "not found",
                         origin_function = "find_rw", ...} =>
            fail "instruction instance not supported"
      end
end
348
local
  (* Term constructors for unary/binary constants of the "cheri" theory. *)
  val monop = #2 o HolKernel.syntax_fns1 "cheri"
  val binop = #2 o HolKernel.syntax_fns2 "cheri"
  (* Projections/predicates of the post-step state that cheri_step computes
     (see the list in cheri_step below). *)
  val mk_exception = monop "cheri_state_exception"
  val mk_exceptionSignalled = monop "exceptionSignalled"
  val mk_BranchDelay = monop "BranchDelay"
  val mk_BranchDelayPCC = monop "cheri_state_BranchDelayPCC"
  val mk_BranchTo = monop "BranchTo"
  val mk_BranchToPCC = monop "cheri_state_BranchToPCC"
  val mk_CCallBranch = monop "cheri_state_CCallBranch"
  val mk_CCallBranchDelay = monop "cheri_state_CCallBranchDelay"
  val mk_currentInst_fupd = binop "cheri_state_currentInst_fupd"
  (* ``cheri_state_currentInst_fupd (K w) st'``, i.e. st' with its
     currentInst field replaced by w. *)
  fun currentInst w st' =
    mk_currentInst_fupd (combinSyntax.mk_K_1 (w, Term.type_of w), st')
  (* The generic state variable that the step theorems are stated over. *)
  val st = ``s:cheri_state``
  (* Definitions unfolded together with the datatype rewrites. *)
  val ths = [exceptionSignalled_def, BranchDelay_def, BranchTo_def]
  (* Rewrite with the datatype theorems of the listed cheri types plus ths. *)
  val datatype_conv =
    REWRITE_CONV
      (utilsLib.datatype_rewrites true "cheri"
         ["cheri_state", "cheri_state_brss__0", "cheri_state_brss__1",
          "procState", "DataType", "CP0", "CapCause", "StatusRegister",
          "ExceptionType"] @ ths)
  (* ASSUME the datatype-simplified form of a condition on the state. *)
  val dt_assume = ASSUME o utilsLib.rhsc o datatype_conv
  (* Assumptions on the pre-state s: processor 0, no exception signalled,
     no pending branch / PCC branch, and not inside a CCall branch. *)
  val procID_th = dt_assume ``^st.procID = 0w``
  val exceptionSignalled_th = dt_assume ``~exceptionSignalled ^st``
  val BranchDelayPCC_th = dt_assume ``^st.BranchDelayPCC = NONE``
  val BranchTo_th = dt_assume ``BranchTo ^st = NONE``
  val BranchToPCC_th = dt_assume ``^st.BranchToPCC = NONE``
  val CCallBranch_th = dt_assume ``~^st.CCallBranch``
  val CCallBranchDelay_th = dt_assume ``~^st.CCallBranchDelay``
  (* Turn |- t = F into |- ~t; leave any other theorem unchanged. *)
  fun eqf_elim th = Drule.EQF_ELIM th handle HOL_ERR _ => th
  (* Simplify a state projection using the datatype rewrites and the
     assumptions above. QCONV yields a reflexive theorem when nothing
     changes, and eqf_elim turns "= F" results into negations. *)
  val STATE_CONV =
     eqf_elim o
     Conv.QCONV
       (datatype_conv
        THENC REWRITE_CONV
                [boolTheory.COND_ID, procID_th, exceptionSignalled_th,
                 BranchDelayPCC_th, BranchTo_th, BranchToPCC_th,
                 CCallBranch_th, CCallBranchDelay_th,
                 GSYM cheriTheory.cheri_state_exception])
  val hyp_rule = utilsLib.ALL_HYP_CONV_RULE datatype_conv
  val full_rule = hyp_rule o Conv.RIGHT_CONV_RULE datatype_conv
  (* Simplify the operand of the right-hand side of the final theorem. *)
  val state_rule = Conv.RIGHT_CONV_RULE (Conv.RAND_CONV (utilsLib.SRW_CONV []))
  (* Pre-simplify the NextStateCHERI theorems. The PATH_CONV paths "lrlr" and
     "lrrrrrlrr" select particular subterms and so depend on the exact shape
     of those theorems (not visible here). *)
  val next_rule =
    hyp_rule o
    Conv.CONV_RULE
      (Conv.RAND_CONV (utilsLib.SRW_CONV [PC_def, CP0_def])
       THENC Conv.PATH_CONV "lrlr" (utilsLib.SRW_CONV [])
       THENC Conv.PATH_CONV "lrrrrrlrr" (utilsLib.SRW_CONV []))
  val NextStateCHERI_nodelay = next_rule cheri_stepTheory.NextStateCHERI_nodelay
  val NextStateCHERI_delay = next_rule cheri_stepTheory.NextStateCHERI_delay
  (* Discharge the antecedent of the no-delay / branch-delay next-state
     theorem with a conjunction of step facts (Drule.MATCH_MP). *)
  val MP_Next  = state_rule o Drule.MATCH_MP NextStateCHERI_nodelay
  val MP_NextB = state_rule o Drule.MATCH_MP NextStateCHERI_delay
  (* Run the instruction on the right of a decode theorem, over state st. *)
  val Run_CONV = utilsLib.Run_CONV ("cheri", st) o utilsLib.rhsc
  (* Split the pair on the right of a theorem. *)
  val get = pairSyntax.dest_pair o utilsLib.rhsc
in
  (* Next-state theorems for the opcode v (a HOL list of booleans). Always
     returns the no-delay theorem; the branch-delay theorem is added when
     MP_NextB succeeds. *)
  fun cheri_step v =
    let
      (* Fetch theorem; its right-hand side is a pair (w, st'). *)
      val thm1 = fetch v
      val (w, st') = get thm1
      val thm2 = cheri_decode v
      (* Run the decoded instruction, then rewrite with the step theorem
         chosen by eval. *)
      val thm3 = Drule.SPEC_ALL (Run_CONV thm2)
      val ethm = eval (utilsLib.rhsc thm3)
      val thm3 = Conv.RIGHT_CONV_RULE (Conv.REWR_CONV ethm) thm3
      (* Instantiate s to st' with currentInst := w and simplify. *)
      val thm3 = full_rule (Thm.INST [st |-> currentInst w st'] thm3)
      val tm = utilsLib.rhsc thm3
      (* Facts about the resulting state. This order, and the order of
         [thm1, thm2, thm3] @ thms below, must match the antecedent conjuncts
         of NextStateCHERI_nodelay/_delay because MATCH_MP discharges them
         as one conjunction. *)
      val thms = List.map (fn f => STATE_CONV (f tm))
                    [mk_exception,
                     mk_BranchDelay,
                     mk_BranchDelayPCC,
                     mk_BranchToPCC,
                     mk_CCallBranch,
                     mk_CCallBranchDelay,
                     mk_BranchTo,
                     mk_exceptionSignalled]
      val thm = hyp_rule (Drule.LIST_CONJ ([thm1, thm2, thm3] @ thms))
    in
      [MP_Next thm]
      @
      ([MP_NextB thm] handle HOL_ERR _ => [])
    end
end
431
432(*
433
434 val tms = strip_conj (fst (dest_imp (concl NextStateCHERI_nodelay)))
435 val tms = strip_conj (fst (dest_imp (concl NextStateCHERI_delay)))
436
437 match_term (List.nth (tms, 0)) (concl thm1);
438 match_term (List.nth (tms, 1)) (concl thm2);
439 match_term (List.nth (tms, 2)) (concl thm3);
440 match_term (List.nth (tms, 3)) (concl (List.nth (thms, 0)));
441 match_term (List.nth (tms, 4)) (concl (List.nth (thms, 1)));
442 match_term (List.nth (tms, 5)) (concl (List.nth (thms, 2)));
443 match_term (List.nth (tms, 6)) (concl (List.nth (thms, 3)));
444 match_term (List.nth (tms, 7)) (concl (List.nth (thms, 4)));
445 match_term (List.nth (tms, 8)) (concl (List.nth (thms, 5)));
446
447*)
448
(* As cheri_step, for an opcode given as a hexadecimal string. *)
fun cheri_step_hex s = cheri_step (bitstringSyntax.bitstring_of_hexstring s)
450
451(* ========================================================================= *)
452
453(* Testing
454
455open cheri_stepLib
456
457val step = cheri_step
458fun test s = step (Redblackmap.find (cheri_dict, s))
459fun test s = (Redblackmap.find (cheri_dict, s))
460
461val v = test "ADDI";
462val v = test "ADDU";
463val v = test "J";
464val v = test "BEQ";
465val v = test "BEQL";
466val v = test "BLTZAL";
467val v = test "BLTZALL";
468val v = test "LB";
469
470val l = List.map (Lib.total step o snd) (Redblackmap.listItems cheri_dict)
471
472*)
473
474end
475