(* Stack analysis for a code section.

   Each instruction's SPEC theorem(s) are summarised as tuples
     (pc, condition, updates, accessed memory address option, next pc).
   The summaries are then executed symbolically from the pc of the first
   summary, starting with r13 bound to itself, and every pc whose accessed
   address mentions r13 (after substituting the current symbolic state) is
   recorded as a stack access. *)
structure stack_analysisLib :> stack_analysisLib =
struct

open HolKernel boolLib bossLib Parse;
open wordsTheory set_sepTheory progTheory helperLib addressTheory combinTheory;
open backgroundLib file_readerLib writerLib;
open GraphLangTheory

(* True when section sec_name returns more than one word (second component
   of section_io); find_stack_accesses reports such sections as expecting a
   pointer to the stack in r0. *)
fun stack_offset_in_r0 sec_name = let
  val (_,ret_word_count,_) = section_io sec_name
  in 1 < ret_word_count end

(* ------------------------------------------------------------------------
   Updates performed by one SPEC theorem
   ------------------------------------------------------------------------ *)

local
  val bool_arb = mk_arb(``:bool``)
  val mem_type = ``:word32->word8``
  val write32_pat = ``WRITE32 a w m``
  (* Split the term WRITE32 a w m into (a, w, m); fails on anything else. *)
  fun dest_write32 tm =
    if can (match_term write32_pat) tm then
      let val (xyz,q) = dest_comb tm
          val (xy,z) = dest_comb xyz
          val (x,y) = dest_comb xy
      in (y,z,q) end
    else failwith "dest_write32"
  (* Turn a nest of WRITE32 applications into a list of (READ32 a m, w)
     pairs, outermost write first, where m is the memory term underneath
     that write.  Any HOL_ERR at a level (in particular reaching a term that
     is not a WRITE32) makes that level return []. *)
  fun list_dest_write32 tm =
    let val (a,w,m) = dest_write32 tm
        val ls = list_dest_write32 m
    in (``READ32 ^a ^m``,w)::ls end handle HOL_ERR _ => []
in
  (* Given the pre- and postcondition of a SPEC theorem (separating
     conjunctions), return (location, new value) pairs for the resources
     that change:
       - each conjunct r x of pre is paired with the value y of the first
         conjunct r y of post that has the same operator r;
       - only pairs where x is a variable and x <> y are kept;
       - locations of type bool get ARB as their new value;
       - a location of type word32->word8 is expanded into one
         (READ32 a m, w) pair per WRITE32 in its new value.
     Fails with HOL_ERR if a conjunct of pre is not an application or has
     no counterpart in post. *)
  fun get_updates pre post = let
    val ps = list_dest dest_star pre
    val qs = list_dest dest_star post
    (* Value of the first conjunct in the list whose operator is q; terms
       that are not applications are skipped.  The HOL_ERR from rator is
       caught around the test only, so a failed search further down the
       list propagates at once instead of being retried at every level
       (which would take time exponential in the list length). *)
    fun find_same q [] = failwith "not found"
      | find_same q (y::ys) =
          if (rator y = q handle HOL_ERR _ => false) then rand y
          else find_same q ys
    fun dest_write (x,y) =
      if type_of x <> mem_type then [(x,y)] else list_dest_write32 y
    in map (fn tm => (rand tm,find_same (rator tm) qs)) ps
       |> filter (fn (x,y) => is_var x andalso x <> y)
       |> map (fn (x,y) => if type_of x = ``:bool`` then (x,bool_arb) else (x,y))
       |> map dest_write |> Lib.flatten
    end
end

(* Normalise a SPEC theorem and split it into (hypotheses, pre, post).
   The theorem is simplified with the word arithmetic lemmas, rewritten
   with SPEC_MOVE_COND (presumably turning side conditions into
   implications) and undischarged, so those conditions become the returned
   hypotheses. *)
fun get_assum_pre_post th = let
  val th = th |> SIMP_RULE (std_ss++sep_cond_ss) [word_arith_lemma1]
              |> SIMP_RULE std_ss [word_arith_lemma3,word_arith_lemma4]
              |> PURE_REWRITE_RULE [SPEC_MOVE_COND]
              |> UNDISCH_ALL
  val (_,pre,_,post) = dest_spec (concl th)
  in (hyp th, pre, post) end

(* The pattern pc w, where pc is the fourth component of get_tools (). *)
fun get_pc_pat () = let
  val (_,_,_,pc) = get_tools ()
  in ``^pc w`` end

(* ------------------------------------------------------------------------
   Instruction summaries
   ------------------------------------------------------------------------ *)

local
  (* The argument w of the first subterm of tm matching get_pc_pat (). *)
  fun get_pc_val tm = let
    val pc_pat = get_pc_pat ()
    in find_term (can (match_term pc_pat)) tm |> rand end
  val dmem_type = ``:word32 set``
  (* For x IN v, where v is a variable of type word32 set, return x. *)
  fun dest_dmem_test tm = let
    val (x,y) = pred_setSyntax.dest_in tm
    val (_,ty) = dest_var y
    val _ = (ty = dmem_type) orelse fail()
    in x end
  (* SOME x for the first hypothesis (in list order) containing a subterm
     accepted by dest_dmem_test; NONE if there is none. *)
  fun get_mem_access [] = NONE
    | get_mem_access (tm::tms) =
        SOME (dest_dmem_test (find_term (can dest_dmem_test) tm))
        handle HOL_ERR _ => get_mem_access tms
  (* Summaries for one instruction, given as (i, (th, l, j), opt) where th
     is its SPEC theorem and opt an optional second theorem (only the
     theorems are inspected).
     - One theorem: one summary per leaf of the (possibly nested
       if-then-else) postcondition, each with condition T.
     - Two theorems: the conjunction of the second theorem's hypotheses is
       its condition, and the (rewritten) negation of that is the first
       theorem's condition.  If any step raises HOL_ERR, both theorems are
       summarised separately as in the one-theorem case.
       NOTE(review): the second theorem's memory access is always NONE
       here; confirm that is intended. *)
  fun summary (i:int,(th,l:int,j:int option),NONE) = let
        val (assum,pre,post) = get_assum_pre_post th
        val pc = get_pc_val pre
        val access = get_mem_access assum
        fun all_posts post =
          if is_cond post then let
            val (_,p1,p2) = dest_cond post
            in all_posts p1 @ all_posts p2 end
          else [post]
        in map (fn post => (pc,T,get_updates pre post,access,get_pc_val post))
               (all_posts post) end
    | summary (i,(th1,l1,j1),SOME (th2,l:int,j:int option)) = let
        val (assum1,pre1,post1) = get_assum_pre_post th1
        val (assum2,pre2,post2) = get_assum_pre_post th2
        val cond2 = list_mk_conj assum2
        val u1 = get_updates pre1 post1
        val u2 = get_updates pre2 post2
        val cond1 = mk_neg cond2 |> QCONV (REWRITE_CONV []) |> concl |> rand
        val pc = get_pc_val pre1
        val res1 = (pc,cond1,u1,get_mem_access assum1,get_pc_val post1)
        val res2 = (pc,cond2,u2,NONE,get_pc_val post2)
        in [res1,res2] end
        handle HOL_ERR _ =>
          summary (i,(th1,l1,j1),NONE) @ summary (i,(th2,l,j),NONE)
  (* Split CALL_TAG name b into the ML string name and the ML bool b;
     fails for any other term, including a b that is neither T nor F. *)
  fun dest_call_tag tm = let
    val (xy,z) = dest_comb tm
    val (x,y) = dest_comb xy
    val _ = (x = ``CALL_TAG``) orelse fail()
    in (stringSyntax.fromHOLstring y,
        if z = T then true else
        if z = F then false else fail()) end
  (* Does the theorem, with its hypotheses discharged, mention a CALL_TAG? *)
  fun has_call_tag th =
    can (find_term (can dest_call_tag)) (concl (DISCH_ALL th))
  (* Updates used for an instruction carrying a CALL_TAG: r0-r3, r14 and
     the n, z, c, v flags all become ARB. *)
  val call_update = let
    fun mk_arb_pair tm = (tm,mk_arb(type_of tm))
    in map mk_arb_pair [``r0:word32``,
         ``r1:word32``, ``r2:word32``, ``r3:word32``, ``r14:word32``,
         ``n:bool``, ``z:bool``, ``c:bool``, ``v:bool``] end
in
  (* Summaries for one instruction (see summary).  If its first theorem
     carries a CALL_TAG, the first summary keeps its pc, condition and
     memory access, but its updates become call_update and its next pc
     becomes the value it writes to r14 (T if it writes none) -- presumably
     the return address, so execution steps over the call; confirm. *)
  fun approx_summary (i,(th1,i1,i2),thi2) = let
    val res = summary (i,(th1,i1,i2),thi2)
    in if not (has_call_tag th1) then res else let
         val (p1,assum1,u1,addr,_) = hd res
         val r14 = mk_var("r14",``:word32``)
         val return_pc = snd (first (fn (x,_) => x = r14) u1)
                         handle HOL_ERR _ => T
         in (p1,assum1,call_update,addr,return_pc) :: tl res end
    end
end

(* ------------------------------------------------------------------------
   Symbolic execution
   ------------------------------------------------------------------------ *)

(* Execute all_summaries symbolically and return the pcs (as ints, most
   recently found first, without duplicates) of the instructions whose
   accessed memory address mentions r13 once the current symbolic state is
   substituted.

   A state is (pc, s, t): s is a list of (location, symbolic value) pairs
   and t an assumption.  Execution starts at the pc of the first summary
   with r13 bound to itself, plus r0 bound to r13 + offset when
   stack_offset_in_r0 sec_name holds.  A state counts as already explored
   when its (t, pc) pair has been seen before; s is not compared.  Returns
   [] when there are no summaries. *)
fun find_stack_accesses_for all_summaries sec_name =
  if null all_summaries then [] else let
  val r13 = ``r13:word32``
  val (init_pc,_,_,_,_) = hd all_summaries
  val init_state =
    (init_pc,
     (r13,r13) :: (if stack_offset_in_r0 sec_name
                   then [(``r0:word32``,``^r13 + offset``)]
                   else []),
     T)
  (* Result accumulator, newest first. *)
  val stack_accesses = ref ([]:int list)
  (* Record pc, which must be an n2w literal (dest_n2w fails otherwise). *)
  fun add_stack_access pc = let
    val n = pc |> wordsSyntax.dest_n2w |> fst |> numSyntax.int_of_term
    val a = !stack_accesses
    in if mem n a then () else stack_accesses := n::a end
  (* May a summary with condition assum fire under assumption t?  Yes
     unless the universal closure of t ==> ~assum is proved by repeated
     case splits followed by rewriting; a failed proof means "yes". *)
  fun can_exec_step t (_,assum,_,_,_) =
    assum = T orelse let
      val test = mk_imp(t,mk_neg assum)
      val closed = list_mk_forall(free_vars test,test)
      val (subgoals,_) = (REPEAT Cases THEN REWRITE_TAC []) ([],closed)
      in not (null subgoals) end
  (* Is a one of r13, r13 + n2w k, n2w k + r13, r13 - n2w k, n2w k - r13?
     NOTE(review): the n2w k - r13 form is accepted too; confirm intended. *)
  fun is_sp_add_or_sub a = let
    fun sp_and_const (w1,w2) =
      (w1 = r13 andalso wordsSyntax.is_n2w w2) orelse
      (w2 = r13 andalso wordsSyntax.is_n2w w1)
    in a = r13 orelse
       (sp_and_const (wordsSyntax.dest_word_add a)
        handle HOL_ERR _ =>
          (sp_and_const (wordsSyntax.dest_word_sub a)
           handle HOL_ERR _ => false))
    end
  val stack_read32_pat = ``READ32 (a:word32) m``
  (* Keep a state entry if its location is a variable, or a READ32 whose
     address is accepted by is_sp_add_or_sub. *)
  fun is_simple_or_stack_read32 (x,_) =
    is_var x orelse
    (can (match_term stack_read32_pat) x andalso
     is_sp_add_or_sub (rand (rator x)))
  (* Normalise word arithmetic in a term; the term is returned unchanged if
     no rewrite applies. *)
  val word_simp_tm =
    rand o concl o
    QCONV (SIMP_CONV std_ss [word_arith_lemma1] THENC
           SIMP_CONV std_ss [word_arith_lemma3,word_arith_lemma4,WORD_ADD_0])
  (* Successor of state (pc, s, t) under one summary:
     - new values are evaluated against s, variable entries first and then
       READ32 entries; READ32 locations (any application) are evaluated
       against the variable entries only;
     - entries of s whose location is not updated are kept, then only
       variable and stack READ32 entries survive;
     - the assumption becomes the summary's condition (t if that is T) and
       is reset to T if it mentions an updated location. *)
  fun exec_step s t (_,assum,u,_,pc2) = let
    fun to_subst l = map (fn (x,y) => x |-> y) l
    val (s_simple,s_read32) = List.partition (fn (x,_) => is_var x) s
    val i_simple = word_simp_tm o subst (to_subst s_simple)
    val i = word_simp_tm o subst (to_subst s_read32) o i_simple
    fun i_location tm = if is_comb tm then i_simple tm else tm
    val new_u_part = map (fn (x,y) => (i_location x,i y)) u
    val u_domain = map fst new_u_part
    val kept = filter (fn (x,_) => not (mem x u_domain)) s
    val assumption = if assum = T then t else assum
    val new_t =
      if null (intersect u_domain (free_vars assumption)) then assumption else T
    in (pc2,filter is_simple_or_stack_read32 (new_u_part @ kept),new_t) end
  val read32_pat = ``READ32 a (m:word32->word8)``
  (* Replace every READ32 subterm by ARB. *)
  fun remove_read32 tm = let
    val xs = find_terms (can (match_term read32_pat)) tm
    in subst (map (fn x => x |-> mk_arb (type_of x)) xs) tm end
  (* Record pc if the accessed address a (if any) mentions r13 after
     substituting s, normalising word arithmetic and erasing READ32s. *)
  fun check_for_stack_accesses _ NONE = ()
    | check_for_stack_accesses (pc,s,_) (SOME a) = let
        val a' = a |> subst (map (fn (x,y) => x |-> y) s)
                   |> word_simp_tm
                   |> remove_read32
        in if mem r13 (free_vars a') then add_stack_access pc else () end
  (* Explored (assumption, pc) pairs. *)
  val seen_nodes = ref ([]:(term * term) list)
  (* True if (t, pc) was explored before; otherwise marks it and returns
     false. *)
  fun has_visited (pc,_,t) =
    mem (t,pc) (!seen_nodes) orelse
    (seen_nodes := (t,pc) :: !seen_nodes; false)
  (* Debugging aid: print a state (see the commented call in exec_steps). *)
  fun print_state (pc,s,t) = let
    val _ = print ("Looking at pc = " ^ term_to_string pc ^ "\n")
    val _ = app (fn (x,y) => print (" " ^ term_to_string x ^ " is " ^
                                    term_to_string y ^ "\n")) s
    val _ = print (" assuming: " ^ term_to_string t ^ "\n\n")
    in () end
  (* Depth-first exploration: check the memory accesses of every summary at
     the current pc that may fire under the current assumption, then
     explore the successor state of each of them. *)
  fun exec_steps state =
    if has_visited state then () else let
      val (pc,s,t) = state
      (* val _ = print_state state *)
      val us = all_summaries
               |> filter (fn (p,_,_,_,_) => p = pc)
               |> filter (can_exec_step t)
      val _ = app (fn (_,_,_,addr,_) => check_for_stack_accesses state addr) us
      val successors = map (exec_step s t) us
      in app exec_steps successors end
  val _ = exec_steps init_state
  in !stack_accesses end

(* Run the stack analysis on section sec_name, whose instructions are given
   by thms in the form taken by approx_summary.  Writes a report
   (write_subsection, write_line, show_annotated_code) and returns the pcs
   found to access the stack.

   The analysis runs twice: with all updates, which also tracks values held
   in stack READ32 slots, and with register/variable updates only.  Pcs
   found by the second run are annotated "stack access"; the remaining pcs
   of the first run are annotated "indirect stack access". *)
fun find_stack_accesses sec_name thms = let
  val _ = write_subsection "Stack analysis"
  val all_summaries = flatten (map approx_summary thms)
  val all_simple_summaries =
    map (fn (pc,assum,u,addr,next) => (pc,assum,filter (is_var o fst) u,addr,next))
        all_summaries
  val all_stack_accesses = find_stack_accesses_for all_summaries sec_name
  val simple_stack_accesses = find_stack_accesses_for all_simple_summaries sec_name
  fun annotation loc =
    if mem loc simple_stack_accesses then "stack access" else
    if mem loc all_stack_accesses then "indirect stack access" else fail()
  val _ = if stack_offset_in_r0 sec_name then
            write_line ("Section `" ^ sec_name ^ "` expects pointer to stack in r0.")
          else ()
  val count = length (all_distinct all_stack_accesses)
  val _ = if count = 0 then
            write_line ("No stack accesses found. Code for `" ^ sec_name ^ "`:")
          else
            write_line (int_to_string count ^ " stack accesses found. " ^
                        "Annotated code for `" ^ sec_name ^ "`:")
  val _ = show_annotated_code annotation sec_name
  in all_stack_accesses end

end