1/*
2 * Copyright (c) 2011, 2015, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.nodes.calc;
24
25import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_2;
26
27import org.graalvm.compiler.core.common.type.ArithmeticOpTable;
28import org.graalvm.compiler.core.common.type.IntegerStamp;
29import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp;
30import org.graalvm.compiler.core.common.type.ArithmeticOpTable.BinaryOp.Mul;
31import org.graalvm.compiler.core.common.type.Stamp;
32import org.graalvm.compiler.graph.NodeClass;
33import org.graalvm.compiler.graph.spi.Canonicalizable.BinaryCommutative;
34import org.graalvm.compiler.graph.spi.CanonicalizerTool;
35import org.graalvm.compiler.lir.gen.ArithmeticLIRGeneratorTool;
36import org.graalvm.compiler.nodeinfo.NodeInfo;
37import org.graalvm.compiler.nodes.ConstantNode;
38import org.graalvm.compiler.nodes.ValueNode;
39import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
40
41import jdk.vm.ci.code.CodeUtil;
42import jdk.vm.ci.meta.Constant;
43import jdk.vm.ci.meta.PrimitiveConstant;
44import jdk.vm.ci.meta.Value;
45
46@NodeInfo(shortName = "*", cycles = CYCLES_2)
47public class MulNode extends BinaryArithmeticNode<Mul> implements NarrowableArithmeticNode, BinaryCommutative<ValueNode> {
48
49    public static final NodeClass<MulNode> TYPE = NodeClass.create(MulNode.class);
50
51    public MulNode(ValueNode x, ValueNode y) {
52        this(TYPE, x, y);
53    }
54
55    protected MulNode(NodeClass<? extends MulNode> c, ValueNode x, ValueNode y) {
56        super(c, ArithmeticOpTable::getMul, x, y);
57    }
58
59    public static ValueNode create(ValueNode x, ValueNode y) {
60        BinaryOp<Mul> op = ArithmeticOpTable.forStamp(x.stamp()).getMul();
61        Stamp stamp = op.foldStamp(x.stamp(), y.stamp());
62        ConstantNode tryConstantFold = tryConstantFold(op, x, y, stamp);
63        if (tryConstantFold != null) {
64            return tryConstantFold;
65        }
66        return canonical(null, op, stamp, x, y);
67    }
68
69    @Override
70    public ValueNode canonical(CanonicalizerTool tool, ValueNode forX, ValueNode forY) {
71        ValueNode ret = super.canonical(tool, forX, forY);
72        if (ret != this) {
73            return ret;
74        }
75
76        if (forX.isConstant() && !forY.isConstant()) {
77            // we try to swap and canonicalize
78            ValueNode improvement = canonical(tool, forY, forX);
79            if (improvement != this) {
80                return improvement;
81            }
82            // if this fails we only swap
83            return new MulNode(forY, forX);
84        }
85        BinaryOp<Mul> op = getOp(forX, forY);
86        return canonical(this, op, stamp(), forX, forY);
87    }
88
89    private static ValueNode canonical(MulNode self, BinaryOp<Mul> op, Stamp stamp, ValueNode forX, ValueNode forY) {
90        if (forY.isConstant()) {
91            Constant c = forY.asConstant();
92            if (op.isNeutral(c)) {
93                return forX;
94            }
95
96            if (c instanceof PrimitiveConstant && ((PrimitiveConstant) c).getJavaKind().isNumericInteger()) {
97                long i = ((PrimitiveConstant) c).asLong();
98
99                if (i == 0) {
100                    return ConstantNode.forIntegerStamp(stamp, 0);
101                } else if (i == 1) {
102                    return forX;
103                } else if (i == -1) {
104                    return NegateNode.create(forX);
105                } else if (i > 0) {
106                    if (CodeUtil.isPowerOf2(i)) {
107                        return new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i)));
108                    } else if (CodeUtil.isPowerOf2(i - 1)) {
109                        return AddNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i - 1))), forX);
110                    } else if (CodeUtil.isPowerOf2(i + 1)) {
111                        return SubNode.create(new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(i + 1))), forX);
112                    } else {
113                        int bitCount = Long.bitCount(i);
114                        long highestBitValue = Long.highestOneBit(i);
115                        if (bitCount == 2) {
116                            // e.g., 0b1000_0010
117                            long lowerBitValue = i - highestBitValue;
118                            assert highestBitValue > 0 && lowerBitValue > 0;
119                            ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(highestBitValue)));
120                            ValueNode right = lowerBitValue == 1 ? forX : new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(lowerBitValue)));
121                            return AddNode.create(left, right);
122                        } else {
123                            // e.g., 0b1111_1101
124                            int shiftToRoundUpToPowerOf2 = CodeUtil.log2(highestBitValue) + 1;
125                            long subValue = (1 << shiftToRoundUpToPowerOf2) - i;
126                            if (CodeUtil.isPowerOf2(subValue) && shiftToRoundUpToPowerOf2 < ((IntegerStamp) stamp).getBits()) {
127                                assert CodeUtil.log2(subValue) >= 1;
128                                ValueNode left = new LeftShiftNode(forX, ConstantNode.forInt(shiftToRoundUpToPowerOf2));
129                                ValueNode right = new LeftShiftNode(forX, ConstantNode.forInt(CodeUtil.log2(subValue)));
130                                return SubNode.create(left, right);
131                            }
132                        }
133                    }
134                } else if (i < 0) {
135                    if (CodeUtil.isPowerOf2(-i)) {
136                        return NegateNode.create(LeftShiftNode.create(forX, ConstantNode.forInt(CodeUtil.log2(-i))));
137                    }
138                }
139            }
140
141            if (op.isAssociative()) {
142                // canonicalize expressions like "(a * 1) * 2"
143                return reassociate(self != null ? self : (MulNode) new MulNode(forX, forY).maybeCommuteInputs(), ValueNode.isConstantPredicate(), forX, forY);
144            }
145        }
146        return self != null ? self : new MulNode(forX, forY).maybeCommuteInputs();
147    }
148
149    @Override
150    public void generate(NodeLIRBuilderTool nodeValueMap, ArithmeticLIRGeneratorTool gen) {
151        Value op1 = nodeValueMap.operand(getX());
152        Value op2 = nodeValueMap.operand(getY());
153        if (shouldSwapInputs(nodeValueMap)) {
154            Value tmp = op1;
155            op1 = op2;
156            op2 = tmp;
157        }
158        nodeValueMap.setResult(this, gen.emitMul(op1, op2, false));
159    }
160}
161