1/*
2 * Copyright (C) 2011 Google Inc.  All rights reserved.
3 * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met:
8 *
9 *     * Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *     * Redistributions in binary form must reproduce the above
12 * copyright notice, this list of conditions and the following disclaimer
13 * in the documentation and/or other materials provided with the
14 * distribution.
15 *     * Neither the name of Google Inc. nor the names of its
16 * contributors may be used to endorse or promote products derived from
17 * this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#include "config.h"
33
34#if ENABLE(WEB_SOCKETS)
35
36#include "WebSocketHandshake.h"
37#include "WebSocket.h"
38
39#include "Cookie.h"
40#include "CookieJar.h"
41#include "Document.h"
42#include "HTTPHeaderMap.h"
43#include "HTTPHeaderNames.h"
44#include "HTTPParsers.h"
45#include "URL.h"
46#include "Logging.h"
47#include "ResourceRequest.h"
48#include "ScriptExecutionContext.h"
49#include "SecurityOrigin.h"
50#include <wtf/CryptographicallyRandomNumber.h>
51#include <wtf/MD5.h>
52#include <wtf/SHA1.h>
53#include <wtf/StdLibExtras.h>
54#include <wtf/StringExtras.h>
55#include <wtf/Vector.h>
56#include <wtf/text/Base64.h>
57#include <wtf/text/CString.h>
58#include <wtf/text/StringBuilder.h>
59#include <wtf/text/WTFString.h>
60#include <wtf/unicode/CharacterNames.h>
61
62namespace WebCore {
63
64static String resourceName(const URL& url)
65{
66    StringBuilder name;
67    name.append(url.path());
68    if (name.isEmpty())
69        name.append('/');
70    if (!url.query().isNull()) {
71        name.append('?');
72        name.append(url.query());
73    }
74    String result = name.toString();
75    ASSERT(!result.isEmpty());
76    ASSERT(!result.contains(' '));
77    return result;
78}
79
80static String hostName(const URL& url, bool secure)
81{
82    ASSERT(url.protocolIs("wss") == secure);
83    StringBuilder builder;
84    builder.append(url.host().lower());
85    if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
86        builder.append(':');
87        builder.appendNumber(url.port());
88    }
89    return builder.toString();
90}
91
92static const size_t maxInputSampleSize = 128;
93static String trimInputSample(const char* p, size_t len)
94{
95    String s = String(p, std::min<size_t>(len, maxInputSampleSize));
96    if (len > maxInputSampleSize)
97        s.append(horizontalEllipsis);
98    return s;
99}
100
101static String generateSecWebSocketKey()
102{
103    static const size_t nonceSize = 16;
104    unsigned char key[nonceSize];
105    cryptographicallyRandomValues(key, nonceSize);
106    return base64Encode(key, nonceSize);
107}
108
109String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
110{
111    static const char* const webSocketKeyGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
112    SHA1 sha1;
113    CString keyData = secWebSocketKey.ascii();
114    sha1.addBytes(reinterpret_cast<const uint8_t*>(keyData.data()), keyData.length());
115    sha1.addBytes(reinterpret_cast<const uint8_t*>(webSocketKeyGUID), strlen(webSocketKeyGUID));
116    SHA1::Digest hash;
117    sha1.computeHash(hash);
118    return base64Encode(hash.data(), SHA1::hashSize);
119}
120
121WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, ScriptExecutionContext* context)
122    : m_url(url)
123    , m_clientProtocol(protocol)
124    , m_secure(m_url.protocolIs("wss"))
125    , m_context(context)
126    , m_mode(Incomplete)
127{
128    m_secWebSocketKey = generateSecWebSocketKey();
129    m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
130}
131
132WebSocketHandshake::~WebSocketHandshake()
133{
134}
135
136const URL& WebSocketHandshake::url() const
137{
138    return m_url;
139}
140
141void WebSocketHandshake::setURL(const URL& url)
142{
143    m_url = url.copy();
144}
145
146const String WebSocketHandshake::host() const
147{
148    return m_url.host().lower();
149}
150
151const String& WebSocketHandshake::clientProtocol() const
152{
153    return m_clientProtocol;
154}
155
156void WebSocketHandshake::setClientProtocol(const String& protocol)
157{
158    m_clientProtocol = protocol;
159}
160
161bool WebSocketHandshake::secure() const
162{
163    return m_secure;
164}
165
166String WebSocketHandshake::clientOrigin() const
167{
168    return m_context->securityOrigin()->toString();
169}
170
171String WebSocketHandshake::clientLocation() const
172{
173    StringBuilder builder;
174    builder.append(m_secure ? "wss" : "ws");
175    builder.append("://");
176    builder.append(hostName(m_url, m_secure));
177    builder.append(resourceName(m_url));
178    return builder.toString();
179}
180
181CString WebSocketHandshake::clientHandshakeMessage() const
182{
183    // Keep the following consistent with clientHandshakeRequest().
184    StringBuilder builder;
185
186    builder.append("GET ");
187    builder.append(resourceName(m_url));
188    builder.append(" HTTP/1.1\r\n");
189
190    Vector<String> fields;
191    fields.append("Upgrade: websocket");
192    fields.append("Connection: Upgrade");
193    fields.append("Host: " + hostName(m_url, m_secure));
194    fields.append("Origin: " + clientOrigin());
195    if (!m_clientProtocol.isEmpty())
196        fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
197
198    URL url = httpURLForAuthenticationAndCookies();
199    if (m_context->isDocument()) {
200        Document* document = toDocument(m_context);
201        String cookie = cookieRequestHeaderFieldValue(document, url);
202        if (!cookie.isEmpty())
203            fields.append("Cookie: " + cookie);
204        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
205    }
206
207    // Add no-cache headers to avoid compatibility issue.
208    // There are some proxies that rewrite "Connection: upgrade"
209    // to "Connection: close" in the response if a request doesn't contain
210    // these headers.
211    fields.append("Pragma: no-cache");
212    fields.append("Cache-Control: no-cache");
213
214    fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
215    fields.append("Sec-WebSocket-Version: 13");
216    const String extensionValue = m_extensionDispatcher.createHeaderValue();
217    if (extensionValue.length())
218        fields.append("Sec-WebSocket-Extensions: " + extensionValue);
219
220    // Add a User-Agent header.
221    fields.append("User-Agent: " + m_context->userAgent(m_context->url()));
222
223    // Fields in the handshake are sent by the client in a random order; the
224    // order is not meaningful.  Thus, it's ok to send the order we constructed
225    // the fields.
226
227    for (size_t i = 0; i < fields.size(); i++) {
228        builder.append(fields[i]);
229        builder.append("\r\n");
230    }
231
232    builder.append("\r\n");
233
234    return builder.toString().utf8();
235}
236
237ResourceRequest WebSocketHandshake::clientHandshakeRequest() const
238{
239    // Keep the following consistent with clientHandshakeMessage().
240    ResourceRequest request(m_url);
241    request.setHTTPMethod("GET");
242
243    request.setHTTPHeaderField(HTTPHeaderName::Connection, "Upgrade");
244    request.setHTTPHeaderField(HTTPHeaderName::Host, hostName(m_url, m_secure));
245    request.setHTTPHeaderField(HTTPHeaderName::Origin, clientOrigin());
246    if (!m_clientProtocol.isEmpty())
247        request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketProtocol, m_clientProtocol);
248
249    URL url = httpURLForAuthenticationAndCookies();
250    if (m_context->isDocument()) {
251        Document* document = toDocument(m_context);
252        String cookie = cookieRequestHeaderFieldValue(document, url);
253        if (!cookie.isEmpty())
254            request.setHTTPHeaderField(HTTPHeaderName::Cookie, cookie);
255        // Set "Cookie2: <cookie>" if cookies 2 exists for url?
256    }
257
258    request.setHTTPHeaderField(HTTPHeaderName::Pragma, "no-cache");
259    request.setHTTPHeaderField(HTTPHeaderName::CacheControl, "no-cache");
260
261    request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketKey, m_secWebSocketKey);
262    request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketVersion, "13");
263    const String extensionValue = m_extensionDispatcher.createHeaderValue();
264    if (extensionValue.length())
265        request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketExtensions, extensionValue);
266
267    // Add a User-Agent header.
268    request.setHTTPHeaderField(HTTPHeaderName::UserAgent, m_context->userAgent(m_context->url()));
269
270    return request;
271}
272
273void WebSocketHandshake::reset()
274{
275    m_mode = Incomplete;
276    m_extensionDispatcher.reset();
277}
278
279void WebSocketHandshake::clearScriptExecutionContext()
280{
281    m_context = 0;
282}
283
284int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
285{
286    m_mode = Incomplete;
287    int statusCode;
288    String statusText;
289    int lineLength = readStatusLine(header, len, statusCode, statusText);
290    if (lineLength == -1)
291        return -1;
292    if (statusCode == -1) {
293        m_mode = Failed; // m_failureReason is set inside readStatusLine().
294        return len;
295    }
296    LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);
297
298    m_serverHandshakeResponse = ResourceResponse();
299    m_serverHandshakeResponse.setHTTPStatusCode(statusCode);
300    m_serverHandshakeResponse.setHTTPStatusText(statusText);
301
302    if (statusCode != 101) {
303        m_mode = Failed;
304        m_failureReason = "Unexpected response code: " + String::number(statusCode);
305        return len;
306    }
307    m_mode = Normal;
308    if (!strnstr(header, "\r\n\r\n", len)) {
309        // Just hasn't been received fully yet.
310        m_mode = Incomplete;
311        return -1;
312    }
313    const char* p = readHTTPHeaders(header + lineLength, header + len);
314    if (!p) {
315        LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
316        m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
317        return len;
318    }
319    if (!checkResponseHeaders()) {
320        LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
321        m_mode = Failed;
322        return p - header;
323    }
324
325    m_mode = Connected;
326    return p - header;
327}
328
329WebSocketHandshake::Mode WebSocketHandshake::mode() const
330{
331    return m_mode;
332}
333
334String WebSocketHandshake::failureReason() const
335{
336    return m_failureReason;
337}
338
339String WebSocketHandshake::serverWebSocketProtocol() const
340{
341    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketProtocol);
342}
343
344String WebSocketHandshake::serverSetCookie() const
345{
346    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie);
347}
348
349String WebSocketHandshake::serverSetCookie2() const
350{
351    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie2);
352}
353
354String WebSocketHandshake::serverUpgrade() const
355{
356    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Upgrade);
357}
358
359String WebSocketHandshake::serverConnection() const
360{
361    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Connection);
362}
363
364String WebSocketHandshake::serverWebSocketAccept() const
365{
366    return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketAccept);
367}
368
369String WebSocketHandshake::acceptedExtensions() const
370{
371    return m_extensionDispatcher.acceptedExtensions();
372}
373
374const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const
375{
376    return m_serverHandshakeResponse;
377}
378
// Registers an extension processor with m_extensionDispatcher. The dispatcher
// builds the Sec-WebSocket-Extensions request value (createHeaderValue()) and
// validates the server's reply (processHeaderValue()).
// NOTE(review): PassOwnPtr presumably hands ownership of |processor| to the dispatcher.
void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
{
    m_extensionDispatcher.addProcessor(processor);
}
383
384URL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
385{
386    URL url = m_url.copy();
387    bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
388    ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
389    return url;
390}
391
392// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
393// If the line is malformed or the status code is not a 3-digit number,
394// statusCode and statusText will be set to -1 and a null string, respectively.
395int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
396{
397    // Arbitrary size limit to prevent the server from sending an unbounded
398    // amount of data with no newlines and forcing us to buffer it all.
399    static const int maximumLength = 1024;
400
401    statusCode = -1;
402    statusText = String();
403
404    const char* space1 = 0;
405    const char* space2 = 0;
406    const char* p;
407    size_t consumedLength;
408
409    for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
410        if (*p == ' ') {
411            if (!space1)
412                space1 = p;
413            else if (!space2)
414                space2 = p;
415        } else if (*p == '\0') {
416            // The caller isn't prepared to deal with null bytes in status
417            // line. WebSockets specification doesn't prohibit this, but HTTP
418            // does, so we'll just treat this as an error.
419            m_failureReason = "Status line contains embedded null";
420            return p + 1 - header;
421        } else if (*p == '\n')
422            break;
423    }
424    if (consumedLength == headerLength)
425        return -1; // We have not received '\n' yet.
426
427    const char* end = p + 1;
428    int lineLength = end - header;
429    if (lineLength > maximumLength) {
430        m_failureReason = "Status line is too long";
431        return maximumLength;
432    }
433
434    // The line must end with "\r\n".
435    if (lineLength < 2 || *(end - 2) != '\r') {
436        m_failureReason = "Status line does not end with CRLF";
437        return lineLength;
438    }
439
440    if (!space1 || !space2) {
441        m_failureReason = "No response code found: " + trimInputSample(header, lineLength - 2);
442        return lineLength;
443    }
444
445    String statusCodeString(space1 + 1, space2 - space1 - 1);
446    if (statusCodeString.length() != 3) // Status code must consist of three digits.
447        return lineLength;
448    for (int i = 0; i < 3; ++i)
449        if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
450            m_failureReason = "Invalid status code: " + statusCodeString;
451            return lineLength;
452        }
453
454    bool ok = false;
455    statusCode = statusCodeString.toInt(&ok);
456    ASSERT(ok);
457
458    statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
459    return lineLength;
460}
461
// Parses header fields from [start, end) into m_serverHandshakeResponse,
// stopping after the empty line that ends the header block.
// Returns the position where parsing stopped, or 0 on failure with
// m_failureReason set. Failure means one of:
// - parseHTTPHeader() rejected a field (it is handed m_failureReason to fill in);
// - the extension dispatcher rejected Sec-WebSocket-Extensions;
// - Sec-WebSocket-Extensions, Sec-WebSocket-Accept or Sec-WebSocket-Protocol
//   appeared more than once.
// Sec-WebSocket-Extensions is handed to m_extensionDispatcher and is NOT
// stored in m_serverHandshakeResponse; all other fields are.
const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
{
    String name;
    String value;
    bool sawSecWebSocketExtensionsHeaderField = false;
    bool sawSecWebSocketAcceptHeaderField = false;
    bool sawSecWebSocketProtocolHeaderField = false;
    const char* p = start;
    // NOTE(review): advancing by consumedLength *and* the loop's p++ assumes
    // parseHTTPHeader() returns a length ending on the '\n' of a field line
    // (and past the whole "\r\n" for the empty line) — confirm in HTTPParsers.
    for (; p < end; p++) {
        size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
        if (!consumedLength)
            return 0;
        p += consumedLength;

        // Stop once we consumed an empty line.
        if (name.isEmpty())
            break;

        if (equalIgnoringCase("sec-websocket-extensions", name)) {
            if (sawSecWebSocketExtensionsHeaderField) {
                m_failureReason = "The Sec-WebSocket-Extensions header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            // Negotiation result lives in the dispatcher, not in the response headers.
            if (!m_extensionDispatcher.processHeaderValue(value)) {
                m_failureReason = m_extensionDispatcher.failureReason();
                return 0;
            }
            sawSecWebSocketExtensionsHeaderField = true;
        } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
            if (sawSecWebSocketAcceptHeaderField) {
                m_failureReason = "The Sec-WebSocket-Accept header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
            sawSecWebSocketAcceptHeaderField = true;
        } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
            if (sawSecWebSocketProtocolHeaderField) {
                m_failureReason = "The Sec-WebSocket-Protocol header MUST NOT appear more than once in an HTTP response";
                return 0;
            }
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
            sawSecWebSocketProtocolHeaderField = true;
        } else
            m_serverHandshakeResponse.addHTTPHeaderField(name, value);
    }
    return p;
}
509
510bool WebSocketHandshake::checkResponseHeaders()
511{
512    const String& serverWebSocketProtocol = this->serverWebSocketProtocol();
513    const String& serverUpgrade = this->serverUpgrade();
514    const String& serverConnection = this->serverConnection();
515    const String& serverWebSocketAccept = this->serverWebSocketAccept();
516
517    if (serverUpgrade.isNull()) {
518        m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing";
519        return false;
520    }
521    if (serverConnection.isNull()) {
522        m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing";
523        return false;
524    }
525    if (serverWebSocketAccept.isNull()) {
526        m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing";
527        return false;
528    }
529
530    if (!equalIgnoringCase(serverUpgrade, "websocket")) {
531        m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'";
532        return false;
533    }
534    if (!equalIgnoringCase(serverConnection, "upgrade")) {
535        m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'";
536        return false;
537    }
538
539    if (serverWebSocketAccept != m_expectedAccept) {
540        m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch";
541        return false;
542    }
543    if (!serverWebSocketProtocol.isNull()) {
544        if (m_clientProtocol.isEmpty()) {
545            m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
546            return false;
547        }
548        Vector<String> result;
549        m_clientProtocol.split(String(WebSocket::subProtocolSeperator()), result);
550        if (!result.contains(serverWebSocketProtocol)) {
551            m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
552            return false;
553        }
554    }
555    return true;
556}
557
558} // namespace WebCore
559
560#endif // ENABLE(WEB_SOCKETS)
561