1//===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// A bitvector that uses an IntervalMap to coalesce adjacent elements
11/// into intervals.
12///
13//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_ADT_COALESCINGBITVECTOR_H
16#define LLVM_ADT_COALESCINGBITVECTOR_H
17
18#include "llvm/ADT/IntervalMap.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/iterator_range.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/raw_ostream.h"
24
25#include <initializer_list>
26
27namespace llvm {
28
29/// A bitvector that, under the hood, relies on an IntervalMap to coalesce
30/// elements into intervals. Good for representing sets which predominantly
31/// contain contiguous ranges. Bad for representing sets with lots of gaps
32/// between elements.
33///
34/// Compared to SparseBitVector, CoalescingBitVector offers more predictable
35/// performance for non-sequential find() operations.
36///
37/// \tparam IndexT - The type of the index into the bitvector.
38template <typename IndexT> class CoalescingBitVector {
39  static_assert(std::is_unsigned<IndexT>::value,
40                "Index must be an unsigned integer.");
41
42  using ThisT = CoalescingBitVector<IndexT>;
43
44  /// An interval map for closed integer ranges. The mapped values are unused.
45  using MapT = IntervalMap<IndexT, char>;
46
47  using UnderlyingIterator = typename MapT::const_iterator;
48
49  using IntervalT = std::pair<IndexT, IndexT>;
50
51public:
52  using Allocator = typename MapT::Allocator;
53
54  /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
55  /// reference.
56  CoalescingBitVector(Allocator &Alloc)
57      : Alloc(&Alloc), Intervals(Alloc) {}
58
59  /// \name Copy/move constructors and assignment operators.
60  /// @{
61
62  CoalescingBitVector(const ThisT &Other)
63      : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
64    set(Other);
65  }
66
67  ThisT &operator=(const ThisT &Other) {
68    clear();
69    set(Other);
70    return *this;
71  }
72
73  CoalescingBitVector(ThisT &&Other) = delete;
74  ThisT &operator=(ThisT &&Other) = delete;
75
76  /// @}
77
78  /// Clear all the bits.
79  void clear() { Intervals.clear(); }
80
81  /// Check whether no bits are set.
82  bool empty() const { return Intervals.empty(); }
83
84  /// Count the number of set bits.
85  unsigned count() const {
86    unsigned Bits = 0;
87    for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
88      Bits += 1 + It.stop() - It.start();
89    return Bits;
90  }
91
92  /// Set the bit at \p Index.
93  ///
94  /// This method does /not/ support setting a bit that has already been set,
95  /// for efficiency reasons. If possible, restructure your code to not set the
96  /// same bit multiple times, or use \ref test_and_set.
97  void set(IndexT Index) {
98    assert(!test(Index) && "Setting already-set bits not supported/efficient, "
99                           "IntervalMap will assert");
100    insert(Index, Index);
101  }
102
103  /// Set the bits set in \p Other.
104  ///
105  /// This method does /not/ support setting already-set bits, see \ref set
106  /// for the rationale. For a safe set union operation, use \ref operator|=.
107  void set(const ThisT &Other) {
108    for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
109         It != End; ++It)
110      insert(It.start(), It.stop());
111  }
112
113  /// Set the bits at \p Indices. Used for testing, primarily.
114  void set(std::initializer_list<IndexT> Indices) {
115    for (IndexT Index : Indices)
116      set(Index);
117  }
118
119  /// Check whether the bit at \p Index is set.
120  bool test(IndexT Index) const {
121    const auto It = Intervals.find(Index);
122    if (It == Intervals.end())
123      return false;
124    assert(It.stop() >= Index && "Interval must end after Index");
125    return It.start() <= Index;
126  }
127
128  /// Set the bit at \p Index. Supports setting an already-set bit.
129  void test_and_set(IndexT Index) {
130    if (!test(Index))
131      set(Index);
132  }
133
134  /// Reset the bit at \p Index. Supports resetting an already-unset bit.
135  void reset(IndexT Index) {
136    auto It = Intervals.find(Index);
137    if (It == Intervals.end())
138      return;
139
140    // Split the interval containing Index into up to two parts: one from
141    // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
142    // either Start or Stop, we create one new interval. If Index is equal to
143    // both Start and Stop, we simply erase the existing interval.
144    IndexT Start = It.start();
145    if (Index < Start)
146      // The index was not set.
147      return;
148    IndexT Stop = It.stop();
149    assert(Index <= Stop && "Wrong interval for index");
150    It.erase();
151    if (Start < Index)
152      insert(Start, Index - 1);
153    if (Index < Stop)
154      insert(Index + 1, Stop);
155  }
156
157  /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
158  /// be a faster alternative.
159  void operator|=(const ThisT &RHS) {
160    // Get the overlaps between the two interval maps.
161    SmallVector<IntervalT, 8> Overlaps;
162    getOverlaps(RHS, Overlaps);
163
164    // Insert the non-overlapping parts of all the intervals from RHS.
165    for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
166         It != End; ++It) {
167      IndexT Start = It.start();
168      IndexT Stop = It.stop();
169      SmallVector<IntervalT, 8> NonOverlappingParts;
170      getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
171      for (IntervalT AdditivePortion : NonOverlappingParts)
172        insert(AdditivePortion.first, AdditivePortion.second);
173    }
174  }
175
176  /// Set intersection.
177  void operator&=(const ThisT &RHS) {
178    // Get the overlaps between the two interval maps (i.e. the intersection).
179    SmallVector<IntervalT, 8> Overlaps;
180    getOverlaps(RHS, Overlaps);
181    // Rebuild the interval map, including only the overlaps.
182    clear();
183    for (IntervalT Overlap : Overlaps)
184      insert(Overlap.first, Overlap.second);
185  }
186
187  /// Reset all bits present in \p Other.
188  void intersectWithComplement(const ThisT &Other) {
189    SmallVector<IntervalT, 8> Overlaps;
190    if (!getOverlaps(Other, Overlaps)) {
191      // If there is no overlap with Other, the intersection is empty.
192      return;
193    }
194
195    // Delete the overlapping intervals. Split up intervals that only partially
196    // intersect an overlap.
197    for (IntervalT Overlap : Overlaps) {
198      IndexT OlapStart, OlapStop;
199      std::tie(OlapStart, OlapStop) = Overlap;
200
201      auto It = Intervals.find(OlapStart);
202      IndexT CurrStart = It.start();
203      IndexT CurrStop = It.stop();
204      assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
205             "Expected some intersection!");
206
207      // Split the overlap interval into up to two parts: one from [CurrStart,
208      // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
209      // equal to CurrStart, the first split interval is unnecessary. Ditto for
210      // when OlapStop is equal to CurrStop, we omit the second split interval.
211      It.erase();
212      if (CurrStart < OlapStart)
213        insert(CurrStart, OlapStart - 1);
214      if (OlapStop < CurrStop)
215        insert(OlapStop + 1, CurrStop);
216    }
217  }
218
219  bool operator==(const ThisT &RHS) const {
220    // We cannot just use std::equal because it checks the dereferenced values
221    // of an iterator pair for equality, not the iterators themselves. In our
222    // case that results in comparison of the (unused) IntervalMap values.
223    auto ItL = Intervals.begin();
224    auto ItR = RHS.Intervals.begin();
225    while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
226           ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
227      ++ItL;
228      ++ItR;
229    }
230    return ItL == Intervals.end() && ItR == RHS.Intervals.end();
231  }
232
233  bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
234
235  class const_iterator {
236    friend class CoalescingBitVector;
237
238  public:
239    using iterator_category = std::forward_iterator_tag;
240    using value_type = IndexT;
241    using difference_type = std::ptrdiff_t;
242    using pointer = value_type *;
243    using reference = value_type &;
244
245  private:
246    // For performance reasons, make the offset at the end different than the
247    // one used in \ref begin, to optimize the common `It == end()` pattern.
248    static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
249
250    UnderlyingIterator MapIterator;
251    unsigned OffsetIntoMapIterator = 0;
252
253    // Querying the start/stop of an IntervalMap iterator can be very expensive.
254    // Cache these values for performance reasons.
255    IndexT CachedStart = IndexT();
256    IndexT CachedStop = IndexT();
257
258    void setToEnd() {
259      OffsetIntoMapIterator = kIteratorAtTheEndOffset;
260      CachedStart = IndexT();
261      CachedStop = IndexT();
262    }
263
264    /// MapIterator has just changed, reset the cached state to point to the
265    /// start of the new underlying iterator.
266    void resetCache() {
267      if (MapIterator.valid()) {
268        OffsetIntoMapIterator = 0;
269        CachedStart = MapIterator.start();
270        CachedStop = MapIterator.stop();
271      } else {
272        setToEnd();
273      }
274    }
275
276    /// Advance the iterator to \p Index, if it is contained within the current
277    /// interval. The public-facing method which supports advancing past the
278    /// current interval is \ref advanceToLowerBound.
279    void advanceTo(IndexT Index) {
280      assert(Index <= CachedStop && "Cannot advance to OOB index");
281      if (Index < CachedStart)
282        // We're already past this index.
283        return;
284      OffsetIntoMapIterator = Index - CachedStart;
285    }
286
287    const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
288      resetCache();
289    }
290
291  public:
292    const_iterator() { setToEnd(); }
293
294    bool operator==(const const_iterator &RHS) const {
295      // Do /not/ compare MapIterator for equality, as this is very expensive.
296      // The cached start/stop values make that check unnecessary.
297      return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
298             std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
299                      RHS.CachedStop);
300    }
301
302    bool operator!=(const const_iterator &RHS) const {
303      return !operator==(RHS);
304    }
305
306    IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
307
308    const_iterator &operator++() { // Pre-increment (++It).
309      if (CachedStart + OffsetIntoMapIterator < CachedStop) {
310        // Keep going within the current interval.
311        ++OffsetIntoMapIterator;
312      } else {
313        // We reached the end of the current interval: advance.
314        ++MapIterator;
315        resetCache();
316      }
317      return *this;
318    }
319
320    const_iterator operator++(int) { // Post-increment (It++).
321      const_iterator tmp = *this;
322      operator++();
323      return tmp;
324    }
325
326    /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
327    /// no such set bit exists, advance to end(). This is like std::lower_bound.
328    /// This is useful if \p Index is close to the current iterator position.
329    /// However, unlike \ref find(), this has worst-case O(n) performance.
330    void advanceToLowerBound(IndexT Index) {
331      if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
332        return;
333
334      // Advance to the first interval containing (or past) Index, or to end().
335      while (Index > CachedStop) {
336        ++MapIterator;
337        resetCache();
338        if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
339          return;
340      }
341
342      advanceTo(Index);
343    }
344  };
345
346  const_iterator begin() const { return const_iterator(Intervals.begin()); }
347
348  const_iterator end() const { return const_iterator(); }
349
350  /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
351  /// If no such set bit exists, return end(). This is like std::lower_bound.
352  /// This has worst-case logarithmic performance (roughly O(log(gaps between
353  /// contiguous ranges))).
354  const_iterator find(IndexT Index) const {
355    auto UnderlyingIt = Intervals.find(Index);
356    if (UnderlyingIt == Intervals.end())
357      return end();
358    auto It = const_iterator(UnderlyingIt);
359    It.advanceTo(Index);
360    return It;
361  }
362
363  /// Return a range iterator which iterates over all of the set bits in the
364  /// half-open range [Start, End).
365  iterator_range<const_iterator> half_open_range(IndexT Start,
366                                                 IndexT End) const {
367    assert(Start < End && "Not a valid range");
368    auto StartIt = find(Start);
369    if (StartIt == end() || *StartIt >= End)
370      return {end(), end()};
371    auto EndIt = StartIt;
372    EndIt.advanceToLowerBound(End);
373    return {StartIt, EndIt};
374  }
375
376  void print(raw_ostream &OS) const {
377    OS << "{";
378    for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
379         ++It) {
380      OS << "[" << It.start();
381      if (It.start() != It.stop())
382        OS << ", " << It.stop();
383      OS << "]";
384    }
385    OS << "}";
386  }
387
388#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
389  LLVM_DUMP_METHOD void dump() const {
390    // LLDB swallows the first line of output after callling dump(). Add
391    // newlines before/after the braces to work around this.
392    dbgs() << "\n";
393    print(dbgs());
394    dbgs() << "\n";
395  }
396#endif
397
398private:
399  void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
400
401  /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
402  /// true if there is any overlap.
403  bool getOverlaps(const ThisT &Other,
404                   SmallVectorImpl<IntervalT> &Overlaps) const {
405    for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
406         I.valid(); ++I)
407      Overlaps.emplace_back(I.start(), I.stop());
408    assert(llvm::is_sorted(Overlaps,
409                           [](IntervalT LHS, IntervalT RHS) {
410                             return LHS.second < RHS.first;
411                           }) &&
412           "Overlaps must be sorted");
413    return !Overlaps.empty();
414  }
415
416  /// Given the set of overlaps between this and some other bitvector, and an
417  /// interval [Start, Stop] from that bitvector, determine the portions of the
418  /// interval which do not overlap with this.
419  void getNonOverlappingParts(IndexT Start, IndexT Stop,
420                              const SmallVectorImpl<IntervalT> &Overlaps,
421                              SmallVectorImpl<IntervalT> &NonOverlappingParts) {
422    IndexT NextUncoveredBit = Start;
423    for (IntervalT Overlap : Overlaps) {
424      IndexT OlapStart, OlapStop;
425      std::tie(OlapStart, OlapStop) = Overlap;
426
427      // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
428      // and Start <= OlapStop.
429      bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
430      if (!DoesOverlap)
431        continue;
432
433      // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
434      // the next uncovered range at OlapStop+1.
435      if (NextUncoveredBit < OlapStart)
436        NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
437      NextUncoveredBit = OlapStop + 1;
438      if (NextUncoveredBit > Stop)
439        break;
440    }
441    if (NextUncoveredBit <= Stop)
442      NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
443  }
444
445  Allocator *Alloc;
446  MapT Intervals;
447};
448
449} // namespace llvm
450
451#endif // LLVM_ADT_COALESCINGBITVECTOR_H
452