1/*
2 * Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23package org.graalvm.compiler.api.directives.test;
24
25import java.io.IOException;
26import java.util.ArrayList;
27import java.util.HashMap;
28import java.util.Iterator;
29import java.util.List;
30import java.util.Map;
31
32import org.graalvm.compiler.test.ExportingClassLoader;
33
34import jdk.internal.org.objectweb.asm.ClassReader;
35import jdk.internal.org.objectweb.asm.ClassWriter;
36import jdk.internal.org.objectweb.asm.Label;
37import jdk.internal.org.objectweb.asm.MethodVisitor;
38import jdk.internal.org.objectweb.asm.Opcodes;
39import jdk.internal.org.objectweb.asm.tree.AbstractInsnNode;
40import jdk.internal.org.objectweb.asm.tree.ClassNode;
41import jdk.internal.org.objectweb.asm.tree.IincInsnNode;
42import jdk.internal.org.objectweb.asm.tree.InsnList;
43import jdk.internal.org.objectweb.asm.tree.JumpInsnNode;
44import jdk.internal.org.objectweb.asm.tree.LabelNode;
45import jdk.internal.org.objectweb.asm.tree.LineNumberNode;
46import jdk.internal.org.objectweb.asm.tree.MethodNode;
47import jdk.internal.org.objectweb.asm.tree.VarInsnNode;
48
49/**
50 * The {@code TinyInstrumentor} is a bytecode instrumentor using ASM bytecode manipulation
51 * framework. It injects given code snippet into a target method and creates a temporary class as
52 * the container. Because the target method is cloned into the temporary class, it is required that
53 * the target method is public static. Any referred method/field in the target method or the
54 * instrumentation snippet should be made public as well.
55 */
56public class TinyInstrumentor implements Opcodes {
57
58    private InsnList instrumentationInstructions;
59    private int instrumentationMaxLocal;
60
61    /**
62     * Create a instrumentor with a instrumentation snippet. The snippet is specified with the given
63     * class {@code instrumentationClass} and the given method name {@code methodName}.
64     */
65    public TinyInstrumentor(Class<?> instrumentationClass, String methodName) throws IOException {
66        MethodNode instrumentationMethod = getMethodNode(instrumentationClass, methodName);
67        assert instrumentationMethod != null;
68        assert (instrumentationMethod.access | ACC_STATIC) != 0;
69        assert "()V".equals(instrumentationMethod.desc);
70        instrumentationInstructions = cloneInstructions(instrumentationMethod.instructions);
71        instrumentationMaxLocal = instrumentationMethod.maxLocals;
72        // replace return instructions with a goto unless there is a single return at the end. In
73        // that case, simply remove the return.
74        List<AbstractInsnNode> returnInstructions = new ArrayList<>();
75        for (AbstractInsnNode instruction : selectAll(instrumentationInstructions)) {
76            if (instruction instanceof LineNumberNode) {
77                instrumentationInstructions.remove(instruction);
78            } else if (instruction.getOpcode() == RETURN) {
79                returnInstructions.add(instruction);
80            }
81        }
82        LabelNode exit = new LabelNode();
83        if (returnInstructions.size() == 1) {
84            AbstractInsnNode returnInstruction = returnInstructions.get(0);
85            if (instrumentationInstructions.getLast() != returnInstruction) {
86                instrumentationInstructions.insertBefore(returnInstruction, new JumpInsnNode(GOTO, exit));
87            }
88            instrumentationInstructions.remove(returnInstruction);
89        } else {
90            for (AbstractInsnNode returnInstruction : returnInstructions) {
91                instrumentationInstructions.insertBefore(returnInstruction, new JumpInsnNode(GOTO, exit));
92                instrumentationInstructions.remove(returnInstruction);
93            }
94        }
95        instrumentationInstructions.add(exit);
96    }
97
98    /**
99     * @return a {@link MethodNode} called {@code methodName} in the given class.
100     */
101    private static MethodNode getMethodNode(Class<?> clazz, String methodName) throws IOException {
102        ClassReader classReader = new ClassReader(clazz.getName());
103        ClassNode classNode = new ClassNode();
104        classReader.accept(classNode, ClassReader.SKIP_FRAMES);
105
106        for (MethodNode methodNode : classNode.methods) {
107            if (methodNode.name.equals(methodName)) {
108                return methodNode;
109            }
110        }
111        return null;
112    }
113
114    /**
115     * Create a {@link ClassNode} with empty constructor.
116     */
117    private static ClassNode emptyClass(String name) {
118        ClassNode classNode = new ClassNode();
119        classNode.visit(52, ACC_SUPER | ACC_PUBLIC, name.replace('.', '/'), null, "java/lang/Object", new String[]{});
120
121        MethodVisitor mv = classNode.visitMethod(ACC_PUBLIC, "<init>", "()V", null, null);
122        mv.visitCode();
123        Label l0 = new Label();
124        mv.visitLabel(l0);
125        mv.visitVarInsn(ALOAD, 0);
126        mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false);
127        mv.visitInsn(RETURN);
128        Label l1 = new Label();
129        mv.visitLabel(l1);
130        mv.visitMaxs(1, 1);
131        mv.visitEnd();
132
133        return classNode;
134    }
135
136    /**
137     * Helper method for iterating the given {@link InsnList}.
138     */
139    private static Iterable<AbstractInsnNode> selectAll(InsnList instructions) {
140        return new Iterable<AbstractInsnNode>() {
141            @Override
142            public Iterator<AbstractInsnNode> iterator() {
143                return instructions.iterator();
144            }
145        };
146    }
147
148    /**
149     * Make a clone of the given {@link InsnList}.
150     */
151    private static InsnList cloneInstructions(InsnList instructions) {
152        Map<LabelNode, LabelNode> labelMap = new HashMap<>();
153        for (AbstractInsnNode instruction : selectAll(instructions)) {
154            if (instruction instanceof LabelNode) {
155                LabelNode clone = new LabelNode(new Label());
156                LabelNode original = (LabelNode) instruction;
157                labelMap.put(original, clone);
158            }
159        }
160        InsnList clone = new InsnList();
161        for (AbstractInsnNode insn : selectAll(instructions)) {
162            clone.add(insn.clone(labelMap));
163        }
164        return clone;
165    }
166
167    /**
168     * Shifts all local variable slot references by a specified amount.
169     */
170    private static void shiftLocalSlots(InsnList instructions, int offset) {
171        for (AbstractInsnNode insn : selectAll(instructions)) {
172            if (insn instanceof VarInsnNode) {
173                VarInsnNode varInsn = (VarInsnNode) insn;
174                varInsn.var += offset;
175
176            } else if (insn instanceof IincInsnNode) {
177                IincInsnNode iincInsn = (IincInsnNode) insn;
178                iincInsn.var += offset;
179            }
180        }
181    }
182
183    /**
184     * Instrument the target method specified by the class {@code targetClass} and the method name
185     * {@code methodName}. For each occurrence of the {@code opcode}, the instrumentor injects a
186     * copy of the instrumentation snippet.
187     */
188    public Class<?> instrument(Class<?> targetClass, String methodName, int opcode) throws IOException, ClassNotFoundException {
189        return instrument(targetClass, methodName, opcode, true);
190    }
191
192    public Class<?> instrument(Class<?> targetClass, String methodName, int opcode, boolean insertAfter) throws IOException, ClassNotFoundException {
193        // create a container class
194        String className = targetClass.getName() + "$$" + methodName;
195        ClassNode classNode = emptyClass(className);
196        // duplicate the target method and add to the container class
197        MethodNode methodNode = getMethodNode(targetClass, methodName);
198        MethodNode newMethodNode = new MethodNode(methodNode.access, methodNode.name, methodNode.desc, methodNode.signature, methodNode.exceptions.toArray(new String[methodNode.exceptions.size()]));
199        methodNode.accept(newMethodNode);
200        classNode.methods.add(newMethodNode);
201        // perform bytecode instrumentation
202        for (AbstractInsnNode instruction : selectAll(newMethodNode.instructions)) {
203            if (instruction.getOpcode() == opcode) {
204                InsnList instrumentation = cloneInstructions(instrumentationInstructions);
205                shiftLocalSlots(instrumentation, newMethodNode.maxLocals);
206                newMethodNode.maxLocals += instrumentationMaxLocal;
207                if (insertAfter) {
208                    newMethodNode.instructions.insert(instruction, instrumentation);
209                } else {
210                    newMethodNode.instructions.insertBefore(instruction, instrumentation);
211                }
212            }
213        }
214        // dump a byte array and load the class with a dedicated loader to separate the namespace
215        ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
216        classNode.accept(classWriter);
217        byte[] bytes = classWriter.toByteArray();
218        return new Loader(className, bytes).findClass(className);
219    }
220
221    private static class Loader extends ExportingClassLoader {
222
223        private String className;
224        private byte[] bytes;
225
226        Loader(String className, byte[] bytes) {
227            super(TinyInstrumentor.class.getClassLoader());
228            this.className = className;
229            this.bytes = bytes;
230        }
231
232        @Override
233        protected Class<?> findClass(String name) throws ClassNotFoundException {
234            if (name.equals(className)) {
235                return defineClass(name, bytes, 0, bytes.length);
236            } else {
237                return super.findClass(name);
238            }
239        }
240    }
241
242}
243