1(*---------------------------------------------------------------------------*)
2(* Operations performed in a round:                                          *)
3(*                                                                           *)
4(*    - applying Sboxes                                                      *)
5(*    - shifting rows                                                        *)
6(*    - mixing columns                                                       *)
7(*    - adding round keys                                                    *)
8(*                                                                           *)
(* We prove "inversion" theorems for each of these.                          *)
10(*                                                                           *)
11(*---------------------------------------------------------------------------*)
12
13(* For interactive work
14  quietdec := true;
15  app load ["sboxTheory","metisLib","MultTheory"];
16  open word8Theory pairTheory metisLib;
17  quietdec := false;
18*)
19
20open HolKernel Parse boolLib bossLib
21     pairTools numLib metisLib pairTheory word8Theory tablesTheory MultTheory;
22
23(*---------------------------------------------------------------------------*)
24(* Make bindings to pre-existing stuff                                       *)
25(*---------------------------------------------------------------------------*)
26
(* Short local name for computeLib's restricted evaluator, which evaluates  *)
(* the goal while leaving the constants in its argument list unexpanded.    *)
(* MixColumns_Inversion uses it to keep ** and the byte constants folded.   *)
val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;

(* Byte-level Sbox inversion theorem from tablesTheory, bound locally;      *)
(* SubBytes_Inversion rewrites with it.                                     *)
val Sbox_Inversion = tablesTheory.Sbox_Inversion;
30
31(*---------------------------------------------------------------------------*)
32(* Create the theory.                                                        *)
33(*---------------------------------------------------------------------------*)
34
(* Begin the theory "RoundOp". Definitions and theorems stored from here on *)
(* belong to it; export_theory at the end of the file writes it out.        *)
val _ = new_theory "RoundOp";
36
37
38(*---------------------------------------------------------------------------*)
39(* A block is 16 bytes. A state also has that type, although states have     *)
40(* a special format.                                                         *)
41(*---------------------------------------------------------------------------*)
42
(* block: a flat 16-tuple of bytes.                                         *)
val _ = type_abbrev("block",
                    Type`:word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8`);

(* state is the same 16-tuple type, read in the row-by-row layout that      *)
(* to_state produces; key is in turn an abbreviation for state.             *)
val _ = type_abbrev("state", Type`:block`);
val _ = type_abbrev("key",   Type`:state`);
(* w8x4: a 4-tuple of bytes.                                                *)
val _ = type_abbrev("w8x4",  Type`:word8 # word8 # word8 # word8`);
52
53
(* The block of 16 ZERO bytes; it is a right identity of XOR_BLOCK         *)
(* (theorem XOR_BLOCK_ZERO).                                                *)
val ZERO_BLOCK_def = Define
 `ZERO_BLOCK = (ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,
                ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO) : block`;
57
58(*---------------------------------------------------------------------------*)
59(* Case analysis on a block.                                                 *)
60(*---------------------------------------------------------------------------*)
61
(* Quantifying over a block is the same as quantifying over its 16 bytes.   *)
(* Rewriting with this turns a block variable into an explicit 16-tuple,    *)
(* so the tuple-pattern definitions in this file can fire. The proof        *)
(* splits the nested pairs repeatedly (FORALL_PROD).                        *)
val FORALL_BLOCK = Q.store_thm
("FORALL_BLOCK",
 `(!b:block. P b) =
   !w1 w2 w3 w4 w5 w6 w7 w8 w9 w10 w11 w12 w13 w14 w15 w16.
    P (w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16)`,
 SIMP_TAC std_ss [FORALL_PROD]);
68
69(*---------------------------------------------------------------------------*)
70(* XOR on blocks. Definition and algebraic properties.                       *)
71(*---------------------------------------------------------------------------*)
72
(* Bytewise XOR of two blocks: component i of the result is a_i # b_i.     *)
(* Here # is the byte XOR whose laws (XOR8_ZERO, XOR8_INV, XOR8_AC) are     *)
(* lifted to blocks below.                                                  *)
val XOR_BLOCK_def = Define
 `XOR_BLOCK ((a0,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12,a13,a14,a15):block)
            ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):block)
       =
      (a0 # b0,   a1 # b1,   a2 # b2,   a3 # b3,
       a4 # b4,   a5 # b5,   a6 # b6,   a7 # b7,
       a8 # b8,   a9 # b9,   a10 # b10, a11 # b11,
       a12 # b12, a13 # b13, a14 # b14, a15 # b15)`;
81
(* x XOR_BLOCK ZERO_BLOCK = x, proved componentwise from XOR8_ZERO.        *)
val XOR_BLOCK_ZERO = Q.store_thm
("XOR_BLOCK_ZERO",
 `!x:block. XOR_BLOCK x ZERO_BLOCK = x`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, ZERO_BLOCK_def, XOR8_ZERO]);

(* x XOR_BLOCK x = ZERO_BLOCK, proved componentwise from XOR8_INV.         *)
val XOR_BLOCK_INV = Q.store_thm
("XOR_BLOCK_INV",
 `!x:block. XOR_BLOCK x x = ZERO_BLOCK`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, ZERO_BLOCK_def, XOR8_INV]);

(* Associativity and commutativity of XOR_BLOCK, lifted from XOR8_AC.      *)
val XOR_BLOCK_AC = Q.store_thm
("XOR_BLOCK_AC",
 `(!x y z:block. XOR_BLOCK (XOR_BLOCK x y) z = XOR_BLOCK x (XOR_BLOCK y z)) /\
  (!x y:block. XOR_BLOCK x y = XOR_BLOCK y x)`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, XOR8_AC]);

(* Split XOR8_AC into its two conjuncts. The names a and c are presumably   *)
(* associativity and commutativity, mirroring XOR_BLOCK_AC. They are passed *)
(* to AC in rearrange_xors and InvMixColumns_Distrib. The list pattern is   *)
(* non-exhaustive: it raises Bind unless XOR8_AC has exactly two conjuncts. *)
val [a,c] = CONJUNCTS XOR8_AC;

(* XORing the same block in twice cancels out. AddRoundKey is defined as    *)
(* XOR_BLOCK, so this is the inversion property of the AddRoundKey step.    *)
val XOR_BLOCK_IDEM = Q.store_thm
("XOR_BLOCK_IDEM",
 `(!v u. XOR_BLOCK (XOR_BLOCK v u) u = v) /\
  (!v u. XOR_BLOCK v (XOR_BLOCK v u) = u)`,
 METIS_TAC [XOR_BLOCK_INV,XOR_BLOCK_AC,XOR_BLOCK_ZERO]);
105
106
107(*---------------------------------------------------------------------------*)
108(*    Moving data into and out of a state                                    *)
109(*---------------------------------------------------------------------------*)
110
(* to_state reorders a flat block into the state layout. Read the state as  *)
(* four rows of four bytes, with row r at positions 4r..4r+3. Block byte    *)
(* 4j+r lands in row r, column j, so the block fills the state column by    *)
(* column (a 4x4 transpose).                                                *)
val to_state_def = Define
 `to_state ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15) :block)
                =
            (b0,b4,b8,b12,
             b1,b5,b9,b13,
             b2,b6,b10,b14,
             b3,b7,b11,b15) : state`;

(* from_state is the reverse reordering: it reads the state column by       *)
(* column back into a flat block.                                           *)
val from_state_def = Define
 `from_state((b0,b4,b8,b12,
              b1,b5,b9,b13,
              b2,b6,b10,b14,
              b3,b7,b11,b15) :state)
 = (b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15) : block`;
125
126
(* Converting a block to a state and back gives the original block.         *)
val to_state_Inversion = Q.store_thm
  ("to_state_Inversion",
   `!s:state. from_state(to_state s) = s`,
   SIMP_TAC std_ss [FORALL_BLOCK, from_state_def, to_state_def]);


(* Converting a state to a block and back gives the original state.         *)
val from_state_Inversion = Q.store_thm
  ("from_state_Inversion",
   `!s:state. to_state(from_state s) = s`,
   SIMP_TAC std_ss [FORALL_BLOCK, from_state_def, to_state_def]);
137
138
139(*---------------------------------------------------------------------------*)
140(*    Apply an Sbox to the state                                             *)
141(*---------------------------------------------------------------------------*)
142
(* Hide the existing constant named S while the definition below is parsed, *)
(* so that S is read as the definition's parameter (a variable).            *)
val _ = Parse.hide "S";

(* genSubBytes S s applies the byte substitution S to each of the 16 bytes  *)
(* of state s independently, leaving byte positions unchanged. SubBytes and *)
(* InvSubBytes instantiate S with Sbox and InvSbox.                         *)
val genSubBytes_def = Define
  `genSubBytes S ((b00,b01,b02,b03,
                   b10,b11,b12,b13,
                   b20,b21,b22,b23,
                   b30,b31,b32,b33) :state)
                          =
             (S b00, S b01, S b02, S b03,
              S b10, S b11, S b12, S b13,
              S b20, S b21, S b22, S b23,
              S b30, S b31, S b32, S b33) :state`;

(* Restore the constant S for the rest of the file.                         *)
val _ = Parse.reveal "S";
157
(* SubBytes sends every byte through Sbox; InvSubBytes through InvSbox.     *)
val SubBytes_def    = Define `SubBytes = genSubBytes Sbox`;
val InvSubBytes_def = Define `InvSubBytes = genSubBytes InvSbox`;

(* Applying InvSbox bytewise undoes applying Sbox bytewise. The proof       *)
(* expands the state and uses the byte-level Sbox_Inversion. The statement  *)
(* is phrased with genSubBytes directly, not SubBytes/InvSubBytes.          *)
val SubBytes_Inversion = Q.store_thm
("SubBytes_Inversion",
 `!s:state. genSubBytes InvSbox (genSubBytes Sbox s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK,genSubBytes_def,Sbox_Inversion]);
165
166
167(*---------------------------------------------------------------------------
168    Left-shift the first row not at all, the second row by 1, the
169    third row by 2, and the last row by 3. And the inverse operation.
170 ---------------------------------------------------------------------------*)
171
(* Row r of the state (bytes br0..br3) is rotated left by r positions.      *)
(* Row 0 is unchanged.                                                      *)
val ShiftRows_def = Define
  `ShiftRows ((b00,b01,b02,b03,
               b10,b11,b12,b13,
               b20,b21,b22,b23,
               b30,b31,b32,b33) :state)
                     =
             (b00,b01,b02,b03,
              b11,b12,b13,b10,
              b22,b23,b20,b21,
              b33,b30,b31,b32) :state`;

(* InvShiftRows pattern-matches on the shifted layout and returns the       *)
(* unshifted one. Row r is rotated right by r positions.                    *)
val InvShiftRows_def = Define
  `InvShiftRows ((b00,b01,b02,b03,
                  b11,b12,b13,b10,
                  b22,b23,b20,b21,
                  b33,b30,b31,b32) :state)
                     =
                (b00,b01,b02,b03,
                 b10,b11,b12,b13,
                 b20,b21,b22,b23,
                 b30,b31,b32,b33) :state`;
193
194(*---------------------------------------------------------------------------
195        InvShiftRows inverts ShiftRows
196 ---------------------------------------------------------------------------*)
197
(* Proof: expand s into its 16 bytes, then evaluate both definitions.       *)
val ShiftRows_Inversion = Q.store_thm
("ShiftRows_Inversion",
 `!s:state. InvShiftRows (ShiftRows s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);
202
203
204(*---------------------------------------------------------------------------*)
205(* For alternative decryption scheme                                         *)
206(*---------------------------------------------------------------------------*)
207
(* SubBytes acts on each byte independently and ShiftRows only permutes     *)
(* byte positions, so the two steps commute. Proof: expand s and evaluate.  *)
val ShiftRows_SubBytes_Commute = Q.store_thm
 ("ShiftRows_SubBytes_Commute",
  `!s. ShiftRows (SubBytes s) = SubBytes (ShiftRows s)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);


(* The same commutation for the inverse steps.                              *)
val InvShiftRows_InvSubBytes_Commute = Q.store_thm
 ("InvShiftRows_InvSubBytes_Commute",
  `!s. InvShiftRows (InvSubBytes s) = InvSubBytes (InvShiftRows s)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);
218
219
220(*---------------------------------------------------------------------------
221        Column multiplication and its inverse
222 ---------------------------------------------------------------------------*)
223
(* MultCol multiplies the column (a,b,c,d) by the circulant matrix          *)
(*      [2 3 1 1]                                                           *)
(*      [1 2 3 1]                                                           *)
(*      [1 1 2 3]                                                           *)
(*      [3 1 1 2]                                                           *)
(* "k ** x" multiplies byte x by constant k, and # (byte XOR) acts as       *)
(* addition. ** and the constants are presumably the field operations from  *)
(* MultTheory; they are not defined in this file.                           *)
val MultCol_def = Define
 `MultCol (a,b,c,d) =
   ((TWO ** a)   # (THREE ** b) #  c           # d,
     a           # (TWO ** b)   # (THREE ** c) # d,
     a           #  b           # (TWO ** c)   # (THREE ** d),
    (THREE ** a) #  b           #  c           # (TWO ** d))`;

(* InvMultCol multiplies by the circulant matrix                            *)
(*      [E B D 9]                                                           *)
(*      [9 E B D]                                                           *)
(*      [D 9 E B]                                                           *)
(*      [B D 9 E]                                                           *)
(* mix_lemma1..4 below show that it undoes MultCol.                         *)
val InvMultCol_def = Define
 `InvMultCol (a,b,c,d) =
   ((E_HEX ** a) # (B_HEX ** b) # (D_HEX ** c) # (NINE  ** d),
    (NINE  ** a) # (E_HEX ** b) # (B_HEX ** c) # (D_HEX ** d),
    (D_HEX ** a) # (NINE  ** b) # (E_HEX ** c) # (B_HEX ** d),
    (B_HEX ** a) # (D_HEX ** b) # (NINE  ** c) # (E_HEX ** d))`;
237
238(*---------------------------------------------------------------------------*)
(* Inversion lemmas for column multiplication. Proved with an ad-hoc tactic. *)
240(*                                                                           *)
241(* Note: could just use case analysis with Sbox_ind, then EVAL_TAC, but      *)
242(* that's far slower.                                                        *)
243(*---------------------------------------------------------------------------*)
244
(* Ad-hoc tactic for closed identities over a single byte, !x. P x.        *)
(*   1. Rewrite once with FORALL_BYTE_VARS and evaluate. This presumably    *)
(*      turns the byte quantifier into quantifiers over its bits; not       *)
(*      visible here.                                                       *)
(*   2. Simplify with XOR8_ZERO, with ZERO unfolded via ZERO_def.           *)
(*   3. Case-split on the leading variable and simplify conditionals with   *)
(*      COND_CLAUSES.                                                       *)
(*   4. Case-split on all remaining variables and evaluate each case.       *)
val BYTE_CASES_TAC =
 Ho_Rewrite.ONCE_REWRITE_TAC [FORALL_BYTE_VARS] THEN EVAL_TAC
 THEN REWRITE_TAC [REWRITE_RULE [ZERO_def] XOR8_ZERO]
 THEN Cases THEN PURE_REWRITE_TAC [COND_CLAUSES]
 THEN REPEAT Cases
 THEN EVAL_TAC;
251
(* Lemmas a1..a4 concern row 1 of InvMultCol (E_HEX, B_HEX, D_HEX, NINE)   *)
(* applied after MultCol. Each lemma collects the four products that        *)
(* involve one input byte. The contributions of b, c and d cancel to ZERO,  *)
(* and a is recovered. Together they give mix_lemma1.                       *)
val lemma_a1 = Q.prove
(`!a. E_HEX ** (TWO ** a) # B_HEX ** a # D_HEX ** a # NINE ** (THREE ** a) = a`,
 BYTE_CASES_TAC);

val lemma_a2 = Q.prove
(`!b. E_HEX ** (THREE ** b) # B_HEX ** (TWO ** b) # D_HEX ** b # NINE  ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_a3 = Q.prove
(`!c. E_HEX ** c # B_HEX ** (THREE ** c) # D_HEX ** (TWO ** c) # NINE ** c = ZERO`,
 BYTE_CASES_TAC);

val lemma_a4 = Q.prove
(`!d. E_HEX ** d # B_HEX ** d # D_HEX ** (THREE ** d) # NINE ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
267
(* Lemmas b1..b4 concern row 2 of InvMultCol (NINE, E_HEX, B_HEX, D_HEX)   *)
(* applied after MultCol. Each lemma collects the four products that        *)
(* involve one input byte. The contributions of a, c and d cancel to ZERO,  *)
(* and b is recovered. Together they give mix_lemma2.                       *)
val lemma_b1 = Q.prove
(`!a. NINE ** (TWO ** a) # E_HEX ** a # B_HEX ** a # D_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC);

val lemma_b2 = Q.prove
(`!b. NINE ** (THREE ** b) # E_HEX ** (TWO ** b) # B_HEX ** b # D_HEX ** b = b`,
 BYTE_CASES_TAC);

val lemma_b3 = Q.prove
(`!c. NINE ** c # E_HEX ** (THREE ** c) # B_HEX ** (TWO ** c) # D_HEX ** c = ZERO`,
 BYTE_CASES_TAC);

val lemma_b4 = Q.prove
(`!d. NINE ** d # E_HEX ** d # B_HEX ** (THREE ** d) # D_HEX ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
283
(* Lemmas c1..c4 concern row 3 of InvMultCol (D_HEX, NINE, E_HEX, B_HEX)   *)
(* applied after MultCol. Each lemma collects the four products that        *)
(* involve one input byte. The contributions of a, b and d cancel to ZERO,  *)
(* and c is recovered. Together they give mix_lemma3.                       *)
(* NOTE(review): lemma_c1 keeps its trailing EVAL_TAC. Whether              *)
(* BYTE_CASES_TAC alone closes it has not been checked.                     *)
val lemma_c1 = Q.prove
(`!a. D_HEX ** (TWO ** a) # NINE ** a # E_HEX ** a # B_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC THEN EVAL_TAC);

val lemma_c2 = Q.prove
(`!b. D_HEX ** (THREE ** b) # NINE ** (TWO ** b) # E_HEX ** b # B_HEX ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_c3 = Q.prove
(`!c. D_HEX ** c # NINE ** (THREE ** c) # E_HEX ** (TWO ** c) # B_HEX ** c = c`,
 BYTE_CASES_TAC);

val lemma_c4 = Q.prove
(`!d. D_HEX ** d # NINE ** d # E_HEX ** (THREE ** d) # B_HEX ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
299
(* Lemmas d1..d4 concern row 4 of InvMultCol (B_HEX, D_HEX, NINE, E_HEX)   *)
(* applied after MultCol. Each lemma collects the four products that        *)
(* involve one input byte. The contributions of a, b and c cancel to ZERO,  *)
(* and d is recovered. Together they give mix_lemma4.                       *)
(* NOTE(review): lemma_d3 keeps its trailing EVAL_TAC. Whether              *)
(* BYTE_CASES_TAC alone closes it has not been checked.                     *)
val lemma_d1 = Q.prove
(`!a. B_HEX ** (TWO ** a) # D_HEX ** a # NINE ** a # E_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC);

val lemma_d2 = Q.prove
(`!b. B_HEX ** (THREE ** b) # D_HEX ** (TWO ** b) # NINE ** b # E_HEX ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_d3 = Q.prove
(`!c. B_HEX ** c # D_HEX ** (THREE ** c) # NINE ** (TWO ** c) # E_HEX ** c = ZERO`,
 BYTE_CASES_TAC THEN EVAL_TAC);

val lemma_d4 = Q.prove
(`!d. B_HEX ** d # D_HEX ** d # NINE ** (THREE ** d) # E_HEX ** (TWO ** d) = d`,
 BYTE_CASES_TAC);
315
316(*---------------------------------------------------------------------------*)
317(* The following lemma is hideous to prove without permutative rewriting     *)
318(*---------------------------------------------------------------------------*)
319
(* Regroup a 4x4 sum of XORed terms from row order to column order.         *)
(* The variables are left free, so the lemma applies to any 16 terms of the *)
(* XOR type. Proved by AC-normalisation with XOR associativity (a) and      *)
(* commutativity (c).                                                       *)
val rearrange_xors = Q.prove
(`(a1 # b1 # c1 # d1) #
  (a2 # b2 # c2 # d2) #
  (a3 # b3 # c3 # d3) #
  (a4 # b4 # c4 # d4)
     =
  (a1 # a2 # a3 # a4) #
  (b1 # b2 # b3 # b4) #
  (c1 # c2 # c3 # c4) #
  (d1 # d2 # d3 # d4)`,
 RW_TAC std_ss [AC a c]);
331
(* mix_lemma1..4: row i of InvMultCol applied to MultCol (a,b,c,d) gives    *)
(* the i-th input byte. Each proof:                                         *)
(*   - distributes the constant multiplications over XOR (ConstMultDistrib);*)
(*   - regroups the 16 products by input byte (rearrange_xors);             *)
(*   - closes with the matching per-byte lemmas and XOR8_ZERO.              *)
(* Rows 2-4 also use XOR8_ZERO commuted by XOR8_AC. This is presumably      *)
(* needed because ZERO terms precede the surviving byte.                    *)
val mix_lemma1 = Q.prove
(`!a b c d.
   (E_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (B_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (D_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (NINE  ** ((THREE ** a) # b # c # (TWO ** d)))
      = a`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_a1,lemma_a2,lemma_a3,lemma_a4,XOR8_ZERO]);

val mix_lemma2 = Q.prove
(`!a b c d.
   (NINE  ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (E_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (B_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (D_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = b`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_b1,lemma_b2,lemma_b3,lemma_b4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);

val mix_lemma3 = Q.prove
(`!a b c d.
   (D_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (NINE  ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (E_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (B_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = c`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_c1,lemma_c2,lemma_c3,lemma_c4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);

val mix_lemma4 = Q.prove
(`!a b c d.
   (B_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (D_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (NINE  ** (a # b # (TWO ** c) # (THREE ** d))) #
   (E_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = d`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_d1,lemma_d2,lemma_d3,lemma_d4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);
378
379(*---------------------------------------------------------------------------*)
380(* Get the constants of various definitions                                  *)
381(*---------------------------------------------------------------------------*)
382
(* Term constants that RESTR_EVAL_TAC must leave unexpanded in              *)
(* MixColumns_Inversion. decls returns every declared constant with the     *)
(* given name. Each one-element list pattern raises Bind unless exactly one *)
(* such constant exists.                                                    *)
val [mult]     = decls "**";
val [TWO]      = decls "TWO";
val [THREE]    = decls "THREE";
val [NINE]     = decls "NINE";
val [B_HEX]    = decls "B_HEX";
val [D_HEX]    = decls "D_HEX";
val [E_HEX]    = decls "E_HEX";
390
391(*---------------------------------------------------------------------------*)
392(* Mixing columns                                                            *)
393(*---------------------------------------------------------------------------*)
394
(* genMixColumns MC s applies the column transformation MC to each of the   *)
(* four columns of state s, where column j is (b0j,b1j,b2j,b3j). The four   *)
(* resulting bytes are written back into the same column.                   *)
val genMixColumns_def = Define
 `genMixColumns MC ((b00,b01,b02,b03,
                     b10,b11,b12,b13,
                     b20,b21,b22,b23,
                     b30,b31,b32,b33) :state)
 = let (b00', b10', b20', b30') = MC (b00,b10,b20,b30) in
   let (b01', b11', b21', b31') = MC (b01,b11,b21,b31) in
   let (b02', b12', b22', b32') = MC (b02,b12,b22,b32) in
   let (b03', b13', b23', b33') = MC (b03,b13,b23,b33)
   in
    (b00', b01', b02', b03',
     b10', b11', b12', b13',
     b20', b21', b22', b23',
     b30', b31', b32', b33') : state`;
409
410
(* MixColumns and InvMixColumns instantiate genMixColumns with MultCol and  *)
(* InvMultCol.                                                              *)
val MixColumns_def    = Define `MixColumns    = genMixColumns MultCol`;
val InvMixColumns_def = Define `InvMixColumns = genMixColumns InvMultCol`;

(* InvMultCol undoes MultCol on every column. The proof:                    *)
(*   - expands the state;                                                   *)
(*   - evaluates the let-structure while keeping ** and the byte constants  *)
(*     folded (RESTR_EVAL_TAC with the constants looked up above);          *)
(*   - finishes each column with mix_lemma1..4.                             *)
val MixColumns_Inversion = Q.store_thm
("MixColumns_Inversion",
 `!s. genMixColumns InvMultCol (genMixColumns MultCol s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK]
  THEN RESTR_EVAL_TAC [mult,B_HEX,D_HEX,E_HEX,TWO,THREE,NINE]
  THEN RW_TAC std_ss [mix_lemma1,mix_lemma2,mix_lemma3,mix_lemma4]);
420
421
422(*---------------------------------------------------------------------------
423    Pairwise XOR the state with the round key
424 ---------------------------------------------------------------------------*)
425
(* Adding a round key is blockwise XOR; by XOR_BLOCK_IDEM it is its own     *)
(* inverse.                                                                 *)
val AddRoundKey_def = Define `AddRoundKey = XOR_BLOCK`;
427
428(*---------------------------------------------------------------------------*)
429(* For alternative decryption scheme                                         *)
430(*---------------------------------------------------------------------------*)
431
(* InvMixColumns distributes over AddRoundKey, i.e. it is linear with       *)
(* respect to XOR. The proof:                                               *)
(*   - expands both blocks and unfolds the definitions;                     *)
(*   - distributes the constant multiplications over XOR (ConstMultDistrib);*)
(*   - normalises the XOR terms with AC a c.                                *)
val InvMixColumns_Distrib = Q.store_thm
("InvMixColumns_Distrib",
 `!s k. InvMixColumns (AddRoundKey s k)
            =
        AddRoundKey (InvMixColumns s) (InvMixColumns k)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 RW_TAC std_ss [XOR_BLOCK_def, AddRoundKey_def, InvMixColumns_def, LET_THM,
                genMixColumns_def, InvMultCol_def, ConstMultDistrib, AC a c]);
440
(* Write out the RoundOp theory with every definition and stored theorem    *)
(* made since new_theory.                                                   *)
val _ = export_theory();
442