1//===- llvm/ADT/SmallBitVector.h - 'Normally small' bit vectors -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SmallBitVector class.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_ADT_SMALLBITVECTOR_H
14#define LLVM_ADT_SMALLBITVECTOR_H
15
16#include "llvm/ADT/BitVector.h"
17#include "llvm/ADT/iterator_range.h"
18#include "llvm/Support/MathExtras.h"
19#include <algorithm>
20#include <cassert>
21#include <climits>
22#include <cstddef>
23#include <cstdint>
24#include <limits>
25#include <utility>
26
27namespace llvm {
28
29/// This is a 'bitvector' (really, a variable-sized bit array), optimized for
30/// the case when the array is small. It contains one pointer-sized field, which
31/// is directly used as a plain collection of bits when possible, or as a
32/// pointer to a larger heap-allocated array when necessary. This allows normal
33/// "small" cases to be fast without losing generality for large inputs.
34class SmallBitVector {
35  // TODO: In "large" mode, a pointer to a BitVector is used, leading to an
36  // unnecessary level of indirection. It would be more efficient to use a
37  // pointer to memory containing size, allocation size, and the array of bits.
38  uintptr_t X = 1;
39
40  enum {
41    // The number of bits in this class.
42    NumBaseBits = sizeof(uintptr_t) * CHAR_BIT,
43
44    // One bit is used to discriminate between small and large mode. The
45    // remaining bits are used for the small-mode representation.
46    SmallNumRawBits = NumBaseBits - 1,
47
48    // A few more bits are used to store the size of the bit set in small mode.
49    // Theoretically this is a ceil-log2. These bits are encoded in the most
50    // significant bits of the raw bits.
51    SmallNumSizeBits = (NumBaseBits == 32 ? 5 :
52                        NumBaseBits == 64 ? 6 :
53                        SmallNumRawBits),
54
55    // The remaining bits are used to store the actual set in small mode.
56    SmallNumDataBits = SmallNumRawBits - SmallNumSizeBits
57  };
58
59  static_assert(NumBaseBits == 64 || NumBaseBits == 32,
60                "Unsupported word size");
61
62public:
63  using size_type = unsigned;
64
65  // Encapsulation of a single bit.
66  class reference {
67    SmallBitVector &TheVector;
68    unsigned BitPos;
69
70  public:
71    reference(SmallBitVector &b, unsigned Idx) : TheVector(b), BitPos(Idx) {}
72
73    reference(const reference&) = default;
74
75    reference& operator=(reference t) {
76      *this = bool(t);
77      return *this;
78    }
79
80    reference& operator=(bool t) {
81      if (t)
82        TheVector.set(BitPos);
83      else
84        TheVector.reset(BitPos);
85      return *this;
86    }
87
88    operator bool() const {
89      return const_cast<const SmallBitVector &>(TheVector).operator[](BitPos);
90    }
91  };
92
93private:
94  BitVector *getPointer() const {
95    assert(!isSmall());
96    return reinterpret_cast<BitVector *>(X);
97  }
98
99  void switchToSmall(uintptr_t NewSmallBits, size_t NewSize) {
100    X = 1;
101    setSmallSize(NewSize);
102    setSmallBits(NewSmallBits);
103  }
104
105  void switchToLarge(BitVector *BV) {
106    X = reinterpret_cast<uintptr_t>(BV);
107    assert(!isSmall() && "Tried to use an unaligned pointer");
108  }
109
110  // Return all the bits used for the "small" representation; this includes
111  // bits for the size as well as the element bits.
112  uintptr_t getSmallRawBits() const {
113    assert(isSmall());
114    return X >> 1;
115  }
116
117  void setSmallRawBits(uintptr_t NewRawBits) {
118    assert(isSmall());
119    X = (NewRawBits << 1) | uintptr_t(1);
120  }
121
122  // Return the size.
123  size_t getSmallSize() const { return getSmallRawBits() >> SmallNumDataBits; }
124
125  void setSmallSize(size_t Size) {
126    setSmallRawBits(getSmallBits() | (Size << SmallNumDataBits));
127  }
128
129  // Return the element bits.
130  uintptr_t getSmallBits() const {
131    return getSmallRawBits() & ~(~uintptr_t(0) << getSmallSize());
132  }
133
134  void setSmallBits(uintptr_t NewBits) {
135    setSmallRawBits((NewBits & ~(~uintptr_t(0) << getSmallSize())) |
136                    (getSmallSize() << SmallNumDataBits));
137  }
138
139public:
140  /// Creates an empty bitvector.
141  SmallBitVector() = default;
142
143  /// Creates a bitvector of specified number of bits. All bits are initialized
144  /// to the specified value.
145  explicit SmallBitVector(unsigned s, bool t = false) {
146    if (s <= SmallNumDataBits)
147      switchToSmall(t ? ~uintptr_t(0) : 0, s);
148    else
149      switchToLarge(new BitVector(s, t));
150  }
151
152  /// SmallBitVector copy ctor.
153  SmallBitVector(const SmallBitVector &RHS) {
154    if (RHS.isSmall())
155      X = RHS.X;
156    else
157      switchToLarge(new BitVector(*RHS.getPointer()));
158  }
159
160  SmallBitVector(SmallBitVector &&RHS) : X(RHS.X) {
161    RHS.X = 1;
162  }
163
164  ~SmallBitVector() {
165    if (!isSmall())
166      delete getPointer();
167  }
168
169  using const_set_bits_iterator = const_set_bits_iterator_impl<SmallBitVector>;
170  using set_iterator = const_set_bits_iterator;
171
172  const_set_bits_iterator set_bits_begin() const {
173    return const_set_bits_iterator(*this);
174  }
175
176  const_set_bits_iterator set_bits_end() const {
177    return const_set_bits_iterator(*this, -1);
178  }
179
180  iterator_range<const_set_bits_iterator> set_bits() const {
181    return make_range(set_bits_begin(), set_bits_end());
182  }
183
184  bool isSmall() const { return X & uintptr_t(1); }
185
186  /// Tests whether there are no bits in this bitvector.
187  bool empty() const {
188    return isSmall() ? getSmallSize() == 0 : getPointer()->empty();
189  }
190
191  /// Returns the number of bits in this bitvector.
192  size_t size() const {
193    return isSmall() ? getSmallSize() : getPointer()->size();
194  }
195
196  /// Returns the number of bits which are set.
197  size_type count() const {
198    if (isSmall()) {
199      uintptr_t Bits = getSmallBits();
200      return countPopulation(Bits);
201    }
202    return getPointer()->count();
203  }
204
205  /// Returns true if any bit is set.
206  bool any() const {
207    if (isSmall())
208      return getSmallBits() != 0;
209    return getPointer()->any();
210  }
211
212  /// Returns true if all bits are set.
213  bool all() const {
214    if (isSmall())
215      return getSmallBits() == (uintptr_t(1) << getSmallSize()) - 1;
216    return getPointer()->all();
217  }
218
219  /// Returns true if none of the bits are set.
220  bool none() const {
221    if (isSmall())
222      return getSmallBits() == 0;
223    return getPointer()->none();
224  }
225
226  /// Returns the index of the first set bit, -1 if none of the bits are set.
227  int find_first() const {
228    if (isSmall()) {
229      uintptr_t Bits = getSmallBits();
230      if (Bits == 0)
231        return -1;
232      return countTrailingZeros(Bits);
233    }
234    return getPointer()->find_first();
235  }
236
237  int find_last() const {
238    if (isSmall()) {
239      uintptr_t Bits = getSmallBits();
240      if (Bits == 0)
241        return -1;
242      return NumBaseBits - countLeadingZeros(Bits) - 1;
243    }
244    return getPointer()->find_last();
245  }
246
247  /// Returns the index of the first unset bit, -1 if all of the bits are set.
248  int find_first_unset() const {
249    if (isSmall()) {
250      if (count() == getSmallSize())
251        return -1;
252
253      uintptr_t Bits = getSmallBits();
254      return countTrailingOnes(Bits);
255    }
256    return getPointer()->find_first_unset();
257  }
258
259  int find_last_unset() const {
260    if (isSmall()) {
261      if (count() == getSmallSize())
262        return -1;
263
264      uintptr_t Bits = getSmallBits();
265      // Set unused bits.
266      Bits |= ~uintptr_t(0) << getSmallSize();
267      return NumBaseBits - countLeadingOnes(Bits) - 1;
268    }
269    return getPointer()->find_last_unset();
270  }
271
272  /// Returns the index of the next set bit following the "Prev" bit.
273  /// Returns -1 if the next set bit is not found.
274  int find_next(unsigned Prev) const {
275    if (isSmall()) {
276      uintptr_t Bits = getSmallBits();
277      // Mask off previous bits.
278      Bits &= ~uintptr_t(0) << (Prev + 1);
279      if (Bits == 0 || Prev + 1 >= getSmallSize())
280        return -1;
281      return countTrailingZeros(Bits);
282    }
283    return getPointer()->find_next(Prev);
284  }
285
286  /// Returns the index of the next unset bit following the "Prev" bit.
287  /// Returns -1 if the next unset bit is not found.
288  int find_next_unset(unsigned Prev) const {
289    if (isSmall()) {
290      ++Prev;
291      uintptr_t Bits = getSmallBits();
292      // Mask in previous bits.
293      uintptr_t Mask = (uintptr_t(1) << Prev) - 1;
294      Bits |= Mask;
295
296      if (Bits == ~uintptr_t(0) || Prev + 1 >= getSmallSize())
297        return -1;
298      return countTrailingOnes(Bits);
299    }
300    return getPointer()->find_next_unset(Prev);
301  }
302
303  /// find_prev - Returns the index of the first set bit that precedes the
304  /// the bit at \p PriorTo.  Returns -1 if all previous bits are unset.
305  int find_prev(unsigned PriorTo) const {
306    if (isSmall()) {
307      if (PriorTo == 0)
308        return -1;
309
310      --PriorTo;
311      uintptr_t Bits = getSmallBits();
312      Bits &= maskTrailingOnes<uintptr_t>(PriorTo + 1);
313      if (Bits == 0)
314        return -1;
315
316      return NumBaseBits - countLeadingZeros(Bits) - 1;
317    }
318    return getPointer()->find_prev(PriorTo);
319  }
320
321  /// Clear all bits.
322  void clear() {
323    if (!isSmall())
324      delete getPointer();
325    switchToSmall(0, 0);
326  }
327
328  /// Grow or shrink the bitvector.
329  void resize(unsigned N, bool t = false) {
330    if (!isSmall()) {
331      getPointer()->resize(N, t);
332    } else if (SmallNumDataBits >= N) {
333      uintptr_t NewBits = t ? ~uintptr_t(0) << getSmallSize() : 0;
334      setSmallSize(N);
335      setSmallBits(NewBits | getSmallBits());
336    } else {
337      BitVector *BV = new BitVector(N, t);
338      uintptr_t OldBits = getSmallBits();
339      for (size_t i = 0, e = getSmallSize(); i != e; ++i)
340        (*BV)[i] = (OldBits >> i) & 1;
341      switchToLarge(BV);
342    }
343  }
344
345  void reserve(unsigned N) {
346    if (isSmall()) {
347      if (N > SmallNumDataBits) {
348        uintptr_t OldBits = getSmallRawBits();
349        size_t SmallSize = getSmallSize();
350        BitVector *BV = new BitVector(SmallSize);
351        for (size_t i = 0; i < SmallSize; ++i)
352          if ((OldBits >> i) & 1)
353            BV->set(i);
354        BV->reserve(N);
355        switchToLarge(BV);
356      }
357    } else {
358      getPointer()->reserve(N);
359    }
360  }
361
362  // Set, reset, flip
363  SmallBitVector &set() {
364    if (isSmall())
365      setSmallBits(~uintptr_t(0));
366    else
367      getPointer()->set();
368    return *this;
369  }
370
371  SmallBitVector &set(unsigned Idx) {
372    if (isSmall()) {
373      assert(Idx <= static_cast<unsigned>(
374                        std::numeric_limits<uintptr_t>::digits) &&
375             "undefined behavior");
376      setSmallBits(getSmallBits() | (uintptr_t(1) << Idx));
377    }
378    else
379      getPointer()->set(Idx);
380    return *this;
381  }
382
383  /// Efficiently set a range of bits in [I, E)
384  SmallBitVector &set(unsigned I, unsigned E) {
385    assert(I <= E && "Attempted to set backwards range!");
386    assert(E <= size() && "Attempted to set out-of-bounds range!");
387    if (I == E) return *this;
388    if (isSmall()) {
389      uintptr_t EMask = ((uintptr_t)1) << E;
390      uintptr_t IMask = ((uintptr_t)1) << I;
391      uintptr_t Mask = EMask - IMask;
392      setSmallBits(getSmallBits() | Mask);
393    } else
394      getPointer()->set(I, E);
395    return *this;
396  }
397
398  SmallBitVector &reset() {
399    if (isSmall())
400      setSmallBits(0);
401    else
402      getPointer()->reset();
403    return *this;
404  }
405
406  SmallBitVector &reset(unsigned Idx) {
407    if (isSmall())
408      setSmallBits(getSmallBits() & ~(uintptr_t(1) << Idx));
409    else
410      getPointer()->reset(Idx);
411    return *this;
412  }
413
414  /// Efficiently reset a range of bits in [I, E)
415  SmallBitVector &reset(unsigned I, unsigned E) {
416    assert(I <= E && "Attempted to reset backwards range!");
417    assert(E <= size() && "Attempted to reset out-of-bounds range!");
418    if (I == E) return *this;
419    if (isSmall()) {
420      uintptr_t EMask = ((uintptr_t)1) << E;
421      uintptr_t IMask = ((uintptr_t)1) << I;
422      uintptr_t Mask = EMask - IMask;
423      setSmallBits(getSmallBits() & ~Mask);
424    } else
425      getPointer()->reset(I, E);
426    return *this;
427  }
428
429  SmallBitVector &flip() {
430    if (isSmall())
431      setSmallBits(~getSmallBits());
432    else
433      getPointer()->flip();
434    return *this;
435  }
436
437  SmallBitVector &flip(unsigned Idx) {
438    if (isSmall())
439      setSmallBits(getSmallBits() ^ (uintptr_t(1) << Idx));
440    else
441      getPointer()->flip(Idx);
442    return *this;
443  }
444
445  // No argument flip.
446  SmallBitVector operator~() const {
447    return SmallBitVector(*this).flip();
448  }
449
450  // Indexing.
451  reference operator[](unsigned Idx) {
452    assert(Idx < size() && "Out-of-bounds Bit access.");
453    return reference(*this, Idx);
454  }
455
456  bool operator[](unsigned Idx) const {
457    assert(Idx < size() && "Out-of-bounds Bit access.");
458    if (isSmall())
459      return ((getSmallBits() >> Idx) & 1) != 0;
460    return getPointer()->operator[](Idx);
461  }
462
463  bool test(unsigned Idx) const {
464    return (*this)[Idx];
465  }
466
467  // Push single bit to end of vector.
468  void push_back(bool Val) {
469    resize(size() + 1, Val);
470  }
471
472  /// Test if any common bits are set.
473  bool anyCommon(const SmallBitVector &RHS) const {
474    if (isSmall() && RHS.isSmall())
475      return (getSmallBits() & RHS.getSmallBits()) != 0;
476    if (!isSmall() && !RHS.isSmall())
477      return getPointer()->anyCommon(*RHS.getPointer());
478
479    for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
480      if (test(i) && RHS.test(i))
481        return true;
482    return false;
483  }
484
485  // Comparison operators.
486  bool operator==(const SmallBitVector &RHS) const {
487    if (size() != RHS.size())
488      return false;
489    if (isSmall() && RHS.isSmall())
490      return getSmallBits() == RHS.getSmallBits();
491    else if (!isSmall() && !RHS.isSmall())
492      return *getPointer() == *RHS.getPointer();
493    else {
494      for (size_t i = 0, e = size(); i != e; ++i) {
495        if ((*this)[i] != RHS[i])
496          return false;
497      }
498      return true;
499    }
500  }
501
502  bool operator!=(const SmallBitVector &RHS) const {
503    return !(*this == RHS);
504  }
505
506  // Intersection, union, disjoint union.
507  // FIXME BitVector::operator&= does not resize the LHS but this does
508  SmallBitVector &operator&=(const SmallBitVector &RHS) {
509    resize(std::max(size(), RHS.size()));
510    if (isSmall() && RHS.isSmall())
511      setSmallBits(getSmallBits() & RHS.getSmallBits());
512    else if (!isSmall() && !RHS.isSmall())
513      getPointer()->operator&=(*RHS.getPointer());
514    else {
515      size_t i, e;
516      for (i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
517        (*this)[i] = test(i) && RHS.test(i);
518      for (e = size(); i != e; ++i)
519        reset(i);
520    }
521    return *this;
522  }
523
524  /// Reset bits that are set in RHS. Same as *this &= ~RHS.
525  SmallBitVector &reset(const SmallBitVector &RHS) {
526    if (isSmall() && RHS.isSmall())
527      setSmallBits(getSmallBits() & ~RHS.getSmallBits());
528    else if (!isSmall() && !RHS.isSmall())
529      getPointer()->reset(*RHS.getPointer());
530    else
531      for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
532        if (RHS.test(i))
533          reset(i);
534
535    return *this;
536  }
537
538  /// Check if (This - RHS) is zero. This is the same as reset(RHS) and any().
539  bool test(const SmallBitVector &RHS) const {
540    if (isSmall() && RHS.isSmall())
541      return (getSmallBits() & ~RHS.getSmallBits()) != 0;
542    if (!isSmall() && !RHS.isSmall())
543      return getPointer()->test(*RHS.getPointer());
544
545    unsigned i, e;
546    for (i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
547      if (test(i) && !RHS.test(i))
548        return true;
549
550    for (e = size(); i != e; ++i)
551      if (test(i))
552        return true;
553
554    return false;
555  }
556
557  SmallBitVector &operator|=(const SmallBitVector &RHS) {
558    resize(std::max(size(), RHS.size()));
559    if (isSmall() && RHS.isSmall())
560      setSmallBits(getSmallBits() | RHS.getSmallBits());
561    else if (!isSmall() && !RHS.isSmall())
562      getPointer()->operator|=(*RHS.getPointer());
563    else {
564      for (size_t i = 0, e = RHS.size(); i != e; ++i)
565        (*this)[i] = test(i) || RHS.test(i);
566    }
567    return *this;
568  }
569
570  SmallBitVector &operator^=(const SmallBitVector &RHS) {
571    resize(std::max(size(), RHS.size()));
572    if (isSmall() && RHS.isSmall())
573      setSmallBits(getSmallBits() ^ RHS.getSmallBits());
574    else if (!isSmall() && !RHS.isSmall())
575      getPointer()->operator^=(*RHS.getPointer());
576    else {
577      for (size_t i = 0, e = RHS.size(); i != e; ++i)
578        (*this)[i] = test(i) != RHS.test(i);
579    }
580    return *this;
581  }
582
583  SmallBitVector &operator<<=(unsigned N) {
584    if (isSmall())
585      setSmallBits(getSmallBits() << N);
586    else
587      getPointer()->operator<<=(N);
588    return *this;
589  }
590
591  SmallBitVector &operator>>=(unsigned N) {
592    if (isSmall())
593      setSmallBits(getSmallBits() >> N);
594    else
595      getPointer()->operator>>=(N);
596    return *this;
597  }
598
599  // Assignment operator.
600  const SmallBitVector &operator=(const SmallBitVector &RHS) {
601    if (isSmall()) {
602      if (RHS.isSmall())
603        X = RHS.X;
604      else
605        switchToLarge(new BitVector(*RHS.getPointer()));
606    } else {
607      if (!RHS.isSmall())
608        *getPointer() = *RHS.getPointer();
609      else {
610        delete getPointer();
611        X = RHS.X;
612      }
613    }
614    return *this;
615  }
616
617  const SmallBitVector &operator=(SmallBitVector &&RHS) {
618    if (this != &RHS) {
619      clear();
620      swap(RHS);
621    }
622    return *this;
623  }
624
625  void swap(SmallBitVector &RHS) {
626    std::swap(X, RHS.X);
627  }
628
629  /// Add '1' bits from Mask to this vector. Don't resize.
630  /// This computes "*this |= Mask".
631  void setBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
632    if (isSmall())
633      applyMask<true, false>(Mask, MaskWords);
634    else
635      getPointer()->setBitsInMask(Mask, MaskWords);
636  }
637
638  /// Clear any bits in this vector that are set in Mask. Don't resize.
639  /// This computes "*this &= ~Mask".
640  void clearBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
641    if (isSmall())
642      applyMask<false, false>(Mask, MaskWords);
643    else
644      getPointer()->clearBitsInMask(Mask, MaskWords);
645  }
646
647  /// Add a bit to this vector for every '0' bit in Mask. Don't resize.
648  /// This computes "*this |= ~Mask".
649  void setBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
650    if (isSmall())
651      applyMask<true, true>(Mask, MaskWords);
652    else
653      getPointer()->setBitsNotInMask(Mask, MaskWords);
654  }
655
656  /// Clear a bit in this vector for every '0' bit in Mask. Don't resize.
657  /// This computes "*this &= Mask".
658  void clearBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
659    if (isSmall())
660      applyMask<false, true>(Mask, MaskWords);
661    else
662      getPointer()->clearBitsNotInMask(Mask, MaskWords);
663  }
664
665private:
666  template <bool AddBits, bool InvertMask>
667  void applyMask(const uint32_t *Mask, unsigned MaskWords) {
668    assert(MaskWords <= sizeof(uintptr_t) && "Mask is larger than base!");
669    uintptr_t M = Mask[0];
670    if (NumBaseBits == 64)
671      M |= uint64_t(Mask[1]) << 32;
672    if (InvertMask)
673      M = ~M;
674    if (AddBits)
675      setSmallBits(getSmallBits() | M);
676    else
677      setSmallBits(getSmallBits() & ~M);
678  }
679};
680
681inline SmallBitVector
682operator&(const SmallBitVector &LHS, const SmallBitVector &RHS) {
683  SmallBitVector Result(LHS);
684  Result &= RHS;
685  return Result;
686}
687
688inline SmallBitVector
689operator|(const SmallBitVector &LHS, const SmallBitVector &RHS) {
690  SmallBitVector Result(LHS);
691  Result |= RHS;
692  return Result;
693}
694
695inline SmallBitVector
696operator^(const SmallBitVector &LHS, const SmallBitVector &RHS) {
697  SmallBitVector Result(LHS);
698  Result ^= RHS;
699  return Result;
700}
701
702} // end namespace llvm
703
704namespace std {
705
706/// Implement std::swap in terms of BitVector swap.
707inline void
708swap(llvm::SmallBitVector &LHS, llvm::SmallBitVector &RHS) {
709  LHS.swap(RHS);
710}
711
712} // end namespace std
713
714#endif // LLVM_ADT_SMALLBITVECTOR_H
715