1/*
2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24import java.io.InvalidClassException;
25import java.io.ObjectInputFilter;
26import java.io.Serializable;
27
28import java.rmi.Remote;
29import java.rmi.RemoteException;
30import java.rmi.UnmarshalException;
31import java.rmi.server.UnicastRemoteObject;
32
33import java.util.Objects;
34
35import org.testng.Assert;
36import org.testng.annotations.DataProvider;
37import org.testng.annotations.Test;
38
39/*
40 * @test
41 * @run testng/othervm FilterUROTest
42 * @summary Check that objects are exported with ObjectInputFilters via UnicastRemoteObject
43 */
44public class FilterUROTest {
45
46    /**
47     * Data to test serialFilter call counts.
48     * - name
49     * - Object
50     * - expected count of calls to checkInput.
51     *
52     * @return array of test data
53     */
54    @DataProvider(name = "bindData")
55    static Object[][] bindObjects() {
56        Object[][] data = {
57                {"SimpleString", "SimpleString", 0},
58                {"String", new XX("now is the time"), 1},
59                {"String[]", new XX(new String[3]), 3},
60                {"Long[4]", new XX(new Long[4]), 3},
61                {"RejectME", new XX(new RejectME()), -1},
62        };
63        return data;
64    }
65
66    /*
67     * Test exporting an object with a serialFilter using exportObject().
68     * Send some objects and check the number of calls to the serialFilter.
69     */
70    @Test(dataProvider = "bindData")
71    public void useExportObject(String name, Object obj, int expectedFilterCount) throws RemoteException {
72        try {
73            RemoteImpl impl = RemoteImpl.create();
74            Echo client = (Echo) UnicastRemoteObject
75                    .exportObject(impl, 0, impl.checker);
76            int count = client.filterCount(obj);
77            System.out.printf("count: %d, obj: %s%n", count, obj);
78            Assert.assertEquals(count, expectedFilterCount, "wrong number of filter calls");
79        } catch (RemoteException rex) {
80            if (expectedFilterCount == -1 &&
81                    UnmarshalException.class.equals(rex.getCause().getClass()) &&
82                    InvalidClassException.class.equals(rex.getCause().getCause().getClass())) {
83                return; // normal expected exception
84            }
85            rex.printStackTrace();
86            Assert.fail("unexpected remote exception", rex);
87        } catch (Exception rex) {
88            Assert.fail("unexpected exception", rex);
89        }
90    }
91
92    /*
93     * Test exporting an object with a serialFilter using exportObject()
94     * with explicit (but null) SocketFactories.
95     * Send some objects and check the number of calls to the serialFilter.
96     */
97    @Test(dataProvider = "bindData")
98    public void useExportObject2(String name, Object obj, int expectedFilterCount) throws RemoteException {
99        try {
100            RemoteImpl impl = RemoteImpl.create();
101            Echo client = (Echo) UnicastRemoteObject
102                    .exportObject(impl, 0, null, null, impl.checker);
103            int count = client.filterCount(obj);
104            System.out.printf("count: %d, obj: %s%n", count, obj);
105            Assert.assertEquals(count, expectedFilterCount, "wrong number of filter calls");
106        } catch (RemoteException rex) {
107            if (expectedFilterCount == -1 &&
108                    UnmarshalException.class.equals(rex.getCause().getClass()) &&
109                    InvalidClassException.class.equals(rex.getCause().getCause().getClass())) {
110                return; // normal expected exception
111            }
112            rex.printStackTrace();
113            Assert.fail("unexpected remote exception", rex);
114        } catch (Exception rex) {
115            Assert.fail("unexpected exception", rex);
116        }
117    }
118
119    /**
120     * A simple Serializable holding an object that is passed by value.
121     * It and its contents are checked by the filter.
122     */
123    static class XX implements Serializable {
124        private static final long serialVersionUID = 362498820763181265L;
125
126        final Object obj;
127
128        XX(Object obj) {
129            this.obj = obj;
130        }
131
132        public String toString() {
133            return super.toString() + "//" + Objects.toString(obj);
134        }
135    }
136
137    interface Echo extends Remote {
138        int filterCount(Object obj) throws RemoteException;
139    }
140
141    /**
142     * This remote object just counts the calls to the serialFilter
143     * and returns it.  The caller can check the number against
144     * what was expected for the object passed as an argument.
145     * A new RemoteImpl is used for each test so the count starts at zero again.
146     */
147    static class RemoteImpl implements Echo {
148
149        private static final long serialVersionUID = -6999613679881262446L;
150
151        transient Checker checker;
152
153        static RemoteImpl create() throws RemoteException {
154            RemoteImpl impl = new RemoteImpl(new Checker());
155            return impl;
156        }
157
158        private RemoteImpl(Checker checker) throws RemoteException {
159            this.checker = checker;
160        }
161
162        public int filterCount(Object obj) throws RemoteException {
163            return checker.count();
164        }
165
166    }
167
168    /**
169     * A ObjectInputFilter that just counts when it is called.
170     */
171    static class Checker implements ObjectInputFilter {
172        int count;
173
174        @Override
175        public Status checkInput(FilterInfo filterInfo) {
176            if (filterInfo.serialClass() == RejectME.class) {
177                return Status.REJECTED;
178            }
179            count++;
180            return Status.UNDECIDED;
181        }
182
183        public int count() {
184            return count;
185        }
186    }
187
188    /**
189     * A class to be rejected by the filter.
190     */
191    static class RejectME implements Serializable {
192        private static final long serialVersionUID = 2L;
193    }
194}
195