1//== SMTConstraintManager.h -------------------------------------*- C++ -*--==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
//  This file defines SMTConstraintManager, a constraint manager for the
//  static analyzer that records path constraints as SMT expressions and
//  checks their satisfiability with an SMT solver.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
15#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
16
17#include "clang/Basic/JsonSupport.h"
18#include "clang/Basic/TargetInfo.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h"
20#include "clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h"
21#include <optional>
22
23typedef llvm::ImmutableSet<
24    std::pair<clang::ento::SymbolRef, const llvm::SMTExpr *>>
25    ConstraintSMTType;
26REGISTER_TRAIT_WITH_PROGRAMSTATE(ConstraintSMT, ConstraintSMTType)
27
28namespace clang {
29namespace ento {
30
31class SMTConstraintManager : public clang::ento::SimpleConstraintManager {
32  mutable llvm::SMTSolverRef Solver = llvm::CreateZ3Solver();
33
34public:
35  SMTConstraintManager(clang::ento::ExprEngine *EE,
36                       clang::ento::SValBuilder &SB)
37      : SimpleConstraintManager(EE, SB) {}
38  virtual ~SMTConstraintManager() = default;
39
40  //===------------------------------------------------------------------===//
41  // Implementation for interface from SimpleConstraintManager.
42  //===------------------------------------------------------------------===//
43
44  ProgramStateRef assumeSym(ProgramStateRef State, SymbolRef Sym,
45                            bool Assumption) override {
46    ASTContext &Ctx = getBasicVals().getContext();
47
48    QualType RetTy;
49    bool hasComparison;
50
51    llvm::SMTExprRef Exp =
52        SMTConv::getExpr(Solver, Ctx, Sym, &RetTy, &hasComparison);
53
54    // Create zero comparison for implicit boolean cast, with reversed
55    // assumption
56    if (!hasComparison && !RetTy->isBooleanType())
57      return assumeExpr(
58          State, Sym,
59          SMTConv::getZeroExpr(Solver, Ctx, Exp, RetTy, !Assumption));
60
61    return assumeExpr(State, Sym, Assumption ? Exp : Solver->mkNot(Exp));
62  }
63
64  ProgramStateRef assumeSymInclusiveRange(ProgramStateRef State, SymbolRef Sym,
65                                          const llvm::APSInt &From,
66                                          const llvm::APSInt &To,
67                                          bool InRange) override {
68    ASTContext &Ctx = getBasicVals().getContext();
69    return assumeExpr(
70        State, Sym, SMTConv::getRangeExpr(Solver, Ctx, Sym, From, To, InRange));
71  }
72
73  ProgramStateRef assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym,
74                                       bool Assumption) override {
75    // Skip anything that is unsupported
76    return State;
77  }
78
79  //===------------------------------------------------------------------===//
80  // Implementation for interface from ConstraintManager.
81  //===------------------------------------------------------------------===//
82
83  ConditionTruthVal checkNull(ProgramStateRef State, SymbolRef Sym) override {
84    ASTContext &Ctx = getBasicVals().getContext();
85
86    QualType RetTy;
87    // The expression may be casted, so we cannot call getZ3DataExpr() directly
88    llvm::SMTExprRef VarExp = SMTConv::getExpr(Solver, Ctx, Sym, &RetTy);
89    llvm::SMTExprRef Exp =
90        SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/true);
91
92    // Negate the constraint
93    llvm::SMTExprRef NotExp =
94        SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/false);
95
96    ConditionTruthVal isSat = checkModel(State, Sym, Exp);
97    ConditionTruthVal isNotSat = checkModel(State, Sym, NotExp);
98
99    // Zero is the only possible solution
100    if (isSat.isConstrainedTrue() && isNotSat.isConstrainedFalse())
101      return true;
102
103    // Zero is not a solution
104    if (isSat.isConstrainedFalse() && isNotSat.isConstrainedTrue())
105      return false;
106
107    // Zero may be a solution
108    return ConditionTruthVal();
109  }
110
111  const llvm::APSInt *getSymVal(ProgramStateRef State,
112                                SymbolRef Sym) const override {
113    BasicValueFactory &BVF = getBasicVals();
114    ASTContext &Ctx = BVF.getContext();
115
116    if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
117      QualType Ty = Sym->getType();
118      assert(!Ty->isRealFloatingType());
119      llvm::APSInt Value(Ctx.getTypeSize(Ty),
120                         !Ty->isSignedIntegerOrEnumerationType());
121
122      // TODO: this should call checkModel so we can use the cache, however,
123      // this method tries to get the interpretation (the actual value) from
124      // the solver, which is currently not cached.
125
126      llvm::SMTExprRef Exp = SMTConv::fromData(Solver, Ctx, SD);
127
128      Solver->reset();
129      addStateConstraints(State);
130
131      // Constraints are unsatisfiable
132      std::optional<bool> isSat = Solver->check();
133      if (!isSat || !*isSat)
134        return nullptr;
135
136      // Model does not assign interpretation
137      if (!Solver->getInterpretation(Exp, Value))
138        return nullptr;
139
140      // A value has been obtained, check if it is the only value
141      llvm::SMTExprRef NotExp = SMTConv::fromBinOp(
142          Solver, Exp, BO_NE,
143          Ty->isBooleanType() ? Solver->mkBoolean(Value.getBoolValue())
144                              : Solver->mkBitvector(Value, Value.getBitWidth()),
145          /*isSigned=*/false);
146
147      Solver->addConstraint(NotExp);
148
149      std::optional<bool> isNotSat = Solver->check();
150      if (!isNotSat || *isNotSat)
151        return nullptr;
152
153      // This is the only solution, store it
154      return &BVF.getValue(Value);
155    }
156
157    if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
158      SymbolRef CastSym = SC->getOperand();
159      QualType CastTy = SC->getType();
160      // Skip the void type
161      if (CastTy->isVoidType())
162        return nullptr;
163
164      const llvm::APSInt *Value;
165      if (!(Value = getSymVal(State, CastSym)))
166        return nullptr;
167      return &BVF.Convert(SC->getType(), *Value);
168    }
169
170    if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
171      const llvm::APSInt *LHS, *RHS;
172      if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
173        LHS = getSymVal(State, SIE->getLHS());
174        RHS = &SIE->getRHS();
175      } else if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
176        LHS = &ISE->getLHS();
177        RHS = getSymVal(State, ISE->getRHS());
178      } else if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
179        // Early termination to avoid expensive call
180        LHS = getSymVal(State, SSM->getLHS());
181        RHS = LHS ? getSymVal(State, SSM->getRHS()) : nullptr;
182      } else {
183        llvm_unreachable("Unsupported binary expression to get symbol value!");
184      }
185
186      if (!LHS || !RHS)
187        return nullptr;
188
189      llvm::APSInt ConvertedLHS, ConvertedRHS;
190      QualType LTy, RTy;
191      std::tie(ConvertedLHS, LTy) = SMTConv::fixAPSInt(Ctx, *LHS);
192      std::tie(ConvertedRHS, RTy) = SMTConv::fixAPSInt(Ctx, *RHS);
193      SMTConv::doIntTypeConversion<llvm::APSInt, &SMTConv::castAPSInt>(
194          Solver, Ctx, ConvertedLHS, LTy, ConvertedRHS, RTy);
195      return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS);
196    }
197
198    llvm_unreachable("Unsupported expression to get symbol value!");
199  }
200
201  ProgramStateRef removeDeadBindings(ProgramStateRef State,
202                                     SymbolReaper &SymReaper) override {
203    auto CZ = State->get<ConstraintSMT>();
204    auto &CZFactory = State->get_context<ConstraintSMT>();
205
206    for (auto I = CZ.begin(), E = CZ.end(); I != E; ++I) {
207      if (SymReaper.isDead(I->first))
208        CZ = CZFactory.remove(CZ, *I);
209    }
210
211    return State->set<ConstraintSMT>(CZ);
212  }
213
214  void printJson(raw_ostream &Out, ProgramStateRef State, const char *NL = "\n",
215                 unsigned int Space = 0, bool IsDot = false) const override {
216    ConstraintSMTType Constraints = State->get<ConstraintSMT>();
217
218    Indent(Out, Space, IsDot) << "\"constraints\": ";
219    if (Constraints.isEmpty()) {
220      Out << "null," << NL;
221      return;
222    }
223
224    ++Space;
225    Out << '[' << NL;
226    for (ConstraintSMTType::iterator I = Constraints.begin();
227         I != Constraints.end(); ++I) {
228      Indent(Out, Space, IsDot)
229          << "{ \"symbol\": \"" << I->first << "\", \"range\": \"";
230      I->second->print(Out);
231      Out << "\" }";
232
233      if (std::next(I) != Constraints.end())
234        Out << ',';
235      Out << NL;
236    }
237
238    --Space;
239    Indent(Out, Space, IsDot) << "],";
240  }
241
242  bool haveEqualConstraints(ProgramStateRef S1,
243                            ProgramStateRef S2) const override {
244    return S1->get<ConstraintSMT>() == S2->get<ConstraintSMT>();
245  }
246
247  bool canReasonAbout(SVal X) const override {
248    const TargetInfo &TI = getBasicVals().getContext().getTargetInfo();
249
250    std::optional<nonloc::SymbolVal> SymVal = X.getAs<nonloc::SymbolVal>();
251    if (!SymVal)
252      return true;
253
254    const SymExpr *Sym = SymVal->getSymbol();
255    QualType Ty = Sym->getType();
256
257    // Complex types are not modeled
258    if (Ty->isComplexType() || Ty->isComplexIntegerType())
259      return false;
260
261    // Non-IEEE 754 floating-point types are not modeled
262    if ((Ty->isSpecificBuiltinType(BuiltinType::LongDouble) &&
263         (&TI.getLongDoubleFormat() == &llvm::APFloat::x87DoubleExtended() ||
264          &TI.getLongDoubleFormat() == &llvm::APFloat::PPCDoubleDouble())))
265      return false;
266
267    if (Ty->isRealFloatingType())
268      return Solver->isFPSupported();
269
270    if (isa<SymbolData>(Sym))
271      return true;
272
273    SValBuilder &SVB = getSValBuilder();
274
275    if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym))
276      return canReasonAbout(SVB.makeSymbolVal(SC->getOperand()));
277
278    if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
279      if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE))
280        return canReasonAbout(SVB.makeSymbolVal(SIE->getLHS()));
281
282      if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE))
283        return canReasonAbout(SVB.makeSymbolVal(ISE->getRHS()));
284
285      if (const SymSymExpr *SSE = dyn_cast<SymSymExpr>(BSE))
286        return canReasonAbout(SVB.makeSymbolVal(SSE->getLHS())) &&
287               canReasonAbout(SVB.makeSymbolVal(SSE->getRHS()));
288    }
289
290    llvm_unreachable("Unsupported expression to reason about!");
291  }
292
293  /// Dumps SMT formula
294  LLVM_DUMP_METHOD void dump() const { Solver->dump(); }
295
296protected:
297  // Check whether a new model is satisfiable, and update the program state.
298  virtual ProgramStateRef assumeExpr(ProgramStateRef State, SymbolRef Sym,
299                                     const llvm::SMTExprRef &Exp) {
300    // Check the model, avoid simplifying AST to save time
301    if (checkModel(State, Sym, Exp).isConstrainedTrue())
302      return State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
303
304    return nullptr;
305  }
306
307  /// Given a program state, construct the logical conjunction and add it to
308  /// the solver
309  virtual void addStateConstraints(ProgramStateRef State) const {
310    // TODO: Don't add all the constraints, only the relevant ones
311    auto CZ = State->get<ConstraintSMT>();
312    auto I = CZ.begin(), IE = CZ.end();
313
314    // Construct the logical AND of all the constraints
315    if (I != IE) {
316      std::vector<llvm::SMTExprRef> ASTs;
317
318      llvm::SMTExprRef Constraint = I++->second;
319      while (I != IE) {
320        Constraint = Solver->mkAnd(Constraint, I++->second);
321      }
322
323      Solver->addConstraint(Constraint);
324    }
325  }
326
327  // Generate and check a Z3 model, using the given constraint.
328  ConditionTruthVal checkModel(ProgramStateRef State, SymbolRef Sym,
329                               const llvm::SMTExprRef &Exp) const {
330    ProgramStateRef NewState =
331        State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
332
333    llvm::FoldingSetNodeID ID;
334    NewState->get<ConstraintSMT>().Profile(ID);
335
336    unsigned hash = ID.ComputeHash();
337    auto I = Cached.find(hash);
338    if (I != Cached.end())
339      return I->second;
340
341    Solver->reset();
342    addStateConstraints(NewState);
343
344    std::optional<bool> res = Solver->check();
345    if (!res)
346      Cached[hash] = ConditionTruthVal();
347    else
348      Cached[hash] = ConditionTruthVal(*res);
349
350    return Cached[hash];
351  }
352
353  // Cache the result of an SMT query (true, false, unknown). The key is the
354  // hash of the constraints in a state
355  mutable llvm::DenseMap<unsigned, ConditionTruthVal> Cached;
356}; // end class SMTConstraintManager
357
358} // namespace ento
359} // namespace clang
360
361#endif
362