class Matrix
  # Adapted from JAMA: http://math.nist.gov/javanumerics/jama/

  #
  # For an m-by-n matrix A with m >= n, the LU decomposition is an m-by-n
  # unit lower triangular matrix L, an n-by-n upper triangular matrix U,
  # and an m-by-m permutation matrix P so that L*U = P*A.
  # If m < n, then L is m-by-m and U is m-by-n.
  #
  # The LUP decomposition with pivoting always exists, even if the matrix is
  # singular, so the constructor never fails for a Matrix argument (anything
  # else raises TypeError).  The primary use of the LU decomposition is in the
  # solution of square systems of simultaneous linear equations.  This will
  # fail if singular? returns true.
  #
  class LUPDecomposition
    include Matrix::ConversionHelper

    # Returns the pivoting indices: row +i+ of <tt>P*A</tt> is row
    # <tt>pivots[i]</tt> of +A+.
    attr_reader :pivots

    # Returns the lower triangular factor +L+ (unit diagonal),
    # of size m-by-min(m, n).
    def l
      Matrix.build(@row_count, [@column_count, @row_count].min) do |i, j|
        if i > j
          @lu[i][j]
        elsif i == j
          1
        else
          0
        end
      end
    end

    # Returns the upper triangular factor +U+, of size min(m, n)-by-n.
    def u
      Matrix.build([@column_count, @row_count].min, @column_count) do |i, j|
        i <= j ? @lu[i][j] : 0
      end
    end

    # Returns the m-by-m permutation matrix +P+ such that <tt>L*U = P*A</tt>.
    def p
      Matrix.build(@row_count, @row_count) { |i, j| @pivots[i] == j ? 1 : 0 }
    end

    # Returns +L+, +U+, +P+ in an array (so <tt>l, u, p = a.lup</tt> works).
    def to_ary
      [l, u, p]
    end
    alias_method :to_a, :to_ary

    # Returns +true+ if +U+, and hence +A+, is singular, i.e. if
    # <tt>A*x = 0</tt> has a nonzero solution.
    #
    # A wide matrix (fewer rows than columns) always has such a solution, so
    # it is reported as singular.  Checking U is not possible in that case
    # because U has no diagonal entry for columns m..n-1.  Pivots are compared
    # exactly (<tt>== 0</tt>), so with Float entries a nearly singular matrix
    # may still report +false+.
    def singular?
      return true if @row_count < @column_count

      @column_count.times.any? { |j| @lu[j][j] == 0 }
    end

    # Returns the determinant of +A+, calculated efficiently
    # from the factorization (the pivot sign times the diagonal of +U+).
    #
    # Raises Matrix::ErrDimensionMismatch unless +A+ is square.
    def det
      raise Matrix::ErrDimensionMismatch unless @row_count == @column_count

      @column_count.times.inject(@pivot_sign) { |d, j| d * @lu[j][j] }
    end
    alias_method :determinant, :det

    # Returns +m+ so that <tt>A*m = b</tt>,
    # or equivalently so that <tt>L*U*m = P*b</tt>.
    # +b+ can be a Matrix (one right-hand side per column) or a Vector.
    #
    # Raises Matrix::ErrNotRegular if singular? is true, and
    # Matrix::ErrDimensionMismatch if +b+ does not have one entry (row)
    # per row of +A+.
    #
    # NOTE: intended for square +A+.  For a tall +A+ (m > n) the substitution
    # loops only visit the first n entries, so the result still has m entries
    # (rows) and the trailing m - n are just the pivoted entries of +b+.
    def solve(b)
      raise Matrix::ErrNotRegular, "Matrix is singular." if singular?

      if b.is_a? Matrix
        raise Matrix::ErrDimensionMismatch if b.row_count != @row_count

        # Copy right hand side with pivoting
        nx = b.column_count
        m = @pivots.map { |row| b.row(row).to_a }

        # Solve L*Y = P*b
        @column_count.times do |k|
          (k + 1).upto(@column_count - 1) do |i|
            nx.times { |j| m[i][j] -= m[k][j] * @lu[i][k] }
          end
        end
        # Solve U*m = Y
        (@column_count - 1).downto(0) do |k|
          nx.times { |j| m[k][j] = m[k][j].quo(@lu[k][k]) }
          k.times do |i|
            nx.times { |j| m[i][j] -= m[k][j] * @lu[i][k] }
          end
        end
        Matrix.send :new, m, nx
      else # same algorithm, specialized for simpler case of a vector
        b = convert_to_array(b)
        raise Matrix::ErrDimensionMismatch if b.size != @row_count

        # Copy right hand side with pivoting
        m = b.values_at(*@pivots)

        # Solve L*Y = P*b
        @column_count.times do |k|
          (k + 1).upto(@column_count - 1) { |i| m[i] -= m[k] * @lu[i][k] }
        end
        # Solve U*m = Y
        (@column_count - 1).downto(0) do |k|
          m[k] = m[k].quo(@lu[k][k])
          k.times { |i| m[i] -= m[k] * @lu[i][k] }
        end
        Vector.elements(m, false)
      end
    end

    # Factorizes +a+ in place on a row-wise copy of its entries.
    #
    # Raises TypeError if +a+ is not a Matrix.
    def initialize(a)
      raise TypeError, "Expected Matrix but got #{a.class}" unless a.is_a?(Matrix)
      # Use a "left-looking", dot-product, Crout/Doolittle algorithm.
      @lu = a.to_a
      @row_count = a.row_count
      @column_count = a.column_count
      @pivots = Array.new(@row_count) { |i| i }
      @pivot_sign = 1
      lu_col_j = Array.new(@row_count)

      # Outer loop: one column of the factorization per iteration.
      @column_count.times do |j|
        # Make a copy of the j-th column to localize references.
        @row_count.times { |i| lu_col_j[i] = @lu[i][j] }

        # Apply previous transformations.
        @row_count.times do |i|
          lu_row_i = @lu[i]

          # Most of the time is spent in the following dot product.
          kmax = [i, j].min
          s = 0
          kmax.times { |k| s += lu_row_i[k] * lu_col_j[k] }

          lu_row_i[j] = lu_col_j[i] -= s
        end

        # Find pivot (largest magnitude at or below the diagonal) and
        # exchange if necessary.  Each row of @lu is its own Array, so
        # swapping the row references is the same as swapping every entry.
        pivot_row = j
        (j + 1).upto(@row_count - 1) do |i|
          pivot_row = i if lu_col_j[i].abs > lu_col_j[pivot_row].abs
        end
        if pivot_row != j
          @lu[pivot_row], @lu[j] = @lu[j], @lu[pivot_row]
          @pivots[pivot_row], @pivots[j] = @pivots[j], @pivots[pivot_row]
          @pivot_sign = -@pivot_sign
        end

        # Compute multipliers (skipped for columns past the last row, and for
        # a zero pivot, which leaves U singular).
        if j < @row_count && @lu[j][j] != 0
          (j + 1).upto(@row_count - 1) do |i|
            @lu[i][j] = @lu[i][j].quo(@lu[j][j])
          end
        end
      end
    end
  end
end
219