1(*===========================================================================*)
2(* Definition of the encryption and decryption algorithms plus               *)
3(* proof of correctness.                                                     *)
4(*===========================================================================*)
5
6(*
7  app load ["RoundOpTheory", "metisLib"];
8*)
9open HolKernel Parse boolLib bossLib RoundOpTheory pairTheory metisLib;
10
11(*---------------------------------------------------------------------------*)
12(* Make bindings to pre-existing stuff                                       *)
13(*---------------------------------------------------------------------------*)
14
(* Short name for computeLib's restricted evaluation tactic. AES_LEMMA calls
   it with a list of constants that evaluation should leave unexpanded. *)
val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;

(* Begin the "aes" theory; it is written out by export_theory at the end. *)
val _ = new_theory "aes";
18
19(*---------------------------------------------------------------------------*)
(* The key schedule is represented as a circular buffer of fixed size. It    *)
(* holds 11 round keys (blocks), and the buffer is rotated each time a key   *)
(* is taken from it.                                                         *)
23(*---------------------------------------------------------------------------*)
24
(* keysched: eleven round keys stored side by side as an 11-tuple. *)
val _ = type_abbrev
  ("keysched",
   Type`:key # key # key # key # key # key #
         key # key # key # key # key`);
27
(*---------------------------------------------------------------------------*)
(* FORALL_KEYSCHED: quantifying over a key schedule is the same as          *)
(* quantifying over its eleven components. AES_LEMMA uses it (via SIMP_TAC) *)
(* to split a keysched variable into individual keys.                       *)
(*---------------------------------------------------------------------------*)
val FORALL_KEYSCHED = Q.prove
(`(!x:keysched. P x) = !k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11.
                       P(k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11)`,
 EQ_TAC THEN RW_TAC std_ss [] THEN
 (* An 11-tuple is a nest of pairs, so repeated use of ABS_PAIR_THM (by
    METIS) exposes its components; the assumption then closes the goal. *)
 `?a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11.
     x = (a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11)`
   by METIS_TAC [ABS_PAIR_THM]
 THEN ASM_REWRITE_TAC[]);
36
37
(* ROTKEYS rotates the schedule one place to the left: the key at the head
   moves to the back, so FST of the result is the next key to use. *)
val ROTKEYS_def =
 Define
   `ROTKEYS (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
      (k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k1) : keysched`;
42
(* REVKEYS reverses the order of the eleven keys in a schedule; decryption
   runs over the schedule back to front. *)
val REVKEYS_def =
 Define
   `REVKEYS (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
      (k11,k10,k9,k8,k7,k6,k5,k4,k3,k2,k1) : keysched`;
47
(* LIST_TO_KEYS ks acc pushes the elements of ks, one at a time, onto the
   front of the 11-tuple acc, dropping its last component each time. The
   result thus holds the last eleven elements of ks in reverse order (with
   leading components of acc filling the tail if ks is shorter). *)
val LIST_TO_KEYS_def =
 Define
  `(LIST_TO_KEYS [] ks = ks) /\
   (LIST_TO_KEYS (x::xs) (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
      LIST_TO_KEYS xs (x,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10))`;
53
(* DUMMY_KEYS: an all-ZERO_BLOCK schedule, used as the initial accumulator
   for LIST_TO_KEYS; every slot is overwritten once eleven keys are pushed. *)
val DUMMY_KEYS_def =
 Define
  `DUMMY_KEYS =
     (ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
      ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK)`;
59
60(*---------------------------------------------------------------------------*)
61(* Orchestrate the round computations.                                       *)
62(*---------------------------------------------------------------------------*)
63
64
(*---------------------------------------------------------------------------*)
(* Round n keys state performs n+1 encryption rounds. Each of the first n   *)
(* applies SubBytes, ShiftRows and MixColumns, adds FST keys, and recurses  *)
(* on the rotated schedule; the last round (n = 0) omits MixColumns.         *)
(*---------------------------------------------------------------------------*)
val (Round_def, Round_ind) = Defn.tprove
 (Hol_defn
   "Round"
   `Round n (keys:keysched) (state:state) =
     if n=0
      then AddRoundKey (FST keys) (ShiftRows (SubBytes state))
      else Round (n-1) (ROTKEYS keys)
            (AddRoundKey (FST keys)
              (MixColumns (ShiftRows (SubBytes state))))`,
  (* Termination: the round counter n (first argument) decreases; after
     stripping quantifiers, n-1 < n given n <> 0 is arithmetic. *)
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);
75
76
(*---------------------------------------------------------------------------*)
(* InvRound n keys state performs n+1 decryption rounds. Each of the first  *)
(* n applies InvShiftRows and InvSubBytes, adds FST keys, applies            *)
(* InvMixColumns and recurses on the rotated schedule; the last round        *)
(* (n = 0) omits InvMixColumns. keys and state carry the same explicit type  *)
(* annotations as in Round.                                                  *)
(*---------------------------------------------------------------------------*)
val (InvRound_def,InvRound_ind) = Defn.tprove
 (Hol_defn
   "InvRound"
   `InvRound n (keys:keysched) (state:state) =
      if n=0
       then AddRoundKey (FST keys) (InvSubBytes (InvShiftRows state))
       else InvRound (n-1) (ROTKEYS keys)
             (InvMixColumns
               (AddRoundKey (FST keys) (InvSubBytes (InvShiftRows state))))`,
  (* Termination: the round counter n (first argument) decreases. *)
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);
87
(* Store the round definitions and induction principles in the theory. *)
val _ =
  List.app (fn (name, th) => ignore (save_thm (name, th)))
    [("Round_def",    Round_def),
     ("Round_ind",    Round_ind),
     ("InvRound_def", InvRound_def),
     ("InvRound_ind", InvRound_ind)];
92
93(*---------------------------------------------------------------------------*)
94(* Encrypt and Decrypt                                                       *)
95(*---------------------------------------------------------------------------*)
96
(* AES_FWD keys: convert the input to a state, add the first key (FST keys),
   run Round 9 on the rotated schedule (nine rounds with MixColumns plus a
   final one without), and convert back. Consumes all eleven keys. *)
val AES_FWD_def =
 Define
  `AES_FWD keys =
    from_state o Round 9 (ROTKEYS keys) o AddRoundKey (FST keys) o to_state`;


(* AES_BWD keys: the same shape with InvRound. Callers in this file
   (AES_LEMMA, AES_def) pass it the reversed schedule REVKEYS keys. *)
val AES_BWD_def =
 Define
  `AES_BWD keys =
    from_state o InvRound 9 (ROTKEYS keys) o AddRoundKey (FST keys) o to_state`;
107
108(*---------------------------------------------------------------------------*)
109(* Main lemma                                                                *)
110(*---------------------------------------------------------------------------*)
111
(* Constants that AES_LEMMA's RESTR_EVAL_TAC call keeps unexpanded. Each
   name must resolve to exactly one constant, or the binding raises Bind. *)
val [[MultCol], [InvMultCol], [genMixColumns]] =
    map decls ["MultCol", "InvMultCol", "genMixColumns"];
115
(*---------------------------------------------------------------------------*)
(* AES_LEMMA: decrypting with the reversed schedule inverts encryption, for *)
(* every input and every key schedule.                                       *)
(*                                                                           *)
(* Proof: FORALL_BLOCK and FORALL_KEYSCHED split the input and schedule into *)
(* components; RESTR_EVAL_TAC evaluates both directions symbolically while   *)
(* leaving MultCol, InvMultCol and genMixColumns unexpanded (presumably to   *)
(* keep the term small); the *_Inversion lemmas and XOR_BLOCK_IDEM then      *)
(* cancel each pair of inverse operations.                                   *)
(*                                                                           *)
(* NOTE(review): plaintext is annotated :state although AES_FWD passes it    *)
(* through to_state; presumably state and block abbreviate the same type -  *)
(* confirm against RoundOpTheory.                                            *)
(*---------------------------------------------------------------------------*)
val AES_LEMMA = Q.store_thm
("AES_LEMMA",
 `!(plaintext:state) (keys:keysched).
     AES_BWD (REVKEYS keys) (AES_FWD keys plaintext) = plaintext`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 SIMP_TAC std_ss [FORALL_KEYSCHED]
   THEN RESTR_EVAL_TAC [MultCol,InvMultCol,genMixColumns]
   THEN RW_TAC std_ss [ShiftRows_Inversion,SubBytes_Inversion,
                       XOR_BLOCK_IDEM,MixColumns_Inversion,
                       from_state_Inversion,from_state_def]);
126
127
128(*---------------------------------------------------------------------------
     Generate the key schedule from a key. We work with 4-tuples of
130     bytes. Unpacking moves from four contiguous 4-tuples to a 16-tuple,
131     and also lays the bytes out in the top-to-bottom, left-to-right
132     order that the state also has.
133 ---------------------------------------------------------------------------*)
134
(* XOR8x4 combines two 4-byte words component by component with #
   (presumably byte XOR from RoundOpTheory). *)
val XOR8x4_def =
 Define
   `(x0,x1,x2,x3) XOR8x4 (y0,y1,y2,y3) =
      (x0 # y0, x1 # y1, x2 # y2, x3 # y3)`;
138
(* SubWord applies the S-box to each byte of a word. *)
val SubWord_def =
 Define
   `SubWord (x0,x1,x2,x3) = (Sbox x0, Sbox x1, Sbox x2, Sbox x3)`;
141
(* RotWord rotates a word one byte to the left. *)
val RotWord_def =
 Define
   `RotWord (x0,x1,x2,x3) = (x1,x2,x3,x0)`;
144
(* Rcon j: round constant whose first byte is PolyExp TWO (j-1) and whose
   other bytes are ZERO. j-1 is natural-number subtraction, so Rcon 0 equals
   Rcon 1; expand only calls it with j >= 1. *)
val Rcon_def =
 Define
   `Rcon j = (PolyExp TWO (j - 1), ZERO, ZERO, ZERO)`;
147
(* unpack ws acc consumes ws four words at a time and conses one 16-byte
   block per group onto acc. The block interleaves the group's bytes: first
   the first byte of each word, taking the fourth word of the group first
   (m,i,e,a), then the second bytes, and so on. When called from expand,
   where the newest word is at the head, each group is words 4r+3..4r.
   The fourth word is then the oldest, and because the blocks are consed
   onto acc, the result lists the blocks in ascending order.
   Only lists whose length is a multiple of four are covered. The variable
   o1 is used because o is HOL's function composition. *)
val unpack_def = Define
  `(unpack [] A = A) /\
   (unpack ((a,b,c,d)::(e,f,g,h)::(i,j,k,l)::(m,n,o1,p)::rst) A
        = unpack rst ((m,i,e,a,n,j,f,b,o1,k,g,c,p,l,h,d)::A))`;
152
153(*---------------------------------------------------------------------------*)
(* Build the key schedule from a key. The definition is hard-wired to        *)
(* AES-128 (44 words, 11 round keys): it handles only 128-bit keys.          *)
156(*---------------------------------------------------------------------------*)
157
(* expand n sched extends the word list sched (newest word at the head) until
   it holds words 0..43, then groups them into round keys with unpack.
   Word n is word n-4 XOR temp, where temp is word n-1, or, when n is a
   multiple of 4, SubWord (RotWord (word n-1)) XOR8x4 Rcon (n DIV 4).
   Termination: 44 - n decreases, since recursion happens only for n <= 43. *)
val (expand_def,expand_ind) =
Defn.tprove
 (Hol_defn
   "expand"
   `expand n sched =
      if 43 < n then unpack sched []
      else let h = HD sched in
           let h' = if ~(n MOD 4 = 0) then h
                       else SubWord(RotWord h) XOR8x4 Rcon(n DIV 4)
           in expand (n+1) ((h' XOR8x4 (HD(TL(TL(TL sched)))))::sched)`,
  WF_REL_TAC `measure ($- 44 o FST)`);


val _ = save_thm ("expand_def", expand_def);
val _ = save_thm ("expand_ind", expand_ind);
(* Register expand_def with computeLib persistently, so EVAL can unfold it
   here and in later sessions that load this theory. *)
val _ = computeLib.add_persistent_funs ["expand_def"];
174
(* mk_keysched key: split the 16-byte key into four words
   w0=(b0..b3), w1=(b4..b7), w2=(b8..b11), w3=(b12..b15) and run expand from
   word index 4, with the newest word (w3) at the head of the list. Returns
   the list of round keys built by expand/unpack (11 of them; see
   keysched_length). *)
val mk_keysched_def = Define
 `mk_keysched ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):key)
      =
  expand 4 [(b12,b13,b14,b15) ; (b8,b9,b10,b11) ;
            (b4,b5,b6,b7)     ; (b0,b1,b2,b3)]`;
180
181
182(*---------------------------------------------------------------------------*)
183(* Sanity check                                                              *)
184(*---------------------------------------------------------------------------*)
185
(* Every key yields exactly eleven round keys. The key is split into its
   components and the goal is evaluated; Count.apply reports the inference
   count. The result is an ML value only (Q.prove), not stored in the theory. *)
val keysched_length = Count.apply Q.prove
(`!key. LENGTH (mk_keysched key) = 11`,
 SIMP_TAC std_ss [FORALL_BLOCK,mk_keysched_def]
  THEN REPEAT GEN_TAC
  THEN EVAL_TAC);
191
192(*---------------------------------------------------------------------------*)
193(* Generate key schedule, and its inverse, then build the encryption and     *)
194(* decryption functions. Called AES, since it wraps everything up into a     *)
195(* single package.                                                           *)
196(*---------------------------------------------------------------------------*)
197
(*---------------------------------------------------------------------------*)
(* AES key = (encrypt, decrypt).                                             *)
(*                                                                           *)
(* mk_keysched returns the round keys as [K0; K1; ...; K10]: expand conses  *)
(* each new word at the head, and unpack conses the blocks back onto an     *)
(* accumulator. LIST_TO_KEYS pushes each list element onto the front of the *)
(* tuple, producing (K10,...,K0). REVKEYS restores (K0,...,K10) so that     *)
(* encryption begins by adding K0, the cipher key itself (words w0..w3), as  *)
(* the FIPS-197 cipher requires. Decryption uses the reversed schedule.      *)
(*---------------------------------------------------------------------------*)
val AES_def = Define
 `AES key =
   let keys = REVKEYS (LIST_TO_KEYS (mk_keysched key) DUMMY_KEYS)
   in (AES_FWD keys, AES_BWD (REVKEYS keys))`;
202
203
204(*---------------------------------------------------------------------------*)
205(* Basic theorem about encryption/decryption                                 *)
206(*---------------------------------------------------------------------------*)
207
(* If (encrypt,decrypt) is the pair returned by AES key, then decrypt undoes
   encrypt on every plaintext. Unfolding AES_def and the let reduces the goal
   to an instance of AES_LEMMA, which holds for any key schedule. *)
val AES_CORRECT = Q.store_thm
  ("AES_CORRECT",
   `!key plaintext.
      ((encrypt,decrypt) = AES key)
      ==>
       (decrypt (encrypt plaintext) = plaintext)`,
 RW_TAC std_ss [AES_def,LET_THM,AES_LEMMA]);
215
216
(* Write out the "aes" theory. *)
val () = export_theory ();
218