1//===-- TimeProfiler.cpp - Hierarchical Time Profiler ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a hierarchical time profiler.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Support/TimeProfiler.h"
14#include "llvm/ADT/StringMap.h"
15#include "llvm/Support/CommandLine.h"
16#include "llvm/Support/JSON.h"
17#include "llvm/Support/Path.h"
18#include <cassert>
19#include <chrono>
20#include <string>
21#include <vector>
22
23using namespace std::chrono;
24
25namespace llvm {
26
// The single global profiler object. It is null until
// timeTraceProfilerInitialize() allocates it and is reset to null again by
// timeTraceProfilerCleanup(); the timeTraceProfilerBegin()/End() entry points
// do nothing while it is null.
TimeTraceProfiler *TimeTraceProfilerInstance = nullptr;
28
// Time types used by the profiler, all based on steady_clock.
using DurationType = duration<steady_clock::rep, steady_clock::period>;
using TimePointType = time_point<steady_clock>;

// (number of occurrences, accumulated duration) for one section name.
using CountAndDurationType = std::pair<size_t, DurationType>;

// A section name together with its occurrence count and accumulated duration.
using NameAndCountAndDurationType =
    std::pair<std::string, CountAndDurationType>;
34
35struct Entry {
36  const TimePointType Start;
37  TimePointType End;
38  const std::string Name;
39  const std::string Detail;
40
41  Entry(TimePointType &&S, TimePointType &&E, std::string &&N, std::string &&Dt)
42      : Start(std::move(S)), End(std::move(E)), Name(std::move(N)),
43        Detail(std::move(Dt)) {}
44
45  // Calculate timings for FlameGraph. Cast time points to microsecond precision
46  // rather than casting duration. This avoid truncation issues causing inner
47  // scopes overruning outer scopes.
48  steady_clock::rep getFlameGraphStartUs(TimePointType StartTime) const {
49    return (time_point_cast<microseconds>(Start) -
50            time_point_cast<microseconds>(StartTime))
51        .count();
52  }
53
54  steady_clock::rep getFlameGraphDurUs() const {
55    return (time_point_cast<microseconds>(End) -
56            time_point_cast<microseconds>(Start))
57        .count();
58  }
59};
60
61struct TimeTraceProfiler {
62  TimeTraceProfiler(unsigned TimeTraceGranularity = 0, StringRef ProcName = "")
63      : StartTime(steady_clock::now()), ProcName(ProcName),
64        TimeTraceGranularity(TimeTraceGranularity) {}
65
66  void begin(std::string Name, llvm::function_ref<std::string()> Detail) {
67    Stack.emplace_back(steady_clock::now(), TimePointType(), std::move(Name),
68                       Detail());
69  }
70
71  void end() {
72    assert(!Stack.empty() && "Must call begin() first");
73    auto &E = Stack.back();
74    E.End = steady_clock::now();
75
76    // Check that end times monotonically increase.
77    assert((Entries.empty() ||
78            (E.getFlameGraphStartUs(StartTime) + E.getFlameGraphDurUs() >=
79             Entries.back().getFlameGraphStartUs(StartTime) +
80                 Entries.back().getFlameGraphDurUs())) &&
81           "TimeProfiler scope ended earlier than previous scope");
82
83    // Calculate duration at full precision for overall counts.
84    DurationType Duration = E.End - E.Start;
85
86    // Only include sections longer or equal to TimeTraceGranularity msec.
87    if (duration_cast<microseconds>(Duration).count() >= TimeTraceGranularity)
88      Entries.emplace_back(E);
89
90    // Track total time taken by each "name", but only the topmost levels of
91    // them; e.g. if there's a template instantiation that instantiates other
92    // templates from within, we only want to add the topmost one. "topmost"
93    // happens to be the ones that don't have any currently open entries above
94    // itself.
95    if (std::find_if(++Stack.rbegin(), Stack.rend(), [&](const Entry &Val) {
96          return Val.Name == E.Name;
97        }) == Stack.rend()) {
98      auto &CountAndTotal = CountAndTotalPerName[E.Name];
99      CountAndTotal.first++;
100      CountAndTotal.second += Duration;
101    }
102
103    Stack.pop_back();
104  }
105
106  void Write(raw_pwrite_stream &OS) {
107    assert(Stack.empty() &&
108           "All profiler sections should be ended when calling Write");
109    json::OStream J(OS);
110    J.objectBegin();
111    J.attributeBegin("traceEvents");
112    J.arrayBegin();
113
114    // Emit all events for the main flame graph.
115    for (const auto &E : Entries) {
116      auto StartUs = E.getFlameGraphStartUs(StartTime);
117      auto DurUs = E.getFlameGraphDurUs();
118
119      J.object([&]{
120        J.attribute("pid", 1);
121        J.attribute("tid", 0);
122        J.attribute("ph", "X");
123        J.attribute("ts", StartUs);
124        J.attribute("dur", DurUs);
125        J.attribute("name", E.Name);
126        if (!E.Detail.empty()) {
127          J.attributeObject("args", [&] { J.attribute("detail", E.Detail); });
128        }
129      });
130    }
131
132    // Emit totals by section name as additional "thread" events, sorted from
133    // longest one.
134    int Tid = 1;
135    std::vector<NameAndCountAndDurationType> SortedTotals;
136    SortedTotals.reserve(CountAndTotalPerName.size());
137    for (const auto &E : CountAndTotalPerName)
138      SortedTotals.emplace_back(E.getKey(), E.getValue());
139
140    llvm::sort(SortedTotals.begin(), SortedTotals.end(),
141               [](const NameAndCountAndDurationType &A,
142                  const NameAndCountAndDurationType &B) {
143                 return A.second.second > B.second.second;
144               });
145    for (const auto &E : SortedTotals) {
146      auto DurUs = duration_cast<microseconds>(E.second.second).count();
147      auto Count = CountAndTotalPerName[E.first].first;
148
149      J.object([&]{
150        J.attribute("pid", 1);
151        J.attribute("tid", Tid);
152        J.attribute("ph", "X");
153        J.attribute("ts", 0);
154        J.attribute("dur", DurUs);
155        J.attribute("name", "Total " + E.first);
156        J.attributeObject("args", [&] {
157          J.attribute("count", int64_t(Count));
158          J.attribute("avg ms", int64_t(DurUs / Count / 1000));
159        });
160      });
161
162      ++Tid;
163    }
164
165    // Emit metadata event with process name.
166    J.object([&] {
167      J.attribute("cat", "");
168      J.attribute("pid", 1);
169      J.attribute("tid", 0);
170      J.attribute("ts", 0);
171      J.attribute("ph", "M");
172      J.attribute("name", "process_name");
173      J.attributeObject("args", [&] { J.attribute("name", ProcName); });
174    });
175
176    J.arrayEnd();
177    J.attributeEnd();
178    J.objectEnd();
179  }
180
181  SmallVector<Entry, 16> Stack;
182  SmallVector<Entry, 128> Entries;
183  StringMap<CountAndDurationType> CountAndTotalPerName;
184  const TimePointType StartTime;
185  const std::string ProcName;
186
187  // Minimum time granularity (in microseconds)
188  const unsigned TimeTraceGranularity;
189};
190
191void timeTraceProfilerInitialize(unsigned TimeTraceGranularity,
192                                 StringRef ProcName) {
193  assert(TimeTraceProfilerInstance == nullptr &&
194         "Profiler should not be initialized");
195  TimeTraceProfilerInstance = new TimeTraceProfiler(
196      TimeTraceGranularity, llvm::sys::path::filename(ProcName));
197}
198
199void timeTraceProfilerCleanup() {
200  delete TimeTraceProfilerInstance;
201  TimeTraceProfilerInstance = nullptr;
202}
203
204void timeTraceProfilerWrite(raw_pwrite_stream &OS) {
205  assert(TimeTraceProfilerInstance != nullptr &&
206         "Profiler object can't be null");
207  TimeTraceProfilerInstance->Write(OS);
208}
209
210void timeTraceProfilerBegin(StringRef Name, StringRef Detail) {
211  if (TimeTraceProfilerInstance != nullptr)
212    TimeTraceProfilerInstance->begin(Name, [&]() { return Detail; });
213}
214
215void timeTraceProfilerBegin(StringRef Name,
216                            llvm::function_ref<std::string()> Detail) {
217  if (TimeTraceProfilerInstance != nullptr)
218    TimeTraceProfilerInstance->begin(Name, Detail);
219}
220
221void timeTraceProfilerEnd() {
222  if (TimeTraceProfilerInstance != nullptr)
223    TimeTraceProfilerInstance->end();
224}
225
226} // namespace llvm
227