1structure stack_analysisLib :> stack_analysisLib =
2struct
3
4open HolKernel boolLib bossLib Parse;
5open wordsTheory set_sepTheory progTheory helperLib addressTheory combinTheory;
6open backgroundLib file_readerLib writerLib;
7open GraphLangTheory
8
(* True when section_io reports more than one return word for sec_name
   (second component of its result).  Callers use this to decide whether the
   section expects a pointer to the stack in r0. *)
fun stack_offset_in_r0 sec_name =
  case section_io sec_name of
    (_, ret_word_count, _) => ret_word_count > 1
12
local
  val bool_ty = ``:bool``
  val bool_arb = mk_arb bool_ty
  val mem_type = ``:word32->word8``
  val write32_pat = ``WRITE32 a w m``
  (* Split ``WRITE32 a w m`` into (a, w, m); fails on any other term. *)
  fun dest_write32 tm =
    if can (match_term write32_pat) tm then
      let val (xyz,q) = dest_comb tm
          val (xy,z) = dest_comb xyz
          val (_,y) = dest_comb xy
      in (y,z,q) end
    else failwith "dest_write32"
  (* Turn a chain of WRITE32 updates into (``READ32 a m``, w) pairs, where m
     is the memory term underneath that particular write.  Returns [] as soon
     as the term is no longer a WRITE32 application. *)
  fun list_dest_write32 tm =
    let val (a,w,m) = dest_write32 tm
        val ls = list_dest_write32 m
    in (``READ32 ^a ^m``,w)::ls end handle HOL_ERR _ => []
in
  (* Compute the state updates performed between a pre- and a postcondition.
     Both are split into their separating-conjunction components.  For each
     component ``f x`` of pre, the component ``f y`` with the same head f is
     looked up in post; the pair (x, y) is kept when x is a variable and
     x <> y.  Boolean variables are mapped to ARB, and word32->word8 memory
     updates are expanded into (``READ32 a m``, w) pairs.
     Fails with "not found" if some component of pre has no counterpart in
     post. *)
  fun get_updates pre post = let
    val ps = list_dest dest_star pre
    val qs = list_dest dest_star post
    (* Does y have the form ``q x``?  Non-applications never match. *)
    fun same_head q y = (rator y = q) handle HOL_ERR _ => false
    (* Argument of the first component of ys whose head is q.  The handler
       covers only rator: wrapping the recursive call would re-scan the tail
       after every miss, which is exponential in the length of ys. *)
    fun find_same q [] = failwith "not found"
      | find_same q (y::ys) =
          if same_head q y then rand y else find_same q ys
    fun dest_write (x,y) =
      if type_of x <> mem_type then [(x,y)] else list_dest_write32 y
    in map (fn tm => (rand tm,find_same (rator tm) qs)) ps
       |> filter (fn (x,y) => is_var x andalso x <> y)
       |> map (fn (x,y) => if type_of x = bool_ty then (x,bool_arb) else (x,y))
       |> map dest_write |> Lib.flatten
    end
end
44
(* Normalise a SPEC theorem and return (hypotheses, precondition,
   postcondition).  Steps, in order:
   - simplify with word_arith_lemma1 (plus sep_cond_ss);
   - simplify with word_arith_lemma3/4;
   - rewrite with SPEC_MOVE_COND;
   - UNDISCH_ALL, so that any remaining antecedents become hypotheses. *)
fun get_assum_pre_post th = let
  val simp1 = SIMP_RULE (std_ss++sep_cond_ss) [word_arith_lemma1] th
  val simp2 = SIMP_RULE std_ss [word_arith_lemma3,word_arith_lemma4] simp1
  val normal = UNDISCH_ALL (PURE_REWRITE_RULE [SPEC_MOVE_COND] simp2)
  val (_,pre,_,post) = dest_spec (concl normal)
  in (hyp normal, pre, post) end
52
(* Pattern ``pc w`` built from the fourth component returned by
   get_tools (); used to locate program-counter assertions in terms. *)
fun get_pc_pat () =
  case get_tools () of
    (_,_,_,pc_fn) => ``^pc_fn w``
56
local
  (* Argument of the first sub-term of tm matching get_pc_pat (); fails if
     there is none. *)
  fun get_pc_val tm =
    rand (find_term (can (match_term (get_pc_pat ()))) tm)
  val dmem_type = ``:word32 set``
  (* For ``x IN dmem`` where dmem is a variable of type word32 set, return x;
     fails on anything else. *)
  fun dest_dmem_test tm = let
    val (addr, dmem) = pred_setSyntax.dest_in tm
    val (_, dmem_ty) = dest_var dmem
    in if dmem_ty = dmem_type then addr else fail() end
  (* Address from the first hypothesis containing a dmem-membership test,
     or NONE when no hypothesis has one. *)
  fun get_mem_access tms =
    case tms of
      [] => NONE
    | tm :: rest =>
        (SOME (dest_dmem_test (find_term (can dest_dmem_test) tm))
         handle HOL_ERR _ => get_mem_access rest)
  (* Leaves of a (possibly nested) if-then-else term. *)
  fun cond_leaves tm =
    if is_cond tm then
      let val (_, then_tm, else_tm) = dest_cond tm
      in cond_leaves then_tm @ cond_leaves else_tm end
    else [tm]
  (* Summaries (pc, guard, updates, mem_access, next_pc) of one instruction.
     - With a single spec: one summary per leaf of the postcondition, each
       with guard T.
     - With two specs: th2's hypotheses (conjoined) guard the second summary,
       and their rewritten negation guards the first.  Only th1 contributes a
       memory access.
     - If that two-spec construction raises HOL_ERR, each spec is summarised
       on its own instead. *)
  fun summary (i:int,(th,l:int,j:int option),NONE) = let
        val (hyps, pre, post) = get_assum_pre_post th
        val pc = get_pc_val pre
        fun leaf q =
          (pc, T, get_updates pre q, get_mem_access hyps, get_pc_val q)
        in map leaf (cond_leaves post) end
    | summary (i,(th1,l1,j1),SOME (th2,l:int,j:int option)) = let
        val (hyps1, pre1, post1) = get_assum_pre_post th1
        val (hyps2, pre2, post2) = get_assum_pre_post th2
        val guard2 = list_mk_conj hyps2
        val upd1 = get_updates pre1 post1
        val upd2 = get_updates pre2 post2
        val guard1 = rand (concl (QCONV (REWRITE_CONV []) (mk_neg guard2)))
        val pc = get_pc_val pre1
        in [(pc, guard1, upd1, get_mem_access hyps1, get_pc_val post1),
            (pc, guard2, upd2, NONE, get_pc_val post2)] end
        handle HOL_ERR _ =>
          summary (i,(th1,l1,j1),NONE) @ summary (i,(th2,l,j),NONE)
  (* Decode ``CALL_TAG name flag`` into (name, flag) where flag is the
     literal T or F; fails on anything else. *)
  fun dest_call_tag tm = let
    val (head_and_name, flag) = dest_comb tm
    val (head, name) = dest_comb head_and_name
    in if head <> ``CALL_TAG`` then fail()
       else (stringSyntax.fromHOLstring name,
             if flag = T then true
             else if flag = F then false
             else fail())
    end
  (* Does the theorem (with hypotheses discharged) mention a CALL_TAG? *)
  fun has_call_tag th =
    th |> DISCH_ALL |> concl |> can (find_term (can dest_call_tag))
  (* Registers r0-r3, r14 and flags n, z, c, v all set to ARB; presumably
     the state clobbered by a call. *)
  val call_update =
    map (fn reg => (reg, mk_arb (type_of reg)))
      [``r0:word32``, ``r1:word32``, ``r2:word32``, ``r3:word32``,
       ``r14:word32``, ``n:bool``, ``z:bool``, ``c:bool``, ``v:bool``]
in
  (* Like summary, but for instructions carrying a CALL_TAG the first
     summary's updates become call_update.  Its next-pc slot becomes the
     value written to r14, or T when r14 is not updated. *)
  fun approx_summary (i,(th1,i1,i2),thi2) = let
    val tagged = has_call_tag th1
    val res = summary (i,(th1,i1,i2),thi2)
    in if not tagged then res else
       case res of
         [] => raise Empty
       | (p1,assum1,u1,addr,_) :: rest => let
           val r14 = mk_var("r14",``:word32``)
           val dest = snd (first (fn (x,_) => x = r14) u1)
                      handle HOL_ERR _ => T
           in (p1,assum1,call_update,addr,dest) :: rest end
    end
end
117
(* Symbolically explore all_summaries, a list of (pc, guard, updates,
   mem_access, next_pc) tuples, starting at the pc of the first summary.
   While exploring, track register and stack-slot values.  Return the integer
   addresses of instructions whose memory access mentions r13 once the
   tracked state is substituted in.  An empty summary list yields []. *)
fun find_stack_accesses_for [] sec_name = []
  | find_stack_accesses_for (all_summaries as (init_pc,_,_,_,_) :: _) sec_name = let
  val r13 = ``r13:word32``
  (* Initially r13 stands for itself.  Sections that expect a stack pointer in
     r0 also start with r0 = r13 + offset. *)
  val init_state = (init_pc,(r13,r13)::(if stack_offset_in_r0 sec_name then
                                         [(``r0:word32``,``^r13 + offset``)]
                                        else []),T)
  val stack_accesses = ref ([]:int list)
  (* Record pc (an n2w literal) as stack-accessing, at most once. *)
  fun add_stack_access pc = let
    val n = pc |> wordsSyntax.dest_n2w |> fst |> numSyntax.int_of_term
    val a = !stack_accesses
    in if mem n a then () else (stack_accesses := n::a) end
  (* A step is considered executable unless t ==> ~guard can be shown by
     repeated case splits plus rewriting. *)
  fun can_exec_step t (pc1,assum,u,addr,pc2) =
    if assum = T then true else let
      val test = mk_imp(t,mk_neg assum)
      val vs = free_vars test
      val test = list_mk_forall(vs,test)
      fun can_prove_by_cases goal =
        ([],goal) |> (REPEAT Cases THEN REWRITE_TAC [])
                  |> (fn (x,_) => length x = 0)
      in not (can_prove_by_cases test) end
  (* Is a of the form r13, r13 + n2w k, n2w k + r13, r13 - n2w k or
     n2w k - r13? *)
  fun is_sp_add_or_sub a =
    if a = r13 then true else let
      val (w1,w2) = wordsSyntax.dest_word_add a
      in (w1 = r13 andalso wordsSyntax.is_n2w w2) orelse
         (w2 = r13 andalso wordsSyntax.is_n2w w1) end
    handle HOL_ERR _ => let
      val (w1,w2) = wordsSyntax.dest_word_sub a
      in (w1 = r13 andalso wordsSyntax.is_n2w w2) orelse
         (w2 = r13 andalso wordsSyntax.is_n2w w1) end
    handle HOL_ERR _ => false
  val stack_read32_pat = ``READ32 (a:word32) m``
  (* Only variables and READ32 at an sp-relative address are tracked. *)
  fun is_simple_or_stack_read32 (x,y) =
    if is_var x then true else
    if can (match_term stack_read32_pat) x then
      (x |> rator |> rand |> is_sp_add_or_sub)
    else false
  val word_simp_tm = rand o concl o QCONV (SIMP_CONV std_ss [word_arith_lemma1] THENC
        SIMP_CONV std_ss [word_arith_lemma3,word_arith_lemma4,WORD_ADD_0])
  (* Apply one summary to the tracked state s under assumption t.
     Returns (next_pc, new_state, new_assumption).  The assumption is
     replaced by the step's guard (if any) and dropped to T when the step
     overwrites one of its free variables. *)
  fun exec_step s t (pc1,assum,u,addr,pc2) = let
    val s_simple = filter (fn (x,_) => is_var x) s
    val s_read32 = filter (fn (x,_) => not (is_var x)) s
    val i_simple = word_simp_tm o subst (map (fn (x,y) => x |-> y) s_simple)
    val i = word_simp_tm o subst (map (fn (x,y) => x |-> y) s_read32) o i_simple
    fun i_read32_only tm = if is_comb tm then i_simple tm else tm
    val new_u_part = map (fn (x,y) => (i_read32_only x,i y)) u
    val u_domain = map fst new_u_part
    val new_u = new_u_part @ filter (fn (x,y) => not (mem x u_domain)) s
    val new_t = if assum = T then t else assum
    val new_t = if intersect u_domain (free_vars new_t) <> [] then T else new_t
    in (pc2,filter is_simple_or_stack_read32 new_u,new_t) end
  (* Debug printer (unused unless enabled in exec_steps). *)
  fun register_state (pc,s,t) = let
    val _ = print ("\nRegister state:\n  pc = " ^ term_to_string pc ^ "\n")
    val _ = print ("  assumes " ^ term_to_string t ^ "\n")
    val s = filter (fn (x,y) => not (is_arb y)) s
    val _ = app (fn (x,y) => print ("  " ^ term_to_string x ^ " is " ^
                                           term_to_string y ^ "\n")) s
    in () end
  val read32_pat = ``READ32 a (m:word32->word8)``
  fun remove_read32 tm = let
    val xs = find_terms (can (match_term read32_pat)) tm
    val ss = map (fn x => x |-> (mk_arb(type_of x))) xs
    in subst ss tm end
  (* Record pc when the (state-substituted) access address mentions r13. *)
  fun check_for_stack_accesses (pc,s,t) NONE = ()
    | check_for_stack_accesses (pc,s,t) (SOME a) = let
    val i = remove_read32 o word_simp_tm o subst (map (fn (x,y) => x |-> y) s)
    in if mem r13 (free_vars (i a)) then add_stack_access pc else () end
  (* Visited nodes are keyed on (assumption, pc) only, not on the register
     state, so each pc is explored once per assumption (presumably to ensure
     termination around loops). *)
  val seen_nodes = ref ([]:(term * term) list)
  fun has_visited (pc,s,t) =
    mem (t,pc) (!seen_nodes) orelse
    (seen_nodes := (t,pc) :: (!seen_nodes); false)
  (* Debug printer (unused unless enabled in exec_steps). *)
  fun print_state (pc,s,t) = let
    val _ = print ("Looking at pc = " ^ term_to_string pc ^ "\n")
    val _ = app (fn (x,y) => print ("  " ^ term_to_string x ^ " is " ^
                                             term_to_string y ^ "\n")) s
    val _ = print ("  assuming: " ^ term_to_string t ^ "\n\n")
    in () end
  (* Depth-first exploration: check every executable step's memory access,
     then recurse into all successor states. *)
  fun exec_steps (state as (pc,s,t)) =
    if has_visited state then () else let
      (* val _ = register_state state *)
      (* val _ = print_state state *)
      val us = all_summaries
               |> filter (fn (p,_,_,_,_) => p = pc)
               |> filter (can_exec_step t)
      val _ = app (check_for_stack_accesses state) (map (fn (_,_,_,a,_) => a) us)
      in app exec_steps (map (exec_step s t) us) end
  val _ = exec_steps init_state
  in !stack_accesses end;
225
(* Run the stack analysis for section sec_name over its instruction specs
   thms, writing a report and the annotated code.  Two analyses are run:
   - "all": tracks registers and stack slots;
   - "simple": tracks register updates only.
   Instructions found by the simple analysis are annotated "stack access";
   those found only by the full one are annotated "indirect stack access".
   Returns the addresses found by the full analysis. *)
fun find_stack_accesses sec_name thms = let
  val _ = write_subsection "Stack analysis"
  val summaries = flatten (map approx_summary thms)
  fun keep_register_updates (pc, guard, upds, addr, next) =
    (pc, guard, filter (fn (lhs,_) => is_var lhs) upds, addr, next)
  val simple_summaries = map keep_register_updates summaries
  val accesses = find_stack_accesses_for summaries sec_name
  val simple_accesses = find_stack_accesses_for simple_summaries sec_name
  fun annotation loc =
    if mem loc simple_accesses then "stack access"
    else if mem loc accesses then "indirect stack access"
    else fail()
  val _ = if stack_offset_in_r0 sec_name
          then write_line ("Section `" ^ sec_name ^ "` expects pointer to stack in r0.")
          else ()
  val n_found = length (all_distinct accesses)
  val _ = if n_found = 0
          then write_line ("No stack accesses found. Code for `" ^ sec_name ^ "`:")
          else write_line (int_to_string n_found ^ " stack accesses found. " ^
                           "Annotated code for `" ^ sec_name ^ "`:")
  val _ = show_annotated_code annotation sec_name
  in accesses end
247
248end
249