1/*
2 * Copyright (C) 2011 Google Inc.  All rights reserved.
3 * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met:
8 *
9 *     * Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *     * Redistributions in binary form must reproduce the above
12 * copyright notice, this list of conditions and the following disclaimer
13 * in the documentation and/or other materials provided with the
14 * distribution.
15 *     * Neither the name of Google Inc. nor the names of its
16 * contributors may be used to endorse or promote products derived from
17 * this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#include "config.h"
33
34#if ENABLE(WEB_SOCKETS)
35
36#include "WebSocketHandshake.h"
37#include "WebSocket.h"
38
39#include "Cookie.h"
40#include "CookieJar.h"
41#include "Document.h"
42#include "HTTPHeaderMap.h"
43#include "HTTPParsers.h"
44#include "KURL.h"
45#include "Logging.h"
46#include "ResourceRequest.h"
47#include "ScriptCallStack.h"
48#include "ScriptExecutionContext.h"
49#include "SecurityOrigin.h"
50#include <wtf/CryptographicallyRandomNumber.h>
51#include <wtf/MD5.h>
52#include <wtf/SHA1.h>
53#include <wtf/StdLibExtras.h>
54#include <wtf/StringExtras.h>
55#include <wtf/Vector.h>
56#include <wtf/text/Base64.h>
57#include <wtf/text/CString.h>
58#include <wtf/text/StringBuilder.h>
59#include <wtf/text/WTFString.h>
60#include <wtf/unicode/CharacterNames.h>
61
62namespace WebCore {
63
// Printable-ASCII character table (space excluded).
// NOTE(review): nothing in this file references this table — the current
// Sec-WebSocket-Key is produced by generateSecWebSocketKey() via base64.
// It looks like a leftover from an earlier handshake format; confirm and
// remove (an unused file-static may also trigger unused-variable warnings).
static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
65
66static String resourceName(const KURL& url)
67{
68    StringBuilder name;
69    name.append(url.path());
70    if (name.isEmpty())
71        name.append('/');
72    if (!url.query().isNull()) {
73        name.append('?');
74        name.append(url.query());
75    }
76    String result = name.toString();
77    ASSERT(!result.isEmpty());
78    ASSERT(!result.contains(' '));
79    return result;
80}
81
82static String hostName(const KURL& url, bool secure)
83{
84    ASSERT(url.protocolIs("wss") == secure);
85    StringBuilder builder;
86    builder.append(url.host().lower());
87    if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
88        builder.append(':');
89        builder.appendNumber(url.port());
90    }
91    return builder.toString();
92}
93
94static const size_t maxInputSampleSize = 128;
95static String trimInputSample(const char* p, size_t len)
96{
97    String s = String(p, std::min<size_t>(len, maxInputSampleSize));
98    if (len > maxInputSampleSize)
99        s.append(horizontalEllipsis);
100    return s;
101}
102
103static String generateSecWebSocketKey()
104{
105    static const size_t nonceSize = 16;
106    unsigned char key[nonceSize];
107    cryptographicallyRandomValues(key, nonceSize);
108    return base64Encode(reinterpret_cast<char*>(key), nonceSize);
109}
110
111String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
112{
113    static const char* const webSocketKeyGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
114    static const size_t sha1HashSize = 20; // FIXME: This should be defined in SHA1.h.
115    SHA1 sha1;
116    CString keyData = secWebSocketKey.ascii();
117    sha1.addBytes(reinterpret_cast<const uint8_t*>(keyData.data()), keyData.length());
118    sha1.addBytes(reinterpret_cast<const uint8_t*>(webSocketKeyGUID), strlen(webSocketKeyGUID));
119    Vector<uint8_t, sha1HashSize> hash;
120    sha1.computeHash(hash);
121    return base64Encode(reinterpret_cast<const char*>(hash.data()), sha1HashSize);
122}
123
124WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
125    : m_url(url)
126    , m_clientProtocol(protocol)
127    , m_secure(m_url.protocolIs("wss"))
128    , m_context(context)
129    , m_mode(Incomplete)
130{
131    m_secWebSocketKey = generateSecWebSocketKey();
132    m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
133}
134
135WebSocketHandshake::~WebSocketHandshake()
136{
137}
138
139const KURL& WebSocketHandshake::url() const
140{
141    return m_url;
142}
143
144void WebSocketHandshake::setURL(const KURL& url)
145{
146    m_url = url.copy();
147}
148
149const String WebSocketHandshake::host() const
150{
151    return m_url.host().lower();
152}
153
154const String& WebSocketHandshake::clientProtocol() const
155{
156    return m_clientProtocol;
157}
158
159void WebSocketHandshake::setClientProtocol(const String& protocol)
160{
161    m_clientProtocol = protocol;
162}
163
164bool WebSocketHandshake::secure() const
165{
166    return m_secure;
167}
168
169String WebSocketHandshake::clientOrigin() const
170{
171    return m_context->securityOrigin()->toString();
172}
173
174String WebSocketHandshake::clientLocation() const
175{
176    StringBuilder builder;
177    builder.append(m_secure ? "wss" : "ws");
178    builder.append("://");
179    builder.append(hostName(m_url, m_secure));
180    builder.append(resourceName(m_url));
181    return builder.toString();
182}
183
184CString WebSocketHandshake::clientHandshakeMessage() const
185{
186    // Keep the following consistent with clientHandshakeRequest().
187    StringBuilder builder;
188
189    builder.append("GET ");
190    builder.append(resourceName(m_url));
191    builder.append(" HTTP/1.1\r\n");
192
193    Vector<String> fields;
194    fields.append("Upgrade: websocket");
195    fields.append("Connection: Upgrade");
196    fields.append("Host: " + hostName(m_url, m_secure));
197    fields.append("Origin: " + clientOrigin());
198    if (!m_clientProtocol.isEmpty())
199        fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
200
201    KURL url = httpURLForAuthenticationAndCookies();
202    if (m_context->isDocument()) {
203        Document* document = toDocument(m_context);
204        String cookie = cookieRequestHeaderFieldValue(document, url);
205        if (!cookie.isEmpty())
206            fields.append("Cookie: " + cookie);
207        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
208    }
209
210    // Add no-cache headers to avoid compatibility issue.
211    // There are some proxies that rewrite "Connection: upgrade"
212    // to "Connection: close" in the response if a request doesn't contain
213    // these headers.
214    fields.append("Pragma: no-cache");
215    fields.append("Cache-Control: no-cache");
216
217    fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
218    fields.append("Sec-WebSocket-Version: 13");
219    const String extensionValue = m_extensionDispatcher.createHeaderValue();
220    if (extensionValue.length())
221        fields.append("Sec-WebSocket-Extensions: " + extensionValue);
222
223    // Add a User-Agent header.
224    fields.append("User-Agent: " + m_context->userAgent(m_context->url()));
225
226    // Fields in the handshake are sent by the client in a random order; the
227    // order is not meaningful.  Thus, it's ok to send the order we constructed
228    // the fields.
229
230    for (size_t i = 0; i < fields.size(); i++) {
231        builder.append(fields[i]);
232        builder.append("\r\n");
233    }
234
235    builder.append("\r\n");
236
237    return builder.toString().utf8();
238}
239
240ResourceRequest WebSocketHandshake::clientHandshakeRequest() const
241{
242    // Keep the following consistent with clientHandshakeMessage().
243    // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
244    // m_key3 in the request?
245    ResourceRequest request(m_url);
246    request.setHTTPMethod("GET");
247
248    request.addHTTPHeaderField("Connection", "Upgrade");
249    request.addHTTPHeaderField("Host", hostName(m_url, m_secure));
250    request.addHTTPHeaderField("Origin", clientOrigin());
251    if (!m_clientProtocol.isEmpty())
252        request.addHTTPHeaderField("Sec-WebSocket-Protocol", m_clientProtocol);
253
254    KURL url = httpURLForAuthenticationAndCookies();
255    if (m_context->isDocument()) {
256        Document* document = toDocument(m_context);
257        String cookie = cookieRequestHeaderFieldValue(document, url);
258        if (!cookie.isEmpty())
259            request.addHTTPHeaderField("Cookie", cookie);
260        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
261    }
262
263    request.addHTTPHeaderField("Pragma", "no-cache");
264    request.addHTTPHeaderField("Cache-Control", "no-cache");
265
266    request.addHTTPHeaderField("Sec-WebSocket-Key", m_secWebSocketKey);
267    request.addHTTPHeaderField("Sec-WebSocket-Version", "13");
268    const String extensionValue = m_extensionDispatcher.createHeaderValue();
269    if (extensionValue.length())
270        request.addHTTPHeaderField("Sec-WebSocket-Extensions", extensionValue);
271
272    // Add a User-Agent header.
273    request.addHTTPHeaderField("User-Agent", m_context->userAgent(m_context->url()));
274
275    return request;
276}
277
278void WebSocketHandshake::reset()
279{
280    m_mode = Incomplete;
281    m_extensionDispatcher.reset();
282}
283
284void WebSocketHandshake::clearScriptExecutionContext()
285{
286    m_context = 0;
287}
288
289int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
290{
291    m_mode = Incomplete;
292    int statusCode;
293    String statusText;
294    int lineLength = readStatusLine(header, len, statusCode, statusText);
295    if (lineLength == -1)
296        return -1;
297    if (statusCode == -1) {
298        m_mode = Failed; // m_failureReason is set inside readStatusLine().
299        return len;
300    }
301    LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);
302
303    m_serverHandshakeResponse = ResourceResponse();
304    m_serverHandshakeResponse.setHTTPStatusCode(statusCode);
305    m_serverHandshakeResponse.setHTTPStatusText(statusText);
306
307    if (statusCode != 101) {
308        m_mode = Failed;
309        m_failureReason = "Unexpected response code: " + String::number(statusCode);
310        return len;
311    }
312    m_mode = Normal;
313    if (!strnstr(header, "\r\n\r\n", len)) {
314        // Just hasn't been received fully yet.
315        m_mode = Incomplete;
316        return -1;
317    }
318    const char* p = readHTTPHeaders(header + lineLength, header + len);
319    if (!p) {
320        LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
321        m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
322        return len;
323    }
324    if (!checkResponseHeaders()) {
325        LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
326        m_mode = Failed;
327        return p - header;
328    }
329
330    m_mode = Connected;
331    return p - header;
332}
333
334WebSocketHandshake::Mode WebSocketHandshake::mode() const
335{
336    return m_mode;
337}
338
339String WebSocketHandshake::failureReason() const
340{
341    return m_failureReason;
342}
343
344String WebSocketHandshake::serverWebSocketProtocol() const
345{
346    return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-protocol");
347}
348
349String WebSocketHandshake::serverSetCookie() const
350{
351    return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie");
352}
353
354String WebSocketHandshake::serverSetCookie2() const
355{
356    return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie2");
357}
358
359String WebSocketHandshake::serverUpgrade() const
360{
361    return m_serverHandshakeResponse.httpHeaderFields().get("upgrade");
362}
363
364String WebSocketHandshake::serverConnection() const
365{
366    return m_serverHandshakeResponse.httpHeaderFields().get("connection");
367}
368
369String WebSocketHandshake::serverWebSocketAccept() const
370{
371    return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-accept");
372}
373
374String WebSocketHandshake::acceptedExtensions() const
375{
376    return m_extensionDispatcher.acceptedExtensions();
377}
378
379const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const
380{
381    return m_serverHandshakeResponse;
382}
383
384void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
385{
386    m_extensionDispatcher.addProcessor(processor);
387}
388
389KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
390{
391    KURL url = m_url.copy();
392    bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
393    ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
394    return url;
395}
396
397// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
398// If the line is malformed or the status code is not a 3-digit number,
399// statusCode and statusText will be set to -1 and a null string, respectively.
400int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
401{
402    // Arbitrary size limit to prevent the server from sending an unbounded
403    // amount of data with no newlines and forcing us to buffer it all.
404    static const int maximumLength = 1024;
405
406    statusCode = -1;
407    statusText = String();
408
409    const char* space1 = 0;
410    const char* space2 = 0;
411    const char* p;
412    size_t consumedLength;
413
414    for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
415        if (*p == ' ') {
416            if (!space1)
417                space1 = p;
418            else if (!space2)
419                space2 = p;
420        } else if (*p == '\0') {
421            // The caller isn't prepared to deal with null bytes in status
422            // line. WebSockets specification doesn't prohibit this, but HTTP
423            // does, so we'll just treat this as an error.
424            m_failureReason = "Status line contains embedded null";
425            return p + 1 - header;
426        } else if (*p == '\n')
427            break;
428    }
429    if (consumedLength == headerLength)
430        return -1; // We have not received '\n' yet.
431
432    const char* end = p + 1;
433    int lineLength = end - header;
434    if (lineLength > maximumLength) {
435        m_failureReason = "Status line is too long";
436        return maximumLength;
437    }
438
439    // The line must end with "\r\n".
440    if (lineLength < 2 || *(end - 2) != '\r') {
441        m_failureReason = "Status line does not end with CRLF";
442        return lineLength;
443    }
444
445    if (!space1 || !space2) {
446        m_failureReason = "No response code found: " + trimInputSample(header, lineLength - 2);
447        return lineLength;
448    }
449
450    String statusCodeString(space1 + 1, space2 - space1 - 1);
451    if (statusCodeString.length() != 3) // Status code must consist of three digits.
452        return lineLength;
453    for (int i = 0; i < 3; ++i)
454        if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
455            m_failureReason = "Invalid status code: " + statusCodeString;
456            return lineLength;
457        }
458
459    bool ok = false;
460    statusCode = statusCodeString.toInt(&ok);
461    ASSERT(ok);
462
463    statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
464    return lineLength;
465}
466
// Parses the response header fields in [start, end) into
// m_serverHandshakeResponse, stopping after the empty line that ends the
// header block.
//
// Returns a pointer just past the consumed headers, or 0 on failure.
// m_failureReason is set here for duplicate-field and extension errors; for
// syntax errors it is handed to parseHTTPHeader(), presumably to be filled in.
// Sec-WebSocket-Extensions is given to m_extensionDispatcher instead of being
// stored in the response. Sec-WebSocket-Extensions, -Accept and -Protocol are
// each rejected if they appear more than once.
const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
{
    AtomicString name;
    String value;
    bool sawSecWebSocketExtensionsHeaderField = false;
    bool sawSecWebSocketAcceptHeaderField = false;
    bool sawSecWebSocketProtocolHeaderField = false;
    const char* p = start;
    // NOTE(review): p advances by consumedLength *and* by the loop's p++.
    // This presumably relies on parseHTTPHeader() returning an offset that
    // leaves p on the final '\n' of a non-empty header line (and past the
    // CRLF of the terminating empty line, where we break before p++) —
    // confirm against HTTPParsers.cpp before changing this arithmetic.
    for (; p < end; p++) {
        size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
        if (!consumedLength)
            return 0;
        p += consumedLength;

        // Stop once we consumed an empty line.
        if (name.isEmpty())
            break;

        if (equalIgnoringCase("sec-websocket-extensions", name)) {
            if (sawSecWebSocketExtensionsHeaderField) {
                m_failureReason = "The Sec-WebSocket-Extensions header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            // Negotiated by the dispatcher; not copied into the response.
            if (!m_extensionDispatcher.processHeaderValue(value)) {
                m_failureReason = m_extensionDispatcher.failureReason();
                return 0;
            }
            sawSecWebSocketExtensionsHeaderField = true;
        } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
            if (sawSecWebSocketAcceptHeaderField) {
                m_failureReason = "The Sec-WebSocket-Accept header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
            sawSecWebSocketAcceptHeaderField = true;
        } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
            if (sawSecWebSocketProtocolHeaderField) {
                m_failureReason = "The Sec-WebSocket-Protocol header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
            sawSecWebSocketProtocolHeaderField = true;
        } else
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
    }
    return p;
}
514
515bool WebSocketHandshake::checkResponseHeaders()
516{
517    const String& serverWebSocketProtocol = this->serverWebSocketProtocol();
518    const String& serverUpgrade = this->serverUpgrade();
519    const String& serverConnection = this->serverConnection();
520    const String& serverWebSocketAccept = this->serverWebSocketAccept();
521
522    if (serverUpgrade.isNull()) {
523        m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing";
524        return false;
525    }
526    if (serverConnection.isNull()) {
527        m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing";
528        return false;
529    }
530    if (serverWebSocketAccept.isNull()) {
531        m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing";
532        return false;
533    }
534
535    if (!equalIgnoringCase(serverUpgrade, "websocket")) {
536        m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'";
537        return false;
538    }
539    if (!equalIgnoringCase(serverConnection, "upgrade")) {
540        m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'";
541        return false;
542    }
543
544    if (serverWebSocketAccept != m_expectedAccept) {
545        m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch";
546        return false;
547    }
548    if (!serverWebSocketProtocol.isNull()) {
549        if (m_clientProtocol.isEmpty()) {
550            m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
551            return false;
552        }
553        Vector<String> result;
554        m_clientProtocol.split(String(WebSocket::subProtocolSeperator()), result);
555        if (!result.contains(serverWebSocketProtocol)) {
556            m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
557            return false;
558        }
559    }
560    return true;
561}
562
563} // namespace WebCore
564
565#endif // ENABLE(WEB_SOCKETS)
566