1/*
2 * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24package asmlib;
25
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import jdk.internal.org.objectweb.asm.ClassReader;
import jdk.internal.org.objectweb.asm.ClassVisitor;
import jdk.internal.org.objectweb.asm.ClassWriter;
import jdk.internal.org.objectweb.asm.MethodVisitor;
import jdk.internal.org.objectweb.asm.Opcodes;
import jdk.internal.org.objectweb.asm.Type;
38
39public class Instrumentor {
40    public static class InstrHelper {
41        private final MethodVisitor mv;
42        private final String name;
43
44        InstrHelper(MethodVisitor mv, String name) {
45            this.mv = mv;
46            this.name = name;
47        }
48
49        public String getName() {
50            return this.name;
51        }
52
53        public void invokeStatic(String owner, String name, String desc, boolean itf) {
54            mv.visitMethodInsn(Opcodes.INVOKESTATIC, owner, name, desc, itf);
55        }
56
57        public void invokeSpecial(String owner, String name, String desc) {
58            mv.visitMethodInsn(Opcodes.INVOKESPECIAL, owner, name, desc, false);
59        }
60
61        public void invokeVirtual(String owner, String name, String desc) {
62            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, owner, name, desc, false);
63        }
64
65        public void push(int val) {
66            if (val >= -1 && val <= 5) {
67                mv.visitInsn(Opcodes.ICONST_0 + val);
68            } else if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
69                mv.visitIntInsn(Opcodes.BIPUSH, val);
70            } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) {
71                mv.visitIntInsn(Opcodes.SIPUSH, val);
72            } else {
73                mv.visitLdcInsn(val);
74            }
75        }
76
77        public void push(Object val) {
78            mv.visitLdcInsn(val);
79        }
80
81        public void println(String s) {
82            mv.visitFieldInsn(Opcodes.GETSTATIC, Type.getInternalName(System.class), "out", Type.getDescriptor(PrintStream.class));
83            mv.visitLdcInsn(s);
84            mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(PrintStream.class), "println", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(String.class)), false);
85        }
86    }
87
88    public static Instrumentor instrFor(byte[] classData) {
89        return new Instrumentor(classData);
90    }
91
92
93    private final ClassReader cr;
94    private final ClassWriter output;
95    private ClassVisitor instrumentingVisitor = null;
96    private final AtomicInteger matches = new AtomicInteger(0);
97
98    private Instrumentor(byte[] classData) {
99        cr = new ClassReader(classData);
100        output = new ClassWriter(ClassWriter.COMPUTE_MAXS);
101        instrumentingVisitor = output;
102    }
103
104    public synchronized Instrumentor addMethodEntryInjection(String methodName, Consumer<InstrHelper> injector) {
105        instrumentingVisitor = new ClassVisitor(Opcodes.ASM5, instrumentingVisitor) {
106            @Override
107            public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
108                MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
109
110                if (name.equals(methodName)) {
111                    matches.getAndIncrement();
112
113                    mv = new MethodVisitor(Opcodes.ASM5, mv) {
114                        @Override
115                        public void visitCode() {
116                            injector.accept(new InstrHelper(mv, name));
117                        }
118                    };
119                }
120                return mv;
121            }
122        };
123        return this;
124    }
125
126    public synchronized Instrumentor addNativeMethodTrackingInjection(String prefix, Consumer<InstrHelper> injector) {
127        instrumentingVisitor = new ClassVisitor(Opcodes.ASM5, instrumentingVisitor) {
128            private final Set<Consumer<ClassVisitor>> wmGenerators = new HashSet<>();
129            private String className;
130
131            @Override
132            public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
133                this.className = name;
134                super.visit(version, access, name, signature, superName, interfaces);
135            }
136
137
138            @Override
139            public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
140                if ((access & Opcodes.ACC_NATIVE) != 0) {
141                    matches.getAndIncrement();
142
143                    String newName = prefix + name;
144                    wmGenerators.add((v)->{
145                        MethodVisitor mv = v.visitMethod(access & ~Opcodes.ACC_NATIVE, name, desc, signature, exceptions);
146                        mv.visitCode();
147                        injector.accept(new InstrHelper(mv, name));
148                        Type[] argTypes = Type.getArgumentTypes(desc);
149                        Type retType = Type.getReturnType(desc);
150
151                        boolean isStatic = (access & Opcodes.ACC_STATIC) != 0;
152                        if (!isStatic) {
153                            mv.visitIntInsn(Opcodes.ALOAD, 0); // load "this"
154                        }
155
156                        // load the method parameters
157                        if (argTypes.length > 0) {
158                            int ptr = isStatic ? 0 : 1;
159                            for(Type argType : argTypes) {
160                                mv.visitIntInsn(argType.getOpcode(Opcodes.ILOAD), ptr);
161                                ptr += argType.getSize();
162                            }
163                        }
164
165                        mv.visitMethodInsn(isStatic ? Opcodes.INVOKESTATIC : Opcodes.INVOKESPECIAL, className, newName, desc, false);
166                        mv.visitInsn(retType.getOpcode(Opcodes.IRETURN));
167
168                        mv.visitMaxs(1, 1); // dummy call; let ClassWriter to deal with this
169                        mv.visitEnd();
170                    });
171                    return super.visitMethod(access, newName, desc, signature, exceptions);
172                }
173                return super.visitMethod(access, name, desc, signature, exceptions);
174            }
175
176            @Override
177            public void visitEnd() {
178                wmGenerators.stream().forEach((e) -> {
179                    e.accept(cv);
180                });
181                super.visitEnd();
182            }
183        };
184
185        return this;
186    }
187
188    public synchronized byte[] apply() {
189        cr.accept(instrumentingVisitor, ClassReader.SKIP_DEBUG + ClassReader.EXPAND_FRAMES);
190
191        return matches.get() == 0 ? null : output.toByteArray();
192    }
193}
194