1//== SMTConstraintManager.h -------------------------------------*- C++ -*--==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
//  This file defines a generic SMT API, which serves as the base class for
//  every SMT-solver-specific class.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
15#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
16
17#include "clang/Basic/JsonSupport.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h"
20
21typedef llvm::ImmutableSet<
22    std::pair<clang::ento::SymbolRef, const llvm::SMTExpr *>>
23    ConstraintSMTType;
24REGISTER_TRAIT_WITH_PROGRAMSTATE(ConstraintSMT, ConstraintSMTType)
25
26namespace clang {
27namespace ento {
28
29class SMTConstraintManager : public clang::ento::SimpleConstraintManager {
30  mutable llvm::SMTSolverRef Solver = llvm::CreateZ3Solver();
31
32public:
33  SMTConstraintManager(clang::ento::SubEngine *SE, clang::ento::SValBuilder &SB)
34      : SimpleConstraintManager(SE, SB) {}
35  virtual ~SMTConstraintManager() = default;
36
37  //===------------------------------------------------------------------===//
38  // Implementation for interface from SimpleConstraintManager.
39  //===------------------------------------------------------------------===//
40
41  ProgramStateRef assumeSym(ProgramStateRef State, SymbolRef Sym,
42                            bool Assumption) override {
43    ASTContext &Ctx = getBasicVals().getContext();
44
45    QualType RetTy;
46    bool hasComparison;
47
48    llvm::SMTExprRef Exp =
49        SMTConv::getExpr(Solver, Ctx, Sym, &RetTy, &hasComparison);
50
51    // Create zero comparison for implicit boolean cast, with reversed
52    // assumption
53    if (!hasComparison && !RetTy->isBooleanType())
54      return assumeExpr(
55          State, Sym,
56          SMTConv::getZeroExpr(Solver, Ctx, Exp, RetTy, !Assumption));
57
58    return assumeExpr(State, Sym, Assumption ? Exp : Solver->mkNot(Exp));
59  }
60
61  ProgramStateRef assumeSymInclusiveRange(ProgramStateRef State, SymbolRef Sym,
62                                          const llvm::APSInt &From,
63                                          const llvm::APSInt &To,
64                                          bool InRange) override {
65    ASTContext &Ctx = getBasicVals().getContext();
66    return assumeExpr(
67        State, Sym, SMTConv::getRangeExpr(Solver, Ctx, Sym, From, To, InRange));
68  }
69
70  ProgramStateRef assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym,
71                                       bool Assumption) override {
72    // Skip anything that is unsupported
73    return State;
74  }
75
76  //===------------------------------------------------------------------===//
77  // Implementation for interface from ConstraintManager.
78  //===------------------------------------------------------------------===//
79
80  ConditionTruthVal checkNull(ProgramStateRef State, SymbolRef Sym) override {
81    ASTContext &Ctx = getBasicVals().getContext();
82
83    QualType RetTy;
84    // The expression may be casted, so we cannot call getZ3DataExpr() directly
85    llvm::SMTExprRef VarExp = SMTConv::getExpr(Solver, Ctx, Sym, &RetTy);
86    llvm::SMTExprRef Exp =
87        SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/true);
88
89    // Negate the constraint
90    llvm::SMTExprRef NotExp =
91        SMTConv::getZeroExpr(Solver, Ctx, VarExp, RetTy, /*Assumption=*/false);
92
93    ConditionTruthVal isSat = checkModel(State, Sym, Exp);
94    ConditionTruthVal isNotSat = checkModel(State, Sym, NotExp);
95
96    // Zero is the only possible solution
97    if (isSat.isConstrainedTrue() && isNotSat.isConstrainedFalse())
98      return true;
99
100    // Zero is not a solution
101    if (isSat.isConstrainedFalse() && isNotSat.isConstrainedTrue())
102      return false;
103
104    // Zero may be a solution
105    return ConditionTruthVal();
106  }
107
108  const llvm::APSInt *getSymVal(ProgramStateRef State,
109                                SymbolRef Sym) const override {
110    BasicValueFactory &BVF = getBasicVals();
111    ASTContext &Ctx = BVF.getContext();
112
113    if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
114      QualType Ty = Sym->getType();
115      assert(!Ty->isRealFloatingType());
116      llvm::APSInt Value(Ctx.getTypeSize(Ty),
117                         !Ty->isSignedIntegerOrEnumerationType());
118
119      // TODO: this should call checkModel so we can use the cache, however,
120      // this method tries to get the interpretation (the actual value) from
121      // the solver, which is currently not cached.
122
123      llvm::SMTExprRef Exp =
124          SMTConv::fromData(Solver, SD->getSymbolID(), Ty, Ctx.getTypeSize(Ty));
125
126      Solver->reset();
127      addStateConstraints(State);
128
129      // Constraints are unsatisfiable
130      Optional<bool> isSat = Solver->check();
131      if (!isSat.hasValue() || !isSat.getValue())
132        return nullptr;
133
134      // Model does not assign interpretation
135      if (!Solver->getInterpretation(Exp, Value))
136        return nullptr;
137
138      // A value has been obtained, check if it is the only value
139      llvm::SMTExprRef NotExp = SMTConv::fromBinOp(
140          Solver, Exp, BO_NE,
141          Ty->isBooleanType() ? Solver->mkBoolean(Value.getBoolValue())
142                              : Solver->mkBitvector(Value, Value.getBitWidth()),
143          /*isSigned=*/false);
144
145      Solver->addConstraint(NotExp);
146
147      Optional<bool> isNotSat = Solver->check();
148      if (!isSat.hasValue() || isNotSat.getValue())
149        return nullptr;
150
151      // This is the only solution, store it
152      return &BVF.getValue(Value);
153    }
154
155    if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
156      SymbolRef CastSym = SC->getOperand();
157      QualType CastTy = SC->getType();
158      // Skip the void type
159      if (CastTy->isVoidType())
160        return nullptr;
161
162      const llvm::APSInt *Value;
163      if (!(Value = getSymVal(State, CastSym)))
164        return nullptr;
165      return &BVF.Convert(SC->getType(), *Value);
166    }
167
168    if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
169      const llvm::APSInt *LHS, *RHS;
170      if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
171        LHS = getSymVal(State, SIE->getLHS());
172        RHS = &SIE->getRHS();
173      } else if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
174        LHS = &ISE->getLHS();
175        RHS = getSymVal(State, ISE->getRHS());
176      } else if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
177        // Early termination to avoid expensive call
178        LHS = getSymVal(State, SSM->getLHS());
179        RHS = LHS ? getSymVal(State, SSM->getRHS()) : nullptr;
180      } else {
181        llvm_unreachable("Unsupported binary expression to get symbol value!");
182      }
183
184      if (!LHS || !RHS)
185        return nullptr;
186
187      llvm::APSInt ConvertedLHS, ConvertedRHS;
188      QualType LTy, RTy;
189      std::tie(ConvertedLHS, LTy) = SMTConv::fixAPSInt(Ctx, *LHS);
190      std::tie(ConvertedRHS, RTy) = SMTConv::fixAPSInt(Ctx, *RHS);
191      SMTConv::doIntTypeConversion<llvm::APSInt, &SMTConv::castAPSInt>(
192          Solver, Ctx, ConvertedLHS, LTy, ConvertedRHS, RTy);
193      return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS);
194    }
195
196    llvm_unreachable("Unsupported expression to get symbol value!");
197  }
198
199  ProgramStateRef removeDeadBindings(ProgramStateRef State,
200                                     SymbolReaper &SymReaper) override {
201    auto CZ = State->get<ConstraintSMT>();
202    auto &CZFactory = State->get_context<ConstraintSMT>();
203
204    for (auto I = CZ.begin(), E = CZ.end(); I != E; ++I) {
205      if (SymReaper.isDead(I->first))
206        CZ = CZFactory.remove(CZ, *I);
207    }
208
209    return State->set<ConstraintSMT>(CZ);
210  }
211
212  void printJson(raw_ostream &Out, ProgramStateRef State, const char *NL = "\n",
213                 unsigned int Space = 0, bool IsDot = false) const override {
214    ConstraintSMTType Constraints = State->get<ConstraintSMT>();
215
216    Indent(Out, Space, IsDot) << "\"constraints\": ";
217    if (Constraints.isEmpty()) {
218      Out << "null," << NL;
219      return;
220    }
221
222    ++Space;
223    Out << '[' << NL;
224    for (ConstraintSMTType::iterator I = Constraints.begin();
225         I != Constraints.end(); ++I) {
226      Indent(Out, Space, IsDot)
227          << "{ \"symbol\": \"" << I->first << "\", \"range\": \"";
228      I->second->print(Out);
229      Out << "\" }";
230
231      if (std::next(I) != Constraints.end())
232        Out << ',';
233      Out << NL;
234    }
235
236    --Space;
237    Indent(Out, Space, IsDot) << "],";
238  }
239
240  bool haveEqualConstraints(ProgramStateRef S1,
241                            ProgramStateRef S2) const override {
242    return S1->get<ConstraintSMT>() == S2->get<ConstraintSMT>();
243  }
244
245  bool canReasonAbout(SVal X) const override {
246    const TargetInfo &TI = getBasicVals().getContext().getTargetInfo();
247
248    Optional<nonloc::SymbolVal> SymVal = X.getAs<nonloc::SymbolVal>();
249    if (!SymVal)
250      return true;
251
252    const SymExpr *Sym = SymVal->getSymbol();
253    QualType Ty = Sym->getType();
254
255    // Complex types are not modeled
256    if (Ty->isComplexType() || Ty->isComplexIntegerType())
257      return false;
258
259    // Non-IEEE 754 floating-point types are not modeled
260    if ((Ty->isSpecificBuiltinType(BuiltinType::LongDouble) &&
261         (&TI.getLongDoubleFormat() == &llvm::APFloat::x87DoubleExtended() ||
262          &TI.getLongDoubleFormat() == &llvm::APFloat::PPCDoubleDouble())))
263      return false;
264
265    if (Ty->isRealFloatingType())
266      return Solver->isFPSupported();
267
268    if (isa<SymbolData>(Sym))
269      return true;
270
271    SValBuilder &SVB = getSValBuilder();
272
273    if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym))
274      return canReasonAbout(SVB.makeSymbolVal(SC->getOperand()));
275
276    if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
277      if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE))
278        return canReasonAbout(SVB.makeSymbolVal(SIE->getLHS()));
279
280      if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE))
281        return canReasonAbout(SVB.makeSymbolVal(ISE->getRHS()));
282
283      if (const SymSymExpr *SSE = dyn_cast<SymSymExpr>(BSE))
284        return canReasonAbout(SVB.makeSymbolVal(SSE->getLHS())) &&
285               canReasonAbout(SVB.makeSymbolVal(SSE->getRHS()));
286    }
287
288    llvm_unreachable("Unsupported expression to reason about!");
289  }
290
291  /// Dumps SMT formula
292  LLVM_DUMP_METHOD void dump() const { Solver->dump(); }
293
294protected:
295  // Check whether a new model is satisfiable, and update the program state.
296  virtual ProgramStateRef assumeExpr(ProgramStateRef State, SymbolRef Sym,
297                                     const llvm::SMTExprRef &Exp) {
298    // Check the model, avoid simplifying AST to save time
299    if (checkModel(State, Sym, Exp).isConstrainedTrue())
300      return State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
301
302    return nullptr;
303  }
304
305  /// Given a program state, construct the logical conjunction and add it to
306  /// the solver
307  virtual void addStateConstraints(ProgramStateRef State) const {
308    // TODO: Don't add all the constraints, only the relevant ones
309    auto CZ = State->get<ConstraintSMT>();
310    auto I = CZ.begin(), IE = CZ.end();
311
312    // Construct the logical AND of all the constraints
313    if (I != IE) {
314      std::vector<llvm::SMTExprRef> ASTs;
315
316      llvm::SMTExprRef Constraint = I++->second;
317      while (I != IE) {
318        Constraint = Solver->mkAnd(Constraint, I++->second);
319      }
320
321      Solver->addConstraint(Constraint);
322    }
323  }
324
325  // Generate and check a Z3 model, using the given constraint.
326  ConditionTruthVal checkModel(ProgramStateRef State, SymbolRef Sym,
327                               const llvm::SMTExprRef &Exp) const {
328    ProgramStateRef NewState =
329        State->add<ConstraintSMT>(std::make_pair(Sym, Exp));
330
331    llvm::FoldingSetNodeID ID;
332    NewState->get<ConstraintSMT>().Profile(ID);
333
334    unsigned hash = ID.ComputeHash();
335    auto I = Cached.find(hash);
336    if (I != Cached.end())
337      return I->second;
338
339    Solver->reset();
340    addStateConstraints(NewState);
341
342    Optional<bool> res = Solver->check();
343    if (!res.hasValue())
344      Cached[hash] = ConditionTruthVal();
345    else
346      Cached[hash] = ConditionTruthVal(res.getValue());
347
348    return Cached[hash];
349  }
350
351  // Cache the result of an SMT query (true, false, unknown). The key is the
352  // hash of the constraints in a state
353  mutable llvm::DenseMap<unsigned, ConditionTruthVal> Cached;
354}; // end class SMTConstraintManager
355
356} // namespace ento
357} // namespace clang
358
359#endif
360