(*===========================================================================*)
(* Definition of the encryption and decryption algorithms plus a             *)
(* proof of correctness.                                                     *)
(*===========================================================================*)

(*
   app load ["RoundOpTheory"];
*)
open HolKernel Parse boolLib bossLib RoundOpTheory pairTheory;

(*---------------------------------------------------------------------------*)
(* Make bindings to pre-existing stuff                                       *)
(*---------------------------------------------------------------------------*)

(* Evaluation tactic that leaves the constants it is given unexpanded; it is *)
(* used in the proof of AES_LEMMA below.                                     *)
val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;

val _ = new_theory "aes";

(*---------------------------------------------------------------------------*)
(* The keyschedule can be represented as a circular buffer of fixed size.    *)
(* It has 11 keys (blocks) in it, and the buffer gets rotated each time      *)
(* a key is taken from it.                                                   *)
(*---------------------------------------------------------------------------*)

val _ =
 type_abbrev ("keysched", ``:key#key#key#key#key#key#key#key#key#key#key``);

(* Quantifying over a keysched is equivalent to quantifying over its 11      *)
(* component keys. Proofs use this to split a schedule into its parts; the   *)
(* interesting direction rests on every schedule being an explicit tuple.    *)
val FORALL_KEYSCHED = Q.store_thm
("FORALL_KEYSCHED",
 `(!x:keysched. P x) = !k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11.
                         P(k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11)`,
 EQ_TAC THEN RW_TAC std_ss [] THEN
 `?a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11.
      x = (a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11)`
   by METIS_TAC [ABS_PAIR_THM]
 THEN ASM_REWRITE_TAC[]);


(* Rotate the schedule by one position: the first key moves to the end.      *)
val ROTKEYS_def =
 Define
  `ROTKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
             (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k0) : keysched`;

(* Reverse the order of the 11 keys in a schedule.                           *)
val REVKEYS_def =
 Define
  `REVKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
             (k10,k9,k8,k7,k6,k5,k4,k3,k2,k1,k0) : keysched`;

(* Fold a list of keys into a schedule. Each list element is pushed onto the *)
(* front of the accumulator and the accumulator's last key is dropped, so    *)
(* the LAST element of the list ends up FIRST: an 11-element list yields its *)
(* keys in reverse order. With a shorter list of length m, the first 11-m    *)
(* keys of acc survive at the end of the result; with a longer list the      *)
(* earliest elements are shifted out again.                                  *)
val LIST_TO_KEYS_def =
 Define
  `(LIST_TO_KEYS [] acc = acc) /\
   (LIST_TO_KEYS (h::t) (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
      LIST_TO_KEYS t (h,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10))`;

(* A schedule of 11 ZERO_BLOCKs, used as the initial accumulator for         *)
(* LIST_TO_KEYS in AES_def.                                                  *)
val DUMMY_KEYS_def =
 Define
  `DUMMY_KEYS = (ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK)`;

(*---------------------------------------------------------------------------*)
(* Orchestrate the round computations.                                       *)
(*                                                                           *)
(* RoundTuple (n,keys,state) performs n full rounds (SubBytes, ShiftRows,    *)
(* MixColumns, AddRoundKey) and then one final round without MixColumns.     *)
(* Every round adds FST keys and then rotates the schedule, so n+1 keys are  *)
(* used. The result is (0, rotated schedule, new state).                     *)
(*                                                                           *)
(* InvRoundTuple has the same shape: each round applies InvShiftRows,        *)
(* InvSubBytes, AddRoundKey and, in all but the last round, InvMixColumns.   *)
(*                                                                           *)
(* Both recurse on a decreasing counter n, so termination is proved with     *)
(* the measure FST.                                                          *)
(*---------------------------------------------------------------------------*)

val (RoundTuple_def, RoundTuple_ind) = Defn.tprove
 (Hol_defn
   "RoundTuple"
   `RoundTuple (n, keys:keysched, state:state) =
      if n=0
       then (0,ROTKEYS keys,
             AddRoundKey (FST keys)
               (ShiftRows (SubBytes state)))
       else RoundTuple (n-1, ROTKEYS keys,
              (AddRoundKey (FST keys)
                 (MixColumns (ShiftRows (SubBytes state)))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);

val (InvRoundTuple_def,InvRoundTuple_ind) = Defn.tprove
 (Hol_defn
   "InvRoundTuple"
   `InvRoundTuple (n, keys:keysched, state:state) =
      if n=0
       then (0,ROTKEYS keys,
             AddRoundKey (FST keys)
               (InvSubBytes (InvShiftRows state)))
       else InvRoundTuple (n-1,ROTKEYS keys,
              (InvMixColumns
                 (AddRoundKey (FST keys)
                    (InvSubBytes (InvShiftRows state)))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);

(* Store the equations and induction theorems produced by Defn.tprove in     *)
(* the theory.                                                               *)
val _ = save_thm ("RoundTuple_def", RoundTuple_def);
val _ = save_thm ("RoundTuple_ind", RoundTuple_ind);
val _ = save_thm ("InvRoundTuple_def", InvRoundTuple_def);
val _ = save_thm ("InvRoundTuple_ind", InvRoundTuple_ind);

(* Keep only the state component of the round computation.                   *)
val Round_def = Define `Round n k s = SND(SND(RoundTuple(n,k,s)))`;
val InvRound_def = Define `InvRound n k s = SND(SND(InvRoundTuple(n,k,s)))`;

(*---------------------------------------------------------------------------*)
(* Encrypt and Decrypt                                                       *)
(*                                                                           *)
(* Reading each composition right to left: convert the block to a state,    *)
(* add the first key of the schedule, run 10 rounds (9 full rounds plus the  *)
(* final one) with the remaining 10 keys, and convert back. AES_BWD has the  *)
(* same shape but uses the inverse rounds; AES_LEMMA below shows that        *)
(* AES_BWD applied to the reversed schedule undoes AES_FWD.                  *)
(*---------------------------------------------------------------------------*)

val AES_FWD_def =
 Define
  `AES_FWD keys =
     from_state o Round 9 (ROTKEYS keys)
                o AddRoundKey (FST keys) o to_state`;

val AES_BWD_def =
 Define
  `AES_BWD keys =
     from_state o InvRound 9 (ROTKEYS keys)
                o AddRoundKey (FST keys) o to_state`;

(*---------------------------------------------------------------------------*)
(* Main lemma                                                                *)
(*---------------------------------------------------------------------------*)

(* Fetch the constants MultCol, InvMultCol and genMixColumns so the proof    *)
(* below can keep them folded during evaluation. Each pattern expects        *)
(* exactly one constant of that name to be in scope.                         *)
val [MultCol] = decls "MultCol";
val [InvMultCol] = decls "InvMultCol";
val [genMixColumns] = decls "genMixColumns";

(* Decryption with the reversed schedule undoes encryption, for every block  *)
(* and every schedule. The proof splits the block and the schedule into      *)
(* their components, symbolically evaluates both directions without          *)
(* unfolding MultCol, InvMultCol and genMixColumns, and then cancels each    *)
(* step with the corresponding inversion lemma.                              *)
(* NOTE(review): plaintext is annotated :state although AES_FWD begins with  *)
(* to_state; this presumably type-checks because block and state are the    *)
(* same tuple type - confirm.                                                *)
val AES_LEMMA = Q.store_thm
("AES_LEMMA",
 `!(plaintext:state) (keys:keysched).
     AES_BWD (REVKEYS keys) (AES_FWD keys plaintext) = plaintext`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 SIMP_TAC std_ss [FORALL_KEYSCHED]
  THEN RESTR_EVAL_TAC [MultCol,InvMultCol,genMixColumns]
  THEN RW_TAC std_ss [ShiftRows_Inversion,SubBytes_Inversion,
                      XOR_BLOCK_IDEM,MixColumns_Inversion,
                      from_state_Inversion,from_state_def]);

(*---------------------------------------------------------------------------
    Generate the key schedule from key. We work using 4-tuples of
    bytes. Unpacking moves from four contiguous 4-tuples to a 16-tuple,
    and also lays the bytes out in the top-to-bottom, left-to-right
    order that the state also has.
 ---------------------------------------------------------------------------*)

val _ = set_fixity "XOR8x4" (Infixr 350);

(* Bytewise XOR (??) of two 4-byte words.                                    *)
val XOR8x4_def =
 Define
  `(a,b,c,d) XOR8x4 (a1,b1,c1,d1) = (a ?? a1, b ?? b1, c ?? c1, d ?? d1)`;

(* Apply Sbox to each byte of a word.                                        *)
val SubWord_def = Define
  `SubWord(b0,b1,b2,b3) = (Sbox b0, Sbox b1, Sbox b2, Sbox b3)`;

(* Rotate the bytes of a word left by one position.                          *)
val RotWord_def = Define
  `RotWord(b0,b1,b2,b3) = (b1,b2,b3,b0)`;

(* Round-constant word; only its first byte is non-zero. PolyExp 2w (i-1)    *)
(* is presumably 2 raised to the power i-1 in GF(2^8) - confirm against the  *)
(* definition of PolyExp. expand only calls this with i >= 1.                *)
val Rcon_def = Define
  `Rcon i = (PolyExp 2w (i-1), 0w,0w,0w)`;

(* Consume the word list four words at a time, turning each group into one   *)
(* 16-byte round key that is prepended to A. If a group is read as words     *)
(* W1,W2,W3,W4, the key's first four bytes are the first bytes of W4, W3,    *)
(* W2, W1, and likewise for the other byte positions. Because expand passes  *)
(* its words newest first, W4 is the oldest word of each group and the       *)
(* resulting list holds the round keys in ascending order. There is no       *)
(* clause for a leftover of 1-3 words.                                       *)
val unpack_def = Define
  `(unpack [] A = A) /\
   (unpack ((a,b,c,d)::(e,f,g,h)::(i,j,k,l)::(m,n,o1,p)::rst) A
      = unpack rst ((m,i,e,a,n,j,f,b,o1,k,g,c,p,l,h,d)::A))`;

(*---------------------------------------------------------------------------*)
(* Build the keyschedule from a key. The definition is specialised to        *)
(* AES-128: it stops after 44 words (11 round keys), which is the schedule   *)
(* length for a 128-bit key.                                                 *)
(*                                                                           *)
(* sched holds the words computed so far, most recent first, and n is the    *)
(* index of the next word. That word is h' XOR8x4 word n-4 (the fourth       *)
(* element of sched), where h' is the previous word h. When n is a multiple  *)
(* of 4, h is first passed through RotWord and SubWord and XORed with        *)
(* Rcon(n DIV 4). This assumes sched has at least four words; mk_keysched    *)
(* supplies exactly four. Termination: 44 - n decreases.                     *)
(*---------------------------------------------------------------------------*)

val (expand_def,expand_ind) =
Defn.tprove
 (Hol_defn
   "expand"
   `expand n sched =
      if 43 < n then unpack sched []
      else let h = HD sched in
           let h' = if ~(n MOD 4 = 0) then h
                    else SubWord(RotWord h) XOR8x4 Rcon(n DIV 4)
           in expand (n+1) ((h' XOR8x4 (HD(TL(TL(TL sched)))))::sched)`,
  WF_REL_TAC `measure ($- 44 o FST)`);


val _ = save_thm ("expand_def", expand_def);
val _ = save_thm ("expand_ind", expand_ind);
(* Register expand_def with the evaluator persistently, so that evaluation   *)
(* can compute key schedules.                                                *)
val _ = computeLib.add_persistent_funs ["expand_def"];

(* Split the 16-byte key into four words (bytes 0-3, 4-7, 8-11 and 12-15)    *)
(* and start expand at word index 4, with those words listed newest first.   *)
val mk_keysched_def = Define
  `mk_keysched ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):key)
       =
   expand 4 [(b12,b13,b14,b15) ; (b8,b9,b10,b11) ;
             (b4,b5,b6,b7) ; (b0,b1,b2,b3)]`;


(*---------------------------------------------------------------------------*)
(* Sanity check (unfinished; kept commented out for reference)               *)
(*---------------------------------------------------------------------------*)
(*
val PolyExp = Q.prove
(`(PolyExp x 0 = 1w) /\
   (PolyExp x (SUC n) = x ** PolyExp x n)`,
*)

(*
   Unfinished scratch proof that mk_keysched yields 11 round keys; it is
   kept commented out for reference.

val keysched_length = Count.apply Q.prove
(`!key. LENGTH (mk_keysched key) = 11`,
  SIMP_TAC std_ss [FORALL_BLOCK,mk_keysched_def]
  THEN REPEAT GEN_TAC
  THEN NTAC 42
      (fn x => (RW_TAC list_ss [Once expand_def,LET_THM]
                THEN FULL_SIMP_TAC list_ss [markerTheory.Abbrev_def]
                THEN RW_TAC list_ss [XOR8x4_def, SubWord_def, RotWord_def, Rcon_def,
                          tablesTheory.Sbox_def,MultTheory.PolyExp_def]) x)
  THEN RULE_ASSUM_TAC (REWRITE_RULE [markerTheory.Abbrev_def])
  THEN NTAC 20 (POP_ASSUM SUBST_ALL_TAC)
  THEN NTAC 5 (POP_ASSUM SUBST_ALL_TAC)
  THEN NTAC 5 (POP_ASSUM SUBST_ALL_TAC)
RW_TAC std_ss [unpack_def]
RW_TAC list_ss [unpack_def]
*)


(*---------------------------------------------------------------------------*)
(* Generate key schedule, and its inverse, then build the encryption and     *)
(* decryption functions. Called AES, since it wraps everything up into a     *)
(* single package.                                                           *)
(*                                                                           *)
(* mk_keysched returns the 11 round keys in ascending order: round key 0,    *)
(* built from the cipher-key words, comes first. LIST_TO_KEYS pushes each    *)
(* list element onto the front of its accumulator, so on its own it would    *)
(* produce the schedule in DESCENDING order, and encryption would start      *)
(* with round key 10. REVKEYS restores ascending order, so AES_FWD begins    *)
(* with round key 0, and AES_BWD (given the reversed schedule) begins with   *)
(* round key 10, as FIPS-197 requires.                                       *)
(*---------------------------------------------------------------------------*)

val AES_def = Define
 `AES key =
   let keys = REVKEYS (LIST_TO_KEYS (mk_keysched key) DUMMY_KEYS)
   in (AES_FWD keys, AES_BWD (REVKEYS keys))`;

(*---------------------------------------------------------------------------*)
(* Basic theorem about encryption/decryption: for every key, the decryption  *)
(* function returned by AES inverts the encryption function. It is an        *)
(* instance of AES_LEMMA, which holds for every schedule.                    *)
(*---------------------------------------------------------------------------*)

val AES_CORRECT = Q.store_thm
 ("AES_CORRECT",
  `!key plaintext.
     ((encrypt,decrypt) = AES key)
     ==>
     (decrypt (encrypt plaintext) = plaintext)`,
  RW_TAC std_ss [AES_def,LET_THM,AES_LEMMA]);


val _ = export_theory();