1/*
2 * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.test;
24
25import java.io.PrintStream;
26import java.io.PrintWriter;
27import java.lang.reflect.Field;
28import java.lang.reflect.Method;
29import java.util.Arrays;
30
31import org.junit.Assert;
32import org.junit.internal.ComparisonCriteria;
33import org.junit.internal.ExactComparisonCriteria;
34
35import sun.misc.Unsafe;
36
37/**
38 * Base class that contains common utility methods and classes useful in unit tests.
39 */
40public class GraalTest {
41
42    public static final Unsafe UNSAFE;
43    static {
44        try {
45            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
46            theUnsafe.setAccessible(true);
47            UNSAFE = (Unsafe) theUnsafe.get(Unsafe.class);
48        } catch (Exception e) {
49            throw new RuntimeException("exception while trying to get Unsafe", e);
50        }
51    }
52
53    public static final boolean Java8OrEarlier = System.getProperty("java.specification.version").compareTo("1.9") < 0;
54
55    protected Method getMethod(String methodName) {
56        return getMethod(getClass(), methodName);
57    }
58
59    protected Method getMethod(Class<?> clazz, String methodName) {
60        Method found = null;
61        for (Method m : clazz.getMethods()) {
62            if (m.getName().equals(methodName)) {
63                Assert.assertNull(found);
64                found = m;
65            }
66        }
67        if (found == null) {
68            /* Now look for non-public methods (but this does not look in superclasses). */
69            for (Method m : clazz.getDeclaredMethods()) {
70                if (m.getName().equals(methodName)) {
71                    Assert.assertNull(found);
72                    found = m;
73                }
74            }
75        }
76        if (found != null) {
77            return found;
78        } else {
79            throw new RuntimeException("method not found: " + methodName);
80        }
81    }
82
83    protected Method getMethod(Class<?> clazz, String methodName, Class<?>[] parameterTypes) {
84        try {
85            return clazz.getMethod(methodName, parameterTypes);
86        } catch (NoSuchMethodException | SecurityException e) {
87            throw new RuntimeException("method not found: " + methodName + "" + Arrays.toString(parameterTypes));
88        }
89    }
90
91    /**
92     * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
93     * Does a deep copy equality comparison if {@code expected} is an array.
94     */
95    protected void assertDeepEquals(Object expected, Object actual) {
96        assertDeepEquals(null, expected, actual);
97    }
98
99    /**
100     * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
101     * Does a deep copy equality comparison if {@code expected} is an array.
102     *
103     * @param message the identifying message for the {@link AssertionError}
104     */
105    protected void assertDeepEquals(String message, Object expected, Object actual) {
106        if (ulpsDelta() > 0) {
107            assertDeepEquals(message, expected, actual, ulpsDelta());
108        } else {
109            assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
110        }
111    }
112
113    /**
114     * Compares two given values for equality, doing a recursive test if both values are arrays of
115     * the same type.
116     *
117     * @param message the identifying message for the {@link AssertionError}
118     * @param delta the maximum delta between two doubles or floats for which both numbers are still
119     *            considered equal.
120     */
121    protected void assertDeepEquals(String message, Object expected, Object actual, double delta) {
122        if (expected != null && actual != null) {
123            Class<?> expectedClass = expected.getClass();
124            Class<?> actualClass = actual.getClass();
125            if (expectedClass.isArray()) {
126                Assert.assertTrue(message, expected != null);
127                Assert.assertTrue(message, actual != null);
128                Assert.assertEquals(message, expectedClass, actual.getClass());
129                if (expected instanceof int[]) {
130                    Assert.assertArrayEquals(message, (int[]) expected, (int[]) actual);
131                } else if (expected instanceof byte[]) {
132                    Assert.assertArrayEquals(message, (byte[]) expected, (byte[]) actual);
133                } else if (expected instanceof char[]) {
134                    Assert.assertArrayEquals(message, (char[]) expected, (char[]) actual);
135                } else if (expected instanceof short[]) {
136                    Assert.assertArrayEquals(message, (short[]) expected, (short[]) actual);
137                } else if (expected instanceof float[]) {
138                    Assert.assertArrayEquals(message, (float[]) expected, (float[]) actual, (float) delta);
139                } else if (expected instanceof long[]) {
140                    Assert.assertArrayEquals(message, (long[]) expected, (long[]) actual);
141                } else if (expected instanceof double[]) {
142                    Assert.assertArrayEquals(message, (double[]) expected, (double[]) actual, delta);
143                } else if (expected instanceof boolean[]) {
144                    new ExactComparisonCriteria().arrayEquals(message, expected, actual);
145                } else if (expected instanceof Object[]) {
146                    new ComparisonCriteria() {
147                        @Override
148                        protected void assertElementsEqual(Object e, Object a) {
149                            assertDeepEquals(message, e, a, delta);
150                        }
151                    }.arrayEquals(message, expected, actual);
152                } else {
153                    Assert.fail((message == null ? "" : message) + "non-array value encountered: " + expected);
154                }
155            } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
156                Assert.assertEquals((double) expected, (double) actual, delta);
157            } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
158                Assert.assertEquals((float) expected, (float) actual, delta);
159            } else {
160                Assert.assertEquals(message, expected, actual);
161            }
162        } else {
163            Assert.assertEquals(message, expected, actual);
164        }
165    }
166
167    /**
168     * Compares two given values for equality, doing a recursive test if both values are arrays of
169     * the same type. Uses {@linkplain StrictMath#ulp(float) ULP}s for comparison of floats.
170     *
171     * @param message the identifying message for the {@link AssertionError}
172     * @param ulpsDelta the maximum allowed ulps difference between two doubles or floats for which
173     *            both numbers are still considered equal.
174     */
175    protected void assertDeepEquals(String message, Object expected, Object actual, int ulpsDelta) {
176        ComparisonCriteria doubleUlpsDeltaCriteria = new ComparisonCriteria() {
177            @Override
178            protected void assertElementsEqual(Object e, Object a) {
179                assertTrue(message, e instanceof Double && a instanceof Double);
180                // determine acceptable error based on whether it is a normal number or a NaN/Inf
181                double de = (Double) e;
182                double epsilon = (!Double.isNaN(de) && Double.isFinite(de) ? ulpsDelta * Math.ulp(de) : 0);
183                Assert.assertEquals(message, (Double) e, (Double) a, epsilon);
184            }
185        };
186
187        ComparisonCriteria floatUlpsDeltaCriteria = new ComparisonCriteria() {
188            @Override
189            protected void assertElementsEqual(Object e, Object a) {
190                assertTrue(message, e instanceof Float && a instanceof Float);
191                // determine acceptable error based on whether it is a normal number or a NaN/Inf
192                float fe = (Float) e;
193                float epsilon = (!Float.isNaN(fe) && Float.isFinite(fe) ? ulpsDelta * Math.ulp(fe) : 0);
194                Assert.assertEquals(message, (Float) e, (Float) a, epsilon);
195            }
196        };
197
198        if (expected != null && actual != null) {
199            Class<?> expectedClass = expected.getClass();
200            Class<?> actualClass = actual.getClass();
201            if (expectedClass.isArray()) {
202                Assert.assertEquals(message, expectedClass, actualClass);
203                if (expected instanceof double[] || expected instanceof Object[]) {
204                    doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
205                    return;
206                } else if (expected instanceof float[] || expected instanceof Object[]) {
207                    floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
208                    return;
209                }
210            } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
211                doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
212                return;
213            } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
214                floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
215                return;
216            }
217        }
218        // anything else just use the non-ulps version
219        assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
220    }
221
222    /**
223     * Gets the value used by {@link #assertDeepEquals(Object, Object)} and
224     * {@link #assertDeepEquals(String, Object, Object)} for the maximum delta between two doubles
225     * or floats for which both numbers are still considered equal.
226     */
227    protected double equalFloatsOrDoublesDelta() {
228        return 0.0D;
229    }
230
231    // unless overridden ulpsDelta is not used
232    protected int ulpsDelta() {
233        return 0;
234    }
235
236    @SuppressWarnings("serial")
237    public static class MultiCauseAssertionError extends AssertionError {
238
239        private Throwable[] causes;
240
241        public MultiCauseAssertionError(String message, Throwable... causes) {
242            super(message);
243            this.causes = causes;
244        }
245
246        @Override
247        public void printStackTrace(PrintStream out) {
248            super.printStackTrace(out);
249            int num = 0;
250            for (Throwable cause : causes) {
251                if (cause != null) {
252                    out.print("cause " + (num++));
253                    cause.printStackTrace(out);
254                }
255            }
256        }
257
258        @Override
259        public void printStackTrace(PrintWriter out) {
260            super.printStackTrace(out);
261            int num = 0;
262            for (Throwable cause : causes) {
263                if (cause != null) {
264                    out.print("cause " + (num++) + ": ");
265                    cause.printStackTrace(out);
266                }
267            }
268        }
269    }
270
271    /*
272     * Overrides to the normal JUnit {@link Assert} routines that provide varargs style formatting
273     * and produce an exception stack trace with the assertion frames trimmed out.
274     */
275
276    /**
277     * Fails a test with the given message.
278     *
279     * @param message the identifying message for the {@link AssertionError} (<code>null</code>
280     *            okay)
281     * @see AssertionError
282     */
283    public static void fail(String message, Object... objects) {
284        AssertionError e;
285        if (message == null) {
286            e = new AssertionError();
287        } else {
288            e = new AssertionError(String.format(message, objects));
289        }
290        // Trim the assert frames from the stack trace
291        StackTraceElement[] trace = e.getStackTrace();
292        int start = 1; // Skip this frame
293        String thisClassName = GraalTest.class.getName();
294        while (start < trace.length && trace[start].getClassName().equals(thisClassName) && (trace[start].getMethodName().equals("assertTrue") || trace[start].getMethodName().equals("assertFalse"))) {
295            start++;
296        }
297        e.setStackTrace(Arrays.copyOfRange(trace, start, trace.length));
298        throw e;
299    }
300
301    /**
302     * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
303     * given message.
304     *
305     * @param message the identifying message for the {@link AssertionError} (<code>null</code>
306     *            okay)
307     * @param condition condition to be checked
308     */
309    public static void assertTrue(String message, boolean condition) {
310        assertTrue(condition, message);
311    }
312
313    /**
314     * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} without a
315     * message.
316     *
317     * @param condition condition to be checked
318     */
319    public static void assertTrue(boolean condition) {
320        assertTrue(condition, null);
321    }
322
323    /**
324     * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
325     * given message.
326     *
327     * @param message the identifying message for the {@link AssertionError} (<code>null</code>
328     *            okay)
329     * @param condition condition to be checked
330     */
331    public static void assertFalse(String message, boolean condition) {
332        assertTrue(!condition, message);
333    }
334
335    /**
336     * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} without a
337     * message.
338     *
339     * @param condition condition to be checked
340     */
341    public static void assertFalse(boolean condition) {
342        assertTrue(!condition, null);
343    }
344
345    /**
346     * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
347     * given message.
348     *
349     * @param condition condition to be checked
350     * @param message the identifying message for the {@link AssertionError}
351     * @param objects arguments to the format string
352     */
353    public static void assertTrue(boolean condition, String message, Object... objects) {
354        if (!condition) {
355            fail(message, objects);
356        }
357    }
358
359    /**
360     * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
361     * given message produced by {@link String#format}.
362     *
363     * @param condition condition to be checked
364     * @param message the identifying message for the {@link AssertionError}
365     * @param objects arguments to the format string
366     */
367    public static void assertFalse(boolean condition, String message, Object... objects) {
368        assertTrue(!condition, message, objects);
369    }
370}
371