1/*
2 * Copyright (c) 2017, 2017, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.nodes.calc;
24
25import jdk.vm.ci.meta.ConstantReflectionProvider;
26import jdk.vm.ci.meta.MetaAccessProvider;
27import org.graalvm.compiler.core.common.calc.Condition;
28import org.graalvm.compiler.core.common.type.IntegerStamp;
29import org.graalvm.compiler.core.common.type.Stamp;
30import org.graalvm.compiler.graph.NodeClass;
31import org.graalvm.compiler.nodeinfo.NodeInfo;
32import org.graalvm.compiler.nodes.ConstantNode;
33import org.graalvm.compiler.nodes.LogicConstantNode;
34import org.graalvm.compiler.nodes.LogicNegationNode;
35import org.graalvm.compiler.nodes.LogicNode;
36import org.graalvm.compiler.nodes.ValueNode;
37import org.graalvm.compiler.nodes.util.GraphUtil;
38
39import jdk.vm.ci.meta.TriState;
40import org.graalvm.compiler.options.OptionValues;
41
42/**
43 * Common super-class for "a < b" comparisons both {@linkplain IntegerLowerThanNode signed} and
44 * {@linkplain IntegerBelowNode unsigned}.
45 */
46@NodeInfo()
47public abstract class IntegerLowerThanNode extends CompareNode {
48    public static final NodeClass<IntegerLowerThanNode> TYPE = NodeClass.create(IntegerLowerThanNode.class);
49    private final LowerOp op;
50
51    protected IntegerLowerThanNode(NodeClass<? extends CompareNode> c, ValueNode x, ValueNode y, LowerOp op) {
52        super(c, op.getCondition(), false, x, y);
53        this.op = op;
54    }
55
56    protected LowerOp getOp() {
57        return op;
58    }
59
60    @Override
61    public Stamp getSucceedingStampForX(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
62        return getSucceedingStampForX(negated, !negated, xStampGeneric, yStampGeneric, getX(), getY());
63    }
64
65    @Override
66    public Stamp getSucceedingStampForY(boolean negated, Stamp xStampGeneric, Stamp yStampGeneric) {
67        return getSucceedingStampForX(!negated, !negated, yStampGeneric, xStampGeneric, getY(), getX());
68    }
69
70    private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric, ValueNode forX, ValueNode forY) {
71        Stamp s = getSucceedingStampForX(mirror, strict, xStampGeneric, yStampGeneric);
72        if (s != null) {
73            return s;
74        }
75        if (forY instanceof AddNode && xStampGeneric instanceof IntegerStamp) {
76            IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
77            AddNode addNode = (AddNode) forY;
78            IntegerStamp aStamp = null;
79            if (addNode.getX() == forX && addNode.getY().stamp() instanceof IntegerStamp) {
80                // x < x + a
81                aStamp = (IntegerStamp) addNode.getY().stamp();
82            } else if (addNode.getY() == forX && addNode.getX().stamp() instanceof IntegerStamp) {
83                // x < a + x
84                aStamp = (IntegerStamp) addNode.getX().stamp();
85            }
86            if (aStamp != null) {
87                IntegerStamp result = getOp().getSucceedingStampForXLowerXPlusA(mirror, strict, aStamp);
88                result = (IntegerStamp) xStamp.tryImproveWith(result);
89                if (result != null) {
90                    return result;
91                }
92            }
93        }
94        return null;
95    }
96
97    private Stamp getSucceedingStampForX(boolean mirror, boolean strict, Stamp xStampGeneric, Stamp yStampGeneric) {
98        if (xStampGeneric instanceof IntegerStamp) {
99            IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
100            if (yStampGeneric instanceof IntegerStamp) {
101                IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
102                assert yStamp.getBits() == xStamp.getBits();
103                Stamp s = getOp().getSucceedingStampForX(xStamp, yStamp, mirror, strict);
104                if (s != null) {
105                    return s;
106                }
107            }
108        }
109        return null;
110    }
111
112    @Override
113    public TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
114        return getOp().tryFold(xStampGeneric, yStampGeneric);
115    }
116
117    public abstract static class LowerOp extends CompareOp {
118        @Override
119        public LogicNode canonical(ConstantReflectionProvider constantReflection, MetaAccessProvider metaAccess, OptionValues options, Integer smallestCompareWidth, Condition condition,
120                        boolean unorderedIsTrue, ValueNode forX, ValueNode forY) {
121            LogicNode result = super.canonical(constantReflection, metaAccess, options, smallestCompareWidth, condition, unorderedIsTrue, forX, forY);
122            if (result != null) {
123                return result;
124            }
125            LogicNode synonym = findSynonym(forX, forY);
126            if (synonym != null) {
127                return synonym;
128            }
129            return null;
130        }
131
132        protected abstract long upperBound(IntegerStamp stamp);
133
134        protected abstract long lowerBound(IntegerStamp stamp);
135
136        protected abstract int compare(long a, long b);
137
138        protected abstract long min(long a, long b);
139
140        protected abstract long max(long a, long b);
141
142        protected long min(long a, long b, int bits) {
143            return min(cast(a, bits), cast(b, bits));
144        }
145
146        protected long max(long a, long b, int bits) {
147            return max(cast(a, bits), cast(b, bits));
148        }
149
150        protected abstract long cast(long a, int bits);
151
152        protected abstract long minValue(int bits);
153
154        protected abstract long maxValue(int bits);
155
156        protected abstract IntegerStamp forInteger(int bits, long min, long max);
157
158        protected abstract Condition getCondition();
159
160        protected abstract IntegerLowerThanNode createNode(ValueNode x, ValueNode y);
161
162        public LogicNode create(ValueNode x, ValueNode y) {
163            LogicNode result = CompareNode.tryConstantFoldPrimitive(getCondition(), x, y, false);
164            if (result != null) {
165                return result;
166            } else {
167                result = findSynonym(x, y);
168                if (result != null) {
169                    return result;
170                }
171                return createNode(x, y);
172            }
173        }
174
175        protected LogicNode findSynonym(ValueNode forX, ValueNode forY) {
176            if (GraphUtil.unproxify(forX) == GraphUtil.unproxify(forY)) {
177                return LogicConstantNode.contradiction();
178            }
179            TriState fold = tryFold(forX.stamp(), forY.stamp());
180            if (fold.isTrue()) {
181                return LogicConstantNode.tautology();
182            } else if (fold.isFalse()) {
183                return LogicConstantNode.contradiction();
184            }
185            if (forY.stamp() instanceof IntegerStamp) {
186                IntegerStamp yStamp = (IntegerStamp) forY.stamp();
187                int bits = yStamp.getBits();
188                if (forX.isJavaConstant() && !forY.isConstant()) {
189                    // bring the constant on the right
190                    long xValue = forX.asJavaConstant().asLong();
191                    if (xValue != maxValue(bits)) {
192                        // c < x <=> !(c >= x) <=> !(x <= c) <=> !(x < c + 1)
193                        return LogicNegationNode.create(create(forY, ConstantNode.forIntegerStamp(yStamp, xValue + 1)));
194                    }
195                }
196                if (forY.isJavaConstant()) {
197                    long yValue = forY.asJavaConstant().asLong();
198                    if (yValue == maxValue(bits)) {
199                        // x < MAX <=> x != MAX
200                        return LogicNegationNode.create(IntegerEqualsNode.create(forX, forY));
201                    }
202                    if (yValue == minValue(bits) + 1) {
203                        // x < MIN + 1 <=> x <= MIN <=> x == MIN
204                        return IntegerEqualsNode.create(forX, ConstantNode.forIntegerStamp(yStamp, minValue(bits)));
205                    }
206                } else if (forY instanceof AddNode) {
207                    AddNode addNode = (AddNode) forY;
208                    LogicNode canonical = canonicalizeXLowerXPlusA(forX, addNode, false, true);
209                    if (canonical != null) {
210                        return canonical;
211                    }
212                }
213                if (forX instanceof AddNode) {
214                    AddNode addNode = (AddNode) forX;
215                    LogicNode canonical = canonicalizeXLowerXPlusA(forY, addNode, true, false);
216                    if (canonical != null) {
217                        return canonical;
218                    }
219                }
220            }
221            return null;
222        }
223
224        private LogicNode canonicalizeXLowerXPlusA(ValueNode forX, AddNode addNode, boolean mirrored, boolean strict) {
225            // x < x + a
226            IntegerStamp succeedingXStamp;
227            boolean exact;
228            if (addNode.getX() == forX && addNode.getY().stamp() instanceof IntegerStamp) {
229                IntegerStamp aStamp = (IntegerStamp) addNode.getY().stamp();
230                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
231                exact = aStamp.lowerBound() == aStamp.upperBound();
232            } else if (addNode.getY() == forX && addNode.getX().stamp() instanceof IntegerStamp) {
233                IntegerStamp aStamp = (IntegerStamp) addNode.getX().stamp();
234                succeedingXStamp = getSucceedingStampForXLowerXPlusA(mirrored, strict, aStamp);
235                exact = aStamp.lowerBound() == aStamp.upperBound();
236            } else {
237                return null;
238            }
239            if (succeedingXStamp.join(forX.stamp()).isEmpty()) {
240                return LogicConstantNode.contradiction();
241            } else if (exact && !succeedingXStamp.isEmpty()) {
242                int bits = succeedingXStamp.getBits();
243                if (compare(lowerBound(succeedingXStamp), minValue(bits)) > 0) {
244                    assert upperBound(succeedingXStamp) == maxValue(bits);
245                    // x must be in [L..MAX] <=> x >= L <=> !(x < L)
246                    return LogicNegationNode.create(create(forX, ConstantNode.forIntegerStamp(succeedingXStamp, lowerBound(succeedingXStamp))));
247                } else if (compare(upperBound(succeedingXStamp), maxValue(bits)) < 0) {
248                    // x must be in [MIN..H] <=> x <= H <=> !(H < x)
249                    return LogicNegationNode.create(create(ConstantNode.forIntegerStamp(succeedingXStamp, upperBound(succeedingXStamp)), forX));
250                }
251            }
252            return null;
253        }
254
255        protected TriState tryFold(Stamp xStampGeneric, Stamp yStampGeneric) {
256            if (xStampGeneric instanceof IntegerStamp && yStampGeneric instanceof IntegerStamp) {
257                IntegerStamp xStamp = (IntegerStamp) xStampGeneric;
258                IntegerStamp yStamp = (IntegerStamp) yStampGeneric;
259                if (compare(upperBound(xStamp), lowerBound(yStamp)) < 0) {
260                    return TriState.TRUE;
261                }
262                if (compare(lowerBound(xStamp), upperBound(yStamp)) >= 0) {
263                    return TriState.FALSE;
264                }
265            }
266            return TriState.UNKNOWN;
267        }
268
269        protected IntegerStamp getSucceedingStampForX(IntegerStamp xStamp, IntegerStamp yStamp, boolean mirror, boolean strict) {
270            int bits = xStamp.getBits();
271            assert yStamp.getBits() == bits;
272            if (mirror) {
273                long low = lowerBound(yStamp);
274                if (strict) {
275                    if (low == maxValue(bits)) {
276                        return null;
277                    }
278                    low += 1;
279                }
280                if (compare(low, lowerBound(xStamp)) > 0) {
281                    return forInteger(bits, low, upperBound(xStamp));
282                }
283            } else {
284                // x < y, i.e., x < y <= Y_UPPER_BOUND so x <= Y_UPPER_BOUND - 1
285                long low = upperBound(yStamp);
286                if (strict) {
287                    if (low == minValue(bits)) {
288                        return null;
289                    }
290                    low -= 1;
291                }
292                if (compare(low, upperBound(xStamp)) < 0) {
293                    return forInteger(bits, lowerBound(xStamp), low);
294                }
295            }
296            return null;
297        }
298
299        protected IntegerStamp getSucceedingStampForXLowerXPlusA(boolean mirrored, boolean strict, IntegerStamp a) {
300            int bits = a.getBits();
301            long min = minValue(bits);
302            long max = maxValue(bits);
303            /*
304             * if x < x + a <=> x + a didn't overflow:
305             *
306             * x is outside ]MAX - a, MAX], i.e., inside [MIN, MAX - a]
307             *
308             * if a is negative those bounds wrap around correctly.
309             *
310             * If a is exactly zero this gives an unbounded stamp (any integer) in the positive case
311             * and an empty stamp in the negative case: if x |<| x is true, then either x has no
312             * value or any value...
313             *
314             * This does not use upper/lowerBound from LowerOp because it's about the (signed)
315             * addition not the comparison.
316             */
317            if (mirrored) {
318                if (a.contains(0)) {
319                    // a may be zero
320                    return a.unrestricted();
321                }
322                return forInteger(bits, min(max - a.lowerBound() + 1, max - a.upperBound() + 1, bits), max);
323            } else {
324                long aLower = a.lowerBound();
325                long aUpper = a.upperBound();
326                if (strict) {
327                    if (aLower == 0) {
328                        aLower = 1;
329                    }
330                    if (aUpper == 0) {
331                        aUpper = -1;
332                    }
333                    if (aLower > aUpper) {
334                        // impossible
335                        return a.empty();
336                    }
337                }
338                if (aLower < 0 && aUpper > 0) {
339                    // a may be zero
340                    return a.unrestricted();
341                }
342                return forInteger(bits, min, max(max - aLower, max - aUpper, bits));
343            }
344        }
345    }
346}
347