1/*
2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24package jdk.incubator.http.internal.websocket;
25
26import org.testng.annotations.Test;
27
28import java.nio.ByteBuffer;
29import java.security.SecureRandom;
30import java.util.stream.IntStream;
31
32import static org.testng.Assert.assertEquals;
33import static jdk.incubator.http.internal.websocket.Frame.Masker.transferMasking;
34import static jdk.incubator.http.internal.websocket.TestSupport.forEachBufferPartition;
35import static jdk.incubator.http.internal.websocket.TestSupport.fullCopy;
36
37public class MaskerTest {
38
39    private static final SecureRandom random = new SecureRandom();
40
41    @Test
42    public void stateless() {
43        IntStream.iterate(0, i -> i + 1).limit(125).boxed()
44                .forEach(r -> {
45                    int m = random.nextInt();
46                    ByteBuffer src = createSourceBuffer(r);
47                    ByteBuffer dst = createDestinationBuffer(r);
48                    verify(src, dst, maskArray(m), 0,
49                            () -> transferMasking(src, dst, m));
50                });
51    }
52
53    /*
54     * Stateful masker to make sure setting a mask resets the state as if a new
55     * Masker instance is created each time
56     */
57    private final Frame.Masker masker = new Frame.Masker();
58
59    @Test
60    public void stateful0() {
61        // This size (17 = 8 + 8 + 1) should test all the stages
62        // (galloping/slow) of masking good enough
63        int N = 17;
64        ByteBuffer src = createSourceBuffer(N);
65        ByteBuffer dst = createDestinationBuffer(N);
66        int mask = random.nextInt();
67        forEachBufferPartition(src,
68                buffers -> {
69                    int offset = 0;
70                    masker.mask(mask);
71                    int[] maskBytes = maskArray(mask);
72                    for (ByteBuffer s : buffers) {
73                        offset = verify(s, dst, maskBytes, offset,
74                                () -> masker.transferMasking(s, dst));
75                    }
76                });
77    }
78
79    @Test
80    public void stateful1() {
81        int m = random.nextInt();
82        masker.mask(m);
83        ByteBuffer src = ByteBuffer.allocate(0);
84        ByteBuffer dst = ByteBuffer.allocate(16);
85        verify(src, dst, maskArray(m), 0,
86                () -> masker.transferMasking(src, dst));
87    }
88
89    private static int verify(ByteBuffer src,
90                              ByteBuffer dst,
91                              int[] maskBytes,
92                              int offset,
93                              Runnable masking) {
94        ByteBuffer srcCopy = fullCopy(src);
95        ByteBuffer dstCopy = fullCopy(dst);
96        masking.run();
97        int srcRemaining = srcCopy.remaining();
98        int dstRemaining = dstCopy.remaining();
99        int masked = Math.min(srcRemaining, dstRemaining);
100        // 1. position check
101        assertEquals(src.position(), srcCopy.position() + masked);
102        assertEquals(dst.position(), dstCopy.position() + masked);
103        // 2. masking check
104        src.position(srcCopy.position());
105        dst.position(dstCopy.position());
106        for (; src.hasRemaining() && dst.hasRemaining();
107             offset = (offset + 1) & 3) {
108            assertEquals(dst.get(), src.get() ^ maskBytes[offset]);
109        }
110        // 3. corruption check
111        // 3.1 src contents haven't changed
112        int srcPosition = src.position();
113        int srcLimit = src.limit();
114        src.clear();
115        srcCopy.clear();
116        assertEquals(src, srcCopy);
117        src.limit(srcLimit).position(srcPosition); // restore src
118        // 3.2 dst leading and trailing regions' contents haven't changed
119        int dstPosition = dst.position();
120        int dstInitialPosition = dstCopy.position();
121        int dstLimit = dst.limit();
122        // leading
123        dst.position(0).limit(dstInitialPosition);
124        dstCopy.position(0).limit(dstInitialPosition);
125        assertEquals(dst, dstCopy);
126        // trailing
127        dst.limit(dst.capacity()).position(dstLimit);
128        dstCopy.limit(dst.capacity()).position(dstLimit);
129        assertEquals(dst, dstCopy);
130        // restore dst
131        dst.position(dstPosition).limit(dstLimit);
132        return offset;
133    }
134
135    private static ByteBuffer createSourceBuffer(int remaining) {
136        int leading = random.nextInt(4);
137        int trailing = random.nextInt(4);
138        byte[] bytes = new byte[leading + remaining + trailing];
139        random.nextBytes(bytes);
140        return ByteBuffer.wrap(bytes).position(leading).limit(leading + remaining);
141    }
142
143    private static ByteBuffer createDestinationBuffer(int remaining) {
144        int leading = random.nextInt(4);
145        int trailing = random.nextInt(4);
146        return ByteBuffer.allocate(leading + remaining + trailing)
147                .position(leading).limit(leading + remaining);
148    }
149
150    private static int[] maskArray(int mask) {
151        return new int[]{
152                (byte) (mask >>> 24),
153                (byte) (mask >>> 16),
154                (byte) (mask >>>  8),
155                (byte) (mask >>>  0)
156        };
157    }
158}
159