1//===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// Heuristic PBQP solver. This solver is able to perform optimal reductions for
// nodes of degree 0, 1 or 2. For nodes of degree >2 a pluggable heuristic is
// used to select a node for reduction.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17#define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
18
#include "Graph.h"
#include "Solution.h"
#include <cassert>
#include <limits>
#include <list>
#include <vector>
23
24namespace PBQP {
25
26  /// \brief Heuristic PBQP solver implementation.
27  ///
28  /// This class should usually be created (and destroyed) indirectly via a call
29  /// to HeuristicSolver<HImpl>::solve(Graph&).
30  /// See the comments for HeuristicSolver.
31  ///
32  /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules,
33  /// backpropagation phase, and maintains the internal copy of the graph on
34  /// which the reduction is carried out (the original being kept to facilitate
35  /// backpropagation).
36  template <typename HImpl>
37  class HeuristicSolverImpl {
38  private:
39
40    typedef typename HImpl::NodeData HeuristicNodeData;
41    typedef typename HImpl::EdgeData HeuristicEdgeData;
42
43    typedef std::list<Graph::EdgeItr> SolverEdges;
44
45  public:
46
47    /// \brief Iterator type for edges in the solver graph.
48    typedef SolverEdges::iterator SolverEdgeItr;
49
50  private:
51
52    class NodeData {
53    public:
54      NodeData() : solverDegree(0) {}
55
56      HeuristicNodeData& getHeuristicData() { return hData; }
57
58      SolverEdgeItr addSolverEdge(Graph::EdgeItr eItr) {
59        ++solverDegree;
60        return solverEdges.insert(solverEdges.end(), eItr);
61      }
62
63      void removeSolverEdge(SolverEdgeItr seItr) {
64        --solverDegree;
65        solverEdges.erase(seItr);
66      }
67
68      SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); }
69      SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); }
70      unsigned getSolverDegree() const { return solverDegree; }
71      void clearSolverEdges() {
72        solverDegree = 0;
73        solverEdges.clear();
74      }
75
76    private:
77      HeuristicNodeData hData;
78      unsigned solverDegree;
79      SolverEdges solverEdges;
80    };
81
82    class EdgeData {
83    public:
84      HeuristicEdgeData& getHeuristicData() { return hData; }
85
86      void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) {
87        this->n1SolverEdgeItr = n1SolverEdgeItr;
88      }
89
90      SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; }
91
92      void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){
93        this->n2SolverEdgeItr = n2SolverEdgeItr;
94      }
95
96      SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; }
97
98    private:
99
100      HeuristicEdgeData hData;
101      SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr;
102    };
103
104    Graph &g;
105    HImpl h;
106    Solution s;
107    std::vector<Graph::NodeItr> stack;
108
109    typedef std::list<NodeData> NodeDataList;
110    NodeDataList nodeDataList;
111
112    typedef std::list<EdgeData> EdgeDataList;
113    EdgeDataList edgeDataList;
114
115  public:
116
117    /// \brief Construct a heuristic solver implementation to solve the given
118    ///        graph.
119    /// @param g The graph representing the problem instance to be solved.
120    HeuristicSolverImpl(Graph &g) : g(g), h(*this) {}
121
122    /// \brief Get the graph being solved by this solver.
123    /// @return The graph representing the problem instance being solved by this
124    ///         solver.
125    Graph& getGraph() { return g; }
126
127    /// \brief Get the heuristic data attached to the given node.
128    /// @param nItr Node iterator.
129    /// @return The heuristic data attached to the given node.
130    HeuristicNodeData& getHeuristicNodeData(Graph::NodeItr nItr) {
131      return getSolverNodeData(nItr).getHeuristicData();
132    }
133
134    /// \brief Get the heuristic data attached to the given edge.
135    /// @param eItr Edge iterator.
136    /// @return The heuristic data attached to the given node.
137    HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeItr eItr) {
138      return getSolverEdgeData(eItr).getHeuristicData();
139    }
140
141    /// \brief Begin iterator for the set of edges adjacent to the given node in
142    ///        the solver graph.
143    /// @param nItr Node iterator.
144    /// @return Begin iterator for the set of edges adjacent to the given node
145    ///         in the solver graph.
146    SolverEdgeItr solverEdgesBegin(Graph::NodeItr nItr) {
147      return getSolverNodeData(nItr).solverEdgesBegin();
148    }
149
150    /// \brief End iterator for the set of edges adjacent to the given node in
151    ///        the solver graph.
152    /// @param nItr Node iterator.
153    /// @return End iterator for the set of edges adjacent to the given node in
154    ///         the solver graph.
155    SolverEdgeItr solverEdgesEnd(Graph::NodeItr nItr) {
156      return getSolverNodeData(nItr).solverEdgesEnd();
157    }
158
159    /// \brief Remove a node from the solver graph.
160    /// @param eItr Edge iterator for edge to be removed.
161    ///
162    /// Does <i>not</i> notify the heuristic of the removal. That should be
163    /// done manually if necessary.
164    void removeSolverEdge(Graph::EdgeItr eItr) {
165      EdgeData &eData = getSolverEdgeData(eItr);
166      NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
167               &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
168
169      n1Data.removeSolverEdge(eData.getN1SolverEdgeItr());
170      n2Data.removeSolverEdge(eData.getN2SolverEdgeItr());
171    }
172
173    /// \brief Compute a solution to the PBQP problem instance with which this
174    ///        heuristic solver was constructed.
175    /// @return A solution to the PBQP problem.
176    ///
177    /// Performs the full PBQP heuristic solver algorithm, including setup,
178    /// calls to the heuristic (which will call back to the reduction rules in
179    /// this class), and cleanup.
180    Solution computeSolution() {
181      setup();
182      h.setup();
183      h.reduce();
184      backpropagate();
185      h.cleanup();
186      cleanup();
187      return s;
188    }
189
190    /// \brief Add to the end of the stack.
191    /// @param nItr Node iterator to add to the reduction stack.
192    void pushToStack(Graph::NodeItr nItr) {
193      getSolverNodeData(nItr).clearSolverEdges();
194      stack.push_back(nItr);
195    }
196
197    /// \brief Returns the solver degree of the given node.
198    /// @param nItr Node iterator for which degree is requested.
199    /// @return Node degree in the <i>solver</i> graph (not the original graph).
200    unsigned getSolverDegree(Graph::NodeItr nItr) {
201      return  getSolverNodeData(nItr).getSolverDegree();
202    }
203
204    /// \brief Set the solution of the given node.
205    /// @param nItr Node iterator to set solution for.
206    /// @param selection Selection for node.
207    void setSolution(const Graph::NodeItr &nItr, unsigned selection) {
208      s.setSelection(nItr, selection);
209
210      for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
211                             aeEnd = g.adjEdgesEnd(nItr);
212           aeItr != aeEnd; ++aeItr) {
213        Graph::EdgeItr eItr(*aeItr);
214        Graph::NodeItr anItr(g.getEdgeOtherNode(eItr, nItr));
215        getSolverNodeData(anItr).addSolverEdge(eItr);
216      }
217    }
218
219    /// \brief Apply rule R0.
220    /// @param nItr Node iterator for node to apply R0 to.
221    ///
222    /// Node will be automatically pushed to the solver stack.
223    void applyR0(Graph::NodeItr nItr) {
224      assert(getSolverNodeData(nItr).getSolverDegree() == 0 &&
225             "R0 applied to node with degree != 0.");
226
227      // Nothing to do. Just push the node onto the reduction stack.
228      pushToStack(nItr);
229
230      s.recordR0();
231    }
232
233    /// \brief Apply rule R1.
234    /// @param xnItr Node iterator for node to apply R1 to.
235    ///
236    /// Node will be automatically pushed to the solver stack.
237    void applyR1(Graph::NodeItr xnItr) {
238      NodeData &nd = getSolverNodeData(xnItr);
239      assert(nd.getSolverDegree() == 1 &&
240             "R1 applied to node with degree != 1.");
241
242      Graph::EdgeItr eItr = *nd.solverEdgesBegin();
243
244      const Matrix &eCosts = g.getEdgeCosts(eItr);
245      const Vector &xCosts = g.getNodeCosts(xnItr);
246
247      // Duplicate a little to avoid transposing matrices.
248      if (xnItr == g.getEdgeNode1(eItr)) {
249        Graph::NodeItr ynItr = g.getEdgeNode2(eItr);
250        Vector &yCosts = g.getNodeCosts(ynItr);
251        for (unsigned j = 0; j < yCosts.getLength(); ++j) {
252          PBQPNum min = eCosts[0][j] + xCosts[0];
253          for (unsigned i = 1; i < xCosts.getLength(); ++i) {
254            PBQPNum c = eCosts[i][j] + xCosts[i];
255            if (c < min)
256              min = c;
257          }
258          yCosts[j] += min;
259        }
260        h.handleRemoveEdge(eItr, ynItr);
261     } else {
262        Graph::NodeItr ynItr = g.getEdgeNode1(eItr);
263        Vector &yCosts = g.getNodeCosts(ynItr);
264        for (unsigned i = 0; i < yCosts.getLength(); ++i) {
265          PBQPNum min = eCosts[i][0] + xCosts[0];
266          for (unsigned j = 1; j < xCosts.getLength(); ++j) {
267            PBQPNum c = eCosts[i][j] + xCosts[j];
268            if (c < min)
269              min = c;
270          }
271          yCosts[i] += min;
272        }
273        h.handleRemoveEdge(eItr, ynItr);
274      }
275      removeSolverEdge(eItr);
276      assert(nd.getSolverDegree() == 0 &&
277             "Degree 1 with edge removed should be 0.");
278      pushToStack(xnItr);
279      s.recordR1();
280    }
281
282    /// \brief Apply rule R2.
283    /// @param xnItr Node iterator for node to apply R2 to.
284    ///
285    /// Node will be automatically pushed to the solver stack.
286    void applyR2(Graph::NodeItr xnItr) {
287      assert(getSolverNodeData(xnItr).getSolverDegree() == 2 &&
288             "R2 applied to node with degree != 2.");
289
290      NodeData &nd = getSolverNodeData(xnItr);
291      const Vector &xCosts = g.getNodeCosts(xnItr);
292
293      SolverEdgeItr aeItr = nd.solverEdgesBegin();
294      Graph::EdgeItr yxeItr = *aeItr,
295                     zxeItr = *(++aeItr);
296
297      Graph::NodeItr ynItr = g.getEdgeOtherNode(yxeItr, xnItr),
298                     znItr = g.getEdgeOtherNode(zxeItr, xnItr);
299
300      bool flipEdge1 = (g.getEdgeNode1(yxeItr) == xnItr),
301           flipEdge2 = (g.getEdgeNode1(zxeItr) == xnItr);
302
303      const Matrix *yxeCosts = flipEdge1 ?
304        new Matrix(g.getEdgeCosts(yxeItr).transpose()) :
305        &g.getEdgeCosts(yxeItr);
306
307      const Matrix *zxeCosts = flipEdge2 ?
308        new Matrix(g.getEdgeCosts(zxeItr).transpose()) :
309        &g.getEdgeCosts(zxeItr);
310
311      unsigned xLen = xCosts.getLength(),
312               yLen = yxeCosts->getRows(),
313               zLen = zxeCosts->getRows();
314
315      Matrix delta(yLen, zLen);
316
317      for (unsigned i = 0; i < yLen; ++i) {
318        for (unsigned j = 0; j < zLen; ++j) {
319          PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0];
320          for (unsigned k = 1; k < xLen; ++k) {
321            PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k];
322            if (c < min) {
323              min = c;
324            }
325          }
326          delta[i][j] = min;
327        }
328      }
329
330      if (flipEdge1)
331        delete yxeCosts;
332
333      if (flipEdge2)
334        delete zxeCosts;
335
336      Graph::EdgeItr yzeItr = g.findEdge(ynItr, znItr);
337      bool addedEdge = false;
338
339      if (yzeItr == g.edgesEnd()) {
340        yzeItr = g.addEdge(ynItr, znItr, delta);
341        addedEdge = true;
342      } else {
343        Matrix &yzeCosts = g.getEdgeCosts(yzeItr);
344        h.preUpdateEdgeCosts(yzeItr);
345        if (ynItr == g.getEdgeNode1(yzeItr)) {
346          yzeCosts += delta;
347        } else {
348          yzeCosts += delta.transpose();
349        }
350      }
351
352      bool nullCostEdge = tryNormaliseEdgeMatrix(yzeItr);
353
354      if (!addedEdge) {
355        // If we modified the edge costs let the heuristic know.
356        h.postUpdateEdgeCosts(yzeItr);
357      }
358
359      if (nullCostEdge) {
360        // If this edge ended up null remove it.
361        if (!addedEdge) {
362          // We didn't just add it, so we need to notify the heuristic
363          // and remove it from the solver.
364          h.handleRemoveEdge(yzeItr, ynItr);
365          h.handleRemoveEdge(yzeItr, znItr);
366          removeSolverEdge(yzeItr);
367        }
368        g.removeEdge(yzeItr);
369      } else if (addedEdge) {
370        // If the edge was added, and non-null, finish setting it up, add it to
371        // the solver & notify heuristic.
372        edgeDataList.push_back(EdgeData());
373        g.setEdgeData(yzeItr, &edgeDataList.back());
374        addSolverEdge(yzeItr);
375        h.handleAddEdge(yzeItr);
376      }
377
378      h.handleRemoveEdge(yxeItr, ynItr);
379      removeSolverEdge(yxeItr);
380      h.handleRemoveEdge(zxeItr, znItr);
381      removeSolverEdge(zxeItr);
382
383      pushToStack(xnItr);
384      s.recordR2();
385    }
386
387    /// \brief Record an application of the RN rule.
388    ///
389    /// For use by the HeuristicBase.
390    void recordRN() { s.recordRN(); }
391
392  private:
393
394    NodeData& getSolverNodeData(Graph::NodeItr nItr) {
395      return *static_cast<NodeData*>(g.getNodeData(nItr));
396    }
397
398    EdgeData& getSolverEdgeData(Graph::EdgeItr eItr) {
399      return *static_cast<EdgeData*>(g.getEdgeData(eItr));
400    }
401
402    void addSolverEdge(Graph::EdgeItr eItr) {
403      EdgeData &eData = getSolverEdgeData(eItr);
404      NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
405               &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
406
407      eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eItr));
408      eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eItr));
409    }
410
411    void setup() {
412      if (h.solverRunSimplify()) {
413        simplify();
414      }
415
416      // Create node data objects.
417      for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
418           nItr != nEnd; ++nItr) {
419        nodeDataList.push_back(NodeData());
420        g.setNodeData(nItr, &nodeDataList.back());
421      }
422
423      // Create edge data objects.
424      for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
425           eItr != eEnd; ++eItr) {
426        edgeDataList.push_back(EdgeData());
427        g.setEdgeData(eItr, &edgeDataList.back());
428        addSolverEdge(eItr);
429      }
430    }
431
432    void simplify() {
433      disconnectTrivialNodes();
434      eliminateIndependentEdges();
435    }
436
437    // Eliminate trivial nodes.
438    void disconnectTrivialNodes() {
439      unsigned numDisconnected = 0;
440
441      for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
442           nItr != nEnd; ++nItr) {
443
444        if (g.getNodeCosts(nItr).getLength() == 1) {
445
446          std::vector<Graph::EdgeItr> edgesToRemove;
447
448          for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
449                                 aeEnd = g.adjEdgesEnd(nItr);
450               aeItr != aeEnd; ++aeItr) {
451
452            Graph::EdgeItr eItr = *aeItr;
453
454            if (g.getEdgeNode1(eItr) == nItr) {
455              Graph::NodeItr otherNodeItr = g.getEdgeNode2(eItr);
456              g.getNodeCosts(otherNodeItr) +=
457                g.getEdgeCosts(eItr).getRowAsVector(0);
458            }
459            else {
460              Graph::NodeItr otherNodeItr = g.getEdgeNode1(eItr);
461              g.getNodeCosts(otherNodeItr) +=
462                g.getEdgeCosts(eItr).getColAsVector(0);
463            }
464
465            edgesToRemove.push_back(eItr);
466          }
467
468          if (!edgesToRemove.empty())
469            ++numDisconnected;
470
471          while (!edgesToRemove.empty()) {
472            g.removeEdge(edgesToRemove.back());
473            edgesToRemove.pop_back();
474          }
475        }
476      }
477    }
478
479    void eliminateIndependentEdges() {
480      std::vector<Graph::EdgeItr> edgesToProcess;
481      unsigned numEliminated = 0;
482
483      for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
484           eItr != eEnd; ++eItr) {
485        edgesToProcess.push_back(eItr);
486      }
487
488      while (!edgesToProcess.empty()) {
489        if (tryToEliminateEdge(edgesToProcess.back()))
490          ++numEliminated;
491        edgesToProcess.pop_back();
492      }
493    }
494
495    bool tryToEliminateEdge(Graph::EdgeItr eItr) {
496      if (tryNormaliseEdgeMatrix(eItr)) {
497        g.removeEdge(eItr);
498        return true;
499      }
500      return false;
501    }
502
503    bool tryNormaliseEdgeMatrix(Graph::EdgeItr &eItr) {
504
505      const PBQPNum infinity = std::numeric_limits<PBQPNum>::infinity();
506
507      Matrix &edgeCosts = g.getEdgeCosts(eItr);
508      Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eItr)),
509             &vCosts = g.getNodeCosts(g.getEdgeNode2(eItr));
510
511      for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
512        PBQPNum rowMin = infinity;
513
514        for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
515          if (vCosts[c] != infinity && edgeCosts[r][c] < rowMin)
516            rowMin = edgeCosts[r][c];
517        }
518
519        uCosts[r] += rowMin;
520
521        if (rowMin != infinity) {
522          edgeCosts.subFromRow(r, rowMin);
523        }
524        else {
525          edgeCosts.setRow(r, 0);
526        }
527      }
528
529      for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
530        PBQPNum colMin = infinity;
531
532        for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
533          if (uCosts[r] != infinity && edgeCosts[r][c] < colMin)
534            colMin = edgeCosts[r][c];
535        }
536
537        vCosts[c] += colMin;
538
539        if (colMin != infinity) {
540          edgeCosts.subFromCol(c, colMin);
541        }
542        else {
543          edgeCosts.setCol(c, 0);
544        }
545      }
546
547      return edgeCosts.isZero();
548    }
549
550    void backpropagate() {
551      while (!stack.empty()) {
552        computeSolution(stack.back());
553        stack.pop_back();
554      }
555    }
556
557    void computeSolution(Graph::NodeItr nItr) {
558
559      NodeData &nodeData = getSolverNodeData(nItr);
560
561      Vector v(g.getNodeCosts(nItr));
562
563      // Solve based on existing solved edges.
564      for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(),
565                         solvedEdgeEnd = nodeData.solverEdgesEnd();
566           solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) {
567
568        Graph::EdgeItr eItr(*solvedEdgeItr);
569        Matrix &edgeCosts = g.getEdgeCosts(eItr);
570
571        if (nItr == g.getEdgeNode1(eItr)) {
572          Graph::NodeItr adjNode(g.getEdgeNode2(eItr));
573          unsigned adjSolution = s.getSelection(adjNode);
574          v += edgeCosts.getColAsVector(adjSolution);
575        }
576        else {
577          Graph::NodeItr adjNode(g.getEdgeNode1(eItr));
578          unsigned adjSolution = s.getSelection(adjNode);
579          v += edgeCosts.getRowAsVector(adjSolution);
580        }
581
582      }
583
584      setSolution(nItr, v.minIndex());
585    }
586
587    void cleanup() {
588      h.cleanup();
589      nodeDataList.clear();
590      edgeDataList.clear();
591    }
592  };
593
594  /// \brief PBQP heuristic solver class.
595  ///
596  /// Given a PBQP Graph g representing a PBQP problem, you can find a solution
597  /// by calling
598  /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt>
599  ///
600  /// The choice of heuristic for the H parameter will affect both the solver
601  /// speed and solution quality. The heuristic should be chosen based on the
602  /// nature of the problem being solved.
603  /// Currently the only solver included with LLVM is the Briggs heuristic for
604  /// register allocation.
605  template <typename HImpl>
606  class HeuristicSolver {
607  public:
608    static Solution solve(Graph &g) {
609      HeuristicSolverImpl<HImpl> hs(g);
610      return hs.computeSolution();
611    }
612  };
613
614}
615
616#endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
617