// HeuristicSolver.h revision 263508
1//===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// Heuristic PBQP solver. This solver is able to perform optimal reductions for 11// nodes of degree 0, 1 or 2. For nodes of degree >2 a plugable heuristic is 12// used to select a node for reduction. 13// 14//===----------------------------------------------------------------------===// 15 16#ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H 17#define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H 18 19#include "Graph.h" 20#include "Solution.h" 21#include <limits> 22#include <vector> 23 24namespace PBQP { 25 26 /// \brief Heuristic PBQP solver implementation. 27 /// 28 /// This class should usually be created (and destroyed) indirectly via a call 29 /// to HeuristicSolver<HImpl>::solve(Graph&). 30 /// See the comments for HeuristicSolver. 31 /// 32 /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules, 33 /// backpropagation phase, and maintains the internal copy of the graph on 34 /// which the reduction is carried out (the original being kept to facilitate 35 /// backpropagation). 36 template <typename HImpl> 37 class HeuristicSolverImpl { 38 private: 39 40 typedef typename HImpl::NodeData HeuristicNodeData; 41 typedef typename HImpl::EdgeData HeuristicEdgeData; 42 43 typedef std::list<Graph::EdgeId> SolverEdges; 44 45 public: 46 47 /// \brief Iterator type for edges in the solver graph. 
48 typedef SolverEdges::iterator SolverEdgeItr; 49 50 private: 51 52 class NodeData { 53 public: 54 NodeData() : solverDegree(0) {} 55 56 HeuristicNodeData& getHeuristicData() { return hData; } 57 58 SolverEdgeItr addSolverEdge(Graph::EdgeId eId) { 59 ++solverDegree; 60 return solverEdges.insert(solverEdges.end(), eId); 61 } 62 63 void removeSolverEdge(SolverEdgeItr seItr) { 64 --solverDegree; 65 solverEdges.erase(seItr); 66 } 67 68 SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); } 69 SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); } 70 unsigned getSolverDegree() const { return solverDegree; } 71 void clearSolverEdges() { 72 solverDegree = 0; 73 solverEdges.clear(); 74 } 75 76 private: 77 HeuristicNodeData hData; 78 unsigned solverDegree; 79 SolverEdges solverEdges; 80 }; 81 82 class EdgeData { 83 public: 84 HeuristicEdgeData& getHeuristicData() { return hData; } 85 86 void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) { 87 this->n1SolverEdgeItr = n1SolverEdgeItr; 88 } 89 90 SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; } 91 92 void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){ 93 this->n2SolverEdgeItr = n2SolverEdgeItr; 94 } 95 96 SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; } 97 98 private: 99 100 HeuristicEdgeData hData; 101 SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr; 102 }; 103 104 Graph &g; 105 HImpl h; 106 Solution s; 107 std::vector<Graph::NodeId> stack; 108 109 typedef std::list<NodeData> NodeDataList; 110 NodeDataList nodeDataList; 111 112 typedef std::list<EdgeData> EdgeDataList; 113 EdgeDataList edgeDataList; 114 115 public: 116 117 /// \brief Construct a heuristic solver implementation to solve the given 118 /// graph. 119 /// @param g The graph representing the problem instance to be solved. 120 HeuristicSolverImpl(Graph &g) : g(g), h(*this) {} 121 122 /// \brief Get the graph being solved by this solver. 
123 /// @return The graph representing the problem instance being solved by this 124 /// solver. 125 Graph& getGraph() { return g; } 126 127 /// \brief Get the heuristic data attached to the given node. 128 /// @param nId Node id. 129 /// @return The heuristic data attached to the given node. 130 HeuristicNodeData& getHeuristicNodeData(Graph::NodeId nId) { 131 return getSolverNodeData(nId).getHeuristicData(); 132 } 133 134 /// \brief Get the heuristic data attached to the given edge. 135 /// @param eId Edge id. 136 /// @return The heuristic data attached to the given node. 137 HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeId eId) { 138 return getSolverEdgeData(eId).getHeuristicData(); 139 } 140 141 /// \brief Begin iterator for the set of edges adjacent to the given node in 142 /// the solver graph. 143 /// @param nId Node id. 144 /// @return Begin iterator for the set of edges adjacent to the given node 145 /// in the solver graph. 146 SolverEdgeItr solverEdgesBegin(Graph::NodeId nId) { 147 return getSolverNodeData(nId).solverEdgesBegin(); 148 } 149 150 /// \brief End iterator for the set of edges adjacent to the given node in 151 /// the solver graph. 152 /// @param nId Node id. 153 /// @return End iterator for the set of edges adjacent to the given node in 154 /// the solver graph. 155 SolverEdgeItr solverEdgesEnd(Graph::NodeId nId) { 156 return getSolverNodeData(nId).solverEdgesEnd(); 157 } 158 159 /// \brief Remove a node from the solver graph. 160 /// @param eId Edge id for edge to be removed. 161 /// 162 /// Does <i>not</i> notify the heuristic of the removal. That should be 163 /// done manually if necessary. 
164 void removeSolverEdge(Graph::EdgeId eId) { 165 EdgeData &eData = getSolverEdgeData(eId); 166 NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eId)), 167 &n2Data = getSolverNodeData(g.getEdgeNode2(eId)); 168 169 n1Data.removeSolverEdge(eData.getN1SolverEdgeItr()); 170 n2Data.removeSolverEdge(eData.getN2SolverEdgeItr()); 171 } 172 173 /// \brief Compute a solution to the PBQP problem instance with which this 174 /// heuristic solver was constructed. 175 /// @return A solution to the PBQP problem. 176 /// 177 /// Performs the full PBQP heuristic solver algorithm, including setup, 178 /// calls to the heuristic (which will call back to the reduction rules in 179 /// this class), and cleanup. 180 Solution computeSolution() { 181 setup(); 182 h.setup(); 183 h.reduce(); 184 backpropagate(); 185 h.cleanup(); 186 cleanup(); 187 return s; 188 } 189 190 /// \brief Add to the end of the stack. 191 /// @param nId Node id to add to the reduction stack. 192 void pushToStack(Graph::NodeId nId) { 193 getSolverNodeData(nId).clearSolverEdges(); 194 stack.push_back(nId); 195 } 196 197 /// \brief Returns the solver degree of the given node. 198 /// @param nId Node id for which degree is requested. 199 /// @return Node degree in the <i>solver</i> graph (not the original graph). 200 unsigned getSolverDegree(Graph::NodeId nId) { 201 return getSolverNodeData(nId).getSolverDegree(); 202 } 203 204 /// \brief Set the solution of the given node. 205 /// @param nId Node id to set solution for. 206 /// @param selection Selection for node. 207 void setSolution(const Graph::NodeId &nId, unsigned selection) { 208 s.setSelection(nId, selection); 209 210 for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nId), 211 aeEnd = g.adjEdgesEnd(nId); 212 aeItr != aeEnd; ++aeItr) { 213 Graph::EdgeId eId(*aeItr); 214 Graph::NodeId anId(g.getEdgeOtherNode(eId, nId)); 215 getSolverNodeData(anId).addSolverEdge(eId); 216 } 217 } 218 219 /// \brief Apply rule R0. 
220 /// @param nId Node id for node to apply R0 to. 221 /// 222 /// Node will be automatically pushed to the solver stack. 223 void applyR0(Graph::NodeId nId) { 224 assert(getSolverNodeData(nId).getSolverDegree() == 0 && 225 "R0 applied to node with degree != 0."); 226 227 // Nothing to do. Just push the node onto the reduction stack. 228 pushToStack(nId); 229 230 s.recordR0(); 231 } 232 233 /// \brief Apply rule R1. 234 /// @param xnId Node id for node to apply R1 to. 235 /// 236 /// Node will be automatically pushed to the solver stack. 237 void applyR1(Graph::NodeId xnId) { 238 NodeData &nd = getSolverNodeData(xnId); 239 assert(nd.getSolverDegree() == 1 && 240 "R1 applied to node with degree != 1."); 241 242 Graph::EdgeId eId = *nd.solverEdgesBegin(); 243 244 const Matrix &eCosts = g.getEdgeCosts(eId); 245 const Vector &xCosts = g.getNodeCosts(xnId); 246 247 // Duplicate a little to avoid transposing matrices. 248 if (xnId == g.getEdgeNode1(eId)) { 249 Graph::NodeId ynId = g.getEdgeNode2(eId); 250 Vector &yCosts = g.getNodeCosts(ynId); 251 for (unsigned j = 0; j < yCosts.getLength(); ++j) { 252 PBQPNum min = eCosts[0][j] + xCosts[0]; 253 for (unsigned i = 1; i < xCosts.getLength(); ++i) { 254 PBQPNum c = eCosts[i][j] + xCosts[i]; 255 if (c < min) 256 min = c; 257 } 258 yCosts[j] += min; 259 } 260 h.handleRemoveEdge(eId, ynId); 261 } else { 262 Graph::NodeId ynId = g.getEdgeNode1(eId); 263 Vector &yCosts = g.getNodeCosts(ynId); 264 for (unsigned i = 0; i < yCosts.getLength(); ++i) { 265 PBQPNum min = eCosts[i][0] + xCosts[0]; 266 for (unsigned j = 1; j < xCosts.getLength(); ++j) { 267 PBQPNum c = eCosts[i][j] + xCosts[j]; 268 if (c < min) 269 min = c; 270 } 271 yCosts[i] += min; 272 } 273 h.handleRemoveEdge(eId, ynId); 274 } 275 removeSolverEdge(eId); 276 assert(nd.getSolverDegree() == 0 && 277 "Degree 1 with edge removed should be 0."); 278 pushToStack(xnId); 279 s.recordR1(); 280 } 281 282 /// \brief Apply rule R2. 
283 /// @param xnId Node id for node to apply R2 to. 284 /// 285 /// Node will be automatically pushed to the solver stack. 286 void applyR2(Graph::NodeId xnId) { 287 assert(getSolverNodeData(xnId).getSolverDegree() == 2 && 288 "R2 applied to node with degree != 2."); 289 290 NodeData &nd = getSolverNodeData(xnId); 291 const Vector &xCosts = g.getNodeCosts(xnId); 292 293 SolverEdgeItr aeItr = nd.solverEdgesBegin(); 294 Graph::EdgeId yxeId = *aeItr, 295 zxeId = *(++aeItr); 296 297 Graph::NodeId ynId = g.getEdgeOtherNode(yxeId, xnId), 298 znId = g.getEdgeOtherNode(zxeId, xnId); 299 300 bool flipEdge1 = (g.getEdgeNode1(yxeId) == xnId), 301 flipEdge2 = (g.getEdgeNode1(zxeId) == xnId); 302 303 const Matrix *yxeCosts = flipEdge1 ? 304 new Matrix(g.getEdgeCosts(yxeId).transpose()) : 305 &g.getEdgeCosts(yxeId); 306 307 const Matrix *zxeCosts = flipEdge2 ? 308 new Matrix(g.getEdgeCosts(zxeId).transpose()) : 309 &g.getEdgeCosts(zxeId); 310 311 unsigned xLen = xCosts.getLength(), 312 yLen = yxeCosts->getRows(), 313 zLen = zxeCosts->getRows(); 314 315 Matrix delta(yLen, zLen); 316 317 for (unsigned i = 0; i < yLen; ++i) { 318 for (unsigned j = 0; j < zLen; ++j) { 319 PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0]; 320 for (unsigned k = 1; k < xLen; ++k) { 321 PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k]; 322 if (c < min) { 323 min = c; 324 } 325 } 326 delta[i][j] = min; 327 } 328 } 329 330 if (flipEdge1) 331 delete yxeCosts; 332 333 if (flipEdge2) 334 delete zxeCosts; 335 336 Graph::EdgeId yzeId = g.findEdge(ynId, znId); 337 bool addedEdge = false; 338 339 if (yzeId == g.invalidEdgeId()) { 340 yzeId = g.addEdge(ynId, znId, delta); 341 addedEdge = true; 342 } else { 343 Matrix &yzeCosts = g.getEdgeCosts(yzeId); 344 h.preUpdateEdgeCosts(yzeId); 345 if (ynId == g.getEdgeNode1(yzeId)) { 346 yzeCosts += delta; 347 } else { 348 yzeCosts += delta.transpose(); 349 } 350 } 351 352 bool nullCostEdge = tryNormaliseEdgeMatrix(yzeId); 353 354 if 
(!addedEdge) { 355 // If we modified the edge costs let the heuristic know. 356 h.postUpdateEdgeCosts(yzeId); 357 } 358 359 if (nullCostEdge) { 360 // If this edge ended up null remove it. 361 if (!addedEdge) { 362 // We didn't just add it, so we need to notify the heuristic 363 // and remove it from the solver. 364 h.handleRemoveEdge(yzeId, ynId); 365 h.handleRemoveEdge(yzeId, znId); 366 removeSolverEdge(yzeId); 367 } 368 g.removeEdge(yzeId); 369 } else if (addedEdge) { 370 // If the edge was added, and non-null, finish setting it up, add it to 371 // the solver & notify heuristic. 372 edgeDataList.push_back(EdgeData()); 373 g.setEdgeData(yzeId, &edgeDataList.back()); 374 addSolverEdge(yzeId); 375 h.handleAddEdge(yzeId); 376 } 377 378 h.handleRemoveEdge(yxeId, ynId); 379 removeSolverEdge(yxeId); 380 h.handleRemoveEdge(zxeId, znId); 381 removeSolverEdge(zxeId); 382 383 pushToStack(xnId); 384 s.recordR2(); 385 } 386 387 /// \brief Record an application of the RN rule. 388 /// 389 /// For use by the HeuristicBase. 390 void recordRN() { s.recordRN(); } 391 392 private: 393 394 NodeData& getSolverNodeData(Graph::NodeId nId) { 395 return *static_cast<NodeData*>(g.getNodeData(nId)); 396 } 397 398 EdgeData& getSolverEdgeData(Graph::EdgeId eId) { 399 return *static_cast<EdgeData*>(g.getEdgeData(eId)); 400 } 401 402 void addSolverEdge(Graph::EdgeId eId) { 403 EdgeData &eData = getSolverEdgeData(eId); 404 NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eId)), 405 &n2Data = getSolverNodeData(g.getEdgeNode2(eId)); 406 407 eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eId)); 408 eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eId)); 409 } 410 411 void setup() { 412 if (h.solverRunSimplify()) { 413 simplify(); 414 } 415 416 // Create node data objects. 417 for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd(); 418 nItr != nEnd; ++nItr) { 419 nodeDataList.push_back(NodeData()); 420 g.setNodeData(*nItr, &nodeDataList.back()); 421 } 422 423 // Create edge data objects. 
424 for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd(); 425 eItr != eEnd; ++eItr) { 426 edgeDataList.push_back(EdgeData()); 427 g.setEdgeData(*eItr, &edgeDataList.back()); 428 addSolverEdge(*eItr); 429 } 430 } 431 432 void simplify() { 433 disconnectTrivialNodes(); 434 eliminateIndependentEdges(); 435 } 436 437 // Eliminate trivial nodes. 438 void disconnectTrivialNodes() { 439 unsigned numDisconnected = 0; 440 441 for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd(); 442 nItr != nEnd; ++nItr) { 443 444 Graph::NodeId nId = *nItr; 445 446 if (g.getNodeCosts(nId).getLength() == 1) { 447 448 std::vector<Graph::EdgeId> edgesToRemove; 449 450 for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nId), 451 aeEnd = g.adjEdgesEnd(nId); 452 aeItr != aeEnd; ++aeItr) { 453 454 Graph::EdgeId eId = *aeItr; 455 456 if (g.getEdgeNode1(eId) == nId) { 457 Graph::NodeId otherNodeId = g.getEdgeNode2(eId); 458 g.getNodeCosts(otherNodeId) += 459 g.getEdgeCosts(eId).getRowAsVector(0); 460 } 461 else { 462 Graph::NodeId otherNodeId = g.getEdgeNode1(eId); 463 g.getNodeCosts(otherNodeId) += 464 g.getEdgeCosts(eId).getColAsVector(0); 465 } 466 467 edgesToRemove.push_back(eId); 468 } 469 470 if (!edgesToRemove.empty()) 471 ++numDisconnected; 472 473 while (!edgesToRemove.empty()) { 474 g.removeEdge(edgesToRemove.back()); 475 edgesToRemove.pop_back(); 476 } 477 } 478 } 479 } 480 481 void eliminateIndependentEdges() { 482 std::vector<Graph::EdgeId> edgesToProcess; 483 unsigned numEliminated = 0; 484 485 for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd(); 486 eItr != eEnd; ++eItr) { 487 edgesToProcess.push_back(*eItr); 488 } 489 490 while (!edgesToProcess.empty()) { 491 if (tryToEliminateEdge(edgesToProcess.back())) 492 ++numEliminated; 493 edgesToProcess.pop_back(); 494 } 495 } 496 497 bool tryToEliminateEdge(Graph::EdgeId eId) { 498 if (tryNormaliseEdgeMatrix(eId)) { 499 g.removeEdge(eId); 500 return true; 501 } 502 return false; 503 } 504 505 bool 
tryNormaliseEdgeMatrix(Graph::EdgeId &eId) { 506 507 const PBQPNum infinity = std::numeric_limits<PBQPNum>::infinity(); 508 509 Matrix &edgeCosts = g.getEdgeCosts(eId); 510 Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eId)), 511 &vCosts = g.getNodeCosts(g.getEdgeNode2(eId)); 512 513 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) { 514 PBQPNum rowMin = infinity; 515 516 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) { 517 if (vCosts[c] != infinity && edgeCosts[r][c] < rowMin) 518 rowMin = edgeCosts[r][c]; 519 } 520 521 uCosts[r] += rowMin; 522 523 if (rowMin != infinity) { 524 edgeCosts.subFromRow(r, rowMin); 525 } 526 else { 527 edgeCosts.setRow(r, 0); 528 } 529 } 530 531 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) { 532 PBQPNum colMin = infinity; 533 534 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) { 535 if (uCosts[r] != infinity && edgeCosts[r][c] < colMin) 536 colMin = edgeCosts[r][c]; 537 } 538 539 vCosts[c] += colMin; 540 541 if (colMin != infinity) { 542 edgeCosts.subFromCol(c, colMin); 543 } 544 else { 545 edgeCosts.setCol(c, 0); 546 } 547 } 548 549 return edgeCosts.isZero(); 550 } 551 552 void backpropagate() { 553 while (!stack.empty()) { 554 computeSolution(stack.back()); 555 stack.pop_back(); 556 } 557 } 558 559 void computeSolution(Graph::NodeId nId) { 560 561 NodeData &nodeData = getSolverNodeData(nId); 562 563 Vector v(g.getNodeCosts(nId)); 564 565 // Solve based on existing solved edges. 
566 for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(), 567 solvedEdgeEnd = nodeData.solverEdgesEnd(); 568 solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) { 569 570 Graph::EdgeId eId(*solvedEdgeItr); 571 Matrix &edgeCosts = g.getEdgeCosts(eId); 572 573 if (nId == g.getEdgeNode1(eId)) { 574 Graph::NodeId adjNode(g.getEdgeNode2(eId)); 575 unsigned adjSolution = s.getSelection(adjNode); 576 v += edgeCosts.getColAsVector(adjSolution); 577 } 578 else { 579 Graph::NodeId adjNode(g.getEdgeNode1(eId)); 580 unsigned adjSolution = s.getSelection(adjNode); 581 v += edgeCosts.getRowAsVector(adjSolution); 582 } 583 584 } 585 586 setSolution(nId, v.minIndex()); 587 } 588 589 void cleanup() { 590 h.cleanup(); 591 nodeDataList.clear(); 592 edgeDataList.clear(); 593 } 594 }; 595 596 /// \brief PBQP heuristic solver class. 597 /// 598 /// Given a PBQP Graph g representing a PBQP problem, you can find a solution 599 /// by calling 600 /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt> 601 /// 602 /// The choice of heuristic for the H parameter will affect both the solver 603 /// speed and solution quality. The heuristic should be chosen based on the 604 /// nature of the problem being solved. 605 /// Currently the only solver included with LLVM is the Briggs heuristic for 606 /// register allocation. 607 template <typename HImpl> 608 class HeuristicSolver { 609 public: 610 static Solution solve(Graph &g) { 611 HeuristicSolverImpl<HImpl> hs(g); 612 return hs.computeSolution(); 613 } 614 }; 615 616} 617 618#endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H 619