1/*
2 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
3 *
4 * This code is free software; you can redistribute it and/or modify it
5 * under the terms of the GNU General Public License version 2 only, as
6 * published by the Free Software Foundation.
7 *
8 * This code is distributed in the hope that it will be useful, but WITHOUT
9 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
11 * version 2 for more details (a copy is included in the LICENSE file that
12 * accompanied this code).
13 *
14 * You should have received a copy of the GNU General Public License version
15 * 2 along with this work; if not, write to the Free Software Foundation,
16 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
17 *
18 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
19 * or visit www.oracle.com if you need additional information or have any
20 * questions.
21 */
22
23/*
24 * This file is available under and governed by the GNU General Public
25 * License version 2 only, as published by the Free Software Foundation.
26 * However, the following notice accompanied the original version of this
27 * file:
28 *
29 * Written by Martin Buchholz with assistance from members of JCP
30 * JSR-166 Expert Group and released to the public domain, as
31 * explained at http://creativecommons.org/publicdomain/zero/1.0/
32 */
33
34/*
35 * @test
36 * @modules java.base/java.util.concurrent:open
37 * @run testng WhiteBox
38 * @summary White box tests of implementation details
39 */
40
41import static org.testng.Assert.*;
42import org.testng.annotations.DataProvider;
43import org.testng.annotations.Test;
44
45import java.io.ByteArrayInputStream;
46import java.io.ByteArrayOutputStream;
47import java.io.ObjectInputStream;
48import java.io.ObjectOutputStream;
49import java.lang.invoke.MethodHandles;
50import java.lang.invoke.VarHandle;
51import java.util.ArrayList;
52import java.util.Iterator;
53import java.util.List;
54import java.util.concurrent.LinkedTransferQueue;
55import java.util.concurrent.ThreadLocalRandom;
56import java.util.concurrent.TimeUnit;
57import static java.util.stream.Collectors.toList;
58import java.util.function.Consumer;
59import java.util.function.Function;
60
61@Test
62public class WhiteBox {
63    final ThreadLocalRandom rnd = ThreadLocalRandom.current();
64    final VarHandle HEAD, TAIL, ITEM, NEXT;
65    final int SWEEP_THRESHOLD;
66
67    public WhiteBox() throws ReflectiveOperationException {
68        Class<?> qClass = LinkedTransferQueue.class;
69        Class<?> nodeClass = Class.forName(qClass.getName() + "$Node");
70        MethodHandles.Lookup lookup
71            = MethodHandles.privateLookupIn(qClass, MethodHandles.lookup());
72        HEAD = lookup.findVarHandle(qClass, "head", nodeClass);
73        TAIL = lookup.findVarHandle(qClass, "tail", nodeClass);
74        NEXT = lookup.findVarHandle(nodeClass, "next", nodeClass);
75        ITEM = lookup.findVarHandle(nodeClass, "item", Object.class);
76        SWEEP_THRESHOLD = (int)
77            lookup.findStaticVarHandle(qClass, "SWEEP_THRESHOLD", int.class)
78            .get();
79    }
80
81    Object head(LinkedTransferQueue q) { return HEAD.getVolatile(q); }
82    Object tail(LinkedTransferQueue q) { return TAIL.getVolatile(q); }
83    Object item(Object node)           { return ITEM.getVolatile(node); }
84    Object next(Object node)           { return NEXT.getVolatile(node); }
85
86    int nodeCount(LinkedTransferQueue q) {
87        int i = 0;
88        for (Object p = head(q); p != null; ) {
89            i++;
90            if (p == (p = next(p))) p = head(q);
91        }
92        return i;
93    }
94
95    int tailCount(LinkedTransferQueue q) {
96        int i = 0;
97        for (Object p = tail(q); p != null; ) {
98            i++;
99            if (p == (p = next(p))) p = head(q);
100        }
101        return i;
102    }
103
104    Object findNode(LinkedTransferQueue q, Object e) {
105        for (Object p = head(q); p != null; ) {
106            if (item(p) != null && e.equals(item(p)))
107                return p;
108            if (p == (p = next(p))) p = head(q);
109        }
110        throw new AssertionError("not found");
111    }
112
113    Iterator iteratorAt(LinkedTransferQueue q, Object e) {
114        for (Iterator it = q.iterator(); it.hasNext(); )
115            if (it.next().equals(e))
116                return it;
117        throw new AssertionError("not found");
118    }
119
120    void assertIsSelfLinked(Object node) {
121        assertSame(next(node), node);
122        assertNull(item(node));
123    }
124    void assertIsNotSelfLinked(Object node) {
125        assertNotSame(node, next(node));
126    }
127
128    @Test
129    public void addRemove() {
130        LinkedTransferQueue q = new LinkedTransferQueue();
131        assertInvariants(q);
132        assertNull(next(head(q)));
133        assertNull(item(head(q)));
134        q.add(1);
135        assertEquals(nodeCount(q), 2);
136        assertInvariants(q);
137        q.remove(1);
138        assertEquals(nodeCount(q), 1);
139        assertInvariants(q);
140    }
141
142    /**
143     * Traversal actions that visit every node and do nothing, but
144     * have side effect of squeezing out dead nodes.
145     */
146    @DataProvider
147    public Object[][] traversalActions() {
148        return List.<Consumer<LinkedTransferQueue>>of(
149            q -> q.forEach(e -> {}),
150            q -> assertFalse(q.contains(new Object())),
151            q -> assertFalse(q.remove(new Object())),
152            q -> q.spliterator().forEachRemaining(e -> {}),
153            q -> q.stream().collect(toList()),
154            q -> assertFalse(q.removeIf(e -> false)),
155            q -> assertFalse(q.removeAll(List.of())))
156            .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
157    }
158
159    @Test(dataProvider = "traversalActions")
160    public void traversalOperationsCollapseLeadingNodes(
161        Consumer<LinkedTransferQueue> traversalAction) {
162        LinkedTransferQueue q = new LinkedTransferQueue();
163        Object oldHead;
164        int n = 1 + rnd.nextInt(5);
165        for (int i = 0; i < n; i++) q.add(i);
166        assertEquals(nodeCount(q), n + 1);
167        oldHead = head(q);
168        traversalAction.accept(q);
169        assertInvariants(q);
170        assertEquals(nodeCount(q), n);
171        assertIsSelfLinked(oldHead);
172    }
173
174    @Test(dataProvider = "traversalActions")
175    public void traversalOperationsCollapseInteriorNodes(
176        Consumer<LinkedTransferQueue> traversalAction) {
177        LinkedTransferQueue q = new LinkedTransferQueue();
178        int n = 6;
179        for (int i = 0; i < n; i++) q.add(i);
180
181        // We must be quite devious to reliably create an interior dead node
182        Object p0 = findNode(q, 0);
183        Object p1 = findNode(q, 1);
184        Object p2 = findNode(q, 2);
185        Object p3 = findNode(q, 3);
186        Object p4 = findNode(q, 4);
187        Object p5 = findNode(q, 5);
188
189        Iterator it1 = iteratorAt(q, 1);
190        Iterator it2 = iteratorAt(q, 2);
191
192        it2.remove(); // causes it2's ancestor to advance to 1
193        assertSame(next(p1), p3);
194        assertSame(next(p2), p3);
195        assertNull(item(p2));
196        it1.remove(); // removes it2's ancestor
197        assertSame(next(p0), p3);
198        assertSame(next(p1), p3);
199        assertSame(next(p2), p3);
200        assertNull(item(p1));
201        assertEquals(it2.next(), 3);
202        it2.remove(); // it2's ancestor can't unlink
203
204        assertSame(next(p0), p3); // p3 is now interior dead node
205        assertSame(next(p1), p4); // it2 uselessly CASed p1.next
206        assertSame(next(p2), p3);
207        assertSame(next(p3), p4);
208        assertInvariants(q);
209
210        int c = nodeCount(q);
211        traversalAction.accept(q);
212        assertEquals(nodeCount(q), c - 1);
213
214        assertSame(next(p0), p4);
215        assertSame(next(p1), p4);
216        assertSame(next(p2), p3);
217        assertSame(next(p3), p4);
218        assertInvariants(q);
219
220        // trailing nodes are not unlinked
221        Iterator it5 = iteratorAt(q, 5); it5.remove();
222        traversalAction.accept(q);
223        assertSame(next(p4), p5);
224        assertNull(next(p5));
225        assertEquals(nodeCount(q), c - 1);
226    }
227
228    /**
229     * Checks that traversal operations collapse a random pattern of
230     * dead nodes as could normally only occur with a race.
231     */
232    @Test(dataProvider = "traversalActions")
233    public void traversalOperationsCollapseRandomNodes(
234        Consumer<LinkedTransferQueue> traversalAction) {
235        LinkedTransferQueue q = new LinkedTransferQueue();
236        int n = rnd.nextInt(6);
237        for (int i = 0; i < n; i++) q.add(i);
238        ArrayList nulledOut = new ArrayList();
239        for (Object p = head(q); p != null; p = next(p))
240            if (rnd.nextBoolean()) {
241                nulledOut.add(item(p));
242                ITEM.setVolatile(p, null);
243            }
244        traversalAction.accept(q);
245        int c = nodeCount(q);
246        assertEquals(q.size(), c - (q.contains(n - 1) ? 0 : 1));
247        for (int i = 0; i < n; i++)
248            assertTrue(nulledOut.contains(i) ^ q.contains(i));
249    }
250
251    /**
252     * Traversal actions that remove every element, and are also
253     * expected to squeeze out dead nodes.
254     */
255    @DataProvider
256    public Object[][] bulkRemovalActions() {
257        return List.<Consumer<LinkedTransferQueue>>of(
258            q -> q.clear(),
259            q -> assertTrue(q.removeIf(e -> true)),
260            q -> assertTrue(q.retainAll(List.of())))
261            .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
262    }
263
264    @Test(dataProvider = "bulkRemovalActions")
265    public void bulkRemovalOperationsCollapseNodes(
266        Consumer<LinkedTransferQueue> bulkRemovalAction) {
267        LinkedTransferQueue q = new LinkedTransferQueue();
268        int n = 1 + rnd.nextInt(5);
269        for (int i = 0; i < n; i++) q.add(i);
270        bulkRemovalAction.accept(q);
271        assertEquals(nodeCount(q), 1);
272        assertInvariants(q);
273    }
274
275    /**
276     * Actions that remove the first element, and are expected to
277     * leave at most one slack dead node at head.
278     */
279    @DataProvider
280    public Object[][] pollActions() {
281        return List.<Consumer<LinkedTransferQueue>>of(
282            q -> assertNotNull(q.poll()),
283            q -> { try { assertNotNull(q.poll(1L, TimeUnit.DAYS)); }
284                catch (Throwable x) { throw new AssertionError(x); }},
285            q -> { try { assertNotNull(q.take()); }
286                catch (Throwable x) { throw new AssertionError(x); }},
287            q -> assertNotNull(q.remove()))
288            .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
289    }
290
291    @Test(dataProvider = "pollActions")
292    public void pollActionsOneNodeSlack(
293        Consumer<LinkedTransferQueue> pollAction) {
294        LinkedTransferQueue q = new LinkedTransferQueue();
295        int n = 1 + rnd.nextInt(5);
296        for (int i = 0; i < n; i++) q.add(i);
297        assertEquals(nodeCount(q), n + 1);
298        for (int i = 0; i < n; i++) {
299            int c = nodeCount(q);
300            boolean slack = item(head(q)) == null;
301            if (slack) assertNotNull(item(next(head(q))));
302            pollAction.accept(q);
303            assertEquals(nodeCount(q), q.isEmpty() ? 1 : c - (slack ? 2 : 0));
304        }
305        assertInvariants(q);
306    }
307
308    /**
309     * Actions that append an element, and are expected to
310     * leave at most one slack node at tail.
311     */
312    @DataProvider
313    public Object[][] addActions() {
314        return List.<Consumer<LinkedTransferQueue>>of(
315            q -> q.add(1),
316            q -> q.offer(1))
317            .stream().map(x -> new Object[]{ x }).toArray(Object[][]::new);
318    }
319
320    @Test(dataProvider = "addActions")
321    public void addActionsOneNodeSlack(
322        Consumer<LinkedTransferQueue> addAction) {
323        LinkedTransferQueue q = new LinkedTransferQueue();
324        int n = 1 + rnd.nextInt(9);
325        for (int i = 0; i < n; i++) {
326            boolean slack = next(tail(q)) != null;
327            addAction.accept(q);
328            if (slack)
329                assertNull(next(tail(q)));
330            else {
331                assertNotNull(next(tail(q)));
332                assertNull(next(next(tail(q))));
333            }
334            assertInvariants(q);
335        }
336    }
337
338    byte[] serialBytes(Object o) {
339        try {
340            ByteArrayOutputStream bos = new ByteArrayOutputStream();
341            ObjectOutputStream oos = new ObjectOutputStream(bos);
342            oos.writeObject(o);
343            oos.flush();
344            oos.close();
345            return bos.toByteArray();
346        } catch (Exception fail) {
347            throw new AssertionError(fail);
348        }
349    }
350
351    @SuppressWarnings("unchecked")
352    <T> T serialClone(T o) {
353        try {
354            ObjectInputStream ois = new ObjectInputStream
355                (new ByteArrayInputStream(serialBytes(o)));
356            T clone = (T) ois.readObject();
357            assertNotSame(o, clone);
358            assertSame(o.getClass(), clone.getClass());
359            return clone;
360        } catch (Exception fail) {
361            throw new AssertionError(fail);
362        }
363    }
364
365    public void testSerialization() {
366        LinkedTransferQueue q = serialClone(new LinkedTransferQueue());
367        assertInvariants(q);
368    }
369
370    public void cancelledNodeSweeping() throws Throwable {
371        assertEquals(SWEEP_THRESHOLD & (SWEEP_THRESHOLD - 1), 0);
372        LinkedTransferQueue q = new LinkedTransferQueue();
373        Thread blockHead = null;
374        if (rnd.nextBoolean()) {
375            blockHead = new Thread(
376                () -> { try { q.take(); } catch (InterruptedException ok) {}});
377            blockHead.start();
378            while (nodeCount(q) != 2) { Thread.yield(); }
379            assertTrue(q.hasWaitingConsumer());
380            assertEquals(q.getWaitingConsumerCount(), 1);
381        }
382        int initialNodeCount = nodeCount(q);
383
384        // Some dead nodes do in fact accumulate ...
385        if (blockHead != null)
386            while (nodeCount(q) < initialNodeCount + SWEEP_THRESHOLD / 2)
387                q.poll(1L, TimeUnit.MICROSECONDS);
388
389        // ... but no more than SWEEP_THRESHOLD nodes accumulate
390        for (int i = rnd.nextInt(SWEEP_THRESHOLD * 10); i-->0; )
391            q.poll(1L, TimeUnit.MICROSECONDS);
392        assertTrue(nodeCount(q) <= initialNodeCount + SWEEP_THRESHOLD);
393
394        if (blockHead != null) {
395            blockHead.interrupt();
396            blockHead.join();
397        }
398    }
399
400    /** Checks conditions which should always be true. */
401    void assertInvariants(LinkedTransferQueue q) {
402        assertNotNull(head(q));
403        assertNotNull(tail(q));
404        // head is never self-linked (but tail may!)
405        for (Object h; next(h = head(q)) == h; )
406            assertNotSame(h, head(q)); // must be update race
407    }
408}
409