1326938Sdim//===--- Random.h - Utilities for random sampling -------------------------===//
2326938Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6326938Sdim//
7326938Sdim//===----------------------------------------------------------------------===//
8326938Sdim//
9326938Sdim// Utilities for random sampling.
10326938Sdim//
11326938Sdim//===----------------------------------------------------------------------===//
12326938Sdim
13326938Sdim#ifndef LLVM_FUZZMUTATE_RANDOM_H
14326938Sdim#define LLVM_FUZZMUTATE_RANDOM_H
15326938Sdim
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <random>
#include <type_traits>
#include <utility>
namespace llvm {

/// Return a uniformly distributed random value in the closed interval
/// [\c Min, \c Max]. Requires \c Min <= \c Max.
///
/// std::uniform_int_distribution is only specified for short and wider
/// integer types. Instantiating it with bool or a character type is undefined
/// behavior, and some standard libraries reject it outright. Byte-sized types
/// are therefore drawn through a 16-bit distribution of the same signedness
/// and converted back. Every other type uses
/// std::uniform_int_distribution<T> directly.
template <typename T, typename GenT> T uniform(GenT &Gen, T Min, T Max) {
  assert(Min <= Max && "uniform: empty range (Min > Max)");
  using DistT = typename std::conditional<
      (sizeof(T) < sizeof(short)),
      typename std::conditional<std::is_signed<T>::value, short,
                                unsigned short>::type,
      T>::type;
  return static_cast<T>(std::uniform_int_distribution<DistT>(Min, Max)(Gen));
}

/// Return a uniformly distributed random value of type \c T, drawn from the
/// whole range [numeric_limits<T>::min(), numeric_limits<T>::max()].
template <typename T, typename GenT> T uniform(GenT &Gen) {
  return uniform<T>(Gen, std::numeric_limits<T>::min(),
                    std::numeric_limits<T>::max());
}

/// Randomly selects an item by sampling into a set with an unknown number of
/// elements, which may each be weighted to be more likely choices.
///
/// This is weighted reservoir sampling with a reservoir of size one. Suppose
/// items with weights w_1..w_n have been offered. Item i is then the current
/// selection with probability w_i / (w_1 + ... + w_n), because each newly
/// offered item replaces the selection with probability
/// (its weight) / (running total).
template <typename T, typename GenT> class ReservoirSampler {
  GenT &RandGen;
  typename std::remove_const<T>::type Selection = {};
  uint64_t TotalWeight = 0;

public:
  ReservoirSampler(GenT &RandGen) : RandGen(RandGen) {}

  /// Sum of the weights of every item sampled so far.
  uint64_t totalWeight() const { return TotalWeight; }
  /// True until an item with a non-zero weight has been sampled.
  bool isEmpty() const { return TotalWeight == 0; }

  /// Return the current selection. Must not be called while empty.
  const T &getSelection() const {
    assert(!isEmpty() && "Nothing selected");
    return Selection;
  }

  explicit operator bool() const { return !isEmpty(); }
  const T &operator*() const { return getSelection(); }

  /// Sample each item in \c Items with unit weight.
  ///
  /// Elements are bound through a forwarding reference. This lets ranges
  /// whose iterators yield temporaries or proxy objects (for example
  /// std::vector<bool>) be sampled as well as ranges of real lvalues.
  template <typename RangeT> ReservoirSampler &sample(RangeT &&Items) {
    for (auto &&I : Items)
      sample(I, 1);
    return *this;
  }

  /// Sample a single item with the given weight. A zero weight is ignored.
  ReservoirSampler &sample(const T &Item, uint64_t Weight) {
    if (!Weight)
      // If the weight is zero, do nothing.
      return *this;
    // A wrapped total would make later items win far too often, so reject
    // overflow instead of silently corrupting the distribution.
    assert(Weight <= std::numeric_limits<uint64_t>::max() - TotalWeight &&
           "ReservoirSampler total weight overflow");
    TotalWeight += Weight;
    // Switch to this item with probability Weight / TotalWeight.
    if (uniform<uint64_t>(RandGen, 1, TotalWeight) <= Weight)
      Selection = Item;
    return *this;
  }
};

/// Create a sampler over \c Items, offering each element with unit weight.
/// The element type is the type produced by dereferencing the range's begin
/// iterator, with any reference removed.
template <typename GenT, typename RangeT,
          typename ElT = typename std::remove_reference<
              decltype(*std::begin(std::declval<RangeT>()))>::type>
ReservoirSampler<ElT, GenT> makeSampler(GenT &RandGen, RangeT &&Items) {
  ReservoirSampler<ElT, GenT> RS(RandGen);
  RS.sample(Items);
  return RS;
}

/// Create a sampler that has been offered \c Item once with weight \c Weight.
/// A zero weight leaves the sampler empty.
template <typename GenT, typename T>
ReservoirSampler<T, GenT> makeSampler(GenT &RandGen, const T &Item,
                                      uint64_t Weight) {
  ReservoirSampler<T, GenT> RS(RandGen);
  RS.sample(Item, Weight);
  return RS;
}

/// Create an empty sampler for items of type \c T.
template <typename T, typename GenT>
ReservoirSampler<T, GenT> makeSampler(GenT &RandGen) {
  return ReservoirSampler<T, GenT>(RandGen);
}

} // end namespace llvm
95326938Sdim
96326938Sdim#endif // LLVM_FUZZMUTATE_RANDOM_H
97