1/*
2 * Copyright (c) 2014, 2014, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.jtt.except;
24
25import org.graalvm.compiler.jtt.JTTTest;
26import org.graalvm.compiler.test.ExportingClassLoader;
27import org.junit.BeforeClass;
28import org.junit.Test;
29import org.objectweb.asm.ClassWriter;
30import org.objectweb.asm.MethodVisitor;
31import org.objectweb.asm.Opcodes;
32import org.objectweb.asm.Type;
33
34public class UntrustedInterfaces extends JTTTest {
35
36    public interface CallBack {
37        int callBack(TestInterface ti);
38    }
39
40    private interface TestInterface {
41        int method();
42    }
43
44    /**
45     * What a GoodPill would look like.
46     *
47     * <pre>
48     * private static final class GoodPill extends Pill {
49     *     public void setField() {
50     *         field = new TestConstant();
51     *     }
52     *
53     *     public void setStaticField() {
54     *         staticField = new TestConstant();
55     *     }
56     *
57     *     public int callMe(CallBack callback) {
58     *         return callback.callBack(new TestConstant());
59     *     }
60     *
61     *     public TestInterface get() {
62     *         return new TestConstant();
63     *     }
64     * }
65     *
66     * private static final class TestConstant implements TestInterface {
67     *     public int method() {
68     *         return 42;
69     *     }
70     * }
71     * </pre>
72     */
73    public abstract static class Pill {
74        public static TestInterface staticField;
75        public TestInterface field;
76
77        public abstract void setField();
78
79        public abstract void setStaticField();
80
81        public abstract int callMe(CallBack callback);
82
83        public abstract TestInterface get();
84    }
85
86    public int callBack(TestInterface list) {
87        return list.method();
88    }
89
90    public int staticFieldInvoke(Pill pill) {
91        pill.setStaticField();
92        return Pill.staticField.method();
93    }
94
95    public int fieldInvoke(Pill pill) {
96        pill.setField();
97        return pill.field.method();
98    }
99
100    public int argumentInvoke(Pill pill) {
101        return pill.callMe(ti -> ti.method());
102    }
103
104    public int returnInvoke(Pill pill) {
105        return pill.get().method();
106    }
107
108    @SuppressWarnings("cast")
109    public boolean staticFieldInstanceof(Pill pill) {
110        pill.setStaticField();
111        return Pill.staticField instanceof TestInterface;
112    }
113
114    @SuppressWarnings("cast")
115    public boolean fieldInstanceof(Pill pill) {
116        pill.setField();
117        return pill.field instanceof TestInterface;
118    }
119
120    @SuppressWarnings("cast")
121    public int argumentInstanceof(Pill pill) {
122        return pill.callMe(ti -> ti instanceof TestInterface ? 42 : 24);
123    }
124
125    @SuppressWarnings("cast")
126    public boolean returnInstanceof(Pill pill) {
127        return pill.get() instanceof TestInterface;
128    }
129
130    public TestInterface staticFieldCheckcast(Pill pill) {
131        pill.setStaticField();
132        return TestInterface.class.cast(Pill.staticField);
133    }
134
135    public TestInterface fieldCheckcast(Pill pill) {
136        pill.setField();
137        return TestInterface.class.cast(pill.field);
138    }
139
140    public int argumentCheckcast(Pill pill) {
141        return pill.callMe(ti -> TestInterface.class.cast(ti).method());
142    }
143
144    public TestInterface returnCheckcast(Pill pill) {
145        return TestInterface.class.cast(pill.get());
146    }
147
148    private static Pill poisonPill;
149
150    // Checkstyle: stop
151    @BeforeClass
152    public static void setUp() throws InstantiationException, IllegalAccessException, ClassNotFoundException {
153        poisonPill = (Pill) new PoisonLoader().findClass(PoisonLoader.POISON_IMPL_NAME).newInstance();
154    }
155
156    // Checkstyle: resume
157
158    @Test
159    public void testStaticField0() {
160        runTest("staticFieldInvoke", poisonPill);
161    }
162
163    @Test
164    public void testStaticField1() {
165        runTest("staticFieldInstanceof", poisonPill);
166    }
167
168    @Test
169    public void testStaticField2() {
170        runTest("staticFieldCheckcast", poisonPill);
171    }
172
173    @Test
174    public void testField0() {
175        runTest("fieldInvoke", poisonPill);
176    }
177
178    @Test
179    public void testField1() {
180        runTest("fieldInstanceof", poisonPill);
181    }
182
183    @Test
184    public void testField2() {
185        runTest("fieldCheckcast", poisonPill);
186    }
187
188    @Test
189    public void testArgument0() {
190        runTest("argumentInvoke", poisonPill);
191    }
192
193    @Test
194    public void testArgument1() {
195        runTest("argumentInstanceof", poisonPill);
196    }
197
198    @Test
199    public void testArgument2() {
200        runTest("argumentCheckcast", poisonPill);
201    }
202
203    @Test
204    public void testReturn0() {
205        runTest("returnInvoke", poisonPill);
206    }
207
208    @Test
209    public void testReturn1() {
210        runTest("returnInstanceof", poisonPill);
211    }
212
213    @Test
214    public void testReturn2() {
215        runTest("returnCheckcast", poisonPill);
216    }
217
218    private static class PoisonLoader extends ExportingClassLoader {
219        public static final String POISON_IMPL_NAME = "org.graalvm.compiler.jtt.except.PoisonPill";
220
221        @Override
222        protected Class<?> findClass(String name) throws ClassNotFoundException {
223            if (name.equals(POISON_IMPL_NAME)) {
224                ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
225
226                cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, POISON_IMPL_NAME.replace('.', '/'), null, Type.getInternalName(Pill.class), null);
227                // constructor
228                MethodVisitor constructor = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "()V", null, null);
229                constructor.visitCode();
230                constructor.visitVarInsn(Opcodes.ALOAD, 0);
231                constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Pill.class), "<init>", "()V", false);
232                constructor.visitInsn(Opcodes.RETURN);
233                constructor.visitMaxs(0, 0);
234                constructor.visitEnd();
235
236                MethodVisitor setList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setField", "()V", null, null);
237                setList.visitCode();
238                setList.visitVarInsn(Opcodes.ALOAD, 0);
239                setList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
240                setList.visitInsn(Opcodes.DUP);
241                setList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
242                setList.visitFieldInsn(Opcodes.PUTFIELD, Type.getInternalName(Pill.class), "field", Type.getDescriptor(TestInterface.class));
243                setList.visitInsn(Opcodes.RETURN);
244                setList.visitMaxs(0, 0);
245                setList.visitEnd();
246
247                MethodVisitor setStaticList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setStaticField", "()V", null, null);
248                setStaticList.visitCode();
249                setStaticList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
250                setStaticList.visitInsn(Opcodes.DUP);
251                setStaticList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
252                setStaticList.visitFieldInsn(Opcodes.PUTSTATIC, Type.getInternalName(Pill.class), "staticField", Type.getDescriptor(TestInterface.class));
253                setStaticList.visitInsn(Opcodes.RETURN);
254                setStaticList.visitMaxs(0, 0);
255                setStaticList.visitEnd();
256
257                MethodVisitor callMe = cw.visitMethod(Opcodes.ACC_PUBLIC, "callMe", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(CallBack.class)), null, null);
258                callMe.visitCode();
259                callMe.visitVarInsn(Opcodes.ALOAD, 1);
260                callMe.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
261                callMe.visitInsn(Opcodes.DUP);
262                callMe.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
263                callMe.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(CallBack.class), "callBack", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(TestInterface.class)), true);
264                callMe.visitInsn(Opcodes.IRETURN);
265                callMe.visitMaxs(0, 0);
266                callMe.visitEnd();
267
268                MethodVisitor getList = cw.visitMethod(Opcodes.ACC_PUBLIC, "get", Type.getMethodDescriptor(Type.getType(TestInterface.class)), null, null);
269                getList.visitCode();
270                getList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class));
271                getList.visitInsn(Opcodes.DUP);
272                getList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "<init>", "()V", false);
273                getList.visitInsn(Opcodes.ARETURN);
274                getList.visitMaxs(0, 0);
275                getList.visitEnd();
276
277                cw.visitEnd();
278
279                byte[] bytes = cw.toByteArray();
280                return defineClass(name, bytes, 0, bytes.length);
281            }
282            return super.findClass(name);
283        }
284    }
285}
286