1(*                                   Twofish Block Cipher
2                                        -- implemented in Standard ML
3
4Twofish is a 128-bit block cipher that accepts a variable-length key
5up to 256 bits. The cipher is a 16-round Feistel network with a
6bijective F function made up of four key-dependent 8-by-8-bit S-boxes,
a fixed 4-by-4 maximum distance separable matrix over GF(2^8), a
8pseudo-Hadamard transform, bitwise rotations, and a carefully designed
9key schedule. For more information, please refer to
10        web site: http://www.counterpane.com/twofish.html
11*)
12
13(* For interactive work
14  quietdec := true;
15  app load ["arithmeticTheory","wordsLib"];
16  open arithmeticTheory wordsTheory pairTheory bitTheory wordsLib;
17  quietdec := false;
18*)
19
20open HolKernel Parse boolLib bossLib
21     arithmeticTheory wordsTheory pairTheory bitTheory wordsLib;
22
(*---------------------------------------------------------------------------*)
(* Create the theory.                                                        *)
(*---------------------------------------------------------------------------*)

(* Start the HOL4 theory "twofish".  The definitions and stored theorems     *)
(* below are recorded in it and written out by export_theory at the end.     *)
val _ = new_theory "twofish";
28
(*---------------------------------------------------------------------------*)
(* Type Definitions                                                          *)
(*---------------------------------------------------------------------------*)

(* A 128-bit data block: four 32-bit words (R0,R1,R2,R3).                    *)
val _ = type_abbrev("block", ``:word32 # word32 # word32 # word32``);
(* A pair of 32-bit words, the shape of one subkey pair.                     *)
val _ = type_abbrev("key",   ``:word32 # word32``);

(* The user key as 32 bytes (256 bits); genM/genS name the components from   *)
(* m31 (first) down to m0 (last).                                            *)
val _ = type_abbrev("initkeys",
        ``:word8 # word8 # word8 # word8 # word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8 # word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8 # word8 # word8 # word8 # word8 #
           word8 # word8 # word8 # word8 # word8 # word8 # word8 # word8``);

(* The expanded key schedule: 40 32-bit words K0..K39 (see expandKeys).      *)
val _ = type_abbrev("keysched",
   ``:word32 # word32 # word32 # word32 # word32 # word32 # word32 # word32 #
      word32 # word32 # word32 # word32 # word32 # word32 # word32 # word32 #
      word32 # word32 # word32 # word32 # word32 # word32 # word32 # word32 #
      word32 # word32 # word32 # word32 # word32 # word32 # word32 # word32 #
      word32 # word32 # word32 # word32 # word32 # word32 # word32 # word32``);
48
(*---------------------------------------------------------------------------*)
(* Case analysis on a block and a pair of keys.                              *)
(*---------------------------------------------------------------------------*)

(* Replace a quantifier over a whole block by one quantifier per word, so    *)
(* proofs can work on the components (used by the round-inversion proofs).   *)
val FORALL_BLOCK = Q.store_thm
  ("FORALL_BLOCK",
    `(!b:block. P b) = !v0 v1 v2 v3. P (v0,v1,v2,v3)`,
    SIMP_TAC std_ss [FORALL_PROD]);

(* The same expansion for a pair of 32-bit words.                            *)
val FORALL_KEY = Q.store_thm
  ("FORALL_KEY",
    `(!b:key. P b) = !k0 k1. P (k0,k1)`,
    SIMP_TAC std_ss [FORALL_PROD]);
62
(*---------------------------------------------------------------------------*)
(* Operations on word8, word32 and word4.                                    *)
(*---------------------------------------------------------------------------*)

(* Word4 rotations.  A 4-bit value is carried in the low nibble of a word8.  *)
(* The two shifted parts must be OR-ed (not AND-ed) together, and the result *)
(* masked back to 4 bits so bits shifted past bit 3 are discarded.  Assumes  *)
(* x < 16 and n <= 4; qq only uses ROR4 with n = 1 on masked nibbles.        *)
val ROR4_def = Define`
  ROR4 (x:word8, n) = word_or (x >>> n) (x << (4 - n)) && 0xFw`;
val ROL4_def = Define`
  ROL4 (x:word8, n) = word_or (x << n) (x >>> (4 - n)) && 0xFw`;
70
(* Let the words library guess word lengths left open in the definitions     *)
(* that follow (e.g. the widths produced by @@ and ><).                      *)
val _ = wordsLib.guess_lengths();

(* Conversion between word8*word8*word8*word8 and word32.                    *)
(* toLarge concatenates four bytes, most significant first: a3 lands in      *)
(* bits 31..24 and a0 in bits 7..0.                                          *)
val toLarge_def = Define`toLarge (a3:word8,a2:word8,a1:word8,a0:word8) =
   a3 @@ a2 @@ a1 @@ a0`;

(* fromLarge splits a word32 into bytes in the same order (bits 31..24       *)
(* first), undoing toLarge.                                                  *)
val fromLarge_def = Define`fromLarge (a:word32) =
   ((31 >< 24) a, (23 >< 16) a, (15 >< 8) a, (7 >< 0) a)`;
80
(*---------------------------------------------------------------------------*)
(* Multiply a byte representing a polynomial by x.                           *)
(*---------------------------------------------------------------------------*)

(* For MDS multiplication, v(x) = x^8 + x^6 + x^5 + x^3 + 1 , i.e. 0x169.    *)
(* Shifting left multiplies by x; when the top bit was set, the x^8 term is  *)
(* reduced by xor-ing in v(x).  As a word8 literal, 0x169w keeps only its    *)
(* low eight bits (0x69), which is exactly the reduction needed.             *)

val xtime1_def = Define
  `xtime1 (w : word8) =
     if word_msb w then w << 1 ?? 0x169w else w << 1`;
90
(* Infix "**": multiplication in the MDS field (reduction via xtime1).       *)
val _ = set_fixity "**" (Infixl 675);

(* Shift-and-add multiplication: look at b1 from its least significant bit, *)
(* xor in the current multiple of b2 when that bit is 1, and multiply b2 by  *)
(* x (xtime1) for the next bit.  The recursive argument b1 >>> 1 shrinks b1  *)
(* until it becomes 0w, where the product is 0w.                             *)
val Mult1_def = xDefine "Mult1"
  `b1 ** b2 =
     if b1 = 0w :word8 then 0w else
     if word_lsb b1
        then b2 ?? ((b1 >>> 1) ** xtime1 b2)
        else        (b1 >>> 1) ** xtime1 b2`;
99
(* RS-code field: multiplication by x modulo the polynomial                  *)
(* x^8 + x^6 + x^3 + x^2 + 1 (0x14D).  The byte is shifted left and, exactly *)
(* when its top bit was 1, the reduction constant is xor-ed in; xor with 0w  *)
(* leaves the shifted byte unchanged otherwise.                              *)

val xtime2_def = Define
  `xtime2 (w : word8) =
     (w << 1) ?? (if word_msb w then 0x14Dw else 0w)`;
105
(* Infix "***": multiplication in the RS field (reduction via xtime2).       *)
val _ = set_fixity "***" (Infixl 675);

(* Same shift-and-add scheme as "**", but multiplying b2 by x with xtime2.   *)
(* Recursion is on b1 >>> 1, which shrinks b1 until it becomes 0w.           *)
val Mult2_def = xDefine "Mult2"
  `b1 *** b2 =
     if b1 = 0w :word8 then 0w else
     if word_lsb b1
        then b2 ?? ((b1 >>> 1) *** xtime2 b2)
        else        (b1 >>> 1) *** xtime2 b2`;
114
(*---------------------------------------------------------------------------*)
(* Matrix Column Multiplication                                              *)
(*---------------------------------------------------------------------------*)

(* Reverse a four-byte tuple: (m0,m1,m2,m3) -> (m3,m2,m1,m0).  MDSMul and    *)
(* RSMul return component 0 first, while toLarge and the word tuples of      *)
(* fun_h list the most significant byte first.                               *)
val InvW_def = Define`
    InvW (m0,m1,m2,m3): (word8 # word8 # word8 # word8) = (m3,m2,m1,m0)`;
121
(* Multiply the column (m0,m1,m2,m3) by the Twofish MDS matrix               *)
(*      01 EF 5B 5B                                                          *)
(*      5B EF EF 01                                                          *)
(*      EF 5B 01 EF                                                          *)
(*      EF 01 EF 5B                                                          *)
(* using "**"; field addition is xor.  Row i gives result component i.       *)

val MDSMul_def = Define`MDSMul(m0,m1,m2,m3) =
  ((0x01w ** m0) ?? (0xEFw ** m1) ?? (0x5Bw ** m2) ?? (0x5Bw ** m3),
   (0x5Bw ** m0) ?? (0xEFw ** m1) ?? (0xEFw ** m2) ?? (0x01w ** m3),
   (0xEFw ** m0) ?? (0x5Bw ** m1) ?? (0x01w ** m2) ?? (0xEFw ** m3),
   (0xEFw ** m0) ?? (0x01w ** m1) ?? (0xEFw ** m2) ?? (0x5Bw ** m3))`;
129
(* Multiply the 8-byte column (m0,...,m7) by the 4-by-8 Twofish RS matrix    *)
(*      01 A4 55 87 5A 58 DB 9E                                              *)
(*      A4 56 82 F3 1E C6 68 E5                                              *)
(*      02 A1 FC C1 47 AE 3D 19                                              *)
(*      A4 55 87 5A 58 DB 9E 03                                              *)
(* using "***"; field addition is xor.  Row i gives result component i.      *)

val RSMul_def = Define`RSMul(m0,m1,m2,m3,m4,m5,m6,m7) =
  ((0x01w *** m0) ?? (0xA4w *** m1) ?? (0x55w *** m2) ?? (0x87w *** m3) ??
   (0x5Aw *** m4) ?? (0x58w *** m5) ?? (0xDBw *** m6) ?? (0x9Ew *** m7),
   (0xA4w *** m0) ?? (0x56w *** m1) ?? (0x82w *** m2) ?? (0xF3w *** m3) ??
   (0x1Ew *** m4) ?? (0xC6w *** m5) ?? (0x68w *** m6) ?? (0xE5w *** m7),
   (0x02w *** m0) ?? (0xA1w *** m1) ?? (0xFCw *** m2) ?? (0xC1w *** m3) ??
   (0x47w *** m4) ?? (0xAEw *** m5) ?? (0x3Dw *** m6) ?? (0x19w *** m7),
   (0xA4w *** m0) ?? (0x55w *** m1) ?? (0x87w *** m2) ?? (0x5Aw *** m3) ??
   (0x58w *** m4) ?? (0xDBw *** m5) ?? (0x9Ew *** m6) ?? (0x03w *** m7))`;
141
(*---------------------------------------------------------------------------*)
(* The permutations q0 and q1 are fixed permutations on 8-bit values.        *)
(* They are constructed from four different 4-bit permutations each.         *)
(*---------------------------------------------------------------------------*)

(* The 4-bit S-boxes for the permutation q0.  Each table maps a nibble       *)
(* 0w..15w to a nibble.  The case expressions list only these 16 inputs, so  *)
(* the tables are meaningful only on values below 16.                        *)

val t00_def = Define`
  t00 (x:word8) =
    case x of
    0w => 0x8w | 1w => 0x1w | 2w => 0x7w | 3w => 0xDw |
    4w => 0x6w | 5w => 0xFw | 6w => 0x3w | 7w => 0x2w |
    8w => 0x0w | 9w => 0xBw | 10w => 0x5w | 11w => 0x9w |
    12w => 0xEw | 13w => 0xCw | 14w => 0xAw | 15w => 0x4w : word8`;

val t01_def = Define`
  t01 (x:word8) =
    case x of
    0w => 0xEw | 1w => 0xCw | 2w => 0xBw | 3w => 0x8w |
    4w => 0x1w | 5w => 0x2w | 6w => 0x3w | 7w => 0x5w |
    8w => 0xFw | 9w => 0x4w | 10w => 0xAw | 11w => 0x6w |
    12w => 0x7w | 13w => 0x0w | 14w => 0x9w | 15w => 0xDw : word8`;

val t02_def = Define`
  t02 (x:word8) =
    case x of
    0w => 0xBw | 1w => 0xAw | 2w => 0x5w | 3w => 0xEw |
    4w => 0x6w | 5w => 0xDw | 6w => 0x9w | 7w => 0x0w |
    8w => 0xCw | 9w => 0x8w | 10w => 0xFw | 11w => 0x3w |
    12w => 0x2w | 13w => 0x4w | 14w => 0x7w | 15w => 0x1w : word8`;

val t03_def = Define`
  t03 (x:word8) =
    case x of
    0w => 0xDw | 1w => 0x7w | 2w => 0xFw | 3w => 0x4w |
    4w => 0x1w | 5w => 0x2w | 6w => 0x6w | 7w => 0xEw |
    8w => 0x9w | 9w => 0xBw | 10w => 0x3w | 11w => 0x0w |
    12w => 0x8w | 13w => 0x5w | 14w => 0xCw | 15w => 0xAw : word8`;
180
(* The 4-bit S-boxes for the permutation q1 (nibble to nibble, defined only  *)
(* for inputs 0w..15w).                                                      *)

val t10_def = Define`
  t10 (x:word8) =
    case x of
    0w => 0x2w | 1w => 0x8w | 2w => 0xBw | 3w => 0xDw |
    4w => 0xFw | 5w => 0x7w | 6w => 0x6w | 7w => 0xEw |
    8w => 0x3w | 9w => 0x1w | 10w => 0x9w | 11w => 0x4w |
    12w => 0x0w | 13w => 0xAw | 14w => 0xCw | 15w => 0x5w : word8`;

val t11_def = Define`
  t11 (x:word8) =
    case x of
    0w => 0x1w | 1w => 0xEw | 2w => 0x2w | 3w => 0xBw |
    4w => 0x4w | 5w => 0xCw | 6w => 0x3w | 7w => 0x7w |
    8w => 0x6w | 9w => 0xDw | 10w => 0xAw | 11w => 0x5w |
    12w => 0xFw | 13w => 0x9w | 14w => 0x0w | 15w => 0x8w : word8`;

val t12_def = Define`
  t12 (x:word8) =
    case x of
    0w => 0x4w | 1w => 0xCw | 2w => 0x7w | 3w => 0x5w |
    4w => 0x1w | 5w => 0x6w | 6w => 0x9w | 7w => 0xAw |
    8w => 0x0w | 9w => 0xEw | 10w => 0xDw | 11w => 0x8w |
    12w => 0x2w | 13w => 0xBw | 14w => 0x3w | 15w => 0xFw : word8`;
206
(* Last 4-bit S-box of q1.  Per the Twofish specification the table is       *)
(* [B 9 5 1 C 3 D E 6 4 7 F 2 0 8 A]; entries 6 and 7 are 0xD and 0xE, which *)
(* makes the table a permutation of 0..15.                                   *)
val t13_def = Define`
  t13 (x:word8) =
    case x of
    0w => 0xBw | 1w => 0x9w | 2w => 0x5w | 3w => 0x1w |
    4w => 0xCw | 5w => 0x3w | 6w => 0xDw | 7w => 0xEw |
    8w => 0x6w | 9w => 0x4w | 10w => 0x7w | 11w => 0xFw |
    12w => 0x2w | 13w => 0x0w | 14w => 0x8w | 15w => 0xAw : word8`;
214
(* First, the byte is split into two nibbles. These are combined in a        *)
(* bijective mixing step. Each nibble is then passed through its own 4-bit   *)
(* fixed S-box. This is followed by another mixing step and S-box lookup.    *)
(* Finally, the two nibbles are recombined into a byte.                      *)
(*                                                                           *)
(*   a0, b0 = x div 16, x mod 16                                             *)
(*   a1 = a0 xor b0            b1 = a0 xor ROR4(b0,1) xor (8*a0 mod 16)      *)
(*   a2, b2 = t0[a1], t1[b1]                                                 *)
(*   a3 = a2 xor b2            b3 = a2 xor ROR4(b2,1) xor (8*a2 mod 16)      *)
(*   a4, b4 = t2[a3], t3[b3]                                                 *)
(*   result = 16*b4 + a4                                                     *)
(* The second mixing step uses a2/b2; it must not reuse a0 (and a2 xor a2   *)
(* would always be 0).                                                       *)

val qq_def = Define`
  qq t0 t1 t2 t3 (x:word8) =
    let (a0, b0) = ((x >>> 4) && 0xFw, x && 0xFw) in
    let (a1, b1) = (a0 ?? b0, a0 ?? ROR4(b0,1) ?? ((8w * a0) && 0xFw)) in
    let (a2, b2) = (t0(a1), t1(b1)) in
    let (a3, b3) = (a2 ?? b2, a2 ?? ROR4(b2,1) ?? ((8w * a2) && 0xFw)) in
    let (a4, b4) = (t2(a3), t3(b3))
    in 16w * b4 + a4 : word8`;
228
(* q0 and q1: qq instantiated with their respective four 4-bit tables.      *)
val q0_def = Define`q0 = qq t00 t01 t02 t03`;
val q1_def = Define`q1 = qq t10 t11 t12 t13`;
231
(*---------------------------------------------------------------------------*)
(* Function h takes two inputs: a 32-bit word X, given as its four bytes     *)
(* (x3,x2,x1,x0) with x0 least significant, and a list                       *)
(* L = (L0,L1,L2,L3) of 32-bit words (here k = 4), each given as bytes       *)
(* (li3,li2,li1,li0).  It produces one word of output. It works in k stages: *)
(* in each stage the four bytes are passed through fixed S-boxes (q0/q1) and *)
(* xored with a byte of one list word, using L3 first and L0 last.  Finally  *)
(* the four bytes are multiplied by the MDS matrix.                          *)
(*                                                                           *)
(* The list is read in the order L0..L3, matching genM (Me = (M0,M2,M4,M6),  *)
(* Mo = (M1,M3,M5,M7)) and genS ((S3,S2,S1,S0), whose first element is L0    *)
(* per the specification).  The combined L1/L0 stage is applied exactly once.*)
(*---------------------------------------------------------------------------*)

val fun_h_def = Define`
  fun_h
   ((x3,x2,x1,x0),(l03,l02,l01,l00),(l13,l12,l11,l10),
    (l23,l22,l21,l20),l33,l32,l31,l30) =
 (let (y0,y1,y2,y3) = (x0,x1,x2,x3) in
  let (y0,y1,y2,y3) =
        (q1 y0 ?? l30,q0 y1 ?? l31,q0 y2 ?? l32,q1 y3 ?? l33)          (* L3 *)
  in
  let (y0,y1,y2,y3) =
        (q1 y0 ?? l20,q1 y1 ?? l21,q0 y2 ?? l22,q0 y3 ?? l23)          (* L2 *)
  in
  let (y0,y1,y2,y3) =
        (q1 (q0 (q0 y0 ?? l10) ?? l00), q0 (q0 (q1 y1 ?? l11) ?? l01),
         q1 (q1 (q0 y2 ?? l12) ?? l02), q0 (q1 (q1 y3 ?? l13) ?? l03)) (* L1,L0 *)
  in
    InvW (MDSMul (y0,y1,y2,y3)))`;
261
(*---------------------------------------------------------------------------*)
(* Key material.  genM groups the 32 key bytes into 32-bit words Mi =        *)
(* (m(4i+3),m(4i+2),m(4i+1),m(4i)) and returns the even words                *)
(* Me = (M0,M2,M4,M6) and the odd words Mo = (M1,M3,M5,M7).  genS (below)    *)
(* takes the key bytes in groups of 8, interprets each group as a vector     *)
(* over GF(2^8), and multiplies it by a 4-by-8 matrix derived from an RS     *)
(* code.                                                                     *)
(*---------------------------------------------------------------------------*)

val genM_def = Define`
  genM
    ((m31,m30,m29,m28,m27,m26,m25,m24,m23,m22,m21,m20,m19,m18,m17,m16,m15,
      m14,m13,m12,m11,m10,m9,m8,m7,m6,m5,m4,m3,m2,m1,m0):initkeys) =
  let Me = ((m3,m2,m1,m0),(m11,m10,m9,m8),(m19,m18,m17,m16),(m27,m26,m25,m24))
  in
  let Mo = ((m7,m6,m5,m4),(m15,m14,m13,m12),(m23,m22,m21,m20),(m31,m30,m29,m28))
  in
    (Me, Mo)`;
276
(* S-box key words: one RS product per 8-byte group, each reversed by InvW   *)
(* into most-significant-first byte order.  The groups appear from the       *)
(* highest (m24..m31) to the lowest (m0..m7).                                *)
val genS_def = Define`
  genS
    (m31,m30,m29,m28,m27,m26,m25,m24,m23,m22,m21,m20,m19,m18,m17,m16,
     m15,m14,m13,m12,m11,m10,m9,m8,m7,m6,m5,m4,m3,m2,m1,m0) =
  (InvW (RSMul (m24,m25,m26,m27,m28,m29,m30,m31)),
   InvW (RSMul (m16,m17,m18,m19,m20,m21,m22,m23)),
   InvW (RSMul (m8,m9,m10,m11,m12,m13,m14,m15)),
   InvW (RSMul (m0,m1,m2,m3,m4,m5,m6,m7)))`;
285
(*---------------------------------------------------------------------------*)
(* Subkey pair i of the expanded key.  Ai is h applied to four copies of the *)
(* byte 2i with the even key words Me; Bi is h applied to four copies of     *)
(* 2i+1 with the odd key words Mo, then rotated left by 8 bits.  Ai and Bi   *)
(* are mixed by a pseudo-Hadamard transform and the second output is further *)
(* rotated left by 9 bits.                                                   *)
(*---------------------------------------------------------------------------*)

val e_rnd_def = Define`
  e_rnd (Me, Mo, i) =
    let j = (n2w i : word8) in
    let even_b = 2w * j in
    let odd_b = even_b + 1w in
    let a = toLarge (fun_h ((even_b, even_b, even_b, even_b), Me)) in
    let b = toLarge (fun_h ((odd_b, odd_b, odd_b, odd_b), Mo)) #<< 8
    in
      ((a + b) && 0xffffffffw, ((a + 2w * b) && 0xffffffffw) #<< 9)`;
302
(* The full schedule (K0,K1,...,K39): component 2i is FST (e_rnd (Me,Mo,i))  *)
(* and component 2i+1 is its SND, for i = 0..19.  e_rnd appears twice per    *)
(* index (once under FST, once under SND).                                   *)
val expandKeys_def = Define`
  expandKeys (Me,Mo) =
     (FST (e_rnd (Me,Mo,0)),SND (e_rnd (Me,Mo,0)),
      FST (e_rnd (Me,Mo,1)),SND (e_rnd (Me,Mo,1)),
      FST (e_rnd (Me,Mo,2)),SND (e_rnd (Me,Mo,2)),
      FST (e_rnd (Me,Mo,3)),SND (e_rnd (Me,Mo,3)),
      FST (e_rnd (Me,Mo,4)),SND (e_rnd (Me,Mo,4)),
      FST (e_rnd (Me,Mo,5)),SND (e_rnd (Me,Mo,5)),
      FST (e_rnd (Me,Mo,6)),SND (e_rnd (Me,Mo,6)),
      FST (e_rnd (Me,Mo,7)),SND (e_rnd (Me,Mo,7)),
      FST (e_rnd (Me,Mo,8)),SND (e_rnd (Me,Mo,8)),
      FST (e_rnd (Me,Mo,9)),SND (e_rnd (Me,Mo,9)),
      FST (e_rnd (Me,Mo,10)),SND (e_rnd (Me,Mo,10)),
      FST (e_rnd (Me,Mo,11)),SND (e_rnd (Me,Mo,11)),
      FST (e_rnd (Me,Mo,12)),SND (e_rnd (Me,Mo,12)),
      FST (e_rnd (Me,Mo,13)),SND (e_rnd (Me,Mo,13)),
      FST (e_rnd (Me,Mo,14)),SND (e_rnd (Me,Mo,14)),
      FST (e_rnd (Me,Mo,15)),SND (e_rnd (Me,Mo,15)),
      FST (e_rnd (Me,Mo,16)),SND (e_rnd (Me,Mo,16)),
      FST (e_rnd (Me,Mo,17)),SND (e_rnd (Me,Mo,17)),
      FST (e_rnd (Me,Mo,18)),SND (e_rnd (Me,Mo,18)),
      FST (e_rnd (Me,Mo,19)),SND (e_rnd (Me,Mo,19)))`;
325
(*---------------------------------------------------------------------------*)
(* Access to the key schedule.  ROTKEYS rotates the 40 words left by two, so *)
(* that GETKEYS (the first two words) yields the next subkey pair;           *)
(* ROTKEYS8 rotates left by eight words.                                     *)
(*---------------------------------------------------------------------------*)

val ROTKEYS_def = Define`
  ROTKEYS
   (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,
    k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,k32,k33,
    k34,k35,k36,k37,k38,k39) =
 (k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,k18,k19,
  k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,k32,k33,k34,k35,
  k36,k37,k38,k39,k0,k1)`;

(* Used by TwofishEncrypt/Decrypt to skip the eight whitening words K0..K7,  *)
(* so the rounds start at K8.                                                *)
val ROTKEYS8_def = Define`
  ROTKEYS8
   (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,
    k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,k32,k33,
    k34,k35,k36,k37,k38,k39) =
 (k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,k18,k19,k20,k21,k22,k23,k24,
  k25,k26,k27,k28,k29,k30,k31,k32,k33,k34,k35,k36,k37,k38,k39,k0,k1,
  k2,k3,k4,k5,k6,k7)`;

(* The current subkey pair: the first two words of the schedule.            *)
val GETKEYS_def = Define`
  GETKEYS
   (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,
    k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,k32,k33,
    k34,k35,k36,k37,k38,k39) =
 (k0,k1)`;
353
(* Quantifier expansions for the 32-tuple initkeys and 40-tuple keysched     *)
(* types.  They are proved with Q.prove, so they are not stored in the       *)
(* theory, and no other proof in this file refers to them.                   *)
val FORALL_INITKEYS = Q.prove(
 `(!x:initkeys. P x) =
   !k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 k16 k17 k18
      k19 k20 k21 k22 k23 k24 k25 k26 k27 k28 k29 k30 k31.
    P (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,
       k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31)`,
  SIMP_TAC std_ss [FORALL_PROD]);

val FORALL_KEYSCHEDS = Q.prove(
 `(!x:keysched. P x) =
  !k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 k16 k17 k18
     k19 k20 k21 k22 k23 k24 k25 k26 k27 k28 k29 k30 k31 k32 k33 k34
     k35 k36 k37 k38 k39.
   P (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,k17,
      k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,k32,k33,
      k34,k35,k36,k37,k38,k39)`,
  SIMP_TAC std_ss [FORALL_PROD]);
371
(*---------------------------------------------------------------------------*)
(* Sanity check                                                              *)
(*---------------------------------------------------------------------------*)

(* Flatten a key schedule into the list [K0; ...; K39].                      *)
val toList_def = Define`
  toList (k:keysched) =
  (let (k0,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16,
        k17,k18,k19,k20,k21,k22,k23,k24,k25,k26,k27,k28,k29,k30,k31,
        k32,k33,k34,k35,k36,k37,k38,k39) = k
   in
     [k0; k1; k2; k3; k4; k5; k6; k7; k8; k9; k10; k11; k12; k13;
      k14; k15; k16; k17; k18; k19; k20; k21; k22; k23; k24; k25;
      k26; k27; k28; k29; k30; k31; k32; k33; k34; k35; k36; k37;
      k38; k39])`;

(* The schedule built by expandKeys has 40 words.  The proof is wrapped in   *)
(* Count.apply (HOL's inference counter) and is not stored in the theory.    *)
val keysched_length = Count.apply Q.prove(
  `!Me Mo. LENGTH (toList(expandKeys(Me,Mo))) = 40`,
  SRW_TAC [] [expandKeys_def, toList_def] THEN EVAL_TAC);
390
(*---------------------------------------------------------------------------*)
(* The Key-dependent S-boxes                                                 *)
(*---------------------------------------------------------------------------*)

(* g: apply h to the four bytes x (most significant first, as produced by    *)
(* fromLarge) with the S-box key words SS, and pack the result as a word32.  *)
val fun_g_def = Define`
  fun_g(x, SS) = toLarge(fun_h(x,SS))`;
397
(*---------------------------------------------------------------------------*)
(* The function FF is a key-dependent permutation on 64-bit values *)
(*---------------------------------------------------------------------------*)

(* T0 = g(R0) and T1 = g(R1 rotated left by 8), then the pseudo-Hadamard     *)
(* transform F0 = T0 + T1 + K0, F1 = T0 + 2*T1 + K1.  Addition is word32     *)
(* arithmetic, so the && 0xffffffffw masks do not change the value.          *)
val FF_def = Define`
  FF((R0,R1),(K0,K1),SS) =
  let T0 = fun_g(fromLarge(R0),SS) in
  let T1 = fun_g(fromLarge(R1 #<< 8),SS) in
  let F0 = (T0 + T1 + K0) && 0xffffffffw in
  let F1 = (T0 + 2w*T1+ K1) && 0xffffffffw
  in (F0,F1)`;
409
(*---------------------------------------------------------------------------*)
(*-------------Forward round used by the encrypting function-----------------*)
(*---------------------------------------------------------------------------*)

(* The operation in each of the 16 rounds.  (F0,F1) is computed from         *)
(* (R0,R1) and the current subkey pair GETKEYS k.  The new left half is      *)
(* ((R2 xor F0) rotated right by 1, (R3 rotated left by 1) xor F1), and the  *)
(* old (R0,R1) becomes the new right half (the Feistel swap).                *)
val Round_Op_def = Define`
  Round_Op((R0,R1,R2,R3),k,ss) =
  let (F0, F1) = FF((R0,R1),GETKEYS(k), ss) in
  let R0' = (R2 ?? F0) #>> 1 in
  let R1' = (R3 #<< 1) ?? F1
  in (R0', R1', R0, R1)`;
422
(* en_rnd i b k ss applies i rounds to b, moving to the next subkey pair     *)
(* (ROTKEYS) after each round.  Termination is shown with measure FST on the *)
(* tupled arguments, i.e. the round counter i decreases.                     *)
val (en_rnd_def, en_rnd_ind) = Defn.tprove (
    Hol_defn "en_rnd"
    `en_rnd i (b:block) k ss =
     if i=0 then b
     else en_rnd (i-1) (Round_Op(b,k,ss)) (ROTKEYS(k)) ss`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);

(* Store the definition and induction theorem in the theory explicitly.      *)
val _ = save_thm ("en_rnd_def", en_rnd_def);
val _ = save_thm ("en_rnd_ind", en_rnd_ind);

(* fwd: the 16 Twofish rounds.                                               *)
val fwd_def = Define `fwd(b,k,s) = en_rnd 16 b k s`;
434
(*---------------------------------------------------------------------------*)
(*-------------Backward round used by the decrypting function----------------*)
(*---------------------------------------------------------------------------*)

(* Decryption. Note that (R2,R3) at round r+1 = (R0,R1) at round r, so FF    *)
(* can be recomputed from the input's right half.  The rotations and xors of *)
(* Round_Op are then undone and the halves swapped back (see                 *)
(* Round_Inversion).                                                         *)
val InvRound_Op_def = Define`
  InvRound_Op((R0,R1,R2,R3),k,ss) =
  let (F0, F1) = FF((R2,R3),GETKEYS(k),ss) in
  let R0' = (R0 #<< 1) ?? F0 in
  let R1' = (R1 ?? F1) #>> 1
  in (R2, R3, R0', R1')`;
446
(* de_rnd i b k ss undoes i rounds.  It first undoes the i-1 rounds keyed by *)
(* ROTKEYS k (recursively) and applies InvRound_Op with the current key pair *)
(* last, mirroring the order in which en_rnd applies Round_Op.               *)
val (de_rnd_def, de_rnd_ind) = Defn.tprove (
    Hol_defn "de_rnd"
    `de_rnd i (b:block) k ss =
     if i=0 then b
     else InvRound_Op(de_rnd (i-1) b (ROTKEYS(k)) ss, k, ss)`,
  WF_REL_TAC `measure FST` THEN REPEAT PairRules.PGEN_TAC THEN DECIDE_TAC);

(* Store de_rnd's own definition and induction theorem in the theory.  The   *)
(* theorem was previously saved as "de_rnd__ind"; that name is kept as an    *)
(* alias so existing references keep working.                                *)
val _ = save_thm ("de_rnd_def", de_rnd_def);
val _ = save_thm ("de_rnd_ind", de_rnd_ind);
val _ = save_thm ("de_rnd__ind", de_rnd_ind);

(* bwd: undo the 16 rounds of fwd.                                           *)
val bwd_def = Define `bwd(b,k,s) = de_rnd 16 b k s`;
458
(* --------------------------------------------------------------------------*)
(*-------------Forward and backward round operation inversion lemmas---------*)
(*---------------------------------------------------------------------------*)

(* Simpset fragment that rewrites paired-lambda applications of the form     *)
(* (\(x,y). s1) s2 with PairRules.PBETA_CONV; tuple-binding lets such as     *)
(* those in Round_Op/InvRound_Op unfold to this shape.                       *)
val PBETA_ss = simpLib.conv_ss
  {name="PBETA",trace = 3,conv=K (K PairRules.PBETA_CONV),
   key = SOME([],``(\(x:'a,y:'b). s1) s2:'c``)};
466
(* A single round is undone by InvRound_Op with the same key schedule and    *)
(* S-box keys: split the block into words, unfold both rounds and simplify.  *)
val Round_Inversion = Q.store_thm("Round_Inversion",
  `!b k s. InvRound_Op(Round_Op(b,k,s),k,s) = b`,
  SIMP_TAC std_ss [FORALL_BLOCK, FORALL_KEY]
    THEN SRW_TAC [boolSimps.LET_ss,PBETA_ss] [Round_Op_def,InvRound_Op_def]);
471
(* The constants Round_Op and InvRound_Op.  Each binding assumes decls       *)
(* returns exactly one constant of that name.                                *)
val [Round_Op] = decls "Round_Op";
val [InvRound_Op] = decls "InvRound_Op";

(* bwd undoes fwd.  RESTR_EVAL_TAC evaluates fwd/bwd through all 16 rounds   *)
(* but keeps Round_Op and InvRound_Op unexpanded, leaving nested             *)
(* InvRound_Op(Round_Op ...) pairs that Round_Inversion cancels.             *)
val Round_Inversion_LEMMA = Q.store_thm("Round_Inversion_LEMMA",
  `!b k s. bwd(fwd(b,k,s),k,s) = b`,
  SIMP_TAC std_ss [FORALL_BLOCK]
    THEN computeLib.RESTR_EVAL_TAC [Round_Op, InvRound_Op]
    THEN RW_TAC std_ss [Round_Inversion]);
480
(*---------------------------------------------------------------------------*)
(* Input whitening and output whitening                                      *)
(*---------------------------------------------------------------------------*)

(* Input whitening: xor the four block words with K0, K1, K2, K3.            *)
val In_Whiten_def = Define`
  In_Whiten(b:block, k) =
    let (R0,R1,R2,R3) = b in
    (R0 ?? FST(GETKEYS(k)), R1 ?? SND(GETKEYS(k)),
     R2 ?? FST(GETKEYS(ROTKEYS(k))), R3 ?? SND(GETKEYS(ROTKEYS(k))))`;

(* Output whitening: xor the four block words with K4, K5, K6, K7 in place;  *)
(* the words are not reordered.                                              *)
val Out_Whiten_def = Define`
  Out_Whiten(b:block, k) =
    let (R0,R1,R2,R3) = b in
    (R0 ?? FST(GETKEYS(ROTKEYS(ROTKEYS(k)))),
     R1 ?? SND(GETKEYS(ROTKEYS(ROTKEYS(k)))),
     R2 ?? FST(GETKEYS(ROTKEYS(ROTKEYS(ROTKEYS(k))))),
     R3 ?? SND(GETKEYS(ROTKEYS(ROTKEYS(ROTKEYS(k))))))`;
498
(* Each whitening step is its own inverse: xor-ing the same key words twice  *)
(* cancels out.                                                              *)
val WHITENING_LEMMA = Q.store_thm("WHITENING_LEMMA",
  `!(b:block) (k:keysched).
    (Out_Whiten(Out_Whiten(b,k),k) = b) /\ (In_Whiten(In_Whiten(b,k),k) = b)`,
  SRW_TAC [] [Out_Whiten_def, In_Whiten_def]);
503
(*---------------------------------------------------------------------------*)
(* Encrypt and Decrypt                                                       *)
(*---------------------------------------------------------------------------*)
(*  Encryption: input whitening xors the block with K0..K3, the 16 rounds
    use K8..K39 (ROTKEYS8 skips the eight whitening words), and output
    whitening xors the result with K4..K7.  Decryption applies the inverse
    steps in reverse order.
    NOTE(review): in the Twofish specification the output whitening step
    also undoes the swap of the last round.  Out_Whiten and fwd as defined
    here do not perform that swap, so ciphertexts may differ from the
    reference cipher even though decryption still inverts encryption.     *)

val TwofishEncrypt_def = Define`
  TwofishEncrypt initM b =
  let (k, ss) = (expandKeys(genM(initM)), genS(initM))
  in  Out_Whiten(fwd(In_Whiten(b,k),ROTKEYS8(k),ss), k)`;

val TwofishDecrypt_def = Define`
  TwofishDecrypt initM b =
  let (k, ss) = (expandKeys(genM(initM)), genS(initM))
  in  In_Whiten(bwd(Out_Whiten(b,k),ROTKEYS8(k),ss), k)`;
521
(*---------------------------------------------------------------------------*)
(* Main Lemma                                                                *)
(*---------------------------------------------------------------------------*)

(* Decryption inverts encryption for every key.  After unfolding both        *)
(* definitions, the whitening steps cancel by WHITENING_LEMMA and the rounds *)
(* cancel by Round_Inversion_LEMMA.                                          *)
val TWOFISH_LEMMA = Q.store_thm("TWOFISH_LEMMA",
  `!(plaintext:block) (keys:initkeys).
     TwofishDecrypt keys (TwofishEncrypt keys plaintext) = plaintext`,
  RW_TAC std_ss [TwofishEncrypt_def]
    THEN RW_TAC std_ss [TwofishDecrypt_def]
    THEN RW_TAC std_ss [WHITENING_LEMMA, Round_Inversion_LEMMA]);
532
(*---------------------------------------------------------------------------*)
(* Basic theorem about encryption/decryption                                 *)
(*---------------------------------------------------------------------------*)

(* The cipher for a given key, as an (encrypt, decrypt) function pair.       *)
val TWOFISH_def = Define`
  TWOFISH (keys) =
  (TwofishEncrypt keys,  TwofishDecrypt keys)`;

(* Whatever pair TWOFISH returns for a key, its decrypt function undoes its  *)
(* encrypt function (a restatement of TWOFISH_LEMMA).                        *)
val TWOFISH_CORRECT = Q.store_thm("TWOFISH_CORRECT",
   `!key plaintext.
       ((encrypt,decrypt) = TWOFISH key)
       ==>
       (decrypt (encrypt plaintext) = plaintext)`,
         RW_TAC std_ss [TWOFISH_def,LET_THM,TWOFISH_LEMMA]);

(* Write the theory "twofish" to disk.                                       *)
val _ = export_theory();
549