1//== SMTConv.h --------------------------------------------------*- C++ -*--==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//  This file defines a set of functions to create SMT expressions
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONV_H
14#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONV_H
15
16#include "clang/AST/Expr.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/APSIntType.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
19#include "llvm/Support/SMTAPI.h"
20
21namespace clang {
22namespace ento {
23
24class SMTConv {
25public:
26  // Returns an appropriate sort, given a QualType and it's bit width.
27  static inline llvm::SMTSortRef mkSort(llvm::SMTSolverRef &Solver,
28                                        const QualType &Ty, unsigned BitWidth) {
29    if (Ty->isBooleanType())
30      return Solver->getBoolSort();
31
32    if (Ty->isRealFloatingType())
33      return Solver->getFloatSort(BitWidth);
34
35    return Solver->getBitvectorSort(BitWidth);
36  }
37
38  /// Constructs an SMTSolverRef from an unary operator.
39  static inline llvm::SMTExprRef fromUnOp(llvm::SMTSolverRef &Solver,
40                                          const UnaryOperator::Opcode Op,
41                                          const llvm::SMTExprRef &Exp) {
42    switch (Op) {
43    case UO_Minus:
44      return Solver->mkBVNeg(Exp);
45
46    case UO_Not:
47      return Solver->mkBVNot(Exp);
48
49    case UO_LNot:
50      return Solver->mkNot(Exp);
51
52    default:;
53    }
54    llvm_unreachable("Unimplemented opcode");
55  }
56
57  /// Constructs an SMTSolverRef from a floating-point unary operator.
58  static inline llvm::SMTExprRef fromFloatUnOp(llvm::SMTSolverRef &Solver,
59                                               const UnaryOperator::Opcode Op,
60                                               const llvm::SMTExprRef &Exp) {
61    switch (Op) {
62    case UO_Minus:
63      return Solver->mkFPNeg(Exp);
64
65    case UO_LNot:
66      return fromUnOp(Solver, Op, Exp);
67
68    default:;
69    }
70    llvm_unreachable("Unimplemented opcode");
71  }
72
73  /// Construct an SMTSolverRef from a n-ary binary operator.
74  static inline llvm::SMTExprRef
75  fromNBinOp(llvm::SMTSolverRef &Solver, const BinaryOperator::Opcode Op,
76             const std::vector<llvm::SMTExprRef> &ASTs) {
77    assert(!ASTs.empty());
78
79    if (Op != BO_LAnd && Op != BO_LOr)
80      llvm_unreachable("Unimplemented opcode");
81
82    llvm::SMTExprRef res = ASTs.front();
83    for (std::size_t i = 1; i < ASTs.size(); ++i)
84      res = (Op == BO_LAnd) ? Solver->mkAnd(res, ASTs[i])
85                            : Solver->mkOr(res, ASTs[i]);
86    return res;
87  }
88
89  /// Construct an SMTSolverRef from a binary operator.
90  static inline llvm::SMTExprRef fromBinOp(llvm::SMTSolverRef &Solver,
91                                           const llvm::SMTExprRef &LHS,
92                                           const BinaryOperator::Opcode Op,
93                                           const llvm::SMTExprRef &RHS,
94                                           bool isSigned) {
95    assert(*Solver->getSort(LHS) == *Solver->getSort(RHS) &&
96           "AST's must have the same sort!");
97
98    switch (Op) {
99    // Multiplicative operators
100    case BO_Mul:
101      return Solver->mkBVMul(LHS, RHS);
102
103    case BO_Div:
104      return isSigned ? Solver->mkBVSDiv(LHS, RHS) : Solver->mkBVUDiv(LHS, RHS);
105
106    case BO_Rem:
107      return isSigned ? Solver->mkBVSRem(LHS, RHS) : Solver->mkBVURem(LHS, RHS);
108
109      // Additive operators
110    case BO_Add:
111      return Solver->mkBVAdd(LHS, RHS);
112
113    case BO_Sub:
114      return Solver->mkBVSub(LHS, RHS);
115
116      // Bitwise shift operators
117    case BO_Shl:
118      return Solver->mkBVShl(LHS, RHS);
119
120    case BO_Shr:
121      return isSigned ? Solver->mkBVAshr(LHS, RHS) : Solver->mkBVLshr(LHS, RHS);
122
123      // Relational operators
124    case BO_LT:
125      return isSigned ? Solver->mkBVSlt(LHS, RHS) : Solver->mkBVUlt(LHS, RHS);
126
127    case BO_GT:
128      return isSigned ? Solver->mkBVSgt(LHS, RHS) : Solver->mkBVUgt(LHS, RHS);
129
130    case BO_LE:
131      return isSigned ? Solver->mkBVSle(LHS, RHS) : Solver->mkBVUle(LHS, RHS);
132
133    case BO_GE:
134      return isSigned ? Solver->mkBVSge(LHS, RHS) : Solver->mkBVUge(LHS, RHS);
135
136      // Equality operators
137    case BO_EQ:
138      return Solver->mkEqual(LHS, RHS);
139
140    case BO_NE:
141      return fromUnOp(Solver, UO_LNot,
142                      fromBinOp(Solver, LHS, BO_EQ, RHS, isSigned));
143
144      // Bitwise operators
145    case BO_And:
146      return Solver->mkBVAnd(LHS, RHS);
147
148    case BO_Xor:
149      return Solver->mkBVXor(LHS, RHS);
150
151    case BO_Or:
152      return Solver->mkBVOr(LHS, RHS);
153
154      // Logical operators
155    case BO_LAnd:
156      return Solver->mkAnd(LHS, RHS);
157
158    case BO_LOr:
159      return Solver->mkOr(LHS, RHS);
160
161    default:;
162    }
163    llvm_unreachable("Unimplemented opcode");
164  }
165
166  /// Construct an SMTSolverRef from a special floating-point binary
167  /// operator.
168  static inline llvm::SMTExprRef
169  fromFloatSpecialBinOp(llvm::SMTSolverRef &Solver, const llvm::SMTExprRef &LHS,
170                        const BinaryOperator::Opcode Op,
171                        const llvm::APFloat::fltCategory &RHS) {
172    switch (Op) {
173    // Equality operators
174    case BO_EQ:
175      switch (RHS) {
176      case llvm::APFloat::fcInfinity:
177        return Solver->mkFPIsInfinite(LHS);
178
179      case llvm::APFloat::fcNaN:
180        return Solver->mkFPIsNaN(LHS);
181
182      case llvm::APFloat::fcNormal:
183        return Solver->mkFPIsNormal(LHS);
184
185      case llvm::APFloat::fcZero:
186        return Solver->mkFPIsZero(LHS);
187      }
188      break;
189
190    case BO_NE:
191      return fromFloatUnOp(Solver, UO_LNot,
192                           fromFloatSpecialBinOp(Solver, LHS, BO_EQ, RHS));
193
194    default:;
195    }
196
197    llvm_unreachable("Unimplemented opcode");
198  }
199
200  /// Construct an SMTSolverRef from a floating-point binary operator.
201  static inline llvm::SMTExprRef fromFloatBinOp(llvm::SMTSolverRef &Solver,
202                                                const llvm::SMTExprRef &LHS,
203                                                const BinaryOperator::Opcode Op,
204                                                const llvm::SMTExprRef &RHS) {
205    assert(*Solver->getSort(LHS) == *Solver->getSort(RHS) &&
206           "AST's must have the same sort!");
207
208    switch (Op) {
209    // Multiplicative operators
210    case BO_Mul:
211      return Solver->mkFPMul(LHS, RHS);
212
213    case BO_Div:
214      return Solver->mkFPDiv(LHS, RHS);
215
216    case BO_Rem:
217      return Solver->mkFPRem(LHS, RHS);
218
219      // Additive operators
220    case BO_Add:
221      return Solver->mkFPAdd(LHS, RHS);
222
223    case BO_Sub:
224      return Solver->mkFPSub(LHS, RHS);
225
226      // Relational operators
227    case BO_LT:
228      return Solver->mkFPLt(LHS, RHS);
229
230    case BO_GT:
231      return Solver->mkFPGt(LHS, RHS);
232
233    case BO_LE:
234      return Solver->mkFPLe(LHS, RHS);
235
236    case BO_GE:
237      return Solver->mkFPGe(LHS, RHS);
238
239      // Equality operators
240    case BO_EQ:
241      return Solver->mkFPEqual(LHS, RHS);
242
243    case BO_NE:
244      return fromFloatUnOp(Solver, UO_LNot,
245                           fromFloatBinOp(Solver, LHS, BO_EQ, RHS));
246
247      // Logical operators
248    case BO_LAnd:
249    case BO_LOr:
250      return fromBinOp(Solver, LHS, Op, RHS, /*isSigned=*/false);
251
252    default:;
253    }
254
255    llvm_unreachable("Unimplemented opcode");
256  }
257
258  /// Construct an SMTSolverRef from a QualType FromTy to a QualType ToTy,
259  /// and their bit widths.
260  static inline llvm::SMTExprRef fromCast(llvm::SMTSolverRef &Solver,
261                                          const llvm::SMTExprRef &Exp,
262                                          QualType ToTy, uint64_t ToBitWidth,
263                                          QualType FromTy,
264                                          uint64_t FromBitWidth) {
265    if ((FromTy->isIntegralOrEnumerationType() &&
266         ToTy->isIntegralOrEnumerationType()) ||
267        (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) ||
268        (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) ||
269        (FromTy->isReferenceType() ^ ToTy->isReferenceType())) {
270
271      if (FromTy->isBooleanType()) {
272        assert(ToBitWidth > 0 && "BitWidth must be positive!");
273        return Solver->mkIte(
274            Exp, Solver->mkBitvector(llvm::APSInt("1"), ToBitWidth),
275            Solver->mkBitvector(llvm::APSInt("0"), ToBitWidth));
276      }
277
278      if (ToBitWidth > FromBitWidth)
279        return FromTy->isSignedIntegerOrEnumerationType()
280                   ? Solver->mkBVSignExt(ToBitWidth - FromBitWidth, Exp)
281                   : Solver->mkBVZeroExt(ToBitWidth - FromBitWidth, Exp);
282
283      if (ToBitWidth < FromBitWidth)
284        return Solver->mkBVExtract(ToBitWidth - 1, 0, Exp);
285
286      // Both are bitvectors with the same width, ignore the type cast
287      return Exp;
288    }
289
290    if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) {
291      if (ToBitWidth != FromBitWidth)
292        return Solver->mkFPtoFP(Exp, Solver->getFloatSort(ToBitWidth));
293
294      return Exp;
295    }
296
297    if (FromTy->isIntegralOrEnumerationType() && ToTy->isRealFloatingType()) {
298      llvm::SMTSortRef Sort = Solver->getFloatSort(ToBitWidth);
299      return FromTy->isSignedIntegerOrEnumerationType()
300                 ? Solver->mkSBVtoFP(Exp, Sort)
301                 : Solver->mkUBVtoFP(Exp, Sort);
302    }
303
304    if (FromTy->isRealFloatingType() && ToTy->isIntegralOrEnumerationType())
305      return ToTy->isSignedIntegerOrEnumerationType()
306                 ? Solver->mkFPtoSBV(Exp, ToBitWidth)
307                 : Solver->mkFPtoUBV(Exp, ToBitWidth);
308
309    llvm_unreachable("Unsupported explicit type cast!");
310  }
311
312  // Callback function for doCast parameter on APSInt type.
313  static inline llvm::APSInt castAPSInt(llvm::SMTSolverRef &Solver,
314                                        const llvm::APSInt &V, QualType ToTy,
315                                        uint64_t ToWidth, QualType FromTy,
316                                        uint64_t FromWidth) {
317    APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType());
318    return TargetType.convert(V);
319  }
320
321  /// Construct an SMTSolverRef from a SymbolData.
322  static inline llvm::SMTExprRef fromData(llvm::SMTSolverRef &Solver,
323                                          const SymbolID ID, const QualType &Ty,
324                                          uint64_t BitWidth) {
325    llvm::Twine Name = "$" + llvm::Twine(ID);
326    return Solver->mkSymbol(Name.str().c_str(), mkSort(Solver, Ty, BitWidth));
327  }
328
329  // Wrapper to generate SMTSolverRef from SymbolCast data.
330  static inline llvm::SMTExprRef getCastExpr(llvm::SMTSolverRef &Solver,
331                                             ASTContext &Ctx,
332                                             const llvm::SMTExprRef &Exp,
333                                             QualType FromTy, QualType ToTy) {
334    return fromCast(Solver, Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy,
335                    Ctx.getTypeSize(FromTy));
336  }
337
338  // Wrapper to generate SMTSolverRef from unpacked binary symbolic
339  // expression. Sets the RetTy parameter. See getSMTSolverRef().
340  static inline llvm::SMTExprRef
341  getBinExpr(llvm::SMTSolverRef &Solver, ASTContext &Ctx,
342             const llvm::SMTExprRef &LHS, QualType LTy,
343             BinaryOperator::Opcode Op, const llvm::SMTExprRef &RHS,
344             QualType RTy, QualType *RetTy) {
345    llvm::SMTExprRef NewLHS = LHS;
346    llvm::SMTExprRef NewRHS = RHS;
347    doTypeConversion(Solver, Ctx, NewLHS, NewRHS, LTy, RTy);
348
349    // Update the return type parameter if the output type has changed.
350    if (RetTy) {
351      // A boolean result can be represented as an integer type in C/C++, but at
352      // this point we only care about the SMT sorts. Set it as a boolean type
353      // to avoid subsequent SMT errors.
354      if (BinaryOperator::isComparisonOp(Op) ||
355          BinaryOperator::isLogicalOp(Op)) {
356        *RetTy = Ctx.BoolTy;
357      } else {
358        *RetTy = LTy;
359      }
360
361      // If the two operands are pointers and the operation is a subtraction,
362      // the result is of type ptrdiff_t, which is signed
363      if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) {
364        *RetTy = Ctx.getPointerDiffType();
365      }
366    }
367
368    return LTy->isRealFloatingType()
369               ? fromFloatBinOp(Solver, NewLHS, Op, NewRHS)
370               : fromBinOp(Solver, NewLHS, Op, NewRHS,
371                           LTy->isSignedIntegerOrEnumerationType());
372  }
373
374  // Wrapper to generate SMTSolverRef from BinarySymExpr.
375  // Sets the hasComparison and RetTy parameters. See getSMTSolverRef().
376  static inline llvm::SMTExprRef getSymBinExpr(llvm::SMTSolverRef &Solver,
377                                               ASTContext &Ctx,
378                                               const BinarySymExpr *BSE,
379                                               bool *hasComparison,
380                                               QualType *RetTy) {
381    QualType LTy, RTy;
382    BinaryOperator::Opcode Op = BSE->getOpcode();
383
384    if (const SymIntExpr *SIE = dyn_cast<SymIntExpr>(BSE)) {
385      llvm::SMTExprRef LHS =
386          getSymExpr(Solver, Ctx, SIE->getLHS(), &LTy, hasComparison);
387      llvm::APSInt NewRInt;
388      std::tie(NewRInt, RTy) = fixAPSInt(Ctx, SIE->getRHS());
389      llvm::SMTExprRef RHS =
390          Solver->mkBitvector(NewRInt, NewRInt.getBitWidth());
391      return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
392    }
393
394    if (const IntSymExpr *ISE = dyn_cast<IntSymExpr>(BSE)) {
395      llvm::APSInt NewLInt;
396      std::tie(NewLInt, LTy) = fixAPSInt(Ctx, ISE->getLHS());
397      llvm::SMTExprRef LHS =
398          Solver->mkBitvector(NewLInt, NewLInt.getBitWidth());
399      llvm::SMTExprRef RHS =
400          getSymExpr(Solver, Ctx, ISE->getRHS(), &RTy, hasComparison);
401      return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
402    }
403
404    if (const SymSymExpr *SSM = dyn_cast<SymSymExpr>(BSE)) {
405      llvm::SMTExprRef LHS =
406          getSymExpr(Solver, Ctx, SSM->getLHS(), &LTy, hasComparison);
407      llvm::SMTExprRef RHS =
408          getSymExpr(Solver, Ctx, SSM->getRHS(), &RTy, hasComparison);
409      return getBinExpr(Solver, Ctx, LHS, LTy, Op, RHS, RTy, RetTy);
410    }
411
412    llvm_unreachable("Unsupported BinarySymExpr type!");
413  }
414
415  // Recursive implementation to unpack and generate symbolic expression.
416  // Sets the hasComparison and RetTy parameters. See getExpr().
417  static inline llvm::SMTExprRef getSymExpr(llvm::SMTSolverRef &Solver,
418                                            ASTContext &Ctx, SymbolRef Sym,
419                                            QualType *RetTy,
420                                            bool *hasComparison) {
421    if (const SymbolData *SD = dyn_cast<SymbolData>(Sym)) {
422      if (RetTy)
423        *RetTy = Sym->getType();
424
425      return fromData(Solver, SD->getSymbolID(), Sym->getType(),
426                      Ctx.getTypeSize(Sym->getType()));
427    }
428
429    if (const SymbolCast *SC = dyn_cast<SymbolCast>(Sym)) {
430      if (RetTy)
431        *RetTy = Sym->getType();
432
433      QualType FromTy;
434      llvm::SMTExprRef Exp =
435          getSymExpr(Solver, Ctx, SC->getOperand(), &FromTy, hasComparison);
436
437      // Casting an expression with a comparison invalidates it. Note that this
438      // must occur after the recursive call above.
439      // e.g. (signed char) (x > 0)
440      if (hasComparison)
441        *hasComparison = false;
442      return getCastExpr(Solver, Ctx, Exp, FromTy, Sym->getType());
443    }
444
445    if (const BinarySymExpr *BSE = dyn_cast<BinarySymExpr>(Sym)) {
446      llvm::SMTExprRef Exp =
447          getSymBinExpr(Solver, Ctx, BSE, hasComparison, RetTy);
448      // Set the hasComparison parameter, in post-order traversal order.
449      if (hasComparison)
450        *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode());
451      return Exp;
452    }
453
454    llvm_unreachable("Unsupported SymbolRef type!");
455  }
456
457  // Generate an SMTSolverRef that represents the given symbolic expression.
458  // Sets the hasComparison parameter if the expression has a comparison
459  // operator. Sets the RetTy parameter to the final return type after
460  // promotions and casts.
461  static inline llvm::SMTExprRef getExpr(llvm::SMTSolverRef &Solver,
462                                         ASTContext &Ctx, SymbolRef Sym,
463                                         QualType *RetTy = nullptr,
464                                         bool *hasComparison = nullptr) {
465    if (hasComparison) {
466      *hasComparison = false;
467    }
468
469    return getSymExpr(Solver, Ctx, Sym, RetTy, hasComparison);
470  }
471
472  // Generate an SMTSolverRef that compares the expression to zero.
473  static inline llvm::SMTExprRef getZeroExpr(llvm::SMTSolverRef &Solver,
474                                             ASTContext &Ctx,
475                                             const llvm::SMTExprRef &Exp,
476                                             QualType Ty, bool Assumption) {
477    if (Ty->isRealFloatingType()) {
478      llvm::APFloat Zero =
479          llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty));
480      return fromFloatBinOp(Solver, Exp, Assumption ? BO_EQ : BO_NE,
481                            Solver->mkFloat(Zero));
482    }
483
484    if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() ||
485        Ty->isBlockPointerType() || Ty->isReferenceType()) {
486
487      // Skip explicit comparison for boolean types
488      bool isSigned = Ty->isSignedIntegerOrEnumerationType();
489      if (Ty->isBooleanType())
490        return Assumption ? fromUnOp(Solver, UO_LNot, Exp) : Exp;
491
492      return fromBinOp(
493          Solver, Exp, Assumption ? BO_EQ : BO_NE,
494          Solver->mkBitvector(llvm::APSInt("0"), Ctx.getTypeSize(Ty)),
495          isSigned);
496    }
497
498    llvm_unreachable("Unsupported type for zero value!");
499  }
500
501  // Wrapper to generate SMTSolverRef from a range. If From == To, an
502  // equality will be created instead.
503  static inline llvm::SMTExprRef
504  getRangeExpr(llvm::SMTSolverRef &Solver, ASTContext &Ctx, SymbolRef Sym,
505               const llvm::APSInt &From, const llvm::APSInt &To, bool InRange) {
506    // Convert lower bound
507    QualType FromTy;
508    llvm::APSInt NewFromInt;
509    std::tie(NewFromInt, FromTy) = fixAPSInt(Ctx, From);
510    llvm::SMTExprRef FromExp =
511        Solver->mkBitvector(NewFromInt, NewFromInt.getBitWidth());
512
513    // Convert symbol
514    QualType SymTy;
515    llvm::SMTExprRef Exp = getExpr(Solver, Ctx, Sym, &SymTy);
516
517    // Construct single (in)equality
518    if (From == To)
519      return getBinExpr(Solver, Ctx, Exp, SymTy, InRange ? BO_EQ : BO_NE,
520                        FromExp, FromTy, /*RetTy=*/nullptr);
521
522    QualType ToTy;
523    llvm::APSInt NewToInt;
524    std::tie(NewToInt, ToTy) = fixAPSInt(Ctx, To);
525    llvm::SMTExprRef ToExp =
526        Solver->mkBitvector(NewToInt, NewToInt.getBitWidth());
527    assert(FromTy == ToTy && "Range values have different types!");
528
529    // Construct two (in)equalities, and a logical and/or
530    llvm::SMTExprRef LHS =
531        getBinExpr(Solver, Ctx, Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp,
532                   FromTy, /*RetTy=*/nullptr);
533    llvm::SMTExprRef RHS = getBinExpr(Solver, Ctx, Exp, SymTy,
534                                      InRange ? BO_LE : BO_GT, ToExp, ToTy,
535                                      /*RetTy=*/nullptr);
536
537    return fromBinOp(Solver, LHS, InRange ? BO_LAnd : BO_LOr, RHS,
538                     SymTy->isSignedIntegerOrEnumerationType());
539  }
540
541  // Recover the QualType of an APSInt.
542  // TODO: Refactor to put elsewhere
543  static inline QualType getAPSIntType(ASTContext &Ctx,
544                                       const llvm::APSInt &Int) {
545    return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned());
546  }
547
548  // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1.
549  static inline std::pair<llvm::APSInt, QualType>
550  fixAPSInt(ASTContext &Ctx, const llvm::APSInt &Int) {
551    llvm::APSInt NewInt;
552
553    // FIXME: This should be a cast from a 1-bit integer type to a boolean type,
554    // but the former is not available in Clang. Instead, extend the APSInt
555    // directly.
556    if (Int.getBitWidth() == 1 && getAPSIntType(Ctx, Int).isNull()) {
557      NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy));
558    } else
559      NewInt = Int;
560
561    return std::make_pair(NewInt, getAPSIntType(Ctx, NewInt));
562  }
563
564  // Perform implicit type conversion on binary symbolic expressions.
565  // May modify all input parameters.
566  // TODO: Refactor to use built-in conversion functions
567  static inline void doTypeConversion(llvm::SMTSolverRef &Solver,
568                                      ASTContext &Ctx, llvm::SMTExprRef &LHS,
569                                      llvm::SMTExprRef &RHS, QualType &LTy,
570                                      QualType &RTy) {
571    assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
572
573    // Perform type conversion
574    if ((LTy->isIntegralOrEnumerationType() &&
575         RTy->isIntegralOrEnumerationType()) &&
576        (LTy->isArithmeticType() && RTy->isArithmeticType())) {
577      SMTConv::doIntTypeConversion<llvm::SMTExprRef, &fromCast>(
578          Solver, Ctx, LHS, LTy, RHS, RTy);
579      return;
580    }
581
582    if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) {
583      SMTConv::doFloatTypeConversion<llvm::SMTExprRef, &fromCast>(
584          Solver, Ctx, LHS, LTy, RHS, RTy);
585      return;
586    }
587
588    if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) ||
589        (LTy->isBlockPointerType() || RTy->isBlockPointerType()) ||
590        (LTy->isReferenceType() || RTy->isReferenceType())) {
591      // TODO: Refactor to Sema::FindCompositePointerType(), and
592      // Sema::CheckCompareOperands().
593
594      uint64_t LBitWidth = Ctx.getTypeSize(LTy);
595      uint64_t RBitWidth = Ctx.getTypeSize(RTy);
596
597      // Cast the non-pointer type to the pointer type.
598      // TODO: Be more strict about this.
599      if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) ||
600          (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) ||
601          (LTy->isReferenceType() ^ RTy->isReferenceType())) {
602        if (LTy->isNullPtrType() || LTy->isBlockPointerType() ||
603            LTy->isReferenceType()) {
604          LHS = fromCast(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
605          LTy = RTy;
606        } else {
607          RHS = fromCast(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
608          RTy = LTy;
609        }
610      }
611
612      // Cast the void pointer type to the non-void pointer type.
613      // For void types, this assumes that the casted value is equal to the
614      // value of the original pointer, and does not account for alignment
615      // requirements.
616      if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) {
617        assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) &&
618               "Pointer types have different bitwidths!");
619        if (RTy->isVoidPointerType())
620          RTy = LTy;
621        else
622          LTy = RTy;
623      }
624
625      if (LTy == RTy)
626        return;
627    }
628
629    // Fallback: for the solver, assume that these types don't really matter
630    if ((LTy.getCanonicalType() == RTy.getCanonicalType()) ||
631        (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) {
632      LTy = RTy;
633      return;
634    }
635
636    // TODO: Refine behavior for invalid type casts
637  }
638
639  // Perform implicit integer type conversion.
640  // May modify all input parameters.
641  // TODO: Refactor to use Sema::handleIntegerConversion()
642  template <typename T, T (*doCast)(llvm::SMTSolverRef &Solver, const T &,
643                                    QualType, uint64_t, QualType, uint64_t)>
644  static inline void doIntTypeConversion(llvm::SMTSolverRef &Solver,
645                                         ASTContext &Ctx, T &LHS, QualType &LTy,
646                                         T &RHS, QualType &RTy) {
647    uint64_t LBitWidth = Ctx.getTypeSize(LTy);
648    uint64_t RBitWidth = Ctx.getTypeSize(RTy);
649
650    assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
651    // Always perform integer promotion before checking type equality.
652    // Otherwise, e.g. (bool) a + (bool) b could trigger a backend assertion
653    if (LTy->isPromotableIntegerType()) {
654      QualType NewTy = Ctx.getPromotedIntegerType(LTy);
655      uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
656      LHS = (*doCast)(Solver, LHS, NewTy, NewBitWidth, LTy, LBitWidth);
657      LTy = NewTy;
658      LBitWidth = NewBitWidth;
659    }
660    if (RTy->isPromotableIntegerType()) {
661      QualType NewTy = Ctx.getPromotedIntegerType(RTy);
662      uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
663      RHS = (*doCast)(Solver, RHS, NewTy, NewBitWidth, RTy, RBitWidth);
664      RTy = NewTy;
665      RBitWidth = NewBitWidth;
666    }
667
668    if (LTy == RTy)
669      return;
670
671    // Perform integer type conversion
672    // Note: Safe to skip updating bitwidth because this must terminate
673    bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType();
674    bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType();
675
676    int order = Ctx.getIntegerTypeOrder(LTy, RTy);
677    if (isLSignedTy == isRSignedTy) {
678      // Same signedness; use the higher-ranked type
679      if (order == 1) {
680        RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
681        RTy = LTy;
682      } else {
683        LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
684        LTy = RTy;
685      }
686    } else if (order != (isLSignedTy ? 1 : -1)) {
687      // The unsigned type has greater than or equal rank to the
688      // signed type, so use the unsigned type
689      if (isRSignedTy) {
690        RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
691        RTy = LTy;
692      } else {
693        LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
694        LTy = RTy;
695      }
696    } else if (LBitWidth != RBitWidth) {
697      // The two types are different widths; if we are here, that
698      // means the signed type is larger than the unsigned type, so
699      // use the signed type.
700      if (isLSignedTy) {
701        RHS = (doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
702        RTy = LTy;
703      } else {
704        LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
705        LTy = RTy;
706      }
707    } else {
708      // The signed type is higher-ranked than the unsigned type,
709      // but isn't actually any bigger (like unsigned int and long
710      // on most 32-bit systems).  Use the unsigned type corresponding
711      // to the signed type.
712      QualType NewTy =
713          Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy);
714      RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
715      RTy = NewTy;
716      LHS = (doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
717      LTy = NewTy;
718    }
719  }
720
721  // Perform implicit floating-point type conversion.
722  // May modify all input parameters.
723  // TODO: Refactor to use Sema::handleFloatConversion()
724  template <typename T, T (*doCast)(llvm::SMTSolverRef &Solver, const T &,
725                                    QualType, uint64_t, QualType, uint64_t)>
726  static inline void
727  doFloatTypeConversion(llvm::SMTSolverRef &Solver, ASTContext &Ctx, T &LHS,
728                        QualType &LTy, T &RHS, QualType &RTy) {
729    uint64_t LBitWidth = Ctx.getTypeSize(LTy);
730    uint64_t RBitWidth = Ctx.getTypeSize(RTy);
731
732    // Perform float-point type promotion
733    if (!LTy->isRealFloatingType()) {
734      LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
735      LTy = RTy;
736      LBitWidth = RBitWidth;
737    }
738    if (!RTy->isRealFloatingType()) {
739      RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
740      RTy = LTy;
741      RBitWidth = LBitWidth;
742    }
743
744    if (LTy == RTy)
745      return;
746
747    // If we have two real floating types, convert the smaller operand to the
748    // bigger result
749    // Note: Safe to skip updating bitwidth because this must terminate
750    int order = Ctx.getFloatingTypeOrder(LTy, RTy);
751    if (order > 0) {
752      RHS = (*doCast)(Solver, RHS, LTy, LBitWidth, RTy, RBitWidth);
753      RTy = LTy;
754    } else if (order == 0) {
755      LHS = (*doCast)(Solver, LHS, RTy, RBitWidth, LTy, LBitWidth);
756      LTy = RTy;
757    } else {
758      llvm_unreachable("Unsupported floating-point type cast!");
759    }
760  }
761};
762} // namespace ento
763} // namespace clang
764
765#endif
766