1(* ========================================================================= *)
2(* FILE          : mlMatrix.sml                                              *)
3(* DESCRIPTION   : Matrix operations.                                        *)
(*                 Matrices are represented as vectors of rows.              *)
5(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
6(* DATE          : 2018                                                      *)
7(* ========================================================================= *)
8
structure mlMatrix :> mlMatrix =
struct

open HolKernel Abbrev boolLib aiLib

val ERR = mk_HOL_ERR "mlMatrix"

(* A vect is a vector of reals; a mat is a vector of rows, each row being
   a vector of reals. *)
type vect = real vector
type mat = real vector vector

(* -------------------------------------------------------------------------
   Vectors
   ------------------------------------------------------------------------- *)

(* Sum of the elements of v, folded from 0.0. *)
fun sum_rvect v = Vector.foldl (op +) 0.0 v

(* Arithmetic mean of the elements of v.
   For an empty vector this computes 0.0 / 0.0 (NaN under the Basis
   Library's IEEE real semantics); callers should avoid empty input. *)
fun average_rvect v = sum_rvect v / Real.fromInt (Vector.length v)

(* Element-wise difference v1 - v2.
   The result has the length of v1: Subscript is raised if v2 is shorter,
   and trailing elements of a longer v2 are ignored. *)
fun diff_rvect v1 v2 =
  let fun f i = Vector.sub (v1,i) - Vector.sub (v2,i) in
    Vector.tabulate (Vector.length v1, f)
  end

(* Element-wise product of v1 and v2.
   Same length convention as diff_rvect: the length of v1 decides. *)
fun mult_rvect v1 v2 =
  let fun f i = Vector.sub (v1,i) * Vector.sub (v2,i) in
    Vector.tabulate (Vector.length v1, f)
  end

(* Dot product of v1 and v2, iterating over the indices of v1.
   Computed in a single pass without building the intermediate product
   vector. Each step adds (v1[i] * v2[i]) to the accumulator, starting from
   0.0, exactly as sum_rvect (mult_rvect v1 v2) does, so the result is
   unchanged. Subscript is raised if v2 is shorter than v1. *)
fun scalar_product v1 v2 =
  let
    val n = Vector.length v1
    fun loop i acc =
      if i >= n then acc
      else loop (i + 1) (Vector.sub (v1,i) * Vector.sub (v2,i) + acc)
  in
    loop 0 0.0
  end

(* Multiplies every element of v by the scalar k. *)
fun scalar_mult k v = Vector.map (fn x => k * x) v

(* Element-wise sum of a non-empty list of vectors.
   The result has the length of the first vector (hd vl), so an empty list
   raises Empty; every other vector must be at least that long. *)
fun add_vectl vl =
  let fun f i = sum_real (map (fn x => Vector.sub (x,i)) vl) in
    Vector.tabulate (Vector.length (hd vl), f)
  end

(* -------------------------------------------------------------------------
   Matrix
   ------------------------------------------------------------------------- *)

(* Matrix-vector product: the i-th output component is the dot product of
   row i of m with the input vector inv. *)
fun mat_mult m inv =
  let fun f line = scalar_product line inv in Vector.map f m end

(* Applies f to every entry of m. *)
fun mat_map f m = Vector.map (Vector.map f) m

(* Builds a linen x coln matrix whose entry (i,j) is f i j. *)
fun mat_tabulate f (linen,coln) =
  let fun mk_line i = Vector.tabulate (coln, f i) in
    Vector.tabulate (linen, mk_line)
  end

(* Multiplies every entry of m by the scalar k. *)
fun mat_smult (k:real) m = mat_map (fn x => k * x) m

(* Dimensions (number of rows, number of columns) of m.
   The column count is read from the first row. An empty matrix has
   dimensions (0,0); previously it raised Subscript. *)
fun mat_dim m =
  if Vector.length m = 0 then (0,0)
  else (Vector.length m, Vector.length (Vector.sub (m,0)))

(* Entry at row i, column j of m. *)
fun mat_sub m i j = Vector.sub (Vector.sub (m,i), j)

(* Returns a copy of m whose entry (i,j) is replaced by k; m itself is not
   modified (Vector.update builds new vectors). *)
fun mat_update m ((i,j),k) =
  let val newv = Vector.update (Vector.sub(m,i),j,k) in
    Vector.update (m,i,newv)
  end

(* Entry-wise sum of m1 and m2. The result has the dimensions of m1. *)
fun mat_add m1 m2 =
  let fun f i j = mat_sub m1 i j + mat_sub m2 i j in
    mat_tabulate f (mat_dim m1)
  end

(* Sum of a non-empty list of matrices, computed as m1 + (m2 + (...)).
   Raises HOL_ERR on the empty list. *)
fun matl_add ml = case ml of
    [] => raise ERR "matl_add" "empty list"
  | [m] => m
  | m :: contl => mat_add m (matl_add contl)

(* Swaps the two components of a dimension pair. *)
fun inv_dim (a,b) = (b,a)

(* Transpose of m1: entry (i,j) of the result is entry (j,i) of m1. *)
fun mat_transpose m1 =
  let fun f i j = mat_sub m1 j i in
    mat_tabulate f (inv_dim (mat_dim m1))
  end

(* Random a x b matrix with entries r * (2 * random_real () - 1), where
   r = sqrt (6 / (a + b)). Assuming random_real returns values in [0,1)
   (defined in aiLib, not visible here), entries lie in [-r, r). *)
fun mat_random (dim as (a,b)) =
  let
    val r = Math.sqrt (6.0 / (Real.fromInt (a + b)))
    fun f i j = r * (2.0 * random_real () - 1.0)
  in
    mat_tabulate f dim
  end

(* -------------------------------------------------------------------------
   Input/output
   ------------------------------------------------------------------------- *)

(* Space-separated entries, printed with Real.toString (negative numbers
   use SML's "~" sign). *)
fun string_of_vect v =
  String.concatWith " " (map Real.toString (vector_to_list v))

(* One line per row, each formatted by string_of_vect. *)
fun string_of_mat m =
  String.concatWith "\n" (map string_of_vect (vector_to_list m))

(* S-expression encoders/decoders. The decoders return NONE when the input
   does not decode; dec_mat does not check that all rows have the same
   length. *)
local open HOLsexp in
fun enc_vect v = list_encode enc_real (vector_to_list v)
fun dec_vect t = Option.map Vector.fromList (list_decode dec_real t)
fun enc_mat m = list_encode enc_vect (vector_to_list m)
fun dec_mat t = Option.map Vector.fromList (list_decode dec_vect t)
end


end (* struct *)
114