1//===- llvm/FixedPointBuilder.h - Builder for fixed-point ops ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the FixedPointBuilder class, which is used as a convenient
10// way to lower fixed-point arithmetic operations to LLVM IR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_IR_FIXEDPOINTBUILDER_H
15#define LLVM_IR_FIXEDPOINTBUILDER_H
16
17#include "llvm/ADT/APFixedPoint.h"
18#include "llvm/IR/Constant.h"
19#include "llvm/IR/Constants.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/InstrTypes.h"
22#include "llvm/IR/Instruction.h"
23#include "llvm/IR/IntrinsicInst.h"
24#include "llvm/IR/Intrinsics.h"
25#include "llvm/IR/Type.h"
26#include "llvm/IR/Value.h"
27
28#include <cmath>
29
30namespace llvm {
31
32template <class IRBuilderTy> class FixedPointBuilder {
33  IRBuilderTy &B;
34
35  Value *Convert(Value *Src, const FixedPointSemantics &SrcSema,
36                 const FixedPointSemantics &DstSema, bool DstIsInteger) {
37    unsigned SrcWidth = SrcSema.getWidth();
38    unsigned DstWidth = DstSema.getWidth();
39    unsigned SrcScale = SrcSema.getScale();
40    unsigned DstScale = DstSema.getScale();
41    bool SrcIsSigned = SrcSema.isSigned();
42    bool DstIsSigned = DstSema.isSigned();
43
44    Type *DstIntTy = B.getIntNTy(DstWidth);
45
46    Value *Result = Src;
47    unsigned ResultWidth = SrcWidth;
48
49    // Downscale.
50    if (DstScale < SrcScale) {
51      // When converting to integers, we round towards zero. For negative
52      // numbers, right shifting rounds towards negative infinity. In this case,
53      // we can just round up before shifting.
54      if (DstIsInteger && SrcIsSigned) {
55        Value *Zero = Constant::getNullValue(Result->getType());
56        Value *IsNegative = B.CreateICmpSLT(Result, Zero);
57        Value *LowBits = ConstantInt::get(
58            B.getContext(), APInt::getLowBitsSet(ResultWidth, SrcScale));
59        Value *Rounded = B.CreateAdd(Result, LowBits);
60        Result = B.CreateSelect(IsNegative, Rounded, Result);
61      }
62
63      Result = SrcIsSigned
64                   ? B.CreateAShr(Result, SrcScale - DstScale, "downscale")
65                   : B.CreateLShr(Result, SrcScale - DstScale, "downscale");
66    }
67
68    if (!DstSema.isSaturated()) {
69      // Resize.
70      Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
71
72      // Upscale.
73      if (DstScale > SrcScale)
74        Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
75    } else {
76      // Adjust the number of fractional bits.
77      if (DstScale > SrcScale) {
78        // Compare to DstWidth to prevent resizing twice.
79        ResultWidth = std::max(SrcWidth + DstScale - SrcScale, DstWidth);
80        Type *UpscaledTy = B.getIntNTy(ResultWidth);
81        Result = B.CreateIntCast(Result, UpscaledTy, SrcIsSigned, "resize");
82        Result = B.CreateShl(Result, DstScale - SrcScale, "upscale");
83      }
84
85      // Handle saturation.
86      bool LessIntBits = DstSema.getIntegralBits() < SrcSema.getIntegralBits();
87      if (LessIntBits) {
88        Value *Max = ConstantInt::get(
89            B.getContext(),
90            APFixedPoint::getMax(DstSema).getValue().extOrTrunc(ResultWidth));
91        Value *TooHigh = SrcIsSigned ? B.CreateICmpSGT(Result, Max)
92                                     : B.CreateICmpUGT(Result, Max);
93        Result = B.CreateSelect(TooHigh, Max, Result, "satmax");
94      }
95      // Cannot overflow min to dest type if src is unsigned since all fixed
96      // point types can cover the unsigned min of 0.
97      if (SrcIsSigned && (LessIntBits || !DstIsSigned)) {
98        Value *Min = ConstantInt::get(
99            B.getContext(),
100            APFixedPoint::getMin(DstSema).getValue().extOrTrunc(ResultWidth));
101        Value *TooLow = B.CreateICmpSLT(Result, Min);
102        Result = B.CreateSelect(TooLow, Min, Result, "satmin");
103      }
104
105      // Resize the integer part to get the final destination size.
106      if (ResultWidth != DstWidth)
107        Result = B.CreateIntCast(Result, DstIntTy, SrcIsSigned, "resize");
108    }
109    return Result;
110  }
111
112  /// Get the common semantic for two semantics, with the added imposition that
113  /// saturated padded types retain the padding bit.
114  FixedPointSemantics
115  getCommonBinopSemantic(const FixedPointSemantics &LHSSema,
116                         const FixedPointSemantics &RHSSema) {
117    auto C = LHSSema.getCommonSemantics(RHSSema);
118    bool BothPadded =
119        LHSSema.hasUnsignedPadding() && RHSSema.hasUnsignedPadding();
120    return FixedPointSemantics(
121        C.getWidth() + (unsigned)(BothPadded && C.isSaturated()), C.getScale(),
122        C.isSigned(), C.isSaturated(), BothPadded);
123  }
124
125  /// Given a floating point type and a fixed-point semantic, return a floating
126  /// point type which can accommodate the fixed-point semantic. This is either
127  /// \p Ty, or a floating point type with a larger exponent than Ty.
128  Type *getAccommodatingFloatType(Type *Ty, const FixedPointSemantics &Sema) {
129    const fltSemantics *FloatSema = &Ty->getFltSemantics();
130    while (!Sema.fitsInFloatSemantics(*FloatSema))
131      FloatSema = APFixedPoint::promoteFloatSemantics(FloatSema);
132    return Type::getFloatingPointTy(Ty->getContext(), *FloatSema);
133  }
134
135public:
136  FixedPointBuilder(IRBuilderTy &Builder) : B(Builder) {}
137
138  /// Convert an integer value representing a fixed-point number from one
139  /// fixed-point semantic to another fixed-point semantic.
140  /// \p Src     - The source value
141  /// \p SrcSema - The fixed-point semantic of the source value
142  /// \p DstSema - The resulting fixed-point semantic
143  Value *CreateFixedToFixed(Value *Src, const FixedPointSemantics &SrcSema,
144                            const FixedPointSemantics &DstSema) {
145    return Convert(Src, SrcSema, DstSema, false);
146  }
147
148  /// Convert an integer value representing a fixed-point number to an integer
149  /// with the given bit width and signedness.
150  /// \p Src         - The source value
151  /// \p SrcSema     - The fixed-point semantic of the source value
152  /// \p DstWidth    - The bit width of the result value
153  /// \p DstIsSigned - The signedness of the result value
154  Value *CreateFixedToInteger(Value *Src, const FixedPointSemantics &SrcSema,
155                              unsigned DstWidth, bool DstIsSigned) {
156    return Convert(
157        Src, SrcSema,
158        FixedPointSemantics::GetIntegerSemantics(DstWidth, DstIsSigned), true);
159  }
160
161  /// Convert an integer value with the given signedness to an integer value
162  /// representing the given fixed-point semantic.
163  /// \p Src         - The source value
164  /// \p SrcIsSigned - The signedness of the source value
165  /// \p DstSema     - The resulting fixed-point semantic
166  Value *CreateIntegerToFixed(Value *Src, unsigned SrcIsSigned,
167                              const FixedPointSemantics &DstSema) {
168    return Convert(Src,
169                   FixedPointSemantics::GetIntegerSemantics(
170                       Src->getType()->getScalarSizeInBits(), SrcIsSigned),
171                   DstSema, false);
172  }
173
174  Value *CreateFixedToFloating(Value *Src, const FixedPointSemantics &SrcSema,
175                               Type *DstTy) {
176    Value *Result;
177    Type *OpTy = getAccommodatingFloatType(DstTy, SrcSema);
178    // Convert the raw fixed-point value directly to floating point. If the
179    // value is too large to fit, it will be rounded, not truncated.
180    Result = SrcSema.isSigned() ? B.CreateSIToFP(Src, OpTy)
181                                : B.CreateUIToFP(Src, OpTy);
182    // Rescale the integral-in-floating point by the scaling factor. This is
183    // lossless, except for overflow to infinity which is unlikely.
184    Result = B.CreateFMul(Result,
185        ConstantFP::get(OpTy, std::pow(2, -(int)SrcSema.getScale())));
186    if (OpTy != DstTy)
187      Result = B.CreateFPTrunc(Result, DstTy);
188    return Result;
189  }
190
191  Value *CreateFloatingToFixed(Value *Src, const FixedPointSemantics &DstSema) {
192    bool UseSigned = DstSema.isSigned() || DstSema.hasUnsignedPadding();
193    Value *Result = Src;
194    Type *OpTy = getAccommodatingFloatType(Src->getType(), DstSema);
195    if (OpTy != Src->getType())
196      Result = B.CreateFPExt(Result, OpTy);
197    // Rescale the floating point value so that its significant bits (for the
198    // purposes of the conversion) are in the integral range.
199    Result = B.CreateFMul(Result,
200        ConstantFP::get(OpTy, std::pow(2, DstSema.getScale())));
201
202    Type *ResultTy = B.getIntNTy(DstSema.getWidth());
203    if (DstSema.isSaturated()) {
204      Intrinsic::ID IID =
205          UseSigned ? Intrinsic::fptosi_sat : Intrinsic::fptoui_sat;
206      Result = B.CreateIntrinsic(IID, {ResultTy, OpTy}, {Result});
207    } else {
208      Result = UseSigned ? B.CreateFPToSI(Result, ResultTy)
209                         : B.CreateFPToUI(Result, ResultTy);
210    }
211
212    // When saturating unsigned-with-padding using signed operations, we may
213    // get negative values. Emit an extra clamp to zero.
214    if (DstSema.isSaturated() && DstSema.hasUnsignedPadding()) {
215      Constant *Zero = Constant::getNullValue(Result->getType());
216      Result =
217          B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
218    }
219
220    return Result;
221  }
222
223  /// Add two fixed-point values and return the result in their common semantic.
224  /// \p LHS     - The left hand side
225  /// \p LHSSema - The semantic of the left hand side
226  /// \p RHS     - The right hand side
227  /// \p RHSSema - The semantic of the right hand side
228  Value *CreateAdd(Value *LHS, const FixedPointSemantics &LHSSema,
229                   Value *RHS, const FixedPointSemantics &RHSSema) {
230    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
231    bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
232
233    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
234    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
235
236    Value *Result;
237    if (CommonSema.isSaturated()) {
238      Intrinsic::ID IID = UseSigned ? Intrinsic::sadd_sat : Intrinsic::uadd_sat;
239      Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
240    } else {
241      Result = B.CreateAdd(WideLHS, WideRHS);
242    }
243
244    return CreateFixedToFixed(Result, CommonSema,
245                              LHSSema.getCommonSemantics(RHSSema));
246  }
247
248  /// Subtract two fixed-point values and return the result in their common
249  /// semantic.
250  /// \p LHS     - The left hand side
251  /// \p LHSSema - The semantic of the left hand side
252  /// \p RHS     - The right hand side
253  /// \p RHSSema - The semantic of the right hand side
254  Value *CreateSub(Value *LHS, const FixedPointSemantics &LHSSema,
255                   Value *RHS, const FixedPointSemantics &RHSSema) {
256    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
257    bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
258
259    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
260    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
261
262    Value *Result;
263    if (CommonSema.isSaturated()) {
264      Intrinsic::ID IID = UseSigned ? Intrinsic::ssub_sat : Intrinsic::usub_sat;
265      Result = B.CreateBinaryIntrinsic(IID, WideLHS, WideRHS);
266    } else {
267      Result = B.CreateSub(WideLHS, WideRHS);
268    }
269
270    // Subtraction can end up below 0 for padded unsigned operations, so emit
271    // an extra clamp in that case.
272    if (CommonSema.isSaturated() && CommonSema.hasUnsignedPadding()) {
273      Constant *Zero = Constant::getNullValue(Result->getType());
274      Result =
275          B.CreateSelect(B.CreateICmpSLT(Result, Zero), Zero, Result, "satmin");
276    }
277
278    return CreateFixedToFixed(Result, CommonSema,
279                              LHSSema.getCommonSemantics(RHSSema));
280  }
281
282  /// Multiply two fixed-point values and return the result in their common
283  /// semantic.
284  /// \p LHS     - The left hand side
285  /// \p LHSSema - The semantic of the left hand side
286  /// \p RHS     - The right hand side
287  /// \p RHSSema - The semantic of the right hand side
288  Value *CreateMul(Value *LHS, const FixedPointSemantics &LHSSema,
289                   Value *RHS, const FixedPointSemantics &RHSSema) {
290    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
291    bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
292
293    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
294    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
295
296    Intrinsic::ID IID;
297    if (CommonSema.isSaturated()) {
298      IID = UseSigned ? Intrinsic::smul_fix_sat : Intrinsic::umul_fix_sat;
299    } else {
300      IID = UseSigned ? Intrinsic::smul_fix : Intrinsic::umul_fix;
301    }
302    Value *Result = B.CreateIntrinsic(
303        IID, {WideLHS->getType()},
304        {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
305
306    return CreateFixedToFixed(Result, CommonSema,
307                              LHSSema.getCommonSemantics(RHSSema));
308  }
309
310  /// Divide two fixed-point values and return the result in their common
311  /// semantic.
312  /// \p LHS     - The left hand side
313  /// \p LHSSema - The semantic of the left hand side
314  /// \p RHS     - The right hand side
315  /// \p RHSSema - The semantic of the right hand side
316  Value *CreateDiv(Value *LHS, const FixedPointSemantics &LHSSema,
317                   Value *RHS, const FixedPointSemantics &RHSSema) {
318    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
319    bool UseSigned = CommonSema.isSigned() || CommonSema.hasUnsignedPadding();
320
321    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
322    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
323
324    Intrinsic::ID IID;
325    if (CommonSema.isSaturated()) {
326      IID = UseSigned ? Intrinsic::sdiv_fix_sat : Intrinsic::udiv_fix_sat;
327    } else {
328      IID = UseSigned ? Intrinsic::sdiv_fix : Intrinsic::udiv_fix;
329    }
330    Value *Result = B.CreateIntrinsic(
331        IID, {WideLHS->getType()},
332        {WideLHS, WideRHS, B.getInt32(CommonSema.getScale())});
333
334    return CreateFixedToFixed(Result, CommonSema,
335                              LHSSema.getCommonSemantics(RHSSema));
336  }
337
338  /// Left shift a fixed-point value by an unsigned integer value. The integer
339  /// value can be any bit width.
340  /// \p LHS     - The left hand side
341  /// \p LHSSema - The semantic of the left hand side
342  /// \p RHS     - The right hand side
343  Value *CreateShl(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
344    bool UseSigned = LHSSema.isSigned() || LHSSema.hasUnsignedPadding();
345
346    RHS = B.CreateIntCast(RHS, LHS->getType(), /*IsSigned=*/false);
347
348    Value *Result;
349    if (LHSSema.isSaturated()) {
350      Intrinsic::ID IID = UseSigned ? Intrinsic::sshl_sat : Intrinsic::ushl_sat;
351      Result = B.CreateBinaryIntrinsic(IID, LHS, RHS);
352    } else {
353      Result = B.CreateShl(LHS, RHS);
354    }
355
356    return Result;
357  }
358
359  /// Right shift a fixed-point value by an unsigned integer value. The integer
360  /// value can be any bit width.
361  /// \p LHS     - The left hand side
362  /// \p LHSSema - The semantic of the left hand side
363  /// \p RHS     - The right hand side
364  Value *CreateShr(Value *LHS, const FixedPointSemantics &LHSSema, Value *RHS) {
365    RHS = B.CreateIntCast(RHS, LHS->getType(), false);
366
367    return LHSSema.isSigned() ? B.CreateAShr(LHS, RHS) : B.CreateLShr(LHS, RHS);
368  }
369
370  /// Compare two fixed-point values for equality.
371  /// \p LHS     - The left hand side
372  /// \p LHSSema - The semantic of the left hand side
373  /// \p RHS     - The right hand side
374  /// \p RHSSema - The semantic of the right hand side
375  Value *CreateEQ(Value *LHS, const FixedPointSemantics &LHSSema,
376                  Value *RHS, const FixedPointSemantics &RHSSema) {
377    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
378
379    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
380    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
381
382    return B.CreateICmpEQ(WideLHS, WideRHS);
383  }
384
385  /// Compare two fixed-point values for inequality.
386  /// \p LHS     - The left hand side
387  /// \p LHSSema - The semantic of the left hand side
388  /// \p RHS     - The right hand side
389  /// \p RHSSema - The semantic of the right hand side
390  Value *CreateNE(Value *LHS, const FixedPointSemantics &LHSSema,
391                  Value *RHS, const FixedPointSemantics &RHSSema) {
392    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
393
394    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
395    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
396
397    return B.CreateICmpNE(WideLHS, WideRHS);
398  }
399
400  /// Compare two fixed-point values as LHS < RHS.
401  /// \p LHS     - The left hand side
402  /// \p LHSSema - The semantic of the left hand side
403  /// \p RHS     - The right hand side
404  /// \p RHSSema - The semantic of the right hand side
405  Value *CreateLT(Value *LHS, const FixedPointSemantics &LHSSema,
406                  Value *RHS, const FixedPointSemantics &RHSSema) {
407    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
408
409    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
410    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
411
412    return CommonSema.isSigned() ? B.CreateICmpSLT(WideLHS, WideRHS)
413                                 : B.CreateICmpULT(WideLHS, WideRHS);
414  }
415
416  /// Compare two fixed-point values as LHS <= RHS.
417  /// \p LHS     - The left hand side
418  /// \p LHSSema - The semantic of the left hand side
419  /// \p RHS     - The right hand side
420  /// \p RHSSema - The semantic of the right hand side
421  Value *CreateLE(Value *LHS, const FixedPointSemantics &LHSSema,
422                  Value *RHS, const FixedPointSemantics &RHSSema) {
423    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
424
425    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
426    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
427
428    return CommonSema.isSigned() ? B.CreateICmpSLE(WideLHS, WideRHS)
429                                 : B.CreateICmpULE(WideLHS, WideRHS);
430  }
431
432  /// Compare two fixed-point values as LHS > RHS.
433  /// \p LHS     - The left hand side
434  /// \p LHSSema - The semantic of the left hand side
435  /// \p RHS     - The right hand side
436  /// \p RHSSema - The semantic of the right hand side
437  Value *CreateGT(Value *LHS, const FixedPointSemantics &LHSSema,
438                  Value *RHS, const FixedPointSemantics &RHSSema) {
439    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
440
441    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
442    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
443
444    return CommonSema.isSigned() ? B.CreateICmpSGT(WideLHS, WideRHS)
445                                 : B.CreateICmpUGT(WideLHS, WideRHS);
446  }
447
448  /// Compare two fixed-point values as LHS >= RHS.
449  /// \p LHS     - The left hand side
450  /// \p LHSSema - The semantic of the left hand side
451  /// \p RHS     - The right hand side
452  /// \p RHSSema - The semantic of the right hand side
453  Value *CreateGE(Value *LHS, const FixedPointSemantics &LHSSema,
454                  Value *RHS, const FixedPointSemantics &RHSSema) {
455    auto CommonSema = getCommonBinopSemantic(LHSSema, RHSSema);
456
457    Value *WideLHS = CreateFixedToFixed(LHS, LHSSema, CommonSema);
458    Value *WideRHS = CreateFixedToFixed(RHS, RHSSema, CommonSema);
459
460    return CommonSema.isSigned() ? B.CreateICmpSGE(WideLHS, WideRHS)
461                                 : B.CreateICmpUGE(WideLHS, WideRHS);
462  }
463};
464
465} // end namespace llvm
466
467#endif // LLVM_IR_FIXEDPOINTBUILDER_H
468