1//===- llvm/ADT/DirectedGraph.h - Directed Graph ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the interface and a base class implementation for a
10// directed graph.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_ADT_DIRECTEDGRAPH_H
15#define LLVM_ADT_DIRECTEDGRAPH_H
16
17#include "llvm/ADT/GraphTraits.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Support/Debug.h"
21#include "llvm/Support/raw_ostream.h"
22
23namespace llvm {
24
/// Represent an edge in the directed graph.
/// The edge contains the target node it connects to.
///
/// \tparam NodeType the node class used by the graph.
/// \tparam EdgeType the concrete edge class; equality is delegated to it
///         through getDerived(), so it is expected to derive from this class.
template <class NodeType, class EdgeType> class DGEdge {
public:
  DGEdge() = delete;
  /// Create an edge pointing to the given node \p N.
  explicit DGEdge(NodeType &N) : TargetNode(&N) {}
  /// Create an edge pointing to the same node as \p E.
  explicit DGEdge(const DGEdge<NodeType, EdgeType> &E)
      : TargetNode(E.TargetNode) {}
  /// Make this edge point to the same node as \p E. Neither the old nor the
  /// new target node is modified.
  DGEdge<NodeType, EdgeType> &operator=(const DGEdge<NodeType, EdgeType> &E) {
    TargetNode = E.TargetNode;
    return *this;
  }

  /// Static polymorphism: delegate implementation (via isEqualTo) to the
  /// derived class.
  bool operator==(const EdgeType &E) const { return getDerived().isEqualTo(E); }
  bool operator!=(const EdgeType &E) const { return !operator==(E); }

  /// Retrieve the target node this edge connects to.
  const NodeType &getTargetNode() const { return *TargetNode; }
  NodeType &getTargetNode() { return *TargetNode; }

  /// Set the target node this edge connects to. Only the edge is updated; the
  /// previously targeted node and \p N are left untouched.
  void setTargetNode(const NodeType &N) {
    // The parameter is a const reference to keep the existing signature, but
    // the edge hands its target back as mutable from the non-const
    // getTargetNode(), so the constness is dropped here. Callers must not pass
    // an object that was itself defined const.
    TargetNode = const_cast<NodeType *>(&N);
  }

protected:
  // As the default implementation use address comparison for equality.
  bool isEqualTo(const EdgeType &E) const { return this == &E; }

  // Cast the 'this' pointer to the derived type and return a reference.
  EdgeType &getDerived() { return *static_cast<EdgeType *>(this); }
  const EdgeType &getDerived() const {
    return *static_cast<const EdgeType *>(this);
  }

  // The target node this edge connects to. Kept as a pointer (always set by a
  // constructor from a reference) rather than a reference, because assigning
  // through a reference would overwrite the referenced node instead of
  // re-targeting the edge.
  NodeType *TargetNode;
};
67
68/// Represent a node in the directed graph.
69/// The node has a (possibly empty) list of outgoing edges.
70template <class NodeType, class EdgeType> class DGNode {
71public:
72  using EdgeListTy = SetVector<EdgeType *>;
73  using iterator = typename EdgeListTy::iterator;
74  using const_iterator = typename EdgeListTy::const_iterator;
75
76  /// Create a node with a single outgoing edge \p E.
77  explicit DGNode(EdgeType &E) : Edges() { Edges.insert(&E); }
78  DGNode() = default;
79
80  explicit DGNode(const DGNode<NodeType, EdgeType> &N) : Edges(N.Edges) {}
81  DGNode(DGNode<NodeType, EdgeType> &&N) : Edges(std::move(N.Edges)) {}
82
83  DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &N) {
84    Edges = N.Edges;
85    return *this;
86  }
87  DGNode<NodeType, EdgeType> &operator=(const DGNode<NodeType, EdgeType> &&N) {
88    Edges = std::move(N.Edges);
89    return *this;
90  }
91
92  /// Static polymorphism: delegate implementation (via isEqualTo) to the
93  /// derived class.
94  bool operator==(const NodeType &N) const { return getDerived().isEqualTo(N); }
95  bool operator!=(const NodeType &N) const { return !operator==(N); }
96
97  const_iterator begin() const { return Edges.begin(); }
98  const_iterator end() const { return Edges.end(); }
99  iterator begin() { return Edges.begin(); }
100  iterator end() { return Edges.end(); }
101  const EdgeType &front() const { return *Edges.front(); }
102  EdgeType &front() { return *Edges.front(); }
103  const EdgeType &back() const { return *Edges.back(); }
104  EdgeType &back() { return *Edges.back(); }
105
106  /// Collect in \p EL, all the edges from this node to \p N.
107  /// Return true if at least one edge was found, and false otherwise.
108  /// Note that this implementation allows more than one edge to connect
109  /// a given pair of nodes.
110  bool findEdgesTo(const NodeType &N, SmallVectorImpl<EdgeType *> &EL) const {
111    assert(EL.empty() && "Expected the list of edges to be empty.");
112    for (auto *E : Edges)
113      if (E->getTargetNode() == N)
114        EL.push_back(E);
115    return !EL.empty();
116  }
117
118  /// Add the given edge \p E to this node, if it doesn't exist already. Returns
119  /// true if the edge is added and false otherwise.
120  bool addEdge(EdgeType &E) { return Edges.insert(&E); }
121
122  /// Remove the given edge \p E from this node, if it exists.
123  void removeEdge(EdgeType &E) { Edges.remove(&E); }
124
125  /// Test whether there is an edge that goes from this node to \p N.
126  bool hasEdgeTo(const NodeType &N) const {
127    return (findEdgeTo(N) != Edges.end());
128  }
129
130  /// Retrieve the outgoing edges for the node.
131  const EdgeListTy &getEdges() const { return Edges; }
132  EdgeListTy &getEdges() {
133    return const_cast<EdgeListTy &>(
134        static_cast<const DGNode<NodeType, EdgeType> &>(*this).Edges);
135  }
136
137  /// Clear the outgoing edges.
138  void clear() { Edges.clear(); }
139
140protected:
141  // As the default implementation use address comparison for equality.
142  bool isEqualTo(const NodeType &N) const { return this == &N; }
143
144  // Cast the 'this' pointer to the derived type and return a reference.
145  NodeType &getDerived() { return *static_cast<NodeType *>(this); }
146  const NodeType &getDerived() const {
147    return *static_cast<const NodeType *>(this);
148  }
149
150  /// Find an edge to \p N. If more than one edge exists, this will return
151  /// the first one in the list of edges.
152  const_iterator findEdgeTo(const NodeType &N) const {
153    return llvm::find_if(
154        Edges, [&N](const EdgeType *E) { return E->getTargetNode() == N; });
155  }
156
157  // The list of outgoing edges.
158  EdgeListTy Edges;
159};
160
161/// Directed graph
162///
163/// The graph is represented by a table of nodes.
164/// Each node contains a (possibly empty) list of outgoing edges.
165/// Each edge contains the target node it connects to.
166template <class NodeType, class EdgeType> class DirectedGraph {
167protected:
168  using NodeListTy = SmallVector<NodeType *, 10>;
169  using EdgeListTy = SmallVector<EdgeType *, 10>;
170public:
171  using iterator = typename NodeListTy::iterator;
172  using const_iterator = typename NodeListTy::const_iterator;
173  using DGraphType = DirectedGraph<NodeType, EdgeType>;
174
175  DirectedGraph() = default;
176  explicit DirectedGraph(NodeType &N) : Nodes() { addNode(N); }
177  DirectedGraph(const DGraphType &G) : Nodes(G.Nodes) {}
178  DirectedGraph(DGraphType &&RHS) : Nodes(std::move(RHS.Nodes)) {}
179  DGraphType &operator=(const DGraphType &G) {
180    Nodes = G.Nodes;
181    return *this;
182  }
183  DGraphType &operator=(const DGraphType &&G) {
184    Nodes = std::move(G.Nodes);
185    return *this;
186  }
187
188  const_iterator begin() const { return Nodes.begin(); }
189  const_iterator end() const { return Nodes.end(); }
190  iterator begin() { return Nodes.begin(); }
191  iterator end() { return Nodes.end(); }
192  const NodeType &front() const { return *Nodes.front(); }
193  NodeType &front() { return *Nodes.front(); }
194  const NodeType &back() const { return *Nodes.back(); }
195  NodeType &back() { return *Nodes.back(); }
196
197  size_t size() const { return Nodes.size(); }
198
199  /// Find the given node \p N in the table.
200  const_iterator findNode(const NodeType &N) const {
201    return llvm::find_if(Nodes,
202                         [&N](const NodeType *Node) { return *Node == N; });
203  }
204  iterator findNode(const NodeType &N) {
205    return const_cast<iterator>(
206        static_cast<const DGraphType &>(*this).findNode(N));
207  }
208
209  /// Add the given node \p N to the graph if it is not already present.
210  bool addNode(NodeType &N) {
211    if (findNode(N) != Nodes.end())
212      return false;
213    Nodes.push_back(&N);
214    return true;
215  }
216
217  /// Collect in \p EL all edges that are coming into node \p N. Return true
218  /// if at least one edge was found, and false otherwise.
219  bool findIncomingEdgesToNode(const NodeType &N, SmallVectorImpl<EdgeType*> &EL) const {
220    assert(EL.empty() && "Expected the list of edges to be empty.");
221    EdgeListTy TempList;
222    for (auto *Node : Nodes) {
223      if (*Node == N)
224        continue;
225      Node->findEdgesTo(N, TempList);
226      EL.insert(EL.end(), TempList.begin(), TempList.end());
227      TempList.clear();
228    }
229    return !EL.empty();
230  }
231
232  /// Remove the given node \p N from the graph. If the node has incoming or
233  /// outgoing edges, they are also removed. Return true if the node was found
234  /// and then removed, and false if the node was not found in the graph to
235  /// begin with.
236  bool removeNode(NodeType &N) {
237    iterator IT = findNode(N);
238    if (IT == Nodes.end())
239      return false;
240    // Remove incoming edges.
241    EdgeListTy EL;
242    for (auto *Node : Nodes) {
243      if (*Node == N)
244        continue;
245      Node->findEdgesTo(N, EL);
246      for (auto *E : EL)
247        Node->removeEdge(*E);
248      EL.clear();
249    }
250    N.clear();
251    Nodes.erase(IT);
252    return true;
253  }
254
255  /// Assuming nodes \p Src and \p Dst are already in the graph, connect node \p
256  /// Src to node \p Dst using the provided edge \p E. Return true if \p Src is
257  /// not already connected to \p Dst via \p E, and false otherwise.
258  bool connect(NodeType &Src, NodeType &Dst, EdgeType &E) {
259    assert(findNode(Src) != Nodes.end() && "Src node should be present.");
260    assert(findNode(Dst) != Nodes.end() && "Dst node should be present.");
261    assert((E.getTargetNode() == Dst) &&
262           "Target of the given edge does not match Dst.");
263    return Src.addEdge(E);
264  }
265
266protected:
267  // The list of nodes in the graph.
268  NodeListTy Nodes;
269};
270
271} // namespace llvm
272
273#endif // LLVM_ADT_DIRECTEDGRAPH_H
274