1(* -------------------------------------------------------------------------
2   Conversions for evaluating some 64-bit (double precision) IEEE-754
3   operations. They can be enabled in EVAL with
4
5    > set_trace "native IEEE" 1;
6
   Calculations are performed in hardware using SML's Real and Math
8   structures and the theorems are produced using Thm.mk_oracle_thm.
9
10   NOTE: Poly/ML 5.5's multiply isn't fully IEEE compliant on 64-bit machines,
11         so the results should not be fully trusted.
12   ------------------------------------------------------------------------- *)
13structure native_ieeeLib :> native_ieeeLib =
14struct
15
16open HolKernel Parse boolLib bossLib
17open wordsLib binary_ieeeLib binary_ieeeSyntax fp64Syntax
18
(* Local override of Parse: it re-exports everything from the global Parse
   structure, but shadows Type and Term with the pair returned by
   parse_from_grammars for binary_ieeeTheory.binary_ieee_grammars.
   The trailing "open Parse" brings this overriding structure into scope for
   the remainder of the file. *)
structure Parse =
struct
  open Parse
  (* Must come after "open Parse" so these bindings shadow the global ones. *)
  val (Type, Term) =
     parse_from_grammars binary_ieeeTheory.binary_ieee_grammars
end
open Parse
26
(* Build a HOL_ERR tagged with this library's name; the caller supplies the
   originating function name and the message. *)
fun ERR fname msg = Feedback.mk_HOL_ERR "native_ieeeLib" fname msg
28
29(* -------------------------------------------------------------------------
30   numToReal
31   realToNum
32   ------------------------------------------------------------------------- *)
33
val n256 = Arbnum.fromInt 256

local
   (* Split n into exactly PackRealBig.bytesPerElem bytes, most significant
      byte first.  Any bits above the top byte are discarded. *)
   fun bigEndianBytes n =
      let
         fun collect (0, _, acc) = acc
           | collect (k, m, acc) =
               let
                  val (high, low) = Arbnum.divmod (m, n256)
               in
                  collect (k - 1, high, Word8.fromInt (Arbnum.toInt low) :: acc)
               end
      in
         collect (PackRealBig.bytesPerElem, n, [])
      end
in
   (* Reinterpret a natural number as the bit pattern of a native real. *)
   val numToReal =
      PackRealBig.fromBytes o Word8Vector.fromList o bigEndianBytes
end
50
local
   val byteToNum = Arbnum.fromInt o Word8.toInt
in
   (* Reinterpret the bit pattern of a native real as a natural number.

      PackRealBig.toBytes yields the bytes most significant first, so a left
      fold over the vector accumulates the value directly.  Folding over the
      vector itself (instead of indexing with a hard-coded 7) keeps this in
      step with numToReal, whose width is driven by
      PackRealBig.bytesPerElem.

      Raises HOL_ERR "NaN" if r is a NaN. *)
   fun realToNum r =
      if Real.isNan r
         then raise ERR "realToNum" "NaN"
      else Word8Vector.foldl
             (fn (b, acc) => Arbnum.+ (Arbnum.* (acc, n256), byteToNum b))
             Arbnum.zero (PackRealBig.toBytes r)
end
67
68(* -------------------------------------------------------------------------
69   wordToReal
70   realToWord
71   ------------------------------------------------------------------------- *)
72
(* Width of a native real in bits, as an int and as an Arbnum.num. *)
val irealwidth = PackRealBig.bytesPerElem * 8
val realwidth = Arbnum.fromInt irealwidth

(* Decode a word literal of exactly realwidth bits as a native real.
   Raises HOL_ERR "length mismatch" for any other width; errors raised by
   dest_mod_word_literal propagate unchanged. *)
fun wordToReal tm =
   let
      val (bits, width) = wordsSyntax.dest_mod_word_literal tm
   in
      if width <> realwidth
         then raise ERR "wordToReal" "length mismatch"
      else numToReal bits
   end

(* Encode a native real as a realwidth-bit word literal
   (realToNum rejects NaN). *)
fun realToWord r =
   let
      val bits = realToNum r
   in
      wordsSyntax.mk_word (bits, realwidth)
   end
85
86(* -------------------------------------------------------------------------
87   floatToReal
88   realToFloat
89   ------------------------------------------------------------------------- *)
90
(* Conversions between HOL float literals of the native format and SML reals.
   Judging by the size check in floatToReal, the float type parameters are
   (fraction width, exponent width) = (Real.precision - 1,
   irealwidth - Real.precision). *)
local
  (* HOL float type whose layout matches the native real. *)
  val native_ty =
    binary_ieeeSyntax.mk_ifloat_ty
      (Real.precision - 1, irealwidth - Real.precision)
  (* itself-term for native_ty's width pair, used to build the infinity
     constants below. *)
  val native_itself =
    (boolSyntax.mk_itself o pairSyntax.mk_prod o
     binary_ieeeSyntax.dest_float_ty) native_ty
  val native_plus_infinity_tm =
    binary_ieeeSyntax.mk_float_plus_infinity native_itself
  val native_minus_infinity_tm =
    binary_ieeeSyntax.mk_float_minus_infinity native_itself
  (* Number of exponent bits. *)
  val exponent = irealwidth - Real.precision
  (* 2^(irealwidth - 1): weight of the sign bit. *)
  val signval = Arbnum.pow (Arbnum.two, Arbnum.fromInt (irealwidth - 1))
  (* 2^exponent: modulus of the exponent field. *)
  val expval = Arbnum.pow (Arbnum.two, Arbnum.fromInt exponent)
  (* 2^(Real.precision - 1): modulus of the fraction field, i.e. the weight
     of the lowest exponent bit. *)
  val manval = Arbnum.pow (Arbnum.two, Arbnum.fromInt (Real.precision - 1))
  fun odd n = Arbnum.mod (n, Arbnum.two) = Arbnum.one
in
  (* Convert a native-format float literal to an SML real by assembling the
     bit pattern sign * signval + e * manval + f and reinterpreting it.

     Raises HOL_ERR "size mismatch" if the literal's widths are not the
     native ones.  If the let body raises a HOL_ERR whose origin_function is
     "dest_floating_point" (presumably triple_of_float rejecting a
     non-literal, such as an infinity constant), then:
       - for a term of native_ty, plus/minus infinity map to
         Real.posInf / Real.negInf, and anything else re-raises that error;
       - for any other type, HOL_ERR "not native float type" is raised. *)
  fun floatToReal tm =
    let
       val ((t, w), (s, e, f)) = binary_ieeeSyntax.triple_of_float tm
       val _ = t + 1 = Real.precision andalso w = exponent orelse
               raise ERR "floatToReal" "size mismatch"
    in
       numToReal
          (Arbnum.+ (if s then signval else Arbnum.zero,
                     Arbnum.+ (Arbnum.* (e, manval), f)))
    end
    handle e as HOL_ERR {origin_function = "dest_floating_point", ...} =>
       if Term.type_of tm = native_ty
          then if binary_ieeeSyntax.is_float_plus_infinity tm
                  then Real.posInf
               else if binary_ieeeSyntax.is_float_minus_infinity tm
                  then Real.negInf
               else raise e
       else raise ERR "floatToReal" "not native float type"
  (* Convert an SML real to a native-format float term.  Infinities become
     the float infinity constants; NaN raises HOL_ERR "NaN".  Otherwise the
     bit pattern n is split as n = (s * expval + e) * manval + f, and the
     low bit of s is taken as the sign. *)
  fun realToFloat r =
    case Real.class r of
       IEEEReal.INF => if Real.signBit r then native_minus_infinity_tm
                       else native_plus_infinity_tm
     | IEEEReal.NAN => raise ERR "realToFloat" "NaN"
     | _ =>
         let
            val n = realToNum r
            val (e, f) = Arbnum.divmod (n, manval)
            val (s, e) = Arbnum.divmod (e, expval)
         in
            binary_ieeeSyntax.float_of_triple
              ((Real.precision - 1, exponent), (odd s, e, f))
         end
end
141
142(* -------------------------------------------------------------------------
143   Native conversions
144   ------------------------------------------------------------------------- *)
145
(* Evaluate f x with the IEEE rounding mode set to the one named by the HOL
   rounding-mode term tm, then restore the rounding mode that was in effect
   on entry.

   tm must be roundTiesToEven_tm, roundTowardZero_tm, roundTowardNegative_tm
   or roundTowardPositive_tm.  Otherwise HOL_ERR "not a valid mode" is raised
   before the rounding mode is changed.

   The saved mode is restored both on normal return and when f x raises; in
   the latter case the exception is re-raised.  Without the handler, a
   raising f (e.g. mk_native rejecting a NaN result) would leave the changed
   rounding mode in effect for all later real arithmetic. *)
fun withRounding tm f x =
   let
      val mode = IEEEReal.getRoundingMode ()
      val newMode =
         if tm = binary_ieeeSyntax.roundTiesToEven_tm
            then IEEEReal.TO_NEAREST
         else if tm = binary_ieeeSyntax.roundTowardZero_tm
            then IEEEReal.TO_ZERO
         else if tm = binary_ieeeSyntax.roundTowardNegative_tm
            then IEEEReal.TO_NEGINF
         else if tm = binary_ieeeSyntax.roundTowardPositive_tm
            then IEEEReal.TO_POSINF
         else raise ERR "setRounding" "not a valid mode"
      fun restore () = IEEEReal.setRoundingMode mode
   in
      IEEEReal.setRoundingMode newMode
    ; (f x handle e => (restore (); raise e)) before restore ()
   end
162
(* Assert the term concl as a theorem tagged with the "native_ieee" oracle. *)
fun mk_native_ieee_thm concl = Thm.mk_oracle_thm "native_ieee" ([], concl)

(* Oracle theorem |- tm = w, where w is the word literal encoding the native
   result r.  A NaN result is rejected with HOL_ERR "result is NaN". *)
fun mk_native tm r =
  if not (Real.isNan r)
     then mk_native_ieee_thm (boolSyntax.mk_eq (tm, realToWord r))
  else raise ERR "mk_native" "result is NaN"

(* wordToReal wrapped with Lib.total: NONE instead of an exception. *)
fun wordToReal' tm = Lib.total wordToReal tm
170
(* Lift a unary native operation f to a conversion.  dst destructs tm into
   (rounding-mode term, argument word).  The argument is decoded to a native
   real, f is applied under that rounding mode, and the result is returned
   as an oracle theorem |- tm = result-word.  Raises HOL_ERR if dst fails or
   the argument cannot be decoded. *)
fun lift1 f dst tm =
  case Lib.total dst tm of
     NONE => raise ERR "lift1" ""
   | SOME (mode, a) =>
       (case wordToReal' a of
           NONE => raise ERR "lift1" "failed to convert to native real"
         | SOME ra => withRounding mode (mk_native tm o f) ra)
178
(* Lift a binary native operation f to a conversion.  dst destructs tm into
   (rounding-mode term, left word, right word).  Both arguments are decoded
   to native reals, f is applied under that rounding mode, and the result is
   returned as an oracle theorem |- tm = result-word.  Raises HOL_ERR
   (origin "lift2") if dst fails or either argument cannot be decoded;
   the origin was previously misspelled "lif2" on the dst-failure path. *)
fun lift2 f dst tm =
  case Lib.total dst tm of
     SOME (mode, a, b) =>
       (case (wordToReal' a, wordToReal' b) of
           (SOME ra, SOME rb) =>
             withRounding mode (mk_native tm o f) (ra, rb)
         | _ => raise ERR "lift2" "failed to convert to native reals")
   | NONE => raise ERR "lift2" ""
187
(* Lift a native comparison f (returning a HOL term) to a conversion.  dst
   destructs tm into its two word operands.  Both are decoded to native
   reals, and the oracle theorem |- tm = f (ra, rb) is returned.  No rounding
   mode is set here.  Raises HOL_ERR if dst fails or an operand cannot be
   decoded. *)
fun liftOrder f dst tm =
  let
     val fail = ERR "liftNativeOrder"
  in
     case Lib.total dst tm of
        NONE => raise fail ""
      | SOME (a, b) =>
          (case (wordToReal' a, wordToReal' b) of
              (SOME ra, SOME rb) =>
                 mk_native_ieee_thm (boolSyntax.mk_eq (tm, f (ra, rb)))
            | _ => raise fail "failed to convert to native reals")
  end
197
(* Map the native IEEE comparison of two reals to the corresponding
   binary_ieee comparison constant (LT, EQ, GT or UN). *)
fun float_compare (x, y) =
  case Real.compareReal (x, y) of
     IEEEReal.LESS      => binary_ieeeSyntax.LT_tm
   | IEEEReal.EQUAL     => binary_ieeeSyntax.EQ_tm
   | IEEEReal.GREATER   => binary_ieeeSyntax.GT_tm
   | IEEEReal.UNORDERED => binary_ieeeSyntax.UN_tm

(* Convert an SML bool to the HOL constant T or F. *)
fun mk_b b = if b then boolSyntax.T else boolSyntax.F
205
(* Install the native implementations into the conversion hooks exposed by
   machine_ieeeTheory.  These top-level declarations run when the structure
   is loaded. *)

(* Arithmetic: evaluated under the rounding mode named in the term. *)
val () = machine_ieeeTheory.sqrt_CONV := lift1 Math.sqrt fp64Syntax.dest_fp_sqrt
val () = machine_ieeeTheory.add_CONV := lift2 Real.+ fp64Syntax.dest_fp_add
val () = machine_ieeeTheory.sub_CONV := lift2 Real.- fp64Syntax.dest_fp_sub
val () = machine_ieeeTheory.mul_CONV := lift2 Real.* fp64Syntax.dest_fp_mul
val () = machine_ieeeTheory.div_CONV := lift2 Real./ fp64Syntax.dest_fp_div

(* Comparisons: results are HOL terms (a comparison constant or T/F). *)
val () = machine_ieeeTheory.compare_CONV :=
           liftOrder float_compare fp64Syntax.dest_fp_compare
val () = machine_ieeeTheory.eq_CONV :=
           liftOrder (mk_b o Real.==) fp64Syntax.dest_fp_equal
val () = machine_ieeeTheory.lt_CONV :=
           liftOrder (mk_b o Real.<) fp64Syntax.dest_fp_lessThan
val () = machine_ieeeTheory.le_CONV :=
           liftOrder (mk_b o Real.<=) fp64Syntax.dest_fp_lessEqual
val () = machine_ieeeTheory.gt_CONV :=
           liftOrder (mk_b o Real.>) fp64Syntax.dest_fp_greaterThan
val () = machine_ieeeTheory.ge_CONV :=
           liftOrder (mk_b o Real.>=) fp64Syntax.dest_fp_greaterEqual
225
226(* ------------------------------------------------------------------------ *)
227
228end
229