1/*
2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.replacements.test;
24
25import java.util.ArrayList;
26import java.util.Collection;
27import java.util.List;
28
29import org.junit.Assert;
30import org.junit.Test;
31import org.junit.runner.RunWith;
32import org.junit.runners.Parameterized;
33import org.junit.runners.Parameterized.Parameter;
34import org.junit.runners.Parameterized.Parameters;
35
36import org.graalvm.compiler.core.common.type.IntegerStamp;
37import org.graalvm.compiler.core.common.type.StampFactory;
38import org.graalvm.compiler.core.test.GraalCompilerTest;
39import org.graalvm.compiler.nodes.ConstantNode;
40import org.graalvm.compiler.nodes.ParameterNode;
41import org.graalvm.compiler.nodes.PiNode;
42import org.graalvm.compiler.nodes.StructuredGraph;
43import org.graalvm.compiler.nodes.ValueNode;
44import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
45import org.graalvm.compiler.nodes.calc.MulNode;
46import org.graalvm.compiler.nodes.java.StoreFieldNode;
47import org.graalvm.compiler.phases.common.CanonicalizerPhase;
48import org.graalvm.compiler.phases.tiers.HighTierContext;
49import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerMulExactNode;
50
51@RunWith(Parameterized.class)
52public class IntegerMulExactFoldTest extends GraalCompilerTest {
53
54    public static int SideEffectI;
55    public static long SideEffectL;
56
57    public static void snippetInt(int a, int b) {
58        SideEffectI = Math.multiplyExact(a, b);
59    }
60
61    public static void snippetLong(long a, long b) {
62        SideEffectL = Math.multiplyExact(a, b);
63    }
64
65    private StructuredGraph prepareGraph(String snippet) {
66        StructuredGraph graph = parseEager(snippet, AllowAssumptions.NO);
67        HighTierContext context = getDefaultHighTierContext();
68        new CanonicalizerPhase().apply(graph, context);
69        return graph;
70    }
71
72    @Parameter(0) public long lowerBound1;
73    @Parameter(1) public long upperBound1;
74    @Parameter(2) public long lowerBound2;
75    @Parameter(3) public long upperBound2;
76    @Parameter(4) public int bits;
77
78    @Test
79    public void tryFold() {
80        assert bits == 32 || bits == 64;
81
82        IntegerStamp a = StampFactory.forInteger(bits, lowerBound1, upperBound1);
83        IntegerStamp b = StampFactory.forInteger(bits, lowerBound2, upperBound2);
84
85        // prepare the graph once for the given stamps, if the canonicalize method thinks it does
86        // not overflow it will replace the exact mul with a normal mul
87        StructuredGraph g = prepareGraph(bits == 32 ? "snippetInt" : "snippetLong");
88        List<ParameterNode> params = g.getNodes(ParameterNode.TYPE).snapshot();
89        params.get(0).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(0), a))), x -> x instanceof IntegerMulExactNode);
90        params.get(1).replaceAtMatchingUsages((g.addOrUnique(new PiNode(params.get(1), b))), x -> x instanceof IntegerMulExactNode);
91        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
92        boolean optimized = g.getNodes().filter(IntegerMulExactNode.class).count() == 0;
93        ValueNode leftOverMull = optimized ? g.getNodes().filter(MulNode.class).first() : g.getNodes().filter(IntegerMulExactNode.class).first();
94        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
95        if (leftOverMull == null) {
96            // result may be constant if there is no mul exact or mul node left
97            leftOverMull = g.getNodes().filter(StoreFieldNode.class).first().inputs().filter(ConstantNode.class).first();
98        }
99        if (leftOverMull == null) {
100            // even mul got canonicalized so we may end up with one of the original nodes
101            leftOverMull = g.getNodes().filter(PiNode.class).first();
102        }
103        IntegerStamp resultStamp = (IntegerStamp) leftOverMull.stamp();
104
105        // now check for all values in the stamp whether their products overflow overflow
106        for (long l1 = lowerBound1; l1 <= upperBound1; l1++) {
107            for (long l2 = lowerBound2; l2 <= upperBound2; l2++) {
108                try {
109                    long res = mulExact(l1, l2, bits);
110                    Assert.assertTrue(resultStamp.contains(res));
111                } catch (ArithmeticException e) {
112                    Assert.assertFalse(optimized);
113                }
114                if (l2 == Long.MAX_VALUE) {
115                    // do not want to overflow the check loop
116                    break;
117                }
118            }
119            if (l1 == Long.MAX_VALUE) {
120                // do not want to overflow the check loop
121                break;
122            }
123        }
124
125    }
126
127    private static long mulExact(long x, long y, int bits) {
128        long r = x * y;
129        if (bits == 8) {
130            if ((byte) r != r) {
131                throw new ArithmeticException("overflow");
132            }
133        } else if (bits == 16) {
134            if ((short) r != r) {
135                throw new ArithmeticException("overflow");
136            }
137        } else if (bits == 32) {
138            return Math.multiplyExact((int) x, (int) y);
139        } else {
140            return Math.multiplyExact(x, y);
141        }
142        return r;
143    }
144
145    @Parameters(name = "a[{0} - {1}] b[{2} - {3}] bits=32")
146    public static Collection<Object[]> data() {
147        ArrayList<Object[]> tests = new ArrayList<>();
148
149        // zero related
150        addTest(tests, -2, 2, 3, 3, 32);
151        addTest(tests, 0, 0, 1, 1, 32);
152        addTest(tests, 1, 1, 0, 0, 32);
153        addTest(tests, -1, 1, 0, 1, 32);
154        addTest(tests, -1, 1, 1, 1, 32);
155        addTest(tests, -1, 1, -1, 1, 32);
156
157        addTest(tests, -2, 2, 3, 3, 64);
158        addTest(tests, 0, 0, 1, 1, 64);
159        addTest(tests, 1, 1, 0, 0, 64);
160        addTest(tests, -1, 1, 0, 1, 64);
161        addTest(tests, -1, 1, 1, 1, 64);
162        addTest(tests, -1, 1, -1, 1, 64);
163
164        addTest(tests, -2, 2, 3, 3, 32);
165        addTest(tests, 0, 0, 1, 1, 32);
166        addTest(tests, 1, 1, 0, 0, 32);
167        addTest(tests, -1, 1, 0, 1, 32);
168        addTest(tests, -1, 1, 1, 1, 32);
169        addTest(tests, -1, 1, -1, 1, 32);
170
171        addTest(tests, 0, 0, 1, 1, 64);
172        addTest(tests, 1, 1, 0, 0, 64);
173        addTest(tests, -1, 1, 0, 1, 64);
174        addTest(tests, -1, 1, 1, 1, 64);
175        addTest(tests, -1, 1, -1, 1, 64);
176
177        // bounds
178        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
179                        Integer.MAX_VALUE, 32);
180        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 32);
181        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
182                        Integer.MAX_VALUE, 64);
183        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 64);
184        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE + 0xFFF, -1, -1, 64);
185
186        // constants
187        addTest(tests, 2, 2, 2, 2, 32);
188        addTest(tests, 1, 1, 2, 2, 32);
189        addTest(tests, 2, 2, 4, 4, 32);
190        addTest(tests, 3, 3, 3, 3, 32);
191        addTest(tests, -4, -4, 3, 3, 32);
192        addTest(tests, -4, -4, -3, -3, 32);
193        addTest(tests, 4, 4, -3, -3, 32);
194
195        addTest(tests, 2, 2, 2, 2, 64);
196        addTest(tests, 1, 1, 2, 2, 64);
197        addTest(tests, 3, 3, 3, 3, 64);
198
199        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, 1, 1, 64);
200        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, -1, -1, 64);
201        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, -1, -1, 64);
202        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, 1, 1, 64);
203
204        return tests;
205    }
206
207    private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits) {
208        tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits});
209    }
210
211}
212