1(*---------------------------------------------------------------------------*)
2(* Operations performed in a round:                                          *)
3(*                                                                           *)
4(*    - applying Sboxes                                                      *)
5(*    - shifting rows                                                        *)
6(*    - mixing columns                                                       *)
7(*    - adding round keys                                                    *)
8(*                                                                           *)
(* We prove "inversion" theorems for each of these operations.               *)
10(*                                                                           *)
11(*---------------------------------------------------------------------------*)
12
(* For interactive work
  quietdec := true;
  app load ["metisLib","tablesTheory","MultTheory"];
  open word8Theory pairTheory metisLib tablesTheory MultTheory;
  quietdec := false;
*)
19
20open HolKernel Parse boolLib bossLib
21     metisLib pairTheory word8Theory tablesTheory MultTheory;
22
23(*---------------------------------------------------------------------------*)
24(* Make bindings to pre-existing stuff                                       *)
25(*---------------------------------------------------------------------------*)
26
(* Restricted evaluation from computeLib: evaluate the goal but leave the
   listed constants unfolded. Used by MixColumns_Inversion below. *)
fun RESTR_EVAL_TAC consts = computeLib.RESTR_EVAL_TAC consts;

(* Short local name for the bytewise S-box inversion fact from tablesTheory,
   used by SubBytes_Inversion below. *)
val Sbox_Inversion = tablesTheory.Sbox_Inversion;
30
31(*---------------------------------------------------------------------------*)
32(* Create the theory.                                                        *)
33(*---------------------------------------------------------------------------*)
34
(* Start the RoundOp theory. Everything stored below is written out by the
   export_theory call at the end of this script. *)
val _ = new_theory "RoundOp";
36
37
38(*---------------------------------------------------------------------------*)
39(* A block is 16 bytes. A state also has that type, although states have     *)
40(* a special format.                                                         *)
41(*---------------------------------------------------------------------------*)
42
(* block is a flat 16-tuple of word8 values. state and key are further names
   for the same type: the state layout is a convention followed by the
   functions below, not something the type enforces. w8x4 names a 4-tuple of
   bytes and is not referenced again in this script. *)
val _ = type_abbrev("block",
                    Type`:word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8 #
                          word8 # word8 # word8 # word8`);

val _ = type_abbrev("state", Type`:block`);
val _ = type_abbrev("key",   Type`:state`);
val _ = type_abbrev("w8x4",  Type`:word8 # word8 # word8 # word8`);
52
53
(* The block whose sixteen bytes are all ZERO. XOR_BLOCK_ZERO below shows it
   is a right identity for XOR_BLOCK. *)
val ZERO_BLOCK_def = Define
 `ZERO_BLOCK = (ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,
                ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO,ZERO) : block`;
57
58(*---------------------------------------------------------------------------*)
59(* Case analysis on a block.                                                 *)
60(*---------------------------------------------------------------------------*)
61
(* Quantifying universally over a block is equivalent to quantifying over its
   sixteen component bytes. Most proofs below begin with
   SIMP_TAC std_ss [FORALL_BLOCK] to turn a block goal into a byte goal.
   Proved by repeatedly splitting the nested pair with FORALL_PROD. *)
val FORALL_BLOCK = Q.store_thm
("FORALL_BLOCK",
 `(!b:block. P b) =
   !w1 w2 w3 w4 w5 w6 w7 w8 w9 w10 w11 w12 w13 w14 w15 w16.
    P (w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16)`,
 SIMP_TAC std_ss [FORALL_PROD]);
68
69(*---------------------------------------------------------------------------*)
70(* XOR on blocks. Definition and algebraic properties.                       *)
71(*---------------------------------------------------------------------------*)
72
(* Componentwise XOR of two blocks. It uses the byte-level operator #, whose
   laws XOR8_ZERO, XOR8_INV and XOR8_AC are lifted to blocks below. *)
val XOR_BLOCK_def = Define
 `XOR_BLOCK ((a0,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12,a13,a14,a15):block)
            ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15):block)
       =
      (a0 # b0,   a1 # b1,   a2 # b2,   a3 # b3,
       a4 # b4,   a5 # b5,   a6 # b6,   a7 # b7,
       a8 # b8,   a9 # b9,   a10 # b10, a11 # b11,
       a12 # b12, a13 # b13, a14 # b14, a15 # b15)`;
81
(* ZERO_BLOCK is a right identity for XOR_BLOCK, from XOR8_ZERO on each byte. *)
val XOR_BLOCK_ZERO = Q.store_thm
("XOR_BLOCK_ZERO",
 `!x:block. XOR_BLOCK x ZERO_BLOCK = x`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, ZERO_BLOCK_def, XOR8_ZERO]);
86
(* Every block is its own XOR_BLOCK inverse, from XOR8_INV on each byte. *)
val XOR_BLOCK_INV = Q.store_thm
("XOR_BLOCK_INV",
 `!x:block. XOR_BLOCK x x = ZERO_BLOCK`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, ZERO_BLOCK_def, XOR8_INV]);
91
(* XOR_BLOCK is associative (first conjunct) and commutative (second
   conjunct), from XOR8_AC on each byte. *)
val XOR_BLOCK_AC = Q.store_thm
("XOR_BLOCK_AC",
 `(!x y z:block. XOR_BLOCK (XOR_BLOCK x y) z = XOR_BLOCK x (XOR_BLOCK y z)) /\
  (!x y:block. XOR_BLOCK x y = XOR_BLOCK y x)`,
 SIMP_TAC std_ss [FORALL_BLOCK,XOR_BLOCK_def, XOR8_AC]);
97
(* Split XOR8_AC into its two conjuncts: a is the associativity law and c is
   the commutativity law for byte XOR. They are passed as AC a c to the
   AC-rewriting in rearrange_xors and InvMixColumns_Distrib. CONJ_PAIR gives
   an exhaustive binding; a list pattern on CONJUNCTS draws a
   non-exhaustive-match warning. *)
val (a, c) = CONJ_PAIR XOR8_AC;
99
(* Cancellation: XORing the same block in twice restores the original. This
   is what makes AddRoundKey, defined below as XOR_BLOCK, invertible. Follows
   by first-order reasoning from XOR_BLOCK_INV, XOR_BLOCK_AC, XOR_BLOCK_ZERO. *)
val XOR_BLOCK_IDEM = Q.store_thm
("XOR_BLOCK_IDEM",
 `(!v u. XOR_BLOCK (XOR_BLOCK v u) u = v) /\
  (!v u. XOR_BLOCK v (XOR_BLOCK v u) = u)`,
 METIS_TAC [XOR_BLOCK_INV,XOR_BLOCK_AC,XOR_BLOCK_ZERO]);
105
106
107(*---------------------------------------------------------------------------*)
108(*    Moving data into and out of a state                                    *)
109(*---------------------------------------------------------------------------*)
110
(* Load a 16-byte block into a state. The state tuple is read as four rows
   of four bytes (the layout ShiftRows below relies on). Input byte i lands
   in row (i MOD 4), column (i DIV 4), so the block fills the state column
   by column. *)
val to_state_def = Define
 `to_state ((b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15) :block)
                =
            (b0,b4,b8,b12,
             b1,b5,b9,b13,
             b2,b6,b10,b14,
             b3,b7,b11,b15) : state`;
118
(* Read a state back out to a block, column by column. This is the exact
   inverse of to_state, as shown by the two inversion theorems below. *)
val from_state_def = Define
 `from_state((b0,b4,b8,b12,
              b1,b5,b9,b13,
              b2,b6,b10,b14,
              b3,b7,b11,b15) :state)
 = (b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15) : block`;
125
126
(* from_state undoes to_state. Proved by splitting into bytes and unfolding
   both definitions. *)
val to_state_Inversion = Q.store_thm
  ("to_state_Inversion",
   `!s:state. from_state(to_state s) = s`,
   SIMP_TAC std_ss [FORALL_BLOCK, from_state_def, to_state_def]);
131
132
(* to_state undoes from_state. Proved by splitting into bytes and unfolding
   both definitions. *)
val from_state_Inversion = Q.store_thm
  ("from_state_Inversion",
   `!s:state. to_state(from_state s) = s`,
   SIMP_TAC std_ss [FORALL_BLOCK, from_state_def, to_state_def]);
137
138
139(*---------------------------------------------------------------------------*)
140(*    Apply an Sbox to the state                                             *)
141(*---------------------------------------------------------------------------*)
142
(* genSubBytes S applies the byte substitution S independently to each of the
   sixteen bytes of a state. The name S is hidden while the definition is
   parsed, so it is read as the bound parameter, and revealed again
   afterwards. NOTE(review): the Define is wrapped in try, presumably so
   that any error raised during definition is reported. *)
val _ = Parse.hide "S";   (* to make parameter S a variable *)

val genSubBytes_def = try Define
  `genSubBytes S ((b00,b01,b02,b03,
                   b10,b11,b12,b13,
                   b20,b21,b22,b23,
                   b30,b31,b32,b33) :state)
                          =
             (S b00, S b01, S b02, S b03,
              S b10, S b11, S b12, S b13,
              S b20, S b21, S b22, S b23,
              S b30, S b31, S b32, S b33) :state`;

val _ = Parse.reveal "S";
157
(* SubBytes substitutes every byte through Sbox. InvSubBytes substitutes
   through InvSbox, which undoes Sbox (see SubBytes_Inversion). Both are
   instances of genSubBytes. *)
val SubBytes_def    = Define `SubBytes = genSubBytes Sbox`;
val InvSubBytes_def = Define `InvSubBytes = genSubBytes InvSbox`;
160
(* Substituting with Sbox and then with InvSbox is the identity on states.
   The theorem is stated in terms of genSubBytes. To use it on
   InvSubBytes (SubBytes s), first unfold SubBytes_def and InvSubBytes_def.
   The proof splits the state into bytes and applies the bytewise fact
   Sbox_Inversion to each one. *)
val SubBytes_Inversion = Q.store_thm
("SubBytes_Inversion",
 `!s:state. genSubBytes InvSbox (genSubBytes Sbox s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK,genSubBytes_def,Sbox_Inversion]);
165
166
167(*---------------------------------------------------------------------------
168    Left-shift the first row not at all, the second row by 1, the
169    third row by 2, and the last row by 3. And the inverse operation.
170 ---------------------------------------------------------------------------*)
171
(* Rotate row r of the state (r = 0..3) left by r positions. Rows are the
   consecutive groups of four bytes in the state tuple. *)
val ShiftRows_def = Define
  `ShiftRows ((b00,b01,b02,b03,
               b10,b11,b12,b13,
               b20,b21,b22,b23,
               b30,b31,b32,b33) :state)
                     =
             (b00,b01,b02,b03,
              b11,b12,b13,b10,
              b22,b23,b20,b21,
              b33,b30,b31,b32) :state`;
182
(* Inverse of ShiftRows. It is defined by pattern matching on the shifted
   layout, so row r is rotated right by r positions. *)
val InvShiftRows_def = Define
  `InvShiftRows ((b00,b01,b02,b03,
                  b11,b12,b13,b10,
                  b22,b23,b20,b21,
                  b33,b30,b31,b32) :state)
                     =
                (b00,b01,b02,b03,
                 b10,b11,b12,b13,
                 b20,b21,b22,b23,
                 b30,b31,b32,b33) :state`;
193
194(*---------------------------------------------------------------------------
195        InvShiftRows inverts ShiftRows
196 ---------------------------------------------------------------------------*)
197
(* InvShiftRows undoes ShiftRows. Proof: split the state into its sixteen
   bytes, then evaluate both definitions. *)
val ShiftRows_Inversion = Q.store_thm
("ShiftRows_Inversion",
 `!s:state. InvShiftRows (ShiftRows s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);
202
203
204(*---------------------------------------------------------------------------*)
205(* For alternative decryption scheme                                         *)
206(*---------------------------------------------------------------------------*)
207
(* ShiftRows only permutes byte positions, and SubBytes transforms each byte
   independently, so the two commute. Proved by splitting into bytes and
   evaluating. *)
val ShiftRows_SubBytes_Commute = Q.store_thm
 ("ShiftRows_SubBytes_Commute",
  `!s. ShiftRows (SubBytes s) = SubBytes (ShiftRows s)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);
212
213
(* Inverse-direction counterpart of ShiftRows_SubBytes_Commute: the byte
   permutation InvShiftRows commutes with the bytewise InvSubBytes. Proved by
   splitting into bytes and evaluating. *)
val InvShiftRows_InvSubBytes_Commute = Q.store_thm
 ("InvShiftRows_InvSubBytes_Commute",
  `!s. InvShiftRows (InvSubBytes s) = InvSubBytes (InvShiftRows s)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN REPEAT STRIP_TAC THEN EVAL_TAC);
218
219
220(*---------------------------------------------------------------------------
221        Column multiplication and its inverse
222 ---------------------------------------------------------------------------*)
223
(* Multiply one column (a,b,c,d) by the fixed circulant matrix
     [2 3 1 1; 1 2 3 1; 1 1 2 3; 3 1 1 2]
   Here the infix ** multiplies a byte by a constant and # adds (XORs) bytes. *)
val MultCol_def = Define
 `MultCol (a,b,c,d) =
   ((TWO ** a)   # (THREE ** b) #  c           # d,
     a           # (TWO ** b)   # (THREE ** c) # d,
     a           #  b           # (TWO ** c)   # (THREE ** d),
    (THREE ** a) #  b           #  c           # (TWO ** d))`;
230
(* The inverse column transform: multiply (a,b,c,d) by the circulant matrix
     [E B D 9; 9 E B D; D 9 E B; B D 9 E]   (hexadecimal coefficients).
   That this undoes MultCol is established by mix_lemma1..4 and
   MixColumns_Inversion below. *)
val InvMultCol_def = Define
 `InvMultCol (a,b,c,d) =
   ((E_HEX ** a) # (B_HEX ** b) # (D_HEX ** c) # (NINE  ** d),
    (NINE  ** a) # (E_HEX ** b) # (B_HEX ** c) # (D_HEX ** d),
    (D_HEX ** a) # (NINE  ** b) # (E_HEX ** c) # (B_HEX ** d),
    (B_HEX ** a) # (D_HEX ** b) # (NINE  ** c) # (E_HEX ** d))`;
237
238(*---------------------------------------------------------------------------*)
(* Table-lookup versions of MultCol and InvMultCol. Faster to use, but they *)
(* require tables (which consume space).                                     *)
241(*---------------------------------------------------------------------------*)
242
(* MultCol restated with the table-based multipliers GF256_by_2 and
   GF256_by_3 in place of multiplication by constants. The proof unfolds
   MultCol_def, then rewrites with tcm_thm and with the MultEquiv equations
   reversed by SYM. NOTE(review): this presumes MultEquiv equates each
   GF256_by_k with multiplication by the constant k; confirm in MultTheory. *)
val TabledMultCol = Q.store_thm
("TabledMultCol",
 `MultCol(a,b,c,d) =
    (GF256_by_2 a # GF256_by_3 b # c # d,
     a # GF256_by_2 b # GF256_by_3 c # d,
     a # b # GF256_by_2 c # GF256_by_3 d,
     GF256_by_3 a # b # c # GF256_by_2 d)`,
 SIMP_TAC std_ss [MultCol_def] THEN
 SIMP_TAC std_ss (tcm_thm::map SYM (CONJUNCTS (SPEC_ALL MultEquiv))));
252
(* InvMultCol restated with the table-based multipliers GF256_by_9, _11, _13
   and _14. Same proof shape as TabledMultCol: unfold, then rewrite with
   tcm_thm and the MultEquiv equations reversed by SYM. NOTE(review): this
   presumes MultEquiv relates each GF256_by_k to multiplication by the
   constant k; confirm in MultTheory. *)
val TabledInvMultCol =
 Q.store_thm
 ("TabledInvMultCol",
  `InvMultCol (a,b,c,d) =
    (GF256_by_14 a # GF256_by_11 b # GF256_by_13 c # GF256_by_9 d,
     GF256_by_9 a # GF256_by_14 b # GF256_by_11 c # GF256_by_13 d,
     GF256_by_13 a # GF256_by_9 b # GF256_by_14 c # GF256_by_11 d,
     GF256_by_11 a # GF256_by_13 b # GF256_by_9 c # GF256_by_14 d)`,
 SIMP_TAC std_ss [InvMultCol_def] THEN
 SIMP_TAC std_ss (tcm_thm::map SYM (CONJUNCTS (SPEC_ALL MultEquiv))));
263
264
265(*---------------------------------------------------------------------------*)
266(* Inversion lemmas for column multiplication. Proved with an ad-hoc tactic  *)
267(*                                                                           *)
268(* Note: could just use case analysis with Sbox_ind, then EVAL_TAC, but      *)
269(* that's far slower.                                                        *)
270(*---------------------------------------------------------------------------*)
271
(* BYTE_CASES_TAC proves the lemma_ facts below in three steps:
    1. simplify with tcm_thm and the MultEquiv equations reversed by SYM;
    2. rewrite once with FORALL_BYTE_BITS, presumably replacing the
       quantified byte by its individual bits;
    3. evaluate what remains with EVAL_TAC. *)
val BYTE_CASES_TAC =
  let
    val mult_eqns  = tcm_thm :: map SYM (CONJUNCTS (SPEC_ALL MultEquiv))
    val simp_mult  = SIMP_TAC std_ss mult_eqns
    val split_byte = Ho_Rewrite.ONCE_REWRITE_TAC [FORALL_BYTE_BITS]
  in
    simp_mult THEN split_byte THEN EVAL_TAC
  end;
276
(* lemma_aJ: row 1 of InvMultCol (coefficients E B D 9) applied to the
   coefficients of input J in MultCol's output. The diagonal case (lemma_a1)
   gives the input back; the off-diagonal cases give ZERO. Each is proved by
   BYTE_CASES_TAC. Count.apply presumably reports the proof cost. *)
val lemma_a1 = Count.apply Q.prove
(`!a. E_HEX ** (TWO ** a) # B_HEX ** a # D_HEX ** a # NINE ** (THREE ** a) = a`,
 BYTE_CASES_TAC);

val lemma_a2 = Count.apply Q.prove
(`!b. E_HEX ** (THREE ** b) # B_HEX ** (TWO ** b) # D_HEX ** b # NINE  ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_a3 = Count.apply Q.prove
(`!c. E_HEX ** c # B_HEX ** (THREE ** c) # D_HEX ** (TWO ** c) # NINE ** c = ZERO`,
 BYTE_CASES_TAC);

val lemma_a4 = Count.apply Q.prove
(`!d. E_HEX ** d # B_HEX ** d # D_HEX ** (THREE ** d) # NINE ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
292
(* lemma_bJ: row 2 of InvMultCol (coefficients 9 E B D) applied to the
   coefficients of input J in MultCol's output. The diagonal case (lemma_b2)
   gives the input back; the others give ZERO. Each is proved by
   BYTE_CASES_TAC. *)
val lemma_b1 = Count.apply Q.prove
(`!a. NINE ** (TWO ** a) # E_HEX ** a # B_HEX ** a # D_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC);

val lemma_b2 = Count.apply Q.prove
(`!b. NINE ** (THREE ** b) # E_HEX ** (TWO ** b) # B_HEX ** b # D_HEX ** b = b`,
 BYTE_CASES_TAC);

val lemma_b3 = Count.apply Q.prove
(`!c. NINE ** c # E_HEX ** (THREE ** c) # B_HEX ** (TWO ** c) # D_HEX ** c = ZERO`,
 BYTE_CASES_TAC);

val lemma_b4 = Count.apply Q.prove
(`!d. NINE ** d # E_HEX ** d # B_HEX ** (THREE ** d) # D_HEX ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
308
(* lemma_cJ: row 3 of InvMultCol (coefficients D 9 E B) applied to the
   coefficients of input J in MultCol's output. The diagonal case (lemma_c3)
   gives the input back; the others give ZERO.
   NOTE(review): lemma_c1 appends THEN EVAL_TAC after BYTE_CASES_TAC, which
   itself already ends with EVAL_TAC. This is possibly redundant; re-run the
   proof before removing it. *)
val lemma_c1 = Count.apply Q.prove
(`!a. D_HEX ** (TWO ** a) # NINE ** a # E_HEX ** a # B_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC THEN EVAL_TAC);

val lemma_c2 = Count.apply Q.prove
(`!b. D_HEX ** (THREE ** b) # NINE ** (TWO ** b) # E_HEX ** b # B_HEX ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_c3 = Count.apply Q.prove
(`!c. D_HEX ** c # NINE ** (THREE ** c) # E_HEX ** (TWO ** c) # B_HEX ** c = c`,
 BYTE_CASES_TAC);

val lemma_c4 = Count.apply Q.prove
(`!d. D_HEX ** d # NINE ** d # E_HEX ** (THREE ** d) # B_HEX ** (TWO ** d) = ZERO`,
 BYTE_CASES_TAC);
324
(* lemma_dJ: row 4 of InvMultCol (coefficients B D 9 E) applied to the
   coefficients of input J in MultCol's output. The diagonal case (lemma_d4)
   gives the input back; the others give ZERO.
   NOTE(review): lemma_d3 appends THEN EVAL_TAC after BYTE_CASES_TAC, which
   already ends with EVAL_TAC. This is possibly redundant; confirm before
   removing it. *)
val lemma_d1 = Count.apply Q.prove
(`!a. B_HEX ** (TWO ** a) # D_HEX ** a # NINE ** a # E_HEX  ** (THREE ** a) = ZERO`,
 BYTE_CASES_TAC);

val lemma_d2 = Count.apply Q.prove
(`!b. B_HEX ** (THREE ** b) # D_HEX ** (TWO ** b) # NINE ** b # E_HEX ** b = ZERO`,
 BYTE_CASES_TAC);

val lemma_d3 = Count.apply Q.prove
(`!c. B_HEX ** c # D_HEX ** (THREE ** c) # NINE ** (TWO ** c) # E_HEX ** c = ZERO`,
 BYTE_CASES_TAC THEN EVAL_TAC);

val lemma_d4 = Count.apply Q.prove
(`!d. B_HEX ** d # D_HEX ** d # NINE ** (THREE ** d) # E_HEX ** (TWO ** d) = d`,
 BYTE_CASES_TAC);
340
341(*---------------------------------------------------------------------------*)
342(* The following lemma is hideous to prove without permutative rewriting     *)
343(*---------------------------------------------------------------------------*)
344
(* Regroup an XOR of four 4-term XORs from row order to column order. After
   the InvMultCol coefficients are distributed, this brings together all
   terms that mention the same input variable. Proved by AC rewriting with
   the byte XOR laws a (associativity) and c (commutativity). *)
val rearrange_xors = Q.prove
(`(a1 # b1 # c1 # d1) #
  (a2 # b2 # c2 # d2) #
  (a3 # b3 # c3 # d3) #
  (a4 # b4 # c4 # d4)
     =
  (a1 # a2 # a3 # a4) #
  (b1 # b2 # b3 # b4) #
  (c1 # c2 # c3 # c4) #
  (d1 # d2 # d3 # d4)`,
 RW_TAC std_ss [AC a c]);
356
(* Row 1 of InvMultCol applied to MultCol (a,b,c,d) gives back a.
   Proof: distribute constant multiplication over XOR (ConstMultDistrib),
   regroup the terms by input with rearrange_xors, then collapse each group
   with lemma_a1..a4 and remove the ZERO terms with XOR8_ZERO. *)
val mix_lemma1 = Q.prove
(`!a b c d.
   (E_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (B_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (D_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (NINE  ** ((THREE ** a) # b # c # (TWO ** d)))
      = a`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_a1,lemma_a2,lemma_a3,lemma_a4,XOR8_ZERO]);
367
(* Row 2 of InvMultCol applied to MultCol (a,b,c,d) gives back b. Same proof
   shape as mix_lemma1, using lemma_b1..b4. The surviving term is not the
   first, so XOR8_ZERO is also supplied with its operands commuted
   (presumably to clear a ZERO on the left). *)
val mix_lemma2 = Q.prove
(`!a b c d.
   (NINE  ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (E_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (B_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (D_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = b`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_b1,lemma_b2,lemma_b3,lemma_b4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);
379
(* Row 3 of InvMultCol applied to MultCol (a,b,c,d) gives back c. Proof:
   distribute, regroup with rearrange_xors, collapse with lemma_c1..c4, then
   remove ZERO terms with XOR8_ZERO and its commuted form. *)
val mix_lemma3 = Q.prove
(`!a b c d.
   (D_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (NINE  ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (E_HEX ** (a # b # (TWO ** c) # (THREE ** d))) #
   (B_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = c`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_c1,lemma_c2,lemma_c3,lemma_c4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);
391
(* Row 4 of InvMultCol applied to MultCol (a,b,c,d) gives back d. Proof:
   distribute, regroup with rearrange_xors, collapse with lemma_d1..d4, then
   remove ZERO terms with XOR8_ZERO and its commuted form. *)
val mix_lemma4 = Q.prove
(`!a b c d.
   (B_HEX ** ((TWO ** a) # (THREE ** b) # c # d)) #
   (D_HEX ** (a # (TWO ** b) # (THREE ** c) # d)) #
   (NINE  ** (a # b # (TWO ** c) # (THREE ** d))) #
   (E_HEX ** ((THREE ** a) # b # c # (TWO ** d)))
     = d`,
 RW_TAC std_ss [ConstMultDistrib]
   THEN ONCE_REWRITE_TAC [rearrange_xors]
   THEN RW_TAC std_ss [lemma_d1,lemma_d2,lemma_d3,lemma_d4,
                       XOR8_ZERO, ONCE_REWRITE_RULE [XOR8_AC] XOR8_ZERO]);
403
404(*---------------------------------------------------------------------------*)
405(* Get the constants of various definitions                                  *)
406(*---------------------------------------------------------------------------*)
407
(* unique_const name returns the single constant called name. When no
   constant, or more than one, has that name it raises a descriptive
   HOL_ERR, instead of the bare Bind exception a list pattern on decls
   would give. *)
fun unique_const name =
  case decls name of
    [tm] => tm
  | tms => raise mk_HOL_ERR "RoundOpScript" "unique_const"
             ("expected exactly one constant named " ^ name ^
              ", found " ^ Int.toString (length tms));

(* Constants kept unfolded by RESTR_EVAL_TAC in MixColumns_Inversion below. *)
val mult  = unique_const "**";
val TWO   = unique_const "TWO";
val THREE = unique_const "THREE";
val NINE  = unique_const "NINE";
val B_HEX = unique_const "B_HEX";
val D_HEX = unique_const "D_HEX";
val E_HEX = unique_const "E_HEX";
415
416(*---------------------------------------------------------------------------*)
417(* Mixing columns                                                            *)
418(*---------------------------------------------------------------------------*)
419
(* genMixColumns MC applies the column transform MC to each of the four
   columns of the state, where column j is (b0j,b1j,b2j,b3j). Each result is
   written back into the same column. *)
val genMixColumns_def = Define
 `genMixColumns MC ((b00,b01,b02,b03,
                     b10,b11,b12,b13,
                     b20,b21,b22,b23,
                     b30,b31,b32,b33) :state)
 = let (b00', b10', b20', b30') = MC (b00,b10,b20,b30) in
   let (b01', b11', b21', b31') = MC (b01,b11,b21,b31) in
   let (b02', b12', b22', b32') = MC (b02,b12,b22,b32) in
   let (b03', b13', b23', b33') = MC (b03,b13,b23,b33)
   in
    (b00', b01', b02', b03',
     b10', b11', b12', b13',
     b20', b21', b22', b23',
     b30', b31', b32', b33') : state`;
434
435
(* MixColumns mixes every column with MultCol. InvMixColumns mixes every
   column with InvMultCol. *)
val MixColumns_def    = Define `MixColumns    = genMixColumns MultCol`;
val InvMixColumns_def = Define `InvMixColumns = genMixColumns InvMultCol`;
438
(* Mixing with InvMultCol undoes mixing with MultCol. The theorem is stated
   via genMixColumns rather than MixColumns and InvMixColumns.
   Proof: split the state into bytes, then evaluate while keeping the
   multiplication operator and the constants TWO, THREE, NINE, B_HEX, D_HEX
   and E_HEX unfolded (RESTR_EVAL_TAC). The resulting byte equations are
   closed with mix_lemma1..4. *)
val MixColumns_Inversion = Q.store_thm
("MixColumns_Inversion",
 `!s. genMixColumns InvMultCol (genMixColumns MultCol s) = s`,
 SIMP_TAC std_ss [FORALL_BLOCK]
  THEN RESTR_EVAL_TAC [mult,B_HEX,D_HEX,E_HEX,TWO,THREE,NINE]
  THEN RW_TAC std_ss [mix_lemma1,mix_lemma2,mix_lemma3,mix_lemma4]);
445
446
447(*---------------------------------------------------------------------------
448    Pairwise XOR the state with the round key
449 ---------------------------------------------------------------------------*)
450
(* AddRoundKey is XOR_BLOCK. Applying the same key twice cancels
   (XOR_BLOCK_IDEM). *)
val AddRoundKey_def = Define `AddRoundKey = XOR_BLOCK`;
452
453(*---------------------------------------------------------------------------*)
454(* For alternative decryption scheme                                         *)
455(*---------------------------------------------------------------------------*)
456
(* InvMixColumns distributes over AddRoundKey, because the column transform
   is linear with respect to XOR. Proof: split into bytes, unfold every
   definition involved, distribute constant multiplication over XOR
   (ConstMultDistrib), and normalise the XOR sums by AC rewriting. *)
val InvMixColumns_Distrib = Q.store_thm
("InvMixColumns_Distrib",
 `!s k. InvMixColumns (AddRoundKey s k)
            =
        AddRoundKey (InvMixColumns s) (InvMixColumns k)`,
 SIMP_TAC std_ss [FORALL_BLOCK] THEN
 RW_TAC std_ss [XOR_BLOCK_def, AddRoundKey_def, InvMixColumns_def, LET_THM,
                genMixColumns_def, InvMultCol_def, ConstMultDistrib, AC a c]);
465
466
(* Write out the RoundOp theory begun with new_theory at the top of this
   script. *)
val _ = export_theory();
468