1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <vector>
16
17#include "CartesianBenchmarks.h"
18#include "benchmark/benchmark.h"
19#include "test_macros.h"
20
21namespace {
22
// Selects whether the keys generated by makeTestingSets are present in the
// sets under test (Hit) or absent from them (Miss).
enum class HitType : int {
  Hit,
  Miss,
};
24
25struct AllHitTypes : EnumValuesAsTuple<AllHitTypes, HitType, 2> {
26  static constexpr const char* Names[] = {"Hit", "Miss"};
27};
28
// Order in which benchmark keys are visited: Ordered keeps them ascending,
// Random shuffles them (see sortKeysBy).
enum class AccessPattern : int {
  Ordered,
  Random,
};
30
31struct AllAccessPattern
32    : EnumValuesAsTuple<AllAccessPattern, AccessPattern, 2> {
33  static constexpr const char* Names[] = {"Ordered", "Random"};
34};
35
36void sortKeysBy(std::vector<uint64_t>& Keys, AccessPattern AP) {
37  if (AP == AccessPattern::Random) {
38    std::random_device R;
39    std::mt19937 M(R());
40    std::shuffle(std::begin(Keys), std::end(Keys), M);
41  }
42}
43
// Fixture returned by makeTestingSets: the sets under test together with the
// keys the benchmarks feed to them.
struct TestSets {
  // NumTables sets with identical contents.
  std::vector<std::set<uint64_t>> Sets;
  // One key per set element, in the order the benchmarks visit them.
  std::vector<uint64_t> Keys;
};
48
49TestSets makeTestingSets(size_t TableSize, size_t NumTables, HitType Hit,
50                         AccessPattern Access) {
51  TestSets R;
52  R.Sets.resize(1);
53
54  for (uint64_t I = 0; I < TableSize; ++I) {
55    R.Sets[0].insert(2 * I);
56    R.Keys.push_back(Hit == HitType::Hit ? 2 * I : 2 * I + 1);
57  }
58  R.Sets.resize(NumTables, R.Sets[0]);
59  sortKeysBy(R.Keys, Access);
60
61  return R;
62}
63
// Shared parameters of every benchmark below. Each benchmark works on
// NumTables std::set<uint64_t> instances of TableSize elements each.
struct Base {
  size_t TableSize;
  size_t NumTables;
  Base(size_t T, size_t N) : TableSize(T), NumTables(N) {}

  // Returns true when the total element count TableSize * NumTables lies
  // outside [100, 1000000].
  //
  // The upper bound is checked by division before the product is formed.
  // Otherwise a very large pair of sizes could wrap around size_t and land
  // inside the range.
  bool skip() const {
    const size_t MinTotal = 100;
    const size_t MaxTotal = 1000000;
    if (TableSize == 0 || NumTables == 0)
      return true; // A total of 0 is below MinTotal.
    if (TableSize > MaxTotal / NumTables)
      return true; // Exactly equivalent to TableSize * NumTables > MaxTotal.
    // Safe now: the product is known to be <= MaxTotal.
    return TableSize * NumTables < MinTotal;
  }

  // Name suffix shared by all benchmarks, e.g. "_TableSize10_NumTables100".
  std::string baseName() const {
    return "_TableSize" + std::to_string(TableSize) + "_NumTables" +
           std::to_string(NumTables);
  }
};
79
80template <class Access>
81struct Create : Base {
82  using Base::Base;
83
84  void run(benchmark::State& State) const {
85    std::vector<uint64_t> Keys(TableSize);
86    std::iota(Keys.begin(), Keys.end(), uint64_t{0});
87    sortKeysBy(Keys, Access());
88
89    while (State.KeepRunningBatch(TableSize * NumTables)) {
90      std::vector<std::set<uint64_t>> Sets(NumTables);
91      for (auto K : Keys) {
92        for (auto& Set : Sets) {
93          benchmark::DoNotOptimize(Set.insert(K));
94        }
95      }
96    }
97  }
98
99  std::string name() const {
100    return "BM_Create" + Access::name() + baseName();
101  }
102};
103
104template <class Hit, class Access>
105struct Find : Base {
106  using Base::Base;
107
108  void run(benchmark::State& State) const {
109    auto Data = makeTestingSets(TableSize, NumTables, Hit(), Access());
110
111    while (State.KeepRunningBatch(TableSize * NumTables)) {
112      for (auto K : Data.Keys) {
113        for (auto& Set : Data.Sets) {
114          benchmark::DoNotOptimize(Set.find(K));
115        }
116      }
117    }
118  }
119
120  std::string name() const {
121    return "BM_Find" + Hit::name() + Access::name() + baseName();
122  }
123};
124
125template <class Hit, class Access>
126struct FindNeEnd : Base {
127  using Base::Base;
128
129  void run(benchmark::State& State) const {
130    auto Data = makeTestingSets(TableSize, NumTables, Hit(), Access());
131
132    while (State.KeepRunningBatch(TableSize * NumTables)) {
133      for (auto K : Data.Keys) {
134        for (auto& Set : Data.Sets) {
135          benchmark::DoNotOptimize(Set.find(K) != Set.end());
136        }
137      }
138    }
139  }
140
141  std::string name() const {
142    return "BM_FindNeEnd" + Hit::name() + Access::name() + baseName();
143  }
144};
145
146template <class Access>
147struct InsertHit : Base {
148  using Base::Base;
149
150  void run(benchmark::State& State) const {
151    auto Data = makeTestingSets(TableSize, NumTables, HitType::Hit, Access());
152
153    while (State.KeepRunningBatch(TableSize * NumTables)) {
154      for (auto K : Data.Keys) {
155        for (auto& Set : Data.Sets) {
156          benchmark::DoNotOptimize(Set.insert(K));
157        }
158      }
159    }
160  }
161
162  std::string name() const {
163    return "BM_InsertHit" + Access::name() + baseName();
164  }
165};
166
167template <class Access>
168struct InsertMissAndErase : Base {
169  using Base::Base;
170
171  void run(benchmark::State& State) const {
172    auto Data = makeTestingSets(TableSize, NumTables, HitType::Miss, Access());
173
174    while (State.KeepRunningBatch(TableSize * NumTables)) {
175      for (auto K : Data.Keys) {
176        for (auto& Set : Data.Sets) {
177          benchmark::DoNotOptimize(Set.erase(Set.insert(K).first));
178        }
179      }
180    }
181  }
182
183  std::string name() const {
184    return "BM_InsertMissAndErase" + Access::name() + baseName();
185  }
186};
187
// Measures a full traversal of every set using a range-based for loop.
// The loop syntax itself is what is being measured (compare IterateBeginEnd),
// so do not rewrite it into another iteration form.
struct IterateRangeFor : Base {
  using Base::Base;

  void run(benchmark::State& State) const {
    // Only Data.Sets is used; the generated keys are never read here.
    auto Data = makeTestingSets(TableSize, NumTables, HitType::Miss,
                                AccessPattern::Ordered);

    // One batch visits all TableSize elements of each of the NumTables sets.
    while (State.KeepRunningBatch(TableSize * NumTables)) {
      for (auto& Set : Data.Sets) {
        for (auto& V : Set) {
          benchmark::DoNotOptimize(V);
        }
      }
    }
  }

  std::string name() const { return "BM_IterateRangeFor" + baseName(); }
};
206
// Measures a full traversal of every set using an explicit begin()/end()
// iterator loop. Unlike the range-based for in IterateRangeFor, this loop
// calls Set.end() on every iteration. The loop form is the subject of the
// measurement, so keep it as written.
struct IterateBeginEnd : Base {
  using Base::Base;

  void run(benchmark::State& State) const {
    // Only Data.Sets is used; the generated keys are never read here.
    auto Data = makeTestingSets(TableSize, NumTables, HitType::Miss,
                                AccessPattern::Ordered);

    // One batch visits all TableSize elements of each of the NumTables sets.
    while (State.KeepRunningBatch(TableSize * NumTables)) {
      for (auto& Set : Data.Sets) {
        for (auto it = Set.begin(); it != Set.end(); ++it) {
          benchmark::DoNotOptimize(*it);
        }
      }
    }
  }

  std::string name() const { return "BM_IterateBeginEnd" + baseName(); }
};
225
226}  // namespace
227
228int main(int argc, char** argv) {
229  benchmark::Initialize(&argc, argv);
230  if (benchmark::ReportUnrecognizedArguments(argc, argv))
231    return 1;
232
233  const std::vector<size_t> TableSize{1, 10, 100, 1000, 10000, 100000, 1000000};
234  const std::vector<size_t> NumTables{1, 10, 100, 1000, 10000, 100000, 1000000};
235
236  makeCartesianProductBenchmark<Create, AllAccessPattern>(TableSize, NumTables);
237  makeCartesianProductBenchmark<Find, AllHitTypes, AllAccessPattern>(
238      TableSize, NumTables);
239  makeCartesianProductBenchmark<FindNeEnd, AllHitTypes, AllAccessPattern>(
240      TableSize, NumTables);
241  makeCartesianProductBenchmark<InsertHit, AllAccessPattern>(
242      TableSize, NumTables);
243  makeCartesianProductBenchmark<InsertMissAndErase, AllAccessPattern>(
244      TableSize, NumTables);
245  makeCartesianProductBenchmark<IterateRangeFor>(TableSize, NumTables);
246  makeCartesianProductBenchmark<IterateBeginEnd>(TableSize, NumTables);
247  benchmark::RunSpecifiedBenchmarks();
248}
249