(*===========================================================================*)
(* Definition of the AES encryption and decryption algorithms, plus a       *)
(* proof that decryption inverts encryption.                                *)
(*===========================================================================*)

(*
 app load ["RoundOpTheory", "metisLib"];
*)
open HolKernel Parse boolLib bossLib RoundOpTheory pairTheory metisLib;

(*---------------------------------------------------------------------------*)
(* Make bindings to pre-existing stuff                                       *)
(*---------------------------------------------------------------------------*)

val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;

val _ = new_theory "aes";

(*---------------------------------------------------------------------------*)
(* The keyschedule can be represented as a circular buffer of fixed size.    *)
(* It has 11 keys (blocks) in it, and the buffer gets rotated each time      *)
(* a key is taken from it.                                                   *)
(*---------------------------------------------------------------------------*)

val _ =
 type_abbrev ("keysched", Type`:key#key#key#key#key#key#key#key#key#key#key`);

(*---------------------------------------------------------------------------*)
(* Quantifying over a keysched is equivalent to quantifying over its eleven  *)
(* components. Rewriting with this theorem exposes the explicit tuple        *)
(* structure of a schedule, so that evaluation can then compute with the     *)
(* individual keys (the AES_LEMMA proof uses it this way).                   *)
(*---------------------------------------------------------------------------*)

val FORALL_KEYSCHED = Q.prove
(`(!x:keysched. P x) = !k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11.
     P(k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11)`,
 EQ_TAC THEN RW_TAC std_ss [] THEN
 (* Decompose an arbitrary x into an explicit 11-tuple; ABS_PAIR_THM says
    every pair is built from its two components. *)
 `?a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11.
    x = (a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11)`
    by METIS_TAC [ABS_PAIR_THM]
  THEN ASM_REWRITE_TAC[]);


(* Rotate the schedule left by one position: the first key moves to the end. *)
val ROTKEYS_def =
 Define
  `ROTKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
        (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k0) : keysched`;

(* Reverse the order of the eleven keys in a schedule. *)
val REVKEYS_def =
 Define
  `REVKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
        (k10,k9,k8,k7,k6,k5,k4,k3,k2,k1,k0) : keysched`;

(*---------------------------------------------------------------------------*)
(* Fold a list of keys into a keysched. Each list element is pushed onto the *)
(* front of the accumulator tuple, and the accumulator's last component is   *)
(* dropped. For an 11-element list [x0; ...; x10] the result is therefore    *)
(* (x10,x9,...,x0): the list order is REVERSED. A shorter list leaves the    *)
(* leading components of acc shifted towards the back. For a longer list,    *)
(* only the last 11 elements survive.                                        *)
(*---------------------------------------------------------------------------*)

val LIST_TO_KEYS_def =
 Define
  `(LIST_TO_KEYS [] acc = acc) /\
   (LIST_TO_KEYS (h::t) (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
        LIST_TO_KEYS t (h,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10))`;

(* A schedule filled with ZERO_BLOCK; used as the initial accumulator
   for LIST_TO_KEYS. *)
val DUMMY_KEYS_def =
 Define
  `DUMMY_KEYS = (ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK, ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK)`;

(*---------------------------------------------------------------------------*)
(* Orchestrate the round computations.                                       *)
(*                                                                           *)
(* Round n keys state performs n+1 rounds. Each round uses the key at the    *)
(* front of the schedule (FST keys), and the schedule is then rotated with   *)
(* ROTKEYS. Every round except the last applies SubBytes, ShiftRows,         *)
(* MixColumns and then AddRoundKey. The last round (n = 0) omits MixColumns. *)
(* Termination is by the decreasing first argument n.                        *)
(*---------------------------------------------------------------------------*)


val (Round_def, Round_ind) = Defn.tprove
 (Hol_defn
   "Round"
   `Round n (keys:keysched) (state:state) =
      if n=0
       then AddRoundKey (FST keys) (ShiftRows (SubBytes state))
       else Round (n-1) (ROTKEYS keys)
              (AddRoundKey (FST keys)
                 (MixColumns (ShiftRows (SubBytes state))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);


(*---------------------------------------------------------------------------*)
(* InvRound n keys state performs n+1 inverse rounds, consuming keys from    *)
(* the front of the rotating schedule just as Round does. Each step applies  *)
(* InvShiftRows, InvSubBytes and AddRoundKey, followed by InvMixColumns in   *)
(* every step except the last (n = 0).                                       *)
(*---------------------------------------------------------------------------*)

val (InvRound_def,InvRound_ind) = Defn.tprove
 (Hol_defn
   "InvRound"
   `InvRound n keys state =
      if n=0
       then AddRoundKey (FST keys) (InvSubBytes (InvShiftRows state))
       else InvRound (n-1) (ROTKEYS keys)
              (InvMixColumns
                (AddRoundKey (FST keys) (InvSubBytes (InvShiftRows state))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);

(* Store the equations and induction theorems returned by Defn.tprove in the
   theory under these names. *)
val _ = save_thm ("Round_def", Round_def);
val _ = save_thm ("Round_ind", Round_ind);
val _ = save_thm ("InvRound_def", InvRound_def);
val _ = save_thm ("InvRound_ind", InvRound_ind);

(*---------------------------------------------------------------------------*)
(* Encrypt and Decrypt                                                       *)
(*                                                                           *)
(* AES_FWD keys: initial AddRoundKey with the first key of the schedule,     *)
(* then Round 9 on the rotated schedule. Round 9 performs ten rounds: nine   *)
(* with MixColumns and a final round without it. These use the remaining     *)
(* ten keys in order.                                                        *)
(*                                                                           *)
(* AES_BWD keys has the same shape but uses InvRound. Applied to             *)
(* REVKEYS keys, it consumes the keys in the reverse of the order used by    *)
(* AES_FWD keys.                                                             *)
(*---------------------------------------------------------------------------*)

val AES_FWD_def =
 Define
  `AES_FWD keys =
     from_state o Round 9 (ROTKEYS keys) o AddRoundKey (FST keys) o to_state`;


val AES_BWD_def =
 Define
  `AES_BWD keys =
     from_state o InvRound 9 (ROTKEYS keys) o AddRoundKey (FST keys) o to_state`;

(*---------------------------------------------------------------------------*)
(* Main lemma                                                                *)
(*                                                                           *)
(* The lemma holds for EVERY schedule of eleven keys. The proof expands the  *)
(* block and the schedule into explicit tuples and evaluates. During         *)
(* evaluation, MultCol, InvMultCol and genMixColumns are kept folded, and    *)
(* the goal is closed with the inversion lemmas for the individual round     *)
(* operations.                                                               *)
(*---------------------------------------------------------------------------*)

val [MultCol] = decls "MultCol";
val [InvMultCol] = decls "InvMultCol";
val [genMixColumns] = decls "genMixColumns";

val AES_LEMMA = Q.store_thm
("AES_LEMMA",
 `!(plaintext:state) (keys:keysched).
     AES_BWD (REVKEYS keys) (AES_FWD keys plaintext) = plaintext`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 SIMP_TAC std_ss [FORALL_KEYSCHED]
  THEN RESTR_EVAL_TAC [MultCol,InvMultCol,genMixColumns]
  THEN RW_TAC std_ss [ShiftRows_Inversion,SubBytes_Inversion,
                      XOR_BLOCK_IDEM,MixColumns_Inversion,
                      from_state_Inversion,from_state_def]);


(*---------------------------------------------------------------------------
    Generate the key schedule from key. We work using 4-tuples of
    bytes. Unpacking moves from four contiguous 4-tuples to a 16-tuple,
    and also lays the bytes out in the top-to-bottom, left-to-right
    order that the state also has.
 ---------------------------------------------------------------------------*)

(* Combine two 4-byte words byte by byte with #. *)
val XOR8x4_def =
 Define
  `(a,b,c,d) XOR8x4 (a1,b1,c1,d1) = (a # a1, b # b1, c # c1, d # d1)`;

(* Apply Sbox to each byte of a word. *)
val SubWord_def = Define
  `SubWord(b0,b1,b2,b3) = (Sbox b0, Sbox b1, Sbox b2, Sbox b3)`;

(* Rotate the bytes of a word left by one position. *)
val RotWord_def = Define
  `RotWord(b0,b1,b2,b3) = (b1,b2,b3,b0)`;

(* Round-constant word. Its first byte is PolyExp TWO (i-1) (presumably
   x^(i-1) in GF(2^8) -- see RoundOpTheory) and the other bytes are ZERO. *)
val Rcon_def = Define
  `Rcon i = (PolyExp TWO (i-1), ZERO,ZERO,ZERO)`;

(*---------------------------------------------------------------------------*)
(* unpack sched A: take the words of sched four at a time, build one 16-byte *)
(* block from each group, and push that block onto A. expand passes its      *)
(* words newest-first, so the blocks come out oldest-first: the block built  *)
(* from w0..w3 is at the head of the result.                                 *)
(*---------------------------------------------------------------------------*)

val unpack_def = Define
  `(unpack [] A = A) /\
   (unpack ((a,b,c,d)::(e,f,g,h)::(i,j,k,l)::(m,n,o1,p)::rst) A
      = unpack rst ((m,i,e,a,n,j,f,b,o1,k,g,c,p,l,h,d)::A))`;

(*---------------------------------------------------------------------------*)
(* Build the keyschedule from a key. This definition is too specific, but    *)
(* works fine for 128 bit blocks.                                            *)
(*                                                                           *)
(* expand n sched: sched holds the words w(n-1), w(n-2), ... newest-first.   *)
(* The next word is w(n) = temp XOR8x4 w(n-4), where temp is w(n-1), or      *)
(* SubWord(RotWord w(n-1)) XOR8x4 Rcon(n DIV 4) when n is a multiple of 4.   *)
(* Expansion stops once n exceeds 43, i.e. after 44 words, and the words are *)
(* then unpacked into blocks.                                                *)
(*---------------------------------------------------------------------------*)

val (expand_def,expand_ind) =
Defn.tprove
 (Hol_defn
   "expand"
   `expand n sched =
     if 43 < n then unpack sched []
     else let h = HD sched in
          let h' = if ~(n MOD 4 = 0) then h
                   else SubWord(RotWord h) XOR8x4 Rcon(n DIV 4)
          in expand (n+1) ((h' XOR8x4 (HD(TL(TL(TL sched)))))::sched)`,
  WF_REL_TAC `measure ($- 44 o FST)`);


val _ = save_thm ("expand_def", expand_def);
val _ = save_thm ("expand_ind", expand_ind);
val _ = computeLib.add_persistent_funs ["expand_def"];

(* Seed the expansion with the four words of the key, newest-first
   (w3, w2, w1, w0). The result lists the round-key blocks in generation
   order, with the block built from the key itself first. *)
val mk_keysched_def = Define
 `mk_keysched ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):key)
    =
  expand 4 [(b12,b13,b14,b15) ; (b8,b9,b10,b11) ;
            (b4,b5,b6,b7) ; (b0,b1,b2,b3)]`;


(*---------------------------------------------------------------------------*)
(* Sanity check                                                              *)
(*---------------------------------------------------------------------------*)

val keysched_length = Count.apply Q.prove
(`!key. LENGTH (mk_keysched key) = 11`,
 SIMP_TAC std_ss [FORALL_BLOCK,mk_keysched_def]
 THEN REPEAT GEN_TAC
 THEN EVAL_TAC);

(*---------------------------------------------------------------------------*)
(* Generate key schedule, and its inverse, then build the encryption and     *)
(* decryption functions. Called AES, since it wraps everything up into a     *)
(* single package.                                                           *)
(*                                                                           *)
(* mk_keysched lists the round keys in generation order, so the first one    *)
(* is the cipher key itself. LIST_TO_KEYS pushes each list element onto the  *)
(* front of its tuple, which REVERSES that order. Without correction,        *)
(* AES_FWD would begin with the last round key. REVKEYS restores generation  *)
(* order, so encryption starts with the cipher key (as in FIPS-197). The     *)
(* decryption schedule REVKEYS keys then starts with the last round key.     *)
(*---------------------------------------------------------------------------*)

val AES_def = Define
 `AES key =
   let keys = REVKEYS (LIST_TO_KEYS (mk_keysched key) DUMMY_KEYS)
   in (AES_FWD keys, AES_BWD (REVKEYS keys))`;


(*---------------------------------------------------------------------------*)
(* Basic theorem about encryption/decryption. AES_LEMMA holds for every      *)
(* schedule, so it applies to the schedule built inside AES.                 *)
(*---------------------------------------------------------------------------*)

val AES_CORRECT = Q.store_thm
 ("AES_CORRECT",
  `!key plaintext.
     ((encrypt,decrypt) = AES key)
     ==>
     (decrypt (encrypt plaintext) = plaintext)`,
  RW_TAC std_ss [AES_def,LET_THM,AES_LEMMA]);


val _ = export_theory();