1//===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
//  This file defines a generic SMT solver API, which serves as the base class
//  for every solver-specific SMT class.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_SUPPORT_SMTAPI_H
15#define LLVM_SUPPORT_SMTAPI_H
16
17#include "llvm/ADT/APFloat.h"
18#include "llvm/ADT/APSInt.h"
19#include "llvm/ADT/FoldingSet.h"
20#include "llvm/Support/raw_ostream.h"
21#include <memory>
22
23namespace llvm {
24
25/// Generic base class for SMT sorts
26class SMTSort {
27public:
28  SMTSort() = default;
29  virtual ~SMTSort() = default;
30
31  /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
32  virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
33
34  /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
35  virtual bool isFloatSort() const { return isFloatSortImpl(); }
36
37  /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
38  virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
39
40  /// Returns the bitvector size, fails if the sort is not a bitvector
41  /// Calls getBitvectorSortSizeImpl().
42  virtual unsigned getBitvectorSortSize() const {
43    assert(isBitvectorSort() && "Not a bitvector sort!");
44    unsigned Size = getBitvectorSortSizeImpl();
45    assert(Size && "Size is zero!");
46    return Size;
47  };
48
49  /// Returns the floating-point size, fails if the sort is not a floating-point
50  /// Calls getFloatSortSizeImpl().
51  virtual unsigned getFloatSortSize() const {
52    assert(isFloatSort() && "Not a floating-point sort!");
53    unsigned Size = getFloatSortSizeImpl();
54    assert(Size && "Size is zero!");
55    return Size;
56  };
57
58  virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
59
60  bool operator<(const SMTSort &Other) const {
61    llvm::FoldingSetNodeID ID1, ID2;
62    Profile(ID1);
63    Other.Profile(ID2);
64    return ID1 < ID2;
65  }
66
67  friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
68    return LHS.equal_to(RHS);
69  }
70
71  virtual void print(raw_ostream &OS) const = 0;
72
73  LLVM_DUMP_METHOD void dump() const;
74
75protected:
76  /// Query the SMT solver and returns true if two sorts are equal (same kind
77  /// and bit width). This does not check if the two sorts are the same objects.
78  virtual bool equal_to(SMTSort const &other) const = 0;
79
80  /// Query the SMT solver and checks if a sort is bitvector.
81  virtual bool isBitvectorSortImpl() const = 0;
82
83  /// Query the SMT solver and checks if a sort is floating-point.
84  virtual bool isFloatSortImpl() const = 0;
85
86  /// Query the SMT solver and checks if a sort is boolean.
87  virtual bool isBooleanSortImpl() const = 0;
88
89  /// Query the SMT solver and returns the sort bit width.
90  virtual unsigned getBitvectorSortSizeImpl() const = 0;
91
92  /// Query the SMT solver and returns the sort bit width.
93  virtual unsigned getFloatSortSizeImpl() const = 0;
94};
95
96/// Shared pointer for SMTSorts, used by SMTSolver API.
97using SMTSortRef = const SMTSort *;
98
99/// Generic base class for SMT exprs
100class SMTExpr {
101public:
102  SMTExpr() = default;
103  virtual ~SMTExpr() = default;
104
105  bool operator<(const SMTExpr &Other) const {
106    llvm::FoldingSetNodeID ID1, ID2;
107    Profile(ID1);
108    Other.Profile(ID2);
109    return ID1 < ID2;
110  }
111
112  virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
113
114  friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
115    return LHS.equal_to(RHS);
116  }
117
118  virtual void print(raw_ostream &OS) const = 0;
119
120  LLVM_DUMP_METHOD void dump() const;
121
122protected:
123  /// Query the SMT solver and returns true if two sorts are equal (same kind
124  /// and bit width). This does not check if the two sorts are the same objects.
125  virtual bool equal_to(SMTExpr const &other) const = 0;
126};
127
128/// Shared pointer for SMTExprs, used by SMTSolver API.
129using SMTExprRef = const SMTExpr *;
130
131/// Generic base class for SMT Solvers
132///
133/// This class is responsible for wrapping all sorts and expression generation,
134/// through the mk* methods. It also provides methods to create SMT expressions
135/// straight from clang's AST, through the from* methods.
136class SMTSolver {
137public:
138  SMTSolver() = default;
139  virtual ~SMTSolver() = default;
140
141  LLVM_DUMP_METHOD void dump() const;
142
143  // Returns an appropriate floating-point sort for the given bitwidth.
144  SMTSortRef getFloatSort(unsigned BitWidth) {
145    switch (BitWidth) {
146    case 16:
147      return getFloat16Sort();
148    case 32:
149      return getFloat32Sort();
150    case 64:
151      return getFloat64Sort();
152    case 128:
153      return getFloat128Sort();
154    default:;
155    }
156    llvm_unreachable("Unsupported floating-point bitwidth!");
157  }
158
159  // Returns a boolean sort.
160  virtual SMTSortRef getBoolSort() = 0;
161
162  // Returns an appropriate bitvector sort for the given bitwidth.
163  virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
164
165  // Returns a floating-point sort of width 16
166  virtual SMTSortRef getFloat16Sort() = 0;
167
168  // Returns a floating-point sort of width 32
169  virtual SMTSortRef getFloat32Sort() = 0;
170
171  // Returns a floating-point sort of width 64
172  virtual SMTSortRef getFloat64Sort() = 0;
173
174  // Returns a floating-point sort of width 128
175  virtual SMTSortRef getFloat128Sort() = 0;
176
177  // Returns an appropriate sort for the given AST.
178  virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
179
180  /// Given a constraint, adds it to the solver
181  virtual void addConstraint(const SMTExprRef &Exp) const = 0;
182
183  /// Creates a bitvector addition operation
184  virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
185
186  /// Creates a bitvector subtraction operation
187  virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
188
189  /// Creates a bitvector multiplication operation
190  virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
191
192  /// Creates a bitvector signed modulus operation
193  virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
194
195  /// Creates a bitvector unsigned modulus operation
196  virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
197
198  /// Creates a bitvector signed division operation
199  virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
200
201  /// Creates a bitvector unsigned division operation
202  virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
203
204  /// Creates a bitvector logical shift left operation
205  virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
206
207  /// Creates a bitvector arithmetic shift right operation
208  virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
209
210  /// Creates a bitvector logical shift right operation
211  virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
212
213  /// Creates a bitvector negation operation
214  virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
215
216  /// Creates a bitvector not operation
217  virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
218
219  /// Creates a bitvector xor operation
220  virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
221
222  /// Creates a bitvector or operation
223  virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
224
225  /// Creates a bitvector and operation
226  virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
227
228  /// Creates a bitvector unsigned less-than operation
229  virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
230
231  /// Creates a bitvector signed less-than operation
232  virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
233
234  /// Creates a bitvector unsigned greater-than operation
235  virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
236
237  /// Creates a bitvector signed greater-than operation
238  virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
239
240  /// Creates a bitvector unsigned less-equal-than operation
241  virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
242
243  /// Creates a bitvector signed less-equal-than operation
244  virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
245
246  /// Creates a bitvector unsigned greater-equal-than operation
247  virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
248
249  /// Creates a bitvector signed greater-equal-than operation
250  virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
251
252  /// Creates a boolean not operation
253  virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
254
255  /// Creates a boolean equality operation
256  virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
257
258  /// Creates a boolean and operation
259  virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
260
261  /// Creates a boolean or operation
262  virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
263
264  /// Creates a boolean ite operation
265  virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
266                           const SMTExprRef &F) = 0;
267
268  /// Creates a bitvector sign extension operation
269  virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
270
271  /// Creates a bitvector zero extension operation
272  virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
273
274  /// Creates a bitvector extract operation
275  virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
276                                 const SMTExprRef &Exp) = 0;
277
278  /// Creates a bitvector concat operation
279  virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
280                                const SMTExprRef &RHS) = 0;
281
282  /// Creates a predicate that checks for overflow in a bitvector addition
283  /// operation
284  virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
285                                       const SMTExprRef &RHS,
286                                       bool isSigned) = 0;
287
288  /// Creates a predicate that checks for underflow in a signed bitvector
289  /// addition operation
290  virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
291                                        const SMTExprRef &RHS) = 0;
292
293  /// Creates a predicate that checks for overflow in a signed bitvector
294  /// subtraction operation
295  virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
296                                       const SMTExprRef &RHS) = 0;
297
298  /// Creates a predicate that checks for underflow in a bitvector subtraction
299  /// operation
300  virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
301                                        const SMTExprRef &RHS,
302                                        bool isSigned) = 0;
303
304  /// Creates a predicate that checks for overflow in a signed bitvector
305  /// division/modulus operation
306  virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
307                                        const SMTExprRef &RHS) = 0;
308
309  /// Creates a predicate that checks for overflow in a bitvector negation
310  /// operation
311  virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
312
313  /// Creates a predicate that checks for overflow in a bitvector multiplication
314  /// operation
315  virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
316                                       const SMTExprRef &RHS,
317                                       bool isSigned) = 0;
318
319  /// Creates a predicate that checks for underflow in a signed bitvector
320  /// multiplication operation
321  virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
322                                        const SMTExprRef &RHS) = 0;
323
324  /// Creates a floating-point negation operation
325  virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
326
327  /// Creates a floating-point isInfinite operation
328  virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
329
330  /// Creates a floating-point isNaN operation
331  virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
332
333  /// Creates a floating-point isNormal operation
334  virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
335
336  /// Creates a floating-point isZero operation
337  virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
338
339  /// Creates a floating-point multiplication operation
340  virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
341
342  /// Creates a floating-point division operation
343  virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
344
345  /// Creates a floating-point remainder operation
346  virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
347
348  /// Creates a floating-point addition operation
349  virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
350
351  /// Creates a floating-point subtraction operation
352  virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
353
354  /// Creates a floating-point less-than operation
355  virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
356
357  /// Creates a floating-point greater-than operation
358  virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
359
360  /// Creates a floating-point less-than-or-equal operation
361  virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
362
363  /// Creates a floating-point greater-than-or-equal operation
364  virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
365
366  /// Creates a floating-point equality operation
367  virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
368                               const SMTExprRef &RHS) = 0;
369
370  /// Creates a floating-point conversion from floatint-point to floating-point
371  /// operation
372  virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
373
374  /// Creates a floating-point conversion from signed bitvector to
375  /// floatint-point operation
376  virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
377                               const SMTSortRef &To) = 0;
378
379  /// Creates a floating-point conversion from unsigned bitvector to
380  /// floatint-point operation
381  virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
382                               const SMTSortRef &To) = 0;
383
384  /// Creates a floating-point conversion from floatint-point to signed
385  /// bitvector operation
386  virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
387
388  /// Creates a floating-point conversion from floatint-point to unsigned
389  /// bitvector operation
390  virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
391
392  /// Creates a new symbol, given a name and a sort
393  virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
394
395  // Returns an appropriate floating-point rounding mode.
396  virtual SMTExprRef getFloatRoundingMode() = 0;
397
398  // If the a model is available, returns the value of a given bitvector symbol
399  virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
400                                    bool isUnsigned) = 0;
401
402  // If the a model is available, returns the value of a given boolean symbol
403  virtual bool getBoolean(const SMTExprRef &Exp) = 0;
404
405  /// Constructs an SMTExprRef from a boolean.
406  virtual SMTExprRef mkBoolean(const bool b) = 0;
407
408  /// Constructs an SMTExprRef from a finite APFloat.
409  virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
410
411  /// Constructs an SMTExprRef from an APSInt and its bit width
412  virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
413
414  /// Given an expression, extract the value of this operand in the model.
415  virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
416
417  /// Given an expression extract the value of this operand in the model.
418  virtual bool getInterpretation(const SMTExprRef &Exp,
419                                 llvm::APFloat &Float) = 0;
420
421  /// Check if the constraints are satisfiable
422  virtual Optional<bool> check() const = 0;
423
424  /// Push the current solver state
425  virtual void push() = 0;
426
427  /// Pop the previous solver state
428  virtual void pop(unsigned NumStates = 1) = 0;
429
430  /// Reset the solver and remove all constraints.
431  virtual void reset() = 0;
432
433  /// Checks if the solver supports floating-points.
434  virtual bool isFPSupported() = 0;
435
436  virtual void print(raw_ostream &OS) const = 0;
437};
438
439/// Shared pointer for SMTSolvers.
440using SMTSolverRef = std::shared_ptr<SMTSolver>;
441
442/// Convenience method to create and Z3Solver object
443SMTSolverRef CreateZ3Solver();
444
445} // namespace llvm
446
447#endif
448