1(*---------------------------------------------------------------------------*)
2(* Multiplication by a constant. Recursive, iterative, and table-lookup      *)
3(* versions.                                                                 *)
4(*---------------------------------------------------------------------------*)
5
6(* For interactive work
7  load "wordsLib";
8  quietdec := true;
9  open wordsTheory bitTheory wordsLib arithmeticTheory;
10  quietdec := false;
11*)
12
13open HolKernel Parse boolLib bossLib
14     wordsTheory bitTheory wordsLib arithmeticTheory;
15
(* Start the theory segment "Mult"; every definition and theorem stored      *)
(* below is recorded in it until export_theory is called.                    *)
val _ = new_theory "Mult";
17
18(*---------------------------------------------------------------------------
19    Multiply a byte (representing a polynomial) by x.
20 ---------------------------------------------------------------------------*)
21
(* xtime w shifts the byte w left by one bit and xors the result with 0x1B   *)
(* when the most significant bit of w is set (with 0 otherwise).             *)
(* NOTE(review): 0x1B looks like the low byte of the AES reduction           *)
(* polynomial x^8 + x^4 + x^3 + x + 1 -- confirm against the intended field. *)
val xtime_def = Define
  `xtime (w : word8) =
     w << 1 ?? (if word_msb w then 0x1Bw else 0w)`;
25
(* Local lemma (proved with Q.prove, so bound only to the ML name MSB_lem):  *)
(* the most significant bit of an xor is set exactly when the most           *)
(* significant bits of the two operands differ.  Discharged by the           *)
(* simplifier extended with the WORD_BIT_EQ_ss fragment.                     *)
val MSB_lem = Q.prove (
  `!a b. word_msb (a ?? b) = ~(word_msb a = word_msb b)`,
  SRW_TAC [WORD_BIT_EQ_ss] []);
29
(* xtime distributes over xor.  The proof unfolds xtime_def, rewrites the    *)
(* msb of the xor with MSB_lem, and lets FULL_SIMP_TAC finish whatever       *)
(* goals SRW_TAC leaves behind (presumably the cases arising from the        *)
(* conditionals -- the exact subgoals are not visible from the script).      *)
val xtime_distrib = Q.store_thm
("xtime_distrib",
 `!a b. xtime (a ?? b) = xtime a ?? xtime b`,
  SRW_TAC [] [xtime_def, MSB_lem] THEN FULL_SIMP_TAC std_ss []);
34
35(*---------------------------------------------------------------------------*)
36(* Multiplication by a constant                                              *)
37(*---------------------------------------------------------------------------*)
38
(* Declare ** as a left-associative infix at precedence 675.  This has to    *)
(* happen before the definition below so that its quotation parses.          *)
val _ = set_fixity "**" (Infixl 675);

(* Recursive multiplication b1 ** b2 on bytes.  b1 is consumed one bit at a  *)
(* time from the least significant end: each step shifts b1 right by one     *)
(* (logical shift), replaces b2 by xtime b2, and xors the current b2 into    *)
(* the result when the low bit of b1 is set.  The recursion stops when b1    *)
(* is 0w.                                                                    *)
(* xDefine is given the explicit stem "ConstMult"; the induction theorem     *)
(* used later in this file is accordingly fetched as "ConstMult_ind".        *)
val ConstMult_def =
 xDefine
   "ConstMult"
   `b1 ** b2 =
      if b1 = 0w:word8 then 0w else
      if word_lsb b1
         then b2 ?? ((b1 >>> 1) ** xtime b2)
         else       ((b1 >>> 1) ** xtime b2)`;

(* Register the defining equation with computeLib under its stored name.     *)
val _ = computeLib.add_persistent_funs ["ConstMult_def"];
51
(* Multiplication by a constant distributes over xor in its second           *)
(* argument.  Proof: induct with the induction theorem generated for         *)
(* ConstMult, unfold the definition once, and simplify using xtime_distrib.  *)
val ConstMultDistrib = Q.store_thm
("ConstMultDistrib",
 `!x y z. x ** (y ?? z) = (x ** y) ?? (x ** z)`,
 recInduct (theorem "ConstMult_ind")
   THEN REPEAT STRIP_TAC
   THEN ONCE_REWRITE_TAC [ConstMult_def]
   THEN SRW_TAC [] [xtime_distrib]);
59
60(*---------------------------------------------------------------------------*)
61(* Iterative version                                                         *)
62(*---------------------------------------------------------------------------*)
63
(* Accumulator-passing form of the multiplication.  The state is the triple  *)
(* (b1, b2, acc).  While b1 is non-zero, each step shifts b1 right by one,   *)
(* replaces b2 by xtime b2, and xors the current b2 into acc when the low    *)
(* bit of b1 is set.  When b1 reaches 0w the whole triple is returned        *)
(* unchanged, so the product is the third component.                         *)
val IterConstMult_def =
 Define
   `IterConstMult (b1,b2,acc) =
       if b1 = 0w:word8 then (b1,b2,acc)
       else IterConstMult (b1 >>> 1, xtime b2,
                           if word_lsb b1 then (b2 ?? acc) else acc)`;

(* Register the defining equation with computeLib under its stored name.     *)
val _ = computeLib.add_persistent_funs ["IterConstMult_def"];
72
73(*---------------------------------------------------------------------------*)
74(* Equivalence between recursive and iterative forms.                        *)
75(*---------------------------------------------------------------------------*)
76
(* The recursive product xored with an accumulator equals the third          *)
(* component of the triple returned by IterConstMult started on that         *)
(* accumulator.  Proof: induct with the induction theorem generated for      *)
(* IterConstMult, unfold both definitions once, and simplify.                *)
val ConstMultEq = Q.store_thm
("ConstMultEq",
 `!b1 b2 acc. (b1 ** b2) ?? acc = SND(SND(IterConstMult (b1,b2,acc)))`,
 recInduct (theorem "IterConstMult_ind") THEN RW_TAC std_ss []
   THEN ONCE_REWRITE_TAC [ConstMult_def,IterConstMult_def]
   THEN FULL_SIMP_TAC (srw_ss()) [] THEN SRW_TAC [] []);
83
84(*---------------------------------------------------------------------------*)
85(* Tabled versions                                                           *)
86(*---------------------------------------------------------------------------*)
87
(* UNROLL_RULE n def                                                         *)
(*                                                                           *)
(* Unroll the recursive equation def n times.  Each round rewrites with def  *)
(* inside the right-hand side of the theorem obtained so far (a DEPTH_CONV   *)
(* pass restricted by RHS_CONV) and then simplifies with arith_ss and        *)
(* LSR_ADD (presumably to merge the nested right shifts introduced by the    *)
(* unfolding -- confirm against wordsTheory).                                *)
(*                                                                           *)
(*   n   : number of unrolling rounds; 0 returns def unchanged.              *)
(*   def : the equation to unroll (callers in this file pass a SPEC_ALL'd    *)
(*         definition).                                                      *)
(*                                                                           *)
(* Raises HOL_ERR for negative n.  Without the guard a negative count never  *)
(* reaches the base case and the recursion does not terminate.               *)
fun UNROLL_RULE n def =
  if n < 0 then
    raise mk_HOL_ERR "Mult" "UNROLL_RULE" "negative unroll count"
  else if n = 0 then def
  else
    SIMP_RULE arith_ss [LSR_ADD]
      (GEN_REWRITE_RULE (RHS_CONV o DEPTH_CONV) empty_rewrites [def]
                        (UNROLL_RULE (n - 1) def));
(* instantiate tm: rewrite tm once with the four-fold unrolling of the       *)
(* ConstMult equation, then simplify the resulting theorem with the srw      *)
(* simpset and xtime_distrib oriented right-to-left.  The unrolled theorem   *)
(* and both rewriting functions are built once, when this binding is         *)
(* evaluated, not on every call.                                             *)
val instantiate =
  let
    val unrolled_mult = UNROLL_RULE 4 (SPEC_ALL ConstMult_def)
    val unfold_once   = ONCE_REWRITE_CONV [unrolled_mult]
    val fold_xtime    = SIMP_RULE (srw_ss()) [GSYM xtime_distrib]
  in
    fn tm => fold_xtime (unfold_once tm)
  end;
96
(* The iterative equation with one extra level of unrolling.                 *)
val IterMult2 =
  let
    val iter_eqn = SPEC_ALL IterConstMult_def
  in
    UNROLL_RULE 1 iter_eqn
  end;
98
99(*---------------------------------------------------------------------------*)
100(* mult_unroll                                                               *)
101(*    |- (2w ** x = xtime x) /\                                              *)
102(*       (3w ** x = x ?? xtime x) /\                                         *)
103(*       (9w ** x = x ?? xtime (xtime (xtime x)))      /\                    *)
104(*       (11w ** x = x ?? xtime (x ?? xtime (xtime x))) /\                   *)
105(*       (13w ** x = x ?? xtime (xtime (x ?? xtime x))) /\                   *)
106(*       (14w ** x = xtime (x ?? xtime (x ?? xtime x)))                      *)
107(*---------------------------------------------------------------------------*)
108
(* Conjunction of the unrolled equations for the six constant multipliers    *)
(* 2, 3, 9, 11, 13 and 14, stored in the theory as "mult_unroll".            *)
val mult_unroll =
  let
    val products =
      [``0x2w ** x``, ``0x3w ** x``, ``0x9w ** x``,
       ``0xBw ** x``, ``0xDw ** x``, ``0xEw ** x``]
    val unrolled = map instantiate products
  in
    save_thm ("mult_unroll", LIST_CONJ unrolled)
  end;
113
(* eval_mult tm: rewrite tm with mult_unroll and with a pre-simplified form  *)
(* of xtime_def, then apply WORD_EVAL_RULE to the resulting theorem.  The    *)
(* pre-simplified xtime equation is computed once, when this binding is      *)
(* evaluated: the right-hand side of xtime_def is simplified under its       *)
(* quantifiers with WORD_MUL_LSL and COND_RAND.                              *)
val eval_mult =
  let
    val simp_rhs =
      STRIP_QUANT_CONV
        (RHS_CONV (SIMP_CONV (srw_ss()) [WORD_MUL_LSL, COND_RAND]))
    val xtime_eqn = CONV_RULE simp_rhs xtime_def
    val unfold    = PURE_REWRITE_CONV [mult_unroll, xtime_eqn]
  in
    fn tm => WORD_EVAL_RULE (unfold tm)
  end;
117
(* mk_word8 i: the n2w term for the ML integer i, built with index type :8.  *)
fun mk_word8 i =
  let
    val numeral = numSyntax.term_of_int i
  in
    wordsSyntax.mk_n2w (numeral, ``:8``)
  end;
119
120(*---------------------------------------------------------------------------*)
121(* Construct specialized multiplication tables.                              *)
122(*---------------------------------------------------------------------------*)
123
(* For each of the six constant multipliers, evaluate the product with       *)
(* every byte value 0..255 using eval_mult, and conjoin all resulting        *)
(* theorems (multiplier by multiplier, bytes in ascending order).            *)
val mult_tables =
  let
    val multipliers =
      [``0x2w:word8``, ``0x3w:word8``, ``0x9w:word8``,
       ``0xBw:word8``, ``0xDw:word8``, ``0xEw:word8``]
    (* All 256 evaluated products for one multiplier c. *)
    fun table_for c =
      List.tabulate (256, fn i =>
        let
          val byte = mk_word8 i
        in
          eval_mult ``^c ** ^byte``
        end)
  in
    LIST_CONJ (List.concat (map table_for multipliers))
  end;
129
130(*---------------------------------------------------------------------------*)
131(* Multiplication by constant implemented by one-step rewrites.              *)
132(*---------------------------------------------------------------------------*)
133
(* Store the conjunction of table entries in the theory as "mult_tables".    *)
val _ = save_thm ("mult_tables", mult_tables)
135
136(*---------------------------------------------------------------------------*)
137(* Multiplication by constant implemented by lookup into balanced binary     *)
138(* tree. Lookup is done bit-by-bit.                                          *)
139(*---------------------------------------------------------------------------*)
140
141(*
142val _ = save_thm ("mult_ifs", mult_ifs)
143*)
144
145(*---------------------------------------------------------------------------*)
146(* Exponentiation                                                            *)
147(*---------------------------------------------------------------------------*)
148
(* PolyExp x n: the n-fold product of x under **, starting from 1w.          *)
(* PolyExp x 0 is 1w; otherwise the result is x ** PolyExp x (n-1).          *)
val PolyExp_def =
 Define
   `PolyExp x n = if n=0 then 1w else x ** PolyExp x (n-1)`;
152
153
(* Close the "Mult" theory segment and write it out.                         *)
val _ = export_theory();
155