1(*===========================================================================*)
2(* Definition of the encryption and decryption algorithms plus               *)
3(* proof of correctness.                                                     *)
4(*===========================================================================*)
5
6(*
7  app load ["RoundOpTheory"];
8*)
9open HolKernel Parse boolLib bossLib RoundOpTheory pairTheory;
10
11(*---------------------------------------------------------------------------*)
12(* Make bindings to pre-existing stuff                                       *)
13(*---------------------------------------------------------------------------*)
14
(* Local alias for computeLib's restricted evaluation tactic.  It is used in
   the proof of AES_LEMMA, where the constants passed to it are kept from
   being unfolded during evaluation. *)
val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;

(* Open the "aes" theory segment; the definitions and theorems stored below
   go into it and are written out by export_theory at the end of the file. *)
val _ = new_theory "aes";
18
(*---------------------------------------------------------------------------*)
(* The key schedule is represented as a fixed-size circular buffer: an       *)
(* 11-tuple of keys (blocks) that is rotated (ROTKEYS) each time a key is    *)
(* taken from its head.                                                      *)
(*---------------------------------------------------------------------------*)
24
(* keysched: 11 round keys.  AES_FWD/AES_BWD consume one for the initial
   AddRoundKey and one for each of the 10 rounds run by Round 9 / InvRound 9. *)
val _ =
  type_abbrev ("keysched", ``:key#key#key#key#key#key#key#key#key#key#key``);
27
(* For any predicate P, quantifying over a whole key schedule is the same as
   quantifying over its 11 component keys.  AES_LEMMA uses this to expose the
   tuple structure so that the schedule operations can be evaluated.
   Proof: RW_TAC closes the trivial direction; for the other, x is split into
   11 components with ABS_PAIR_THM (applied repeatedly by METIS_TAC) and the
   goal is rewritten with that decomposition. *)
val FORALL_KEYSCHED = Q.store_thm
("FORALL_KEYSCHED",
 `(!x:keysched. P x) = !k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11.
                       P(k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11)`,
 EQ_TAC THEN RW_TAC std_ss [] THEN
 `?a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11.
    x = (a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11)`
   by METIS_TAC [ABS_PAIR_THM]
 THEN ASM_REWRITE_TAC[]);
37
38
(* ROTKEYS rotates the schedule left by one position: the head key moves to
   the end, so FST (ROTKEYS ks) is the second key of ks.  RoundTuple and
   InvRoundTuple call it once per round to advance to the next round key. *)
val ROTKEYS_def =
 Define
   `ROTKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
            (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k0) : keysched`;
43
(* REVKEYS reverses the order of the 11 keys.  Decryption runs AES_BWD on the
   reversed encryption schedule (see AES_LEMMA and AES_def). *)
val REVKEYS_def =
 Define
   `REVKEYS (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10) =
            (k10,k9,k8,k7,k6,k5,k4,k3,k2,k1,k0) : keysched`;
48
(* LIST_TO_KEYS l acc pushes the elements of l, in order, onto the front of
   the 11-tuple acc, discarding acc's last component at each step.
   NOTE: this REVERSES the list: for l = [x0;...;x10] the result is
   (x10,x9,...,x0).  With fewer than 11 elements the trailing components of
   acc survive; with more, the earliest elements are pushed out. *)
val LIST_TO_KEYS_def =
 Define
  `(LIST_TO_KEYS [] acc = acc) /\
   (LIST_TO_KEYS (h::t) (k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11) =
         LIST_TO_KEYS t (h,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10))`;
54
(* A schedule made entirely of ZERO_BLOCK; used only as the initial
   accumulator for LIST_TO_KEYS in AES_def. *)
val DUMMY_KEYS_def =
 Define
  `DUMMY_KEYS = (ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK,
                 ZERO_BLOCK,ZERO_BLOCK,ZERO_BLOCK)`;
60
61(*---------------------------------------------------------------------------*)
62(* Orchestrate the round computations.                                       *)
63(*---------------------------------------------------------------------------*)
64
(* RoundTuple (n, keys, state) runs n+1 encryption rounds, each using the head
   key of the schedule and then rotating it with ROTKEYS.
     - n > 0 : full round   SubBytes, ShiftRows, MixColumns, AddRoundKey
     - n = 0 : final round  SubBytes, ShiftRows, AddRoundKey (no MixColumns)
   The result is (0, rotated schedule, state); Round_def keeps only the state.
   Termination: the counter (first component) strictly decreases. *)
val (RoundTuple_def, RoundTuple_ind) = Defn.tprove
 (Hol_defn
   "RoundTuple"
   `RoundTuple (n, keys:keysched, state:state) =
     if n=0
      then (0,ROTKEYS keys,
            AddRoundKey (FST keys)
              (ShiftRows (SubBytes state)))
      else RoundTuple (n-1, ROTKEYS keys,
            (AddRoundKey (FST keys)
              (MixColumns (ShiftRows (SubBytes state)))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);
77
(* InvRoundTuple (n, keys, state) mirrors RoundTuple for decryption, running
   n+1 inverse rounds with the head key and rotating the schedule each time.
     - n > 0 : InvShiftRows, InvSubBytes, AddRoundKey, InvMixColumns
     - n = 0 : InvShiftRows, InvSubBytes, AddRoundKey (no InvMixColumns)
   Termination: the counter (first component) strictly decreases. *)
val (InvRoundTuple_def,InvRoundTuple_ind) = Defn.tprove
 (Hol_defn
   "InvRoundTuple"
   `InvRoundTuple (n, keys:keysched, state:state) =
      if n=0
       then (0,ROTKEYS keys,
             AddRoundKey (FST keys)
               (InvSubBytes (InvShiftRows state)))
       else InvRoundTuple (n-1,ROTKEYS keys,
             (InvMixColumns
               (AddRoundKey (FST keys)
                 (InvSubBytes (InvShiftRows state)))))`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);
91
(* The theorems returned by Defn.tprove above are not stored automatically;
   record the recursion equations and induction principles in the theory,
   in the same order as before. *)
val _ = List.app (ignore o save_thm)
          [("RoundTuple_def",    RoundTuple_def),
           ("RoundTuple_ind",    RoundTuple_ind),
           ("InvRoundTuple_def", InvRoundTuple_def),
           ("InvRoundTuple_ind", InvRoundTuple_ind)];
96
(* Round n k s / InvRound n k s: the state produced by running the n+1 rounds
   of RoundTuple / InvRoundTuple, discarding the counter and the schedule. *)
val Round_def = Define `Round n k s = SND(SND(RoundTuple(n,k,s)))`;
val InvRound_def = Define `InvRound n k s = SND(SND(InvRoundTuple(n,k,s)))`;
99
100(*---------------------------------------------------------------------------*)
101(* Encrypt and Decrypt                                                       *)
102(*---------------------------------------------------------------------------*)
103
(* AES_FWD keys: encrypt one block.  Convert it to a state, add the head key
   of the schedule, run 10 rounds (Round 9 = 9 full rounds plus the final
   one) with the remaining 10 keys, and convert back.  The schedule must be in
   encryption order: FST keys is the key for the initial AddRoundKey. *)
val AES_FWD_def =
 Define
  `AES_FWD keys =
    from_state o Round 9 (ROTKEYS keys)
               o AddRoundKey (FST keys) o to_state`;
109
(* AES_BWD keys: decrypt one block.  Same shape as AES_FWD but runs InvRound.
   The schedule must be in decryption order, i.e. the REVKEYS of the
   encryption schedule (this is the form AES_LEMMA proves correct). *)
val AES_BWD_def =
 Define
  `AES_BWD keys =
    from_state o InvRound 9 (ROTKEYS keys)
               o AddRoundKey (FST keys) o to_state`;
115
116(*---------------------------------------------------------------------------*)
117(* Main lemma                                                                *)
118(*---------------------------------------------------------------------------*)
119
(* Constants that RESTR_EVAL_TAC must leave folded in the proof of AES_LEMMA.
   Presumably these are the column-mixing helpers behind MixColumns /
   InvMixColumns, kept opaque so evaluation stays tractable and
   MixColumns_Inversion still applies -- TODO confirm against RoundOpTheory.
   NOTE(review): these are non-exhaustive bindings; each raises Bind unless
   decls finds exactly one constant with the given name. *)
val [MultCol] = decls "MultCol";
val [InvMultCol] = decls "InvMultCol";
val [genMixColumns] = decls "genMixColumns";
123
(* Core correctness lemma: for every input and every schedule, decrypting
   with the reversed schedule undoes encryption.
   Proof: split the input block (FORALL_BLOCK) and the schedule
   (FORALL_KEYSCHED) into components, evaluate everything except the
   constants bound above, then close the goal with the inversion lemmas for
   each round step. *)
val AES_LEMMA = Q.store_thm
("AES_LEMMA",
 `!(plaintext:state) (keys:keysched).
     AES_BWD (REVKEYS keys) (AES_FWD keys plaintext) = plaintext`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 SIMP_TAC std_ss [FORALL_KEYSCHED]
   THEN RESTR_EVAL_TAC [MultCol,InvMultCol,genMixColumns]
   THEN RW_TAC std_ss [ShiftRows_Inversion,SubBytes_Inversion,
                       XOR_BLOCK_IDEM,MixColumns_Inversion,
                       from_state_Inversion,from_state_def]);
134
135(*---------------------------------------------------------------------------
136     Generate the key schedule from key. We work using 4-tuples of
137     bytes. Unpacking moves from four contiguous 4-tuples to a 16-tuple,
138     and also lays the bytes out in the top-to-bottom, left-to-right
139     order that the state also has.
140 ---------------------------------------------------------------------------*)
141
(* XOR8x4: bytewise exclusive-or (??) of two 4-byte words, declared as a
   right-associative infix at precedence 350. *)
val _ = set_fixity "XOR8x4"  (Infixr 350);

val XOR8x4_def =
 Define
   `(a,b,c,d) XOR8x4 (a1,b1,c1,d1) = (a ?? a1, b ?? b1, c ?? c1, d ?? d1)`;
147
(* SubWord: apply the S-box to each byte of a 4-byte word. *)
val SubWord_def = Define
   `SubWord(b0,b1,b2,b3) = (Sbox b0, Sbox b1, Sbox b2, Sbox b3)`;
150
(* RotWord: cyclic left rotation of a 4-byte word by one byte. *)
val RotWord_def = Define
   `RotWord(b0,b1,b2,b3) = (b1,b2,b3,b0)`;
153
(* Rcon i: round-constant word for key-expansion step i.  The first byte is
   PolyExp 2w (i-1) -- presumably 2 raised to the (i-1)th power in the byte
   field, TODO confirm against PolyExp's definition -- and the rest are 0. *)
val Rcon_def = Define
   `Rcon i = (PolyExp 2w (i-1), 0w,0w,0w)`;
156
(* unpack ws acc: consume the word list four words at a time, pushing one
   16-byte block per group onto acc.
   - Group layout: a group [w3;w2;w1;w0] (newest word first, as built by
     expand) becomes the block (w0[0],w1[0],w2[0],w3[0], w0[1],...,w3[3]).
   - Ordering: blocks are pushed onto acc, so the first group consumed ends
     up LAST in the result.
   - Coverage: lists whose length is not a multiple of 4 are not covered by
     these equations. *)
val unpack_def = Define
  `(unpack [] A = A) /\
   (unpack ((a,b,c,d)::(e,f,g,h)::(i,j,k,l)::(m,n,o1,p)::rst) A
        = unpack rst ((m,i,e,a,n,j,f,b,o1,k,g,c,p,l,h,d)::A))`;
161
(*---------------------------------------------------------------------------*)
(* Build the key schedule from a key. This definition is specialised to      *)
(* AES-128: 128-bit keys (4 key words), 44 schedule words, 11 round keys.    *)
(*---------------------------------------------------------------------------*)
166
(* expand n sched: AES-128 key expansion.
   - sched holds the words generated so far, newest first; n is the index of
     the next word.
   - Word n is built from word n-1 (HD sched). When n is a multiple of 4 it
     is first transformed to SubWord (RotWord w) XOR Rcon (n DIV 4). The
     result is then XORed with word n-4 (the fourth element of sched).
   - Once words 0..43 exist (43 < n), unpack turns them into 11 round-key
     blocks.
   Termination: 44 - n strictly decreases. *)
val (expand_def,expand_ind) =
Defn.tprove
 (Hol_defn
   "expand"
   `expand n sched =
      if 43 < n then unpack sched []
      else let h = HD sched in
           let h' = if ~(n MOD 4 = 0) then h
                       else SubWord(RotWord h) XOR8x4 Rcon(n DIV 4)
           in expand (n+1) ((h' XOR8x4 (HD(TL(TL(TL sched)))))::sched)`,
  WF_REL_TAC `measure ($- 44 o FST)`);
178
179
(* Store the equations and induction principle of expand (Defn.tprove does
   not store them), then register expand_def with the evaluator so that EVAL
   can compute key schedules.  Registration must follow the save. *)
val _ = List.app (ignore o save_thm)
          [("expand_def", expand_def),
           ("expand_ind", expand_ind)];
val _ = computeLib.add_persistent_funs ["expand_def"];
183
(* mk_keysched key: the list of 11 round-key blocks for a 16-byte key.
   - Input: the key bytes are grouped into words w0 = (b0,b1,b2,b3), ...,
     w3 = (b12,...,b15).
   - Seed: the words are passed to expand newest first, starting at n = 4.
   - Result order: because unpack reverses the group order, the block built
     from words 0..3 comes FIRST and the one from words 40..43 comes last. *)
val mk_keysched_def = Define
 `mk_keysched ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):key)
      =
  expand 4 [(b12,b13,b14,b15) ; (b8,b9,b10,b11) ;
            (b4,b5,b6,b7)     ; (b0,b1,b2,b3)]`;
189
190
191(*---------------------------------------------------------------------------*)
192(* Sanity check                                                              *)
193(*---------------------------------------------------------------------------*)
194(*
195val PolyExp = Q.prove
196(`(PolyExp x 0 = 1w) /\
197  (PolyExp x (SUC n) = x ** PolyExp x n)`,
198*)
199
200(*
201val keysched_length = Count.apply Q.prove
202(`!key. LENGTH (mk_keysched key) = 11`,
203 SIMP_TAC std_ss [FORALL_BLOCK,mk_keysched_def]
204  THEN REPEAT GEN_TAC
205  THEN NTAC 42
206     (fn x => (RW_TAC list_ss [Once expand_def,LET_THM]
207       THEN FULL_SIMP_TAC list_ss [markerTheory.Abbrev_def]
208       THEN RW_TAC list_ss [XOR8x4_def, SubWord_def, RotWord_def, Rcon_def,
209                            tablesTheory.Sbox_def,MultTheory.PolyExp_def]) x)
210  THEN RULE_ASSUM_TAC (REWRITE_RULE [markerTheory.Abbrev_def])
211  THEN NTAC 20 (POP_ASSUM SUBST_ALL_TAC)
212  THEN NTAC 5 (POP_ASSUM SUBST_ALL_TAC)
213  THEN NTAC 5 (POP_ASSUM SUBST_ALL_TAC)
214RW_TAC std_ss [unpack_def]
215RW_TAC list_ss [unpack_def]
216*)
217
218
219(*---------------------------------------------------------------------------*)
220(* Generate key schedule, and its inverse, then build the encryption and     *)
221(* decryption functions. Called AES, since it wraps everything up into a     *)
222(* single package.                                                           *)
223(*---------------------------------------------------------------------------*)
224
(* AES key = (encrypt, decrypt) for a 128-bit key.
   - mk_keysched returns the round keys in order of use: the key from
     schedule words 0..3 comes first.
   - LIST_TO_KEYS reverses its input list, so its tuple starts with the LAST
     round key.
   - AES_FWD needs the first round key at the head of the schedule for the
     initial AddRoundKey, so the tuple is reversed back with REVKEYS.
     Without this, encryption would apply the round keys in reverse order.
   - Decryption uses the reversed schedule, the form AES_LEMMA proves
     correct. *)
val AES_def = Define
 `AES key =
   let keys = REVKEYS (LIST_TO_KEYS (mk_keysched key) DUMMY_KEYS)
   in (AES_FWD keys, AES_BWD (REVKEYS keys))`;
229
230(*---------------------------------------------------------------------------*)
231(* Basic theorem about encryption/decryption                                 *)
232(*---------------------------------------------------------------------------*)
233
(* Decryption inverts encryption for every key and plaintext.  encrypt and
   decrypt are free variables constrained by the antecedent.
   Proof: unfold AES and the let, split the pair equation into the two
   component functions, and finish with AES_LEMMA. *)
val AES_CORRECT = Q.store_thm
  ("AES_CORRECT",
   `!key plaintext.
      ((encrypt,decrypt) = AES key)
      ==>
       (decrypt (encrypt plaintext) = plaintext)`,
 RW_TAC std_ss [AES_def,LET_THM,AES_LEMMA]);
241
242
(* Close the "aes" theory segment and write it out. *)
val _ = export_theory();
244