(* ========================================================================= *)
(* FILE          : mlMatrix.sml                                              *)
(* DESCRIPTION   : Matrix operations.                                        *)
(*                 Vectors are real vectors; a matrix is a vector of rows    *)
(*                 (type real vector vector).                                *)
(* AUTHOR        : (c) Thibault Gauthier, Czech Technical University         *)
(* DATE          : 2018                                                      *)
(* ========================================================================= *)

structure mlMatrix :> mlMatrix =
struct

open HolKernel Abbrev boolLib aiLib

val ERR = mk_HOL_ERR "mlMatrix"

type vect = real vector
type mat = real vector vector

(* -------------------------------------------------------------------------
   Vectors
   ------------------------------------------------------------------------- *)

(* Sum of all components of v (0.0 for the empty vector). *)
fun sum_rvect v = Vector.foldl (op +) 0.0 v

(* Arithmetic mean of the components of v.
   NOTE: for the empty vector this evaluates 0.0 / 0.0 (nan). *)
fun average_rvect v = sum_rvect v / Real.fromInt (Vector.length v)

(* Component-wise difference v1 - v2.  The result has the length of v1:
   raises Subscript if v2 is shorter, extra components of v2 are ignored. *)
fun diff_rvect v1 v2 =
  let fun f i = Vector.sub (v1,i) - Vector.sub (v2,i) in
    Vector.tabulate (Vector.length v1, f)
  end

(* Component-wise product of v1 and v2, with the same length convention
   as diff_rvect (length of v1, Subscript if v2 is shorter). *)
fun mult_rvect v1 v2 =
  let fun f i = Vector.sub (v1,i) * Vector.sub (v2,i) in
    Vector.tabulate (Vector.length v1, f)
  end

(* Dot product: sum of the component-wise products of v1 and v2. *)
fun scalar_product v1 v2 = sum_rvect (mult_rvect v1 v2)

(* Multiplies every component of v by the scalar k. *)
fun scalar_mult k v = Vector.map (fn x => k * x) v

(* Component-wise sum of a list of vectors.  The result has the length of
   the first vector; raises Empty on the empty list (via hd).
   sum_real comes from aiLib (presumably the sum of a real list). *)
fun add_vectl vl =
  let fun f i = sum_real (map (fn x => Vector.sub (x,i)) vl) in
    Vector.tabulate (Vector.length (hd vl), f)
  end

(* -------------------------------------------------------------------------
   Matrix
   ------------------------------------------------------------------------- *)

(* Matrix-vector product: the i-th output component is the dot product of
   row i of m with inv, so the result has one component per row of m. *)
fun mat_mult m inv =
  let fun f line = scalar_product line inv in Vector.map f m end

(* Applies f to every entry of m. *)
fun mat_map f m = Vector.map (Vector.map f) m

(* Builds a matrix with linen rows and coln columns whose entry (i,j)
   is f i j. *)
fun mat_tabulate f (linen,coln) =
  let fun mk_line i = Vector.tabulate (coln, f i) in
    Vector.tabulate (linen, mk_line)
  end

(* Multiplies every entry of m by the scalar k. *)
fun mat_smult (k:real) m = mat_map (fn x => k * x) m

(* Dimension (number of rows, number of columns) of m.  The column count is
   read from the first row.  A matrix with no rows has dimension (0,0)
   instead of raising Subscript, so that mat_add / mat_transpose also
   work on the empty matrix. *)
fun mat_dim m =
  if Vector.length m = 0 then (0,0)
  else (Vector.length m, Vector.length (Vector.sub (m,0)))

(* Entry at row i, column j (raises Subscript when out of range). *)
fun mat_sub m i j = Vector.sub (Vector.sub (m,i), j)

(* Returns a copy of m whose entry (i,j) is k; m itself is unchanged since
   Vector.update builds new vectors. *)
fun mat_update m ((i,j),k) =
  let val newv = Vector.update (Vector.sub(m,i),j,k) in
    Vector.update (m,i,newv)
  end

(* Entry-wise sum m1 + m2, with the dimension of m1. *)
fun mat_add m1 m2 =
  let fun f i j = mat_sub m1 i j + mat_sub m2 i j in
    mat_tabulate f (mat_dim m1)
  end

(* Entry-wise sum of a non-empty list of matrices, associated to the right
   (m1 + (m2 + (...))).  Raises a HOL_ERR on the empty list. *)
fun matl_add ml = case ml of
    [] => raise ERR "matl_add" "empty list"
  | [m] => m
  | m :: contl => mat_add m (matl_add contl)

(* Swaps the components of a dimension pair. *)
fun inv_dim (a,b) = (b,a)

(* Transpose of m1: entry (i,j) of the result is entry (j,i) of m1. *)
fun mat_transpose m1 =
  let fun f i j = mat_sub m1 j i in
    mat_tabulate f (inv_dim (mat_dim m1))
  end

(* Random matrix with a rows and b columns.  Each entry is r * (2u - 1)
   with u = random_real () and r = sqrt (6 / (a + b)).  Assuming random_real
   (from aiLib) is uniform in [0,1), entries are uniform in [-r, r), which
   matches the Glorot/Xavier uniform scaling. *)
fun mat_random (dim as (a,b)) =
  let
    val r = Math.sqrt (6.0 / (Real.fromInt (a + b)))
    fun f _ _ = r * (2.0 * random_real () - 1.0)
  in
    mat_tabulate f dim
  end

(* -------------------------------------------------------------------------
   Input/output
   ------------------------------------------------------------------------- *)

(* Components of v formatted with Real.toString (SML notation, e.g. a
   negative number is printed with a leading "~"), separated by spaces. *)
fun string_of_vect v =
  String.concatWith " " (map Real.toString (vector_to_list v))

(* One line per row, rows formatted as in string_of_vect. *)
fun string_of_mat m =
  String.concatWith "\n" (map string_of_vect (vector_to_list m))

(* S-expression encoders/decoders (HOLsexp); the decoders return NONE
   when the input does not have the expected shape. *)
local open HOLsexp in
fun enc_vect v = list_encode enc_real (vector_to_list v)
fun dec_vect t = Option.map Vector.fromList (list_decode dec_real t)
fun enc_mat m = list_encode enc_vect (vector_to_list m)
fun dec_mat t = Option.map Vector.fromList (list_decode dec_vect t)
end


end (* struct *)