1/*
2 * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24/*
25 * @test
26 * @bug 8035776
27 * @summary Ensure that invocation parameters are always cast to the instantiatedMethodType
28 */
29import java.lang.invoke.*;
30import java.util.Arrays;
31import static java.lang.invoke.MethodType.methodType;
32
public class MetafactoryParameterCastTest {

    static final MethodHandles.Lookup lookup = MethodHandles.lookup();

    public static class A {
    }

    public static class B extends A {
        void instance0() {}
        void instance1(B arg) {}
        static void static1(B arg) {}
        static void static2(B arg1, B arg2) {}
    }

    public static class C extends B {}
    public static class NotC extends B {}

    public interface ASink { void take(A arg); }
    public interface BSink { void take(B arg); }

    public static void main(String... args) throws Throwable {
        new MetafactoryParameterCastTest().test();
    }

    /**
     * Links lambdas for every combination of the following, using an instantiated
     * parameter type of {@code C}:
     * <ul>
     *   <li>implementation method (instance or static);</li>
     *   <li>functional interface ({@link ASink} or {@link BSink});</li>
     *   <li>metafactory entry point ({@code metafactory} or {@code altMetafactory});</li>
     *   <li>bridge configuration.</li>
     * </ul>
     * Each resulting instance is then checked by the {@code try*Sink} methods.
     */
    void test() throws Throwable {
        MethodType takeA = methodType(void.class, A.class);
        MethodType takeB = methodType(void.class, B.class);
        MethodType takeC = methodType(void.class, C.class);

        Class<?>[] noCapture = {};
        Class<?>[] captureB = { B.class };

        // Implementation methods whose only "free" parameter is a B (receiver or first argument).
        MethodHandle[] oneBParam = { lookup.findVirtual(B.class, "instance0", methodType(void.class)),
                                     lookup.findStatic(B.class, "static1", methodType(void.class, B.class)) };
        // Implementation methods taking two B's; the first one is supplied as a captured value.
        MethodHandle[] twoBParams = { lookup.findVirtual(B.class, "instance1", methodType(void.class, B.class)),
                                      lookup.findStatic(B.class, "static2", methodType(void.class, B.class, B.class)) };

        for (MethodHandle mh : oneBParam) {
            // sam
            tryASink(invokeMetafactory(mh, ASink.class, "take", noCapture, takeC, takeA));
            tryBSink(invokeMetafactory(mh, BSink.class, "take", noCapture, takeC, takeB));
            tryASink(invokeAltMetafactory(mh, ASink.class, "take", noCapture, takeC, takeA));
            tryBSink(invokeAltMetafactory(mh, BSink.class, "take", noCapture, takeC, takeB));

            // bridge: SAM type is (C)void, and the interface's real signature arrives as a bridge
            tryASink(invokeAltMetafactory(mh, ASink.class, "take", noCapture, takeC, takeC, takeA));
            tryBSink(invokeAltMetafactory(mh, BSink.class, "take", noCapture, takeC, takeC, takeB));
        }

        for (MethodHandle mh : twoBParams) {
            // sam
            tryCapASink(invokeMetafactory(mh, ASink.class, "take", captureB, takeC, takeA));
            tryCapBSink(invokeMetafactory(mh, BSink.class, "take", captureB, takeC, takeB));
            tryCapASink(invokeAltMetafactory(mh, ASink.class, "take", captureB, takeC, takeA));
            tryCapBSink(invokeAltMetafactory(mh, BSink.class, "take", captureB, takeC, takeB));

            // bridge
            tryCapASink(invokeAltMetafactory(mh, ASink.class, "take", captureB, takeC, takeC, takeA));
            tryCapBSink(invokeAltMetafactory(mh, BSink.class, "take", captureB, takeC, takeC, takeB));
        }
    }

    /** Instantiates a non-capturing {@link ASink} from {@code cs} and checks its casts. */
    void tryASink(CallSite cs) throws Throwable {
        ASink sink = (ASink) cs.dynamicInvoker().invoke();
        tryASink(sink);
    }

    /** Instantiates an {@link ASink} capturing a fresh {@code B} from {@code cs} and checks its casts. */
    void tryCapASink(CallSite cs) throws Throwable {
        ASink sink = (ASink) cs.dynamicInvoker().invoke(new B());
        tryASink(sink);
    }

    /** Instantiates a non-capturing {@link BSink} from {@code cs} and checks its casts. */
    void tryBSink(CallSite cs) throws Throwable {
        BSink sink = (BSink) cs.dynamicInvoker().invoke();
        tryBSink(sink);
    }

    /** Instantiates a {@link BSink} capturing a fresh {@code B} from {@code cs} and checks its casts. */
    void tryCapBSink(CallSite cs) throws Throwable {
        BSink sink = (BSink) cs.dynamicInvoker().invoke(new B());
        tryBSink(sink);
    }

    /**
     * Checks that {@code sink} accepts a {@code C} and rejects a {@code B} or a
     * {@code NotC} with a {@link ClassCastException}.
     *
     * @throws AssertionError if the {@code C} is rejected, with the
     *         {@code ClassCastException} as its cause, or if either non-{@code C}
     *         argument is accepted
     */
    void tryASink(ASink sink) {
        try { sink.take(new C()); }
        catch (ClassCastException e) {
            // Keep the CCE as the cause so the offending cast site shows up in the trace.
            throw new AssertionError("Unexpected cast failure: " + e + " " + lastMFParams(), e);
        }

        try {
            sink.take(new B());
            throw new AssertionError("Missing cast from A to C: " + lastMFParams());
        }
        catch (ClassCastException e) { /* expected */ }

        try {
            sink.take(new NotC());
            throw new AssertionError("Missing cast from A to C: " + lastMFParams());
        }
        catch (ClassCastException e) { /* expected */ }
    }

    /**
     * Checks that {@code sink} accepts a {@code C} and rejects a {@code B} or a
     * {@code NotC} with a {@link ClassCastException}.
     *
     * @throws AssertionError if the {@code C} is rejected, with the
     *         {@code ClassCastException} as its cause, or if either non-{@code C}
     *         argument is accepted
     */
    void tryBSink(BSink sink) {
        try { sink.take(new C()); }
        catch (ClassCastException e) {
            // Keep the CCE as the cause so the offending cast site shows up in the trace.
            throw new AssertionError("Unexpected cast failure: " + e + " " + lastMFParams(), e);
        }

        try {
            sink.take(new B());
            throw new AssertionError("Missing cast from B to C: " + lastMFParams());
        }
        catch (ClassCastException e) { /* expected */ }

        try {
            sink.take(new NotC());
            throw new AssertionError("Missing cast from B to C: " + lastMFParams());
        }
        catch (ClassCastException e) { /* expected */ }
    }

    // Arguments of the most recent invokeMetafactory/invokeAltMetafactory call,
    // kept only so that assertion messages can say which linkage failed.
    MethodHandle lastMH;
    Class<?>[] lastCaptured;
    MethodType lastInstMT;
    MethodType lastSamMT;
    MethodType[] lastBridgeMTs;

    /** Describes the most recently recorded metafactory arguments for diagnostics. */
    String lastMFParams() {
        return "mh=" + lastMH +
               ", captured=" + Arrays.toString(lastCaptured) +
               ", instMT=" + lastInstMT +
               ", samMT=" + lastSamMT +
               ", bridgeMTs=" + Arrays.toString(lastBridgeMTs);
    }

    /**
     * Links a lambda through {@link LambdaMetafactory#metafactory} and records
     * the arguments used.
     *
     * @param mh         implementation method
     * @param sam        functional interface to implement
     * @param methodName name of the interface method
     * @param captured   types of the captured values (the invoked type's parameters)
     * @param instMT     instantiated method type
     * @param samMT      SAM method type
     * @return the linked call site
     * @throws RuntimeException wrapping a {@link LambdaConversionException}, which
     *         this test never expects
     */
    CallSite invokeMetafactory(MethodHandle mh, Class<?> sam, String methodName,
                               Class<?>[] captured, MethodType instMT, MethodType samMT) {
        lastMH = mh;
        lastCaptured = captured;
        lastInstMT = instMT;
        lastSamMT = samMT;
        lastBridgeMTs = new MethodType[]{};
        try {
            return LambdaMetafactory.metafactory(lookup, methodName, methodType(sam, captured),
                                                 samMT, mh, instMT);
        }
        catch (LambdaConversionException e) {
            // unexpected linkage error
            throw new RuntimeException(e);
        }
    }

    /**
     * Links a lambda through {@link LambdaMetafactory#altMetafactory} and records
     * the arguments used.
     *
     * <p>{@code FLAG_BRIDGES} is set only when {@code bridgeMTs} is non-empty.
     *
     * @param mh         implementation method
     * @param sam        functional interface to implement
     * @param methodName name of the interface method
     * @param captured   types of the captured values (the invoked type's parameters)
     * @param instMT     instantiated method type
     * @param samMT      SAM method type
     * @param bridgeMTs  additional method types to bridge to the SAM method
     * @return the linked call site
     * @throws RuntimeException wrapping a {@link LambdaConversionException}, which
     *         this test never expects
     */
    CallSite invokeAltMetafactory(MethodHandle mh, Class<?> sam, String methodName,
                                  Class<?>[] captured, MethodType instMT,
                                  MethodType samMT, MethodType... bridgeMTs) {
        lastMH = mh;
        lastCaptured = captured;
        lastInstMT = instMT;
        lastSamMT = samMT;
        lastBridgeMTs = bridgeMTs;
        try {
            boolean bridge = bridgeMTs.length > 0;
            // Argument layout: samMT, implMethod, instMT, flags [, bridgeCount, bridge types...]
            Object[] args = new Object[bridge ? 5 + bridgeMTs.length : 4];
            args[0] = samMT;
            args[1] = mh;
            args[2] = instMT;
            args[3] = bridge ? LambdaMetafactory.FLAG_BRIDGES : 0;
            if (bridge) {
                args[4] = bridgeMTs.length;
                System.arraycopy(bridgeMTs, 0, args, 5, bridgeMTs.length);
            }
            return LambdaMetafactory.altMetafactory(lookup, methodName, methodType(sam, captured), args);
        }
        catch (LambdaConversionException e) {
            // unexpected linkage error
            throw new RuntimeException(e);
        }
    }

}
212