1//===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Float2Int pass, which aims to demote floating
10// point operations to work on integers, where that is losslessly possible.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/InitializePasses.h"
15#include "llvm/Support/CommandLine.h"
16#define DEBUG_TYPE "float2int"
17
18#include "llvm/Transforms/Scalar/Float2Int.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Analysis/AliasAnalysis.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/IR/Constants.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstIterator.h"
27#include "llvm/IR/Instructions.h"
28#include "llvm/IR/Module.h"
29#include "llvm/Pass.h"
30#include "llvm/Support/Debug.h"
31#include "llvm/Support/raw_ostream.h"
32#include "llvm/Transforms/Scalar.h"
33#include <deque>
34#include <functional> // For std::function
35using namespace llvm;
36
// The algorithm is simple. Start at instructions that convert from the
// float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
// graph, using an equivalence data structure to unify graphs that interfere.
//
// Mappable instructions are those with an integer counterpart that, given
// integer domain inputs, produce an integer output; fadd, for example.
//
// If a non-mappable instruction is seen, this entire def-use graph is marked
// as non-transformable. If we see an instruction that converts from the
// integer domain to the FP domain (uitofp, sitofp), we terminate our walk.
47
48/// The largest integer type worth dealing with.
49static cl::opt<unsigned>
50MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
51             cl::desc("Max integer bitwidth to consider in float2int"
52                      "(default=64)"));
53
54namespace {
55  struct Float2IntLegacyPass : public FunctionPass {
56    static char ID; // Pass identification, replacement for typeid
57    Float2IntLegacyPass() : FunctionPass(ID) {
58      initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
59    }
60
61    bool runOnFunction(Function &F) override {
62      if (skipFunction(F))
63        return false;
64
65      const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
66      return Impl.runImpl(F, DT);
67    }
68
69    void getAnalysisUsage(AnalysisUsage &AU) const override {
70      AU.setPreservesCFG();
71      AU.addRequired<DominatorTreeWrapperPass>();
72      AU.addPreserved<GlobalsAAWrapperPass>();
73    }
74
75  private:
76    Float2IntPass Impl;
77  };
78}
79
80char Float2IntLegacyPass::ID = 0;
81INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
82
83// Given a FCmp predicate, return a matching ICmp predicate if one
84// exists, otherwise return BAD_ICMP_PREDICATE.
85static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
86  switch (P) {
87  case CmpInst::FCMP_OEQ:
88  case CmpInst::FCMP_UEQ:
89    return CmpInst::ICMP_EQ;
90  case CmpInst::FCMP_OGT:
91  case CmpInst::FCMP_UGT:
92    return CmpInst::ICMP_SGT;
93  case CmpInst::FCMP_OGE:
94  case CmpInst::FCMP_UGE:
95    return CmpInst::ICMP_SGE;
96  case CmpInst::FCMP_OLT:
97  case CmpInst::FCMP_ULT:
98    return CmpInst::ICMP_SLT;
99  case CmpInst::FCMP_OLE:
100  case CmpInst::FCMP_ULE:
101    return CmpInst::ICMP_SLE;
102  case CmpInst::FCMP_ONE:
103  case CmpInst::FCMP_UNE:
104    return CmpInst::ICMP_NE;
105  default:
106    return CmpInst::BAD_ICMP_PREDICATE;
107  }
108}
109
110// Given a floating point binary operator, return the matching
111// integer version.
112static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
113  switch (Opcode) {
114  default: llvm_unreachable("Unhandled opcode!");
115  case Instruction::FAdd: return Instruction::Add;
116  case Instruction::FSub: return Instruction::Sub;
117  case Instruction::FMul: return Instruction::Mul;
118  }
119}
120
121// Find the roots - instructions that convert from the FP domain to
122// integer domain.
123void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
124  for (BasicBlock &BB : F) {
125    // Unreachable code can take on strange forms that we are not prepared to
126    // handle. For example, an instruction may have itself as an operand.
127    if (!DT.isReachableFromEntry(&BB))
128      continue;
129
130    for (Instruction &I : BB) {
131      if (isa<VectorType>(I.getType()))
132        continue;
133      switch (I.getOpcode()) {
134      default: break;
135      case Instruction::FPToUI:
136      case Instruction::FPToSI:
137        Roots.insert(&I);
138        break;
139      case Instruction::FCmp:
140        if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
141            CmpInst::BAD_ICMP_PREDICATE)
142          Roots.insert(&I);
143        break;
144      }
145    }
146  }
147}
148
149// Helper - mark I as having been traversed, having range R.
150void Float2IntPass::seen(Instruction *I, ConstantRange R) {
151  LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
152  auto IT = SeenInsts.find(I);
153  if (IT != SeenInsts.end())
154    IT->second = std::move(R);
155  else
156    SeenInsts.insert(std::make_pair(I, std::move(R)));
157}
158
159// Helper - get a range representing a poison value.
160ConstantRange Float2IntPass::badRange() {
161  return ConstantRange::getFull(MaxIntegerBW + 1);
162}
163ConstantRange Float2IntPass::unknownRange() {
164  return ConstantRange::getEmpty(MaxIntegerBW + 1);
165}
166ConstantRange Float2IntPass::validateRange(ConstantRange R) {
167  if (R.getBitWidth() > MaxIntegerBW + 1)
168    return badRange();
169  return R;
170}
171
// The most obvious way to structure the search is a depth-first, eager
// search from each root. However, that requires direct recursion and so
// can only handle small instruction sequences. Instead, we split the search
// up into two phases:
//   - walkBackwards: A worklist-driven (non-recursive) walk of the use-def
//                    graph starting from the roots. Populate "SeenInsts" with
//                    interesting instructions and poison values if they're
//                    obvious and cheap to compute. Calculate the equivalence
//                    set structure while we're here too.
//   - walkForwards:  Visit SeenInsts in reverse order, so that defs are
//                    normally visited before their uses. Calculate the real
//                    range info.
// Walk the use-def graph upwards from the roots, without recursion. This
// determines the set of nodes we care about and eagerly marks some of them
// as poisonous (badRange). The worklist is popped from the back (LIFO), so
// the visiting order is depth-first, not breadth-first. walkForwards visits
// SeenInsts in the reverse of the insertion order produced here.
void Float2IntPass::walkBackwards() {
  std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    if (SeenInsts.find(I) != SeenInsts.end())
      // Seen already.
      continue;

    switch (I->getOpcode()) {
      // FIXME: Handle select and phi nodes.
    default:
      // Path terminated uncleanly: this opcode has no integer counterpart.
      // Its operands are still unioned into the same class below, but they
      // are not walked.
      seen(I, badRange());
      break;

    case Instruction::UIToFP:
    case Instruction::SIToFP: {
      // Path terminated cleanly - use the type of the integer input to seed
      // the analysis.
      unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
      auto Input = ConstantRange::getFull(BW);
      auto CastOp = (Instruction::CastOps)I->getOpcode();
      // validateRange turns inputs wider than MaxIntegerBW + 1 into badRange.
      seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));
      // 'continue' skips the operand loop, so the integer input is neither
      // walked nor unioned into this equivalence class.
      continue;
    }

    // Mappable (or root) instructions: the real range is computed later by
    // walkForwards; record a placeholder for now.
    case Instruction::FNeg:
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul:
    case Instruction::FPToUI:
    case Instruction::FPToSI:
    case Instruction::FCmp:
      seen(I, unknownRange());
      break;
    }

    for (Value *O : I->operands()) {
      if (Instruction *OI = dyn_cast<Instruction>(O)) {
        // Unify def-use chains if they interfere.
        ECs.unionSets(I, OI);
        // Only keep walking upwards while I itself is still convertible.
        if (SeenInsts.find(I)->second != badRange())
          Worklist.push_back(OI);
      } else if (!isa<ConstantFP>(O)) {
        // Not an instruction or ConstantFP (e.g. a function argument)? We
        // can't do anything.
        seen(I, badRange());
      }
    }
  }
}
238
239// Walk forwards down the list of seen instructions, so we visit defs before
240// uses.
241void Float2IntPass::walkForwards() {
242  for (auto &It : reverse(SeenInsts)) {
243    if (It.second != unknownRange())
244      continue;
245
246    Instruction *I = It.first;
247    std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
248    switch (I->getOpcode()) {
249      // FIXME: Handle select and phi nodes.
250    default:
251    case Instruction::UIToFP:
252    case Instruction::SIToFP:
253      llvm_unreachable("Should have been handled in walkForwards!");
254
255    case Instruction::FNeg:
256      Op = [](ArrayRef<ConstantRange> Ops) {
257        assert(Ops.size() == 1 && "FNeg is a unary operator!");
258        unsigned Size = Ops[0].getBitWidth();
259        auto Zero = ConstantRange(APInt::getNullValue(Size));
260        return Zero.sub(Ops[0]);
261      };
262      break;
263
264    case Instruction::FAdd:
265    case Instruction::FSub:
266    case Instruction::FMul:
267      Op = [I](ArrayRef<ConstantRange> Ops) {
268        assert(Ops.size() == 2 && "its a binary operator!");
269        auto BinOp = (Instruction::BinaryOps) I->getOpcode();
270        return Ops[0].binaryOp(BinOp, Ops[1]);
271      };
272      break;
273
274    //
275    // Root-only instructions - we'll only see these if they're the
276    //                          first node in a walk.
277    //
278    case Instruction::FPToUI:
279    case Instruction::FPToSI:
280      Op = [I](ArrayRef<ConstantRange> Ops) {
281        assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
282        // Note: We're ignoring the casts output size here as that's what the
283        // caller expects.
284        auto CastOp = (Instruction::CastOps)I->getOpcode();
285        return Ops[0].castOp(CastOp, MaxIntegerBW+1);
286      };
287      break;
288
289    case Instruction::FCmp:
290      Op = [](ArrayRef<ConstantRange> Ops) {
291        assert(Ops.size() == 2 && "FCmp is a binary operator!");
292        return Ops[0].unionWith(Ops[1]);
293      };
294      break;
295    }
296
297    bool Abort = false;
298    SmallVector<ConstantRange,4> OpRanges;
299    for (Value *O : I->operands()) {
300      if (Instruction *OI = dyn_cast<Instruction>(O)) {
301        assert(SeenInsts.find(OI) != SeenInsts.end() &&
302               "def not seen before use!");
303        OpRanges.push_back(SeenInsts.find(OI)->second);
304      } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
305        // Work out if the floating point number can be losslessly represented
306        // as an integer.
307        // APFloat::convertToInteger(&Exact) purports to do what we want, but
308        // the exactness can be too precise. For example, negative zero can
309        // never be exactly converted to an integer.
310        //
311        // Instead, we ask APFloat to round itself to an integral value - this
312        // preserves sign-of-zero - then compare the result with the original.
313        //
314        const APFloat &F = CF->getValueAPF();
315
316        // First, weed out obviously incorrect values. Non-finite numbers
317        // can't be represented and neither can negative zero, unless
318        // we're in fast math mode.
319        if (!F.isFinite() ||
320            (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
321             !I->hasNoSignedZeros())) {
322          seen(I, badRange());
323          Abort = true;
324          break;
325        }
326
327        APFloat NewF = F;
328        auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
329        if (Res != APFloat::opOK || NewF != F) {
330          seen(I, badRange());
331          Abort = true;
332          break;
333        }
334        // OK, it's representable. Now get it.
335        APSInt Int(MaxIntegerBW+1, false);
336        bool Exact;
337        CF->getValueAPF().convertToInteger(Int,
338                                           APFloat::rmNearestTiesToEven,
339                                           &Exact);
340        OpRanges.push_back(ConstantRange(Int));
341      } else {
342        llvm_unreachable("Should have already marked this as badRange!");
343      }
344    }
345
346    // Reduce the operands' ranges to a single range and return.
347    if (!Abort)
348      seen(I, Op(OpRanges));
349  }
350}
351
352// If there is a valid transform to be done, do it.
353bool Float2IntPass::validateAndTransform() {
354  bool MadeChange = false;
355
356  // Iterate over every disjoint partition of the def-use graph.
357  for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
358    ConstantRange R(MaxIntegerBW + 1, false);
359    bool Fail = false;
360    Type *ConvertedToTy = nullptr;
361
362    // For every member of the partition, union all the ranges together.
363    for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
364         MI != ME; ++MI) {
365      Instruction *I = *MI;
366      auto SeenI = SeenInsts.find(I);
367      if (SeenI == SeenInsts.end())
368        continue;
369
370      R = R.unionWith(SeenI->second);
371      // We need to ensure I has no users that have not been seen.
372      // If it does, transformation would be illegal.
373      //
374      // Don't count the roots, as they terminate the graphs.
375      if (Roots.count(I) == 0) {
376        // Set the type of the conversion while we're here.
377        if (!ConvertedToTy)
378          ConvertedToTy = I->getType();
379        for (User *U : I->users()) {
380          Instruction *UI = dyn_cast<Instruction>(U);
381          if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
382            LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
383            Fail = true;
384            break;
385          }
386        }
387      }
388      if (Fail)
389        break;
390    }
391
392    // If the set was empty, or we failed, or the range is poisonous,
393    // bail out.
394    if (ECs.member_begin(It) == ECs.member_end() || Fail ||
395        R.isFullSet() || R.isSignWrappedSet())
396      continue;
397    assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
398
399    // The number of bits required is the maximum of the upper and
400    // lower limits, plus one so it can be signed.
401    unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
402                              R.getUpper().getMinSignedBits()) + 1;
403    LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
404
405    // If we've run off the realms of the exactly representable integers,
406    // the floating point result will differ from an integer approximation.
407
408    // Do we need more bits than are in the mantissa of the type we converted
409    // to? semanticsPrecision returns the number of mantissa bits plus one
410    // for the sign bit.
411    unsigned MaxRepresentableBits
412      = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
413    if (MinBW > MaxRepresentableBits) {
414      LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
415      continue;
416    }
417    if (MinBW > 64) {
418      LLVM_DEBUG(
419          dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
420      continue;
421    }
422
423    // OK, R is known to be representable. Now pick a type for it.
424    // FIXME: Pick the smallest legal type that will fit.
425    Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
426
427    for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
428         MI != ME; ++MI)
429      convert(*MI, Ty);
430    MadeChange = true;
431  }
432
433  return MadeChange;
434}
435
436Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {
437  if (ConvertedInsts.find(I) != ConvertedInsts.end())
438    // Already converted this instruction.
439    return ConvertedInsts[I];
440
441  SmallVector<Value*,4> NewOperands;
442  for (Value *V : I->operands()) {
443    // Don't recurse if we're an instruction that terminates the path.
444    if (I->getOpcode() == Instruction::UIToFP ||
445        I->getOpcode() == Instruction::SIToFP) {
446      NewOperands.push_back(V);
447    } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
448      NewOperands.push_back(convert(VI, ToTy));
449    } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
450      APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false);
451      bool Exact;
452      CF->getValueAPF().convertToInteger(Val,
453                                         APFloat::rmNearestTiesToEven,
454                                         &Exact);
455      NewOperands.push_back(ConstantInt::get(ToTy, Val));
456    } else {
457      llvm_unreachable("Unhandled operand type?");
458    }
459  }
460
461  // Now create a new instruction.
462  IRBuilder<> IRB(I);
463  Value *NewV = nullptr;
464  switch (I->getOpcode()) {
465  default: llvm_unreachable("Unhandled instruction!");
466
467  case Instruction::FPToUI:
468    NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
469    break;
470
471  case Instruction::FPToSI:
472    NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
473    break;
474
475  case Instruction::FCmp: {
476    CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
477    assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
478    NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
479    break;
480  }
481
482  case Instruction::UIToFP:
483    NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
484    break;
485
486  case Instruction::SIToFP:
487    NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
488    break;
489
490  case Instruction::FNeg:
491    NewV = IRB.CreateNeg(NewOperands[0], I->getName());
492    break;
493
494  case Instruction::FAdd:
495  case Instruction::FSub:
496  case Instruction::FMul:
497    NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
498                           NewOperands[0], NewOperands[1],
499                           I->getName());
500    break;
501  }
502
503  // If we're a root instruction, RAUW.
504  if (Roots.count(I))
505    I->replaceAllUsesWith(NewV);
506
507  ConvertedInsts[I] = NewV;
508  return NewV;
509}
510
511// Perform dead code elimination on the instructions we just modified.
512void Float2IntPass::cleanup() {
513  for (auto &I : reverse(ConvertedInsts))
514    I.first->eraseFromParent();
515}
516
517bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
518  LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
519  // Clear out all state.
520  ECs = EquivalenceClasses<Instruction*>();
521  SeenInsts.clear();
522  ConvertedInsts.clear();
523  Roots.clear();
524
525  Ctx = &F.getParent()->getContext();
526
527  findRoots(F, DT);
528
529  walkBackwards();
530  walkForwards();
531
532  bool Modified = validateAndTransform();
533  if (Modified)
534    cleanup();
535  return Modified;
536}
537
538namespace llvm {
539FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
540
541PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
542  const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
543  if (!runImpl(F, DT))
544    return PreservedAnalyses::all();
545
546  PreservedAnalyses PA;
547  PA.preserveSet<CFGAnalyses>();
548  PA.preserve<GlobalsAA>();
549  return PA;
550}
551} // End namespace llvm
552