1(*---------------------------------------------------------------------------*)
2(* Operations performed in a round:                                          *)
3(*                                                                           *)
4(*    - applying Sboxes                                                      *)
5(*    - shifting rows                                                        *)
6(*    - mixing columns                                                       *)
7(*    - adding round keys                                                    *)
8(*                                                                           *)
(* We prove "inversion" theorems for each of these operations.              *)
10(*                                                                           *)
11(*---------------------------------------------------------------------------*)
12
13(* For interactive work
14  quietdec := true;
15  app load ["MultTheory", "tablesTheory", "wordsLib"];
16  quietdec := false;
17*)
18
19open HolKernel Parse boolLib bossLib;
20open pairTheory wordsTheory MultTheory wordsLib;
21
22(*---------------------------------------------------------------------------*)
23(* Make bindings to pre-existing stuff                                       *)
24(*---------------------------------------------------------------------------*)
25
(* S-box inversion law from tablesTheory; used when proving that
   InvSubBytes undoes SubBytes. *)
val Sbox_Inversion = tablesTheory.Sbox_Inversion;

(* Restricted evaluation tactic: evaluates the goal but leaves the listed
   constants unexpanded. *)
val RESTR_EVAL_TAC = computeLib.RESTR_EVAL_TAC;
29
30(*---------------------------------------------------------------------------*)
31(* Create the theory.                                                        *)
32(*---------------------------------------------------------------------------*)
33
val () = new_theory "RoundOp";
35
36(*---------------------------------------------------------------------------*)
(* A block is 16 bytes. A state has the same type, but its bytes are laid  *)
(* out differently: to_state transposes the 4x4 byte matrix of a block     *)
(* and from_state transposes it back.                                      *)
39(*---------------------------------------------------------------------------*)
40
(* Sixteen bytes as a flat 16-tuple. *)
val () =
  type_abbrev
    ("block",
     Type`:word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8`);

(* States and keys reuse the block representation; w8x4 is four bytes.
   These must be registered one after another: each Type quotation below
   mentions the abbreviation introduced just before it. *)
val () = type_abbrev ("state", Type`:block`);
val () = type_abbrev ("key",   Type`:state`);
val () = type_abbrev ("w8x4",  Type`:word8 # word8 # word8 # word8`);
50
51
(* The all-zero block: the identity element for XOR_BLOCK
   (see XOR_BLOCK_ZERO). *)
val ZERO_BLOCK_def = Define
  `ZERO_BLOCK =
     (0w, 0w, 0w, 0w,
      0w, 0w, 0w, 0w,
      0w, 0w, 0w, 0w,
      0w, 0w, 0w, 0w) : block`;
58
59(*---------------------------------------------------------------------------*)
60(* Case analysis on a block.                                                 *)
61(*---------------------------------------------------------------------------*)
62
(* Quantifying over a block is the same as quantifying over its sixteen
   component bytes; proved by repeatedly splitting pairs. *)
val FORALL_BLOCK = store_thm
 ("FORALL_BLOCK",
  ``(!b:block. P b) =
      !w1 w2 w3 w4 w5 w6 w7 w8 w9 w10 w11 w12 w13 w14 w15 w16.
        P (w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16)``,
  SIMP_TAC std_ss [FORALL_PROD]);
69
70(*---------------------------------------------------------------------------*)
71(* XOR on blocks. Definition and algebraic properties.                       *)
72(*---------------------------------------------------------------------------*)
73
(* Componentwise (bytewise) XOR of two blocks. *)
val XOR_BLOCK_def =
  Define
  `XOR_BLOCK ((a0,a1,a2,a3,a4,a5,a6,a7,
               a8,a9,a10,a11,a12,a13,a14,a15) : block)
             ((b0,b1,b2,b3,b4,b5,b6,b7,
               b8,b9,b10,b11,b12,b13,b14,b15) : block) =
     (a0  ?? b0,  a1  ?? b1,  a2  ?? b2,  a3  ?? b3,
      a4  ?? b4,  a5  ?? b5,  a6  ?? b6,  a7  ?? b7,
      a8  ?? b8,  a9  ?? b9,  a10 ?? b10, a11 ?? b11,
      a12 ?? b12, a13 ?? b13, a14 ?? b14, a15 ?? b15)`;
82
(* Algebra of XOR_BLOCK: ZERO_BLOCK is a right identity, every block is its
   own inverse, the operation is associative and commutative, and XORing
   the same block in twice cancels out. *)
local
  (* Split both blocks into bytes, unfold the definitions and finish with
     the word-level XOR laws. *)
  val XOR_BLOCK_SIMP_TAC =
    SIMP_TAC std_ss
      [FORALL_BLOCK, XOR_BLOCK_def, ZERO_BLOCK_def, WORD_XOR_CLAUSES]
in
  val XOR_BLOCK_ZERO = store_thm
   ("XOR_BLOCK_ZERO",
    ``!x:block. XOR_BLOCK x ZERO_BLOCK = x``,
    XOR_BLOCK_SIMP_TAC)

  val XOR_BLOCK_INV = store_thm
   ("XOR_BLOCK_INV",
    ``!x:block. XOR_BLOCK x x = ZERO_BLOCK``,
    XOR_BLOCK_SIMP_TAC)
end;

val XOR_BLOCK_AC = store_thm
 ("XOR_BLOCK_AC",
  ``(!x y z:block. XOR_BLOCK (XOR_BLOCK x y) z = XOR_BLOCK x (XOR_BLOCK y z)) /\
    (!x y:block. XOR_BLOCK x y = XOR_BLOCK y x)``,
  SIMP_TAC (srw_ss()) [FORALL_BLOCK, XOR_BLOCK_def]);

(* Cancellation follows from the three laws above by first-order reasoning. *)
val XOR_BLOCK_IDEM = store_thm
 ("XOR_BLOCK_IDEM",
  ``(!v u. XOR_BLOCK (XOR_BLOCK v u) u = v) /\
    (!v u. XOR_BLOCK v (XOR_BLOCK v u) = u)``,
  METIS_TAC [XOR_BLOCK_INV, XOR_BLOCK_AC, XOR_BLOCK_ZERO]);
106
107(*---------------------------------------------------------------------------*)
108(*    Moving data into and out of a state                                    *)
109(*---------------------------------------------------------------------------*)
110
(* to_state transposes the 4x4 byte matrix of a block: block bytes
   0,4,8,12 become the first row of the state, bytes 1,5,9,13 the second,
   and so on.  from_state performs the reverse rearrangement. *)
val to_state_def = Define
 `to_state ((b0,b1,b2,b3,b4,b5,b6,b7,
             b8,b9,b10,b11,b12,b13,b14,b15) : block) =
    (b0, b4, b8,  b12,
     b1, b5, b9,  b13,
     b2, b6, b10, b14,
     b3, b7, b11, b15) : state`;

val from_state_def = Define
 `from_state ((b0, b4, b8,  b12,
               b1, b5, b9,  b13,
               b2, b6, b10, b14,
               b3, b7, b11, b15) : state) =
    (b0,b1,b2,b3,b4,b5,b6,b7,
     b8,b9,b10,b11,b12,b13,b14,b15) : block`;
125
126
(* to_state and from_state are mutually inverse; both directions follow by
   splitting the 16-tuple and unfolding the two definitions. *)
local
  val STATE_CONV_TAC =
    SIMP_TAC std_ss [FORALL_BLOCK, from_state_def, to_state_def]
in
  val to_state_Inversion = store_thm
    ("to_state_Inversion",
     ``!s:state. from_state (to_state s) = s``,
     STATE_CONV_TAC)

  val from_state_Inversion = store_thm
    ("from_state_Inversion",
     ``!s:state. to_state (from_state s) = s``,
     STATE_CONV_TAC)
end;
137
138
139(*---------------------------------------------------------------------------*)
140(*    Apply an Sbox to the state                                             *)
141(*---------------------------------------------------------------------------*)
142
(* genSubBytes S applies the byte substitution S to each of the 16 bytes of
   a state.  The name "S" is hidden while the definition is parsed so that
   the parameter S is read as a variable rather than as an existing
   constant of that name; it is revealed again straight afterwards. *)
val _ = Parse.hide "S";

(* Plain Define, consistent with every other definition in this theory
   (the former `try` wrapper only re-printed any exception before
   re-raising it). *)
val genSubBytes_def = Define
  `genSubBytes S ((b00,b01,b02,b03,
                   b10,b11,b12,b13,
                   b20,b21,b22,b23,
                   b30,b31,b32,b33) :state)
                          =
             (S b00, S b01, S b02, S b03,
              S b10, S b11, S b12, S b13,
              S b20, S b21, S b22, S b23,
              S b30, S b31, S b32, S b33) :state`;

val _ = Parse.reveal "S";

(* Forward and inverse SubBytes: genSubBytes instantiated with Sbox and
   InvSbox respectively. *)
val SubBytes_def    = Define `SubBytes = genSubBytes Sbox`;
val InvSubBytes_def = Define `InvSubBytes = genSubBytes InvSbox`;
160
(* Substituting with InvSbox after Sbox restores every byte, hence the
   whole state; this is Sbox_Inversion lifted to states. *)
val SubBytes_Inversion = store_thm
 ("SubBytes_Inversion",
  ``!s:state. genSubBytes InvSbox (genSubBytes Sbox s) = s``,
  SIMP_TAC std_ss [FORALL_BLOCK, genSubBytes_def, Sbox_Inversion]);
165
166
167(*---------------------------------------------------------------------------
168    Left-shift the first row not at all, the second row by 1, the
169    third row by 2, and the last row by 3. And the inverse operation.
170 ---------------------------------------------------------------------------*)
171
(* Row r of the state (r = 0..3) is rotated left by r positions. *)
val ShiftRows_def = Define
  `ShiftRows ((b00,b01,b02,b03,
               b10,b11,b12,b13,
               b20,b21,b22,b23,
               b30,b31,b32,b33) :state) =
     (b00,b01,b02,b03,
      b11,b12,b13,b10,
      b22,b23,b20,b21,
      b33,b30,b31,b32) :state`;

(* Inverse: the pattern matches a shifted state and returns the rows
   rotated back into place. *)
val InvShiftRows_def = Define
  `InvShiftRows ((b00,b01,b02,b03,
                  b11,b12,b13,b10,
                  b22,b23,b20,b21,
                  b33,b30,b31,b32) :state) =
     (b00,b01,b02,b03,
      b10,b11,b12,b13,
      b20,b21,b22,b23,
      b30,b31,b32,b33) :state`;
193
194(*---------------------------------------------------------------------------
195        InvShiftRows inverts ShiftRows
196 ---------------------------------------------------------------------------*)
197
(* Each of these facts is proved the same way: split the state into its
   sixteen bytes and evaluate both sides. *)
local
  val BYTEWISE_EVAL_TAC =
    EVERY [SIMP_TAC std_ss [FORALL_BLOCK], REPEAT STRIP_TAC, EVAL_TAC]
in
  val ShiftRows_Inversion = store_thm
   ("ShiftRows_Inversion",
    ``!s:state. InvShiftRows (ShiftRows s) = s``,
    BYTEWISE_EVAL_TAC)

  (* Commutation facts, for the alternative decryption scheme: the byte
     permutation of (Inv)ShiftRows commutes with bytewise substitution. *)
  val ShiftRows_SubBytes_Commute = store_thm
   ("ShiftRows_SubBytes_Commute",
    ``!s. ShiftRows (SubBytes s) = SubBytes (ShiftRows s)``,
    BYTEWISE_EVAL_TAC)

  val InvShiftRows_InvSubBytes_Commute = store_thm
   ("InvShiftRows_InvSubBytes_Commute",
    ``!s. InvShiftRows (InvSubBytes s) = InvSubBytes (InvShiftRows s)``,
    BYTEWISE_EVAL_TAC)
end;
218
219
220(*---------------------------------------------------------------------------
221        Column multiplication and its inverse
222 ---------------------------------------------------------------------------*)
223
(* MultCol multiplies a 4-byte column by the circulant matrix whose first
   row is (2, 3, 1, 1).  InvMultCol uses the circulant matrix whose first
   row is (0xE, 0xB, 0xD, 9).  Here ** is the byte multiplication from
   MultTheory and ?? is XOR. *)
val MultCol_def = Define
 `MultCol (a,b,c,d) =
    ((2w ** a) ?? (3w ** b) ??  c        ??  d,
      a        ?? (2w ** b) ?? (3w ** c) ??  d,
      a        ??  b        ?? (2w ** c) ?? (3w ** d),
     (3w ** a) ??  b        ??  c        ?? (2w ** d))`;

val InvMultCol_def = Define
 `InvMultCol (a,b,c,d) =
    ((0xEw ** a) ?? (0xBw ** b) ?? (0xDw ** c) ?? (9w   ** d),
     (9w   ** a) ?? (0xEw ** b) ?? (0xBw ** c) ?? (0xDw ** d),
     (0xDw ** a) ?? (9w   ** b) ?? (0xEw ** c) ?? (0xBw ** d),
     (0xBw ** a) ?? (0xDw ** b) ?? (9w   ** c) ?? (0xEw ** d))`;
237
238(*---------------------------------------------------------------------------*)
(* Inversion lemmas for column multiplication, each proved by case analysis *)
(* on every possible byte value (BYTE_CASES_TAC).                           *)
240(*---------------------------------------------------------------------------*)
241
(* Proves a universally quantified byte identity exhaustively.  It
   case-splits the byte into n2w n, uses dimword_8 together with a
   numeral-expanded LESS_THM to enumerate the possible n, then rewrites
   with the multiplication tables and evaluates the resulting word
   expressions. *)
local
  val mult_tables = fetch "Mult" "mult_tables"
  val less_numeral_cases =
    CONV_RULE numLib.SUC_TO_NUMERAL_DEFN_CONV prim_recTheory.LESS_THM
in
  val BYTE_CASES_TAC =
    Cases
      THEN FULL_SIMP_TAC std_ss [wordsTheory.dimword_8, less_numeral_cases]
      THEN RW_TAC std_ss [mult_tables]
      THEN REWRITE_TAC [mult_tables]
      THEN WORD_EVAL_TAC
end;
249
250
(* lemma_<r><k>: the contribution of the k-th input byte to output row r
   (rows a..d = 1..4) of InvMultCol after MultCol.  It is the byte itself
   when r and k coincide, and 0w otherwise.  Each lemma is timed with
   Count.apply.  lemma_c1 and lemma_d3 finish with an extra EVAL_TAC. *)
local
  fun byte_lemma q = Count.apply Q.prove (q, BYTE_CASES_TAC)
  fun byte_lemma_eval q = Count.apply Q.prove (q, BYTE_CASES_TAC THEN EVAL_TAC)
in
  val lemma_a1 = byte_lemma
   `!a. 0xEw ** (2w ** a) ?? 0xBw ** a ?? 0xDw ** a ?? 9w ** (3w ** a) = a`
  val lemma_a2 = byte_lemma
   `!b. 0xEw ** (3w ** b) ?? 0xBw ** (2w ** b) ?? 0xDw ** b ?? 9w ** b = 0w`
  val lemma_a3 = byte_lemma
   `!c. 0xEw ** c ?? 0xBw ** (3w ** c) ?? 0xDw ** (2w ** c) ?? 9w ** c = 0w`
  val lemma_a4 = byte_lemma
   `!d. 0xEw ** d ?? 0xBw ** d ?? 0xDw ** (3w ** d) ?? 9w ** (2w ** d) = 0w`

  val lemma_b1 = byte_lemma
   `!a. 9w ** (2w ** a) ?? 0xEw ** a ?? 0xBw ** a ?? 0xDw ** (3w ** a) = 0w`
  val lemma_b2 = byte_lemma
   `!b. 9w ** (3w ** b) ?? 0xEw ** (2w ** b) ?? 0xBw ** b ?? 0xDw ** b = b`
  val lemma_b3 = byte_lemma
   `!c. 9w ** c ?? 0xEw ** (3w ** c) ?? 0xBw ** (2w ** c) ?? 0xDw ** c = 0w`
  val lemma_b4 = byte_lemma
   `!d. 9w ** d ?? 0xEw ** d ?? 0xBw ** (3w ** d) ?? 0xDw ** (2w ** d) = 0w`

  val lemma_c1 = byte_lemma_eval
   `!a. 0xDw ** (2w ** a) ?? 9w ** a ?? 0xEw ** a ?? 0xBw ** (3w ** a) = 0w`
  val lemma_c2 = byte_lemma
   `!b. 0xDw ** (3w ** b) ?? 9w ** (2w ** b) ?? 0xEw ** b ?? 0xBw ** b = 0w`
  val lemma_c3 = byte_lemma
   `!c. 0xDw ** c ?? 9w ** (3w ** c) ?? 0xEw ** (2w ** c) ?? 0xBw ** c = c`
  val lemma_c4 = byte_lemma
   `!d. 0xDw ** d ?? 9w ** d ?? 0xEw ** (3w ** d) ?? 0xBw ** (2w ** d) = 0w`

  val lemma_d1 = byte_lemma
   `!a. 0xBw ** (2w ** a) ?? 0xDw ** a ?? 9w ** a ?? 0xEw ** (3w ** a) = 0w`
  val lemma_d2 = byte_lemma
   `!b. 0xBw ** (3w ** b) ?? 0xDw ** (2w ** b) ?? 9w ** b ?? 0xEw ** b = 0w`
  val lemma_d3 = byte_lemma_eval
   `!c. 0xBw ** c ?? 0xDw ** (3w ** c) ?? 9w ** (2w ** c) ?? 0xEw ** c = 0w`
  val lemma_d4 = byte_lemma
   `!d. 0xBw ** d ?? 0xDw ** d ?? 9w ** (3w ** d) ?? 0xEw ** (2w ** d) = d`
end;
314
315(*---------------------------------------------------------------------------*)
316(* The following lemma is hideous to prove without permutative rewriting     *)
317(*---------------------------------------------------------------------------*)
318
(* Regroups a 4x4 grid of XORed terms from row-by-row to column-by-column.
   The simplifier's permutative rewriting does all the work. *)
val rearrange_xors = prove
 (``(a1 ?? b1 ?? c1 ?? d1) ??
    (a2 ?? b2 ?? c2 ?? d2) ??
    (a3 ?? b3 ?? c3 ?? d3) ??
    (a4 ?? b4 ?? c4 ?? d4)
      =
    (a1 ?? a2 ?? a3 ?? a4) ??
    (b1 ?? b2 ?? b3 ?? b4) ??
    (c1 ?? c2 ?? c3 ?? c4) ??
    (d1 ?? d2 ?? d3 ?? d4)``,
  SRW_TAC [] []);
330
(* mix_lemmaK: row K of InvMultCol, applied to the four outputs of MultCol,
   yields the K-th input.  MIX_TAC distributes the constant
   multiplications, regroups the sixteen XOR terms by input variable using
   rearrange_xors, and closes each group with the matching byte lemmas. *)
local
  fun MIX_TAC lemmas =
    EVERY [RW_TAC std_ss [ConstMultDistrib],
           ONCE_REWRITE_TAC [rearrange_xors],
           RW_TAC std_ss (lemmas @ [WORD_XOR_CLAUSES])]
in
  val mix_lemma1 = Q.prove
   (`!a b c d.
       (0xEw ** ((2w ** a) ?? (3w ** b) ?? c ?? d)) ??
       (0xBw ** (a ?? (2w ** b) ?? (3w ** c) ?? d)) ??
       (0xDw ** (a ?? b ?? (2w ** c) ?? (3w ** d))) ??
       (9w   ** ((3w ** a) ?? b ?? c ?? (2w ** d)))
         = a`,
    MIX_TAC [lemma_a1, lemma_a2, lemma_a3, lemma_a4])

  val mix_lemma2 = Q.prove
   (`!a b c d.
       (9w   ** ((2w ** a) ?? (3w ** b) ?? c ?? d)) ??
       (0xEw ** (a ?? (2w ** b) ?? (3w ** c) ?? d)) ??
       (0xBw ** (a ?? b ?? (2w ** c) ?? (3w ** d))) ??
       (0xDw ** ((3w ** a) ?? b ?? c ?? (2w ** d)))
         = b`,
    MIX_TAC [lemma_b1, lemma_b2, lemma_b3, lemma_b4])

  val mix_lemma3 = Q.prove
   (`!a b c d.
       (0xDw ** ((2w ** a) ?? (3w ** b) ?? c ?? d)) ??
       (9w   ** (a ?? (2w ** b) ?? (3w ** c) ?? d)) ??
       (0xEw ** (a ?? b ?? (2w ** c) ?? (3w ** d))) ??
       (0xBw ** ((3w ** a) ?? b ?? c ?? (2w ** d)))
         = c`,
    MIX_TAC [lemma_c1, lemma_c2, lemma_c3, lemma_c4])

  val mix_lemma4 = Q.prove
   (`!a b c d.
       (0xBw ** ((2w ** a) ?? (3w ** b) ?? c ?? d)) ??
       (0xDw ** (a ?? (2w ** b) ?? (3w ** c) ?? d)) ??
       (9w   ** (a ?? b ?? (2w ** c) ?? (3w ** d))) ??
       (0xEw ** ((3w ** a) ?? b ?? c ?? (2w ** d)))
         = d`,
    MIX_TAC [lemma_d1, lemma_d2, lemma_d3, lemma_d4])
end;
374
375(*---------------------------------------------------------------------------*)
(* Terms for the constants Mult$** and n2w, passed to RESTR_EVAL_TAC so    *)
(* that evaluation leaves them unexpanded.                                 *)
377(*---------------------------------------------------------------------------*)
378
val mult = ``Mult$**``;   (* byte multiplication constant *)
val n2w  = ``n2w``;       (* numeral-to-word constant *)
381
382(*---------------------------------------------------------------------------*)
383(* Mixing columns                                                            *)
384(*---------------------------------------------------------------------------*)
385
(* genMixColumns MC applies MC to each of the four columns
   (b0j, b1j, b2j, b3j) of the state and writes the results back into
   column j.  The nested lets are kept exactly as they are, because later
   proofs unfold this definition with LET_THM and restricted evaluation. *)
val genMixColumns_def = Define
 `genMixColumns MC ((b00,b01,b02,b03,
                     b10,b11,b12,b13,
                     b20,b21,b22,b23,
                     b30,b31,b32,b33) :state) =
    let (b00', b10', b20', b30') = MC (b00,b10,b20,b30) in
    let (b01', b11', b21', b31') = MC (b01,b11,b21,b31) in
    let (b02', b12', b22', b32') = MC (b02,b12,b22,b32) in
    let (b03', b13', b23', b33') = MC (b03,b13,b23,b33)
    in
      (b00', b01', b02', b03',
       b10', b11', b12', b13',
       b20', b21', b22', b23',
       b30', b31', b32', b33') : state`;

(* Forward and inverse MixColumns instantiate genMixColumns with MultCol
   and InvMultCol. *)
val MixColumns_def    = Define `MixColumns    = genMixColumns MultCol`;
val InvMixColumns_def = Define `InvMixColumns = genMixColumns InvMultCol`;
404
(* Mixing with InvMultCol after MultCol restores the state.  Evaluation is
   restricted so that ** and n2w stay symbolic; the column lemmas then
   finish each byte position. *)
val MixColumns_Inversion = store_thm
 ("MixColumns_Inversion",
  ``!s. genMixColumns InvMultCol (genMixColumns MultCol s) = s``,
  EVERY [SIMP_TAC std_ss [FORALL_BLOCK],
         RESTR_EVAL_TAC [mult, n2w],
         RW_TAC std_ss [mix_lemma1, mix_lemma2, mix_lemma3, mix_lemma4]]);
411
412
413(*---------------------------------------------------------------------------
414    Pairwise XOR the state with the round key
415 ---------------------------------------------------------------------------*)
416
(* Round-key addition is simply bytewise block XOR. *)
val AddRoundKey_def =
  Define `AddRoundKey = XOR_BLOCK`;
418
419(*---------------------------------------------------------------------------*)
420(* For alternative decryption scheme                                         *)
421(*---------------------------------------------------------------------------*)
422
(* InvMixColumns distributes over round-key addition.  This follows from
   the constant multiplications distributing over XOR (ConstMultDistrib)
   once everything has been unfolded bytewise. *)
val InvMixColumns_Distrib = store_thm
 ("InvMixColumns_Distrib",
  ``!s k. InvMixColumns (AddRoundKey s k) =
          AddRoundKey (InvMixColumns s) (InvMixColumns k)``,
  EVERY [SIMP_TAC std_ss [FORALL_BLOCK],
         SRW_TAC [] [XOR_BLOCK_def, AddRoundKey_def, InvMixColumns_def,
                     LET_THM, genMixColumns_def, InvMultCol_def,
                     ConstMultDistrib]]);
431
432
val () = export_theory ();
434