1/*
2 * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.  Oracle designates this
8 * particular file as subject to the "Classpath" exception as provided
9 * by Oracle in the LICENSE file that accompanied this code.
10 *
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have any
23 * questions.
24 */
25
26package jdk.incubator.http.internal.websocket;
27
28import jdk.incubator.http.WebSocket.MessagePart;
29import jdk.incubator.http.internal.common.Log;
30import jdk.incubator.http.internal.websocket.Frame.Opcode;
31
32import java.nio.ByteBuffer;
33import java.nio.CharBuffer;
34import java.nio.charset.CharacterCodingException;
35
36import static java.lang.String.format;
37import static java.nio.charset.StandardCharsets.UTF_8;
38import static java.util.Objects.requireNonNull;
39import static jdk.incubator.http.internal.common.Utils.dump;
40import static jdk.incubator.http.internal.websocket.StatusCodes.NO_STATUS_CODE;
41import static jdk.incubator.http.internal.websocket.StatusCodes.isLegalToReceiveFromServer;
42
43/*
44 * Consumes frame parts and notifies a message consumer, when there is
45 * sufficient data to produce a message, or part thereof.
46 *
47 * Data consumed but not yet translated is accumulated until it's sufficient to
48 * form a message.
49 */
50/* Non-final for testing purposes only */
51class FrameConsumer implements Frame.Consumer {
52
53    private final MessageStreamConsumer output;
54    private final UTF8AccumulatingDecoder decoder = new UTF8AccumulatingDecoder();
55    private boolean fin;
56    private Opcode opcode, originatingOpcode;
57    private MessagePart part = MessagePart.WHOLE;
58    private long payloadLen;
59    private long unconsumedPayloadLen;
60    private ByteBuffer binaryData;
61
62    FrameConsumer(MessageStreamConsumer output) {
63        this.output = requireNonNull(output);
64    }
65
66    /* Exposed for testing purposes only */
67    MessageStreamConsumer getOutput() {
68        return output;
69    }
70
71    @Override
72    public void fin(boolean value) {
73        Log.logTrace("Reading fin: {0}", value);
74        fin = value;
75    }
76
77    @Override
78    public void rsv1(boolean value) {
79        Log.logTrace("Reading rsv1: {0}", value);
80        if (value) {
81            throw new FailWebSocketException("Unexpected rsv1 bit");
82        }
83    }
84
85    @Override
86    public void rsv2(boolean value) {
87        Log.logTrace("Reading rsv2: {0}", value);
88        if (value) {
89            throw new FailWebSocketException("Unexpected rsv2 bit");
90        }
91    }
92
93    @Override
94    public void rsv3(boolean value) {
95        Log.logTrace("Reading rsv3: {0}", value);
96        if (value) {
97            throw new FailWebSocketException("Unexpected rsv3 bit");
98        }
99    }
100
101    @Override
102    public void opcode(Opcode v) {
103        Log.logTrace("Reading opcode: {0}", v);
104        if (v == Opcode.PING || v == Opcode.PONG || v == Opcode.CLOSE) {
105            if (!fin) {
106                throw new FailWebSocketException("Fragmented control frame  " + v);
107            }
108            opcode = v;
109        } else if (v == Opcode.TEXT || v == Opcode.BINARY) {
110            if (originatingOpcode != null) {
111                throw new FailWebSocketException(
112                        format("Unexpected frame %s (fin=%s)", v, fin));
113            }
114            opcode = v;
115            if (!fin) {
116                originatingOpcode = v;
117            }
118        } else if (v == Opcode.CONTINUATION) {
119            if (originatingOpcode == null) {
120                throw new FailWebSocketException(
121                        format("Unexpected frame %s (fin=%s)", v, fin));
122            }
123            opcode = v;
124        } else {
125            throw new FailWebSocketException("Unknown opcode " + v);
126        }
127    }
128
129    @Override
130    public void mask(boolean value) {
131        Log.logTrace("Reading mask: {0}", value);
132        if (value) {
133            throw new FailWebSocketException("Masked frame received");
134        }
135    }
136
137    @Override
138    public void payloadLen(long value) {
139        Log.logTrace("Reading payloadLen: {0}", value);
140        if (opcode.isControl()) {
141            if (value > 125) {
142                throw new FailWebSocketException(
143                        format("%s's payload length %s", opcode, value));
144            }
145            assert Opcode.CLOSE.isControl();
146            if (opcode == Opcode.CLOSE && value == 1) {
147                throw new FailWebSocketException("Incomplete status code");
148            }
149        }
150        payloadLen = value;
151        unconsumedPayloadLen = value;
152    }
153
154    @Override
155    public void maskingKey(int value) {
156        // `FrameConsumer.mask(boolean)` is where a masked frame is detected and
157        // reported on; `FrameConsumer.mask(boolean)` MUST be invoked before
158        // this method;
159        // So this method (`maskingKey`) is not supposed to be invoked while
160        // reading a frame that has came from the server. If this method is
161        // invoked, then it's an error in implementation, thus InternalError
162        throw new InternalError();
163    }
164
165    @Override
166    public void payloadData(ByteBuffer data) {
167        Log.logTrace("Reading payloadData: data={0}", data);
168        unconsumedPayloadLen -= data.remaining();
169        boolean isLast = unconsumedPayloadLen == 0;
170        if (opcode.isControl()) {
171            if (binaryData != null) { // An intermediate or the last chunk
172                binaryData.put(data);
173            } else if (!isLast) { // The first chunk
174                int remaining = data.remaining();
175                // It shouldn't be 125, otherwise the next chunk will be of size
176                // 0, which is not what Reader promises to deliver (eager
177                // reading)
178                assert remaining < 125 : dump(remaining);
179                binaryData = ByteBuffer.allocate(125).put(data);
180            } else { // The only chunk
181                binaryData = ByteBuffer.allocate(data.remaining()).put(data);
182            }
183        } else {
184            part = determinePart(isLast);
185            boolean text = opcode == Opcode.TEXT || originatingOpcode == Opcode.TEXT;
186            if (!text) {
187                output.onBinary(part, data.slice());
188                data.position(data.limit()); // Consume
189            } else {
190                boolean binaryNonEmpty = data.hasRemaining();
191                CharBuffer textData;
192                try {
193                    textData = decoder.decode(data, part == MessagePart.WHOLE || part == MessagePart.LAST);
194                } catch (CharacterCodingException e) {
195                    throw new FailWebSocketException(
196                            "Invalid UTF-8 in frame " + opcode, StatusCodes.NOT_CONSISTENT)
197                            .initCause(e);
198                }
199                if (!(binaryNonEmpty && !textData.hasRemaining())) {
200                    // If there's a binary data, that result in no text, then we
201                    // don't deliver anything
202                    output.onText(part, textData);
203                }
204            }
205        }
206    }
207
208    @Override
209    public void endFrame() {
210        if (opcode.isControl()) {
211            binaryData.flip();
212        }
213        switch (opcode) {
214            case CLOSE:
215                char statusCode = NO_STATUS_CODE;
216                String reason = "";
217                if (payloadLen != 0) {
218                    int len = binaryData.remaining();
219                    assert 2 <= len && len <= 125 : dump(len, payloadLen);
220                    statusCode = binaryData.getChar();
221                    if (!isLegalToReceiveFromServer(statusCode)) {
222                        throw new FailWebSocketException(
223                                "Illegal status code: " + statusCode);
224                    }
225                    try {
226                        reason = UTF_8.newDecoder().decode(binaryData).toString();
227                    } catch (CharacterCodingException e) {
228                        throw new FailWebSocketException("Illegal close reason")
229                                .initCause(e);
230                    }
231                }
232                output.onClose(statusCode, reason);
233                break;
234            case PING:
235                output.onPing(binaryData);
236                binaryData = null;
237                break;
238            case PONG:
239                output.onPong(binaryData);
240                binaryData = null;
241                break;
242            default:
243                assert opcode == Opcode.TEXT || opcode == Opcode.BINARY
244                        || opcode == Opcode.CONTINUATION : dump(opcode);
245                if (fin) {
246                    // It is always the last chunk:
247                    // either TEXT(FIN=TRUE)/BINARY(FIN=TRUE) or CONT(FIN=TRUE)
248                    originatingOpcode = null;
249                }
250                break;
251        }
252        payloadLen = 0;
253        opcode = null;
254    }
255
256    private MessagePart determinePart(boolean isLast) {
257        boolean lastChunk = fin && isLast;
258        switch (part) {
259            case LAST:
260            case WHOLE:
261                return lastChunk ? MessagePart.WHOLE : MessagePart.FIRST;
262            case FIRST:
263            case PART:
264                return lastChunk ? MessagePart.LAST : MessagePart.PART;
265            default:
266                throw new InternalError(String.valueOf(part));
267        }
268    }
269}
270