1require 'bigdecimal'
2
3#
4# Solves a*x = b for x, using LU decomposition.
5#
module LUSolve
  module_function

  # Performs LU decomposition of the n by n matrix +a+ using Gaussian
  # elimination with scaled partial pivoting.
  #
  # +a+ is a flat, row-major Array of n*n elements that must support
  # +div+/+mult+ with a precision argument (i.e. BigDecimal). It is
  # overwritten in place: afterwards, for the i-th pivot row r = ps[i],
  # a[r*n + j] holds the L multiplier for j < i (L's unit diagonal is
  # implicit) and the U entry for j >= i.
  #
  # +zero+ is the value row maxima, pivots and multipliers are compared
  # against. +one+ is divided by each row's largest absolute value to obtain
  # that row's scale factor. An Integer +one+ (including the default) is
  # converted to BigDecimal first, because Integer#div accepts no precision
  # argument.
  #
  # div/mult use the current BigDecimal.limit as their precision (0 means the
  # same precision rules as / and *).
  #
  # Returns the pivot vector +ps+: position i of the factorization uses
  # original row ps[i]. Returns [] when n is 0.
  #
  # Raises RuntimeError ("Singular matrix") when a row is entirely zero, when
  # no nonzero pivot can be found for a column, or when the last diagonal
  # entry of U is zero.
  def ludecomp(a, n, zero = 0, one = 1)
    return [] if n.zero?

    prec = BigDecimal.limit(nil)
    # Integer#div(x, prec) would raise ArgumentError, so promote the default.
    one = BigDecimal(one) if one.is_a?(Integer)

    ps = (0...n).to_a
    # Per-row scale factor: one / (largest absolute value in that row).
    scales = (0...n).map do |row|
      base = row * n
      largest = zero
      n.times do |col|
        magnitude = a[base + col].abs
        largest = magnitude if magnitude > largest
      end
      raise "Singular matrix" unless largest > zero

      one.div(largest, prec)
    end

    last = n - 1
    (0...last).each do |k|
      # Pick the remaining row whose scaled entry in column k is largest.
      best = zero
      pividx = nil
      (k...n).each do |i|
        size = a[ps[i] * n + k].abs * scales[ps[i]]
        next unless size > best

        best = size
        pividx = i
      end
      # No scaled entry exceeded zero: the column has no usable pivot.
      raise "Singular matrix" if pividx.nil?

      ps[k], ps[pividx] = ps[pividx], ps[k] unless pividx == k

      pivot_row = ps[k] * n
      pivot = a[pivot_row + k]
      ((k + 1)...n).each do |i|
        row = ps[i] * n
        mult = a[row + k].div(pivot, prec)
        a[row + k] = mult # the eliminated slot stores the L multiplier
        next if mult == zero

        ((k + 1)...n).each do |j|
          a[row + j] -= mult.mult(a[pivot_row + j], prec)
        end
      end
    end
    raise "Singular matrix" if a[ps[last] * n + last] == zero

    ps
  end

  # Solves a*x = b for x, using the LU decomposition produced by ludecomp.
  #
  # +a+ is the matrix as overwritten by ludecomp. +b+ is the constant vector,
  # indexed by original row. +ps+ is the pivot vector ludecomp returned; its
  # length gives n. +zero+ seeds each dot-product accumulator. Entries of +b+
  # are presumably BigDecimal, like those of +a+.
  #
  # Performs forward substitution with L (implicit unit diagonal), then back
  # substitution with U. Neither +a+ nor +b+ is modified.
  #
  # Returns the solution x as an Array of n elements.
  def lusolve(a, b, ps, zero = 0.0)
    prec = BigDecimal.limit(nil)
    n = ps.size
    x = []
    # Forward substitution: solve L*y = P*b, building y in x.
    n.times do |i|
      row = ps[i] * n
      dot = (0...i).inject(zero) { |acc, j| a[row + j].mult(x[j], prec) + acc }
      x << (b[ps[i]] - dot)
    end
    # Back substitution: solve U*x = y, overwriting y with x from the bottom up.
    (n - 1).downto(0) do |i|
      row = ps[i] * n
      dot = ((i + 1)...n).inject(zero) { |acc, j| a[row + j].mult(x[j], prec) + acc }
      x[i] = (x[i] - dot).div(a[row + i], prec)
    end
    x
  end
end
89