1require_relative "utils"
2
if defined?(OpenSSL)

# Functional tests for OpenSSL::SSL::SSLContext and OpenSSL::SSL::SSLSocket.
#
# Fixtures and helpers used below (start_server, PORT, ITERATIONS, starttls,
# readwrite_loop, issue_cert, and the @ca_*/@svr_*/@cli_* keys, certs and
# names) come from OpenSSL::SSLTestCase in the required "utils" file, which is
# not shown here. Tests that write a line and assert it comes back unchanged
# rely on the default server echoing data (as implemented in utils).
class OpenSSL::TestSSL < OpenSSL::SSLTestCase

  # Options SSLContext#set_params is expected to apply (see
  # test_sslctx_set_params): OP_ALL with the OP_DONT_INSERT_EMPTY_FRAGMENTS bit
  # cleared when this OpenSSL build defines that constant.
  TLS_DEFAULT_OPS = defined?(OpenSSL::SSL::OP_DONT_INSERT_EMPTY_FRAGMENTS) ?
                    OpenSSL::SSL::OP_ALL & ~OpenSSL::SSL::OP_DONT_INSERT_EMPTY_FRAGMENTS :
                    OpenSSL::SSL::OP_ALL

  # SSLContext#setup returns true on the first call and nil on later calls.
  def test_ctx_setup
    ctx = OpenSSL::SSL::SSLContext.new
    assert_equal(true, ctx.setup)
    assert_nil(ctx.setup)
  end

  # OP_NO_COMPRESSION set before #setup is still present in #options afterwards.
  def test_ctx_setup_no_compression
    ctx = OpenSSL::SSL::SSLContext.new
    ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_COMPRESSION
    assert_equal(true, ctx.setup)
    assert_nil(ctx.setup)
    assert_equal(OpenSSL::SSL::OP_NO_COMPRESSION,
                 ctx.options & OpenSSL::SSL::OP_NO_COMPRESSION)
  end if defined?(OpenSSL::SSL::OP_NO_COMPRESSION)

  # An SSLSocket wrapped around a plain (non-socket) IO that never performed a
  # handshake returns nil from #cert.
  def test_not_started_session
    skip "non socket argument of SSLSocket.new is not supported on this platform" if /mswin|mingw/ =~ RUBY_PLATFORM
    File.open(__FILE__) do |f|
      assert_nil OpenSSL::SSL::SSLSocket.new(f).cert
    end
  end

  # read_nonblock raises IO::WaitReadable while nothing is available, and can
  # then consume the echoed line in pieces.
  def test_ssl_read_nonblock
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true) { |server, port|
      server_connect(port) { |ssl|
        assert_raise(IO::WaitReadable) { ssl.read_nonblock(100) }
        ssl.write("abc\n")
        IO.select [ssl]
        assert_equal('a', ssl.read_nonblock(1))
        assert_equal("bc\n", ssl.read_nonblock(100))
        assert_raise(IO::WaitReadable) { ssl.read_nonblock(100) }
      }
    }
  end

  # SSLSocket#close leaves the underlying TCP socket open by default and
  # closes it as well when sync_close is true.
  def test_connect_and_close
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ssl = OpenSSL::SSL::SSLSocket.new(sock)
      assert(ssl.connect)
      ssl.close
      assert(!sock.closed?)
      sock.close

      sock = TCPSocket.new("127.0.0.1", port)
      ssl = OpenSSL::SSL::SSLSocket.new(sock)
      ssl.sync_close = true  # !!
      assert(ssl.connect)
      ssl.close
      assert(sock.closed?)
    }
  end

  # Round-trips data through syswrite/sysread, puts/gets and write/read,
  # including the forms that fill a caller-supplied buffer in place.
  def test_read_and_write
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      server_connect(port) { |ssl|
        # syswrite and sysread
        ITERATIONS.times{|i|
          str = "x" * 100 + "\n"
          ssl.syswrite(str)
          assert_equal(str, ssl.sysread(str.size))

          str = "x" * i * 100 + "\n"
          buf = ""
          ssl.syswrite(str)
          # sysread must fill and return the very buffer it was given.
          assert_equal(buf.object_id, ssl.sysread(str.size, buf).object_id)
          assert_equal(str, buf)
        }

        # puts and gets
        ITERATIONS.times{
          str = "x" * 100 + "\n"
          ssl.puts(str)
          assert_equal(str, ssl.gets)

          str = "x" * 100
          ssl.puts(str)
          # gets with a length limit stops before the newline puts appended.
          assert_equal(str, ssl.gets("\n", 100))
          assert_equal("\n", ssl.gets)
        }

        # read and write
        ITERATIONS.times{|i|
          str = "x" * 100 + "\n"
          ssl.write(str)
          assert_equal(str, ssl.read(str.size))

          str = "x" * i * 100 + "\n"
          buf = ""
          ssl.write(str)
          assert_equal(buf.object_id, ssl.read(str.size, buf).object_id)
          assert_equal(str, buf)
        }
      }
    }
  end

  # With VERIFY_FAIL_IF_NO_PEER_CERT the server rejects a client that presents
  # no certificate; a certificate set on the context, or returned from
  # client_cert_cb, is accepted.
  def test_client_auth
    vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT
    start_server(PORT, vflag, true){|server, port|
      # server_connect closes the socket even though the handshake fails.
      assert_raise(OpenSSL::SSL::SSLError, Errno::ECONNRESET){
        server_connect(port)
      }

      ctx = OpenSSL::SSL::SSLContext.new
      ctx.key = @cli_key
      ctx.cert = @cli_cert

      server_connect(port, ctx) { |ssl|
        ssl.puts("foo")
        assert_equal("foo\n", ssl.gets)
      }

      called = nil
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.client_cert_cb = Proc.new{ |sslconn|
        called = true
        [@cli_cert, @cli_key]
      }

      server_connect(port, ctx) { |ssl|
        assert(called)
        ssl.puts("foo")
        assert_equal("foo\n", ssl.gets)
      }
    }
  end

  # The client CA list configured on the server is visible to the client's
  # client_cert_cb through SSLSocket#client_ca.
  def test_client_ca
    ctx_proc = Proc.new do |ctx|
      ctx.client_ca = [@ca_cert]
    end

    vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT
    start_server(PORT, vflag, true, :ctx_proc => ctx_proc){|server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      client_ca_from_server = nil
      ctx.client_cert_cb = Proc.new do |sslconn|
        client_ca_from_server = sslconn.client_ca
        [@cli_cert, @cli_key]
      end
      server_connect(port, ctx) { |_ssl| assert_equal([@ca], client_ca_from_server) }
    }
  end

  # Writes lines through the SSLSocket before any handshake (warnings
  # silenced), calls the utils `starttls` helper, then repeats the exchange.
  def test_starttls
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, false){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ssl = OpenSSL::SSL::SSLSocket.new(sock)
      ssl.sync_close = true
      str = "x" * 1000 + "\n"

      OpenSSL::TestUtils.silent do
        ITERATIONS.times{
          ssl.puts(str)
          assert_equal(str, ssl.gets)
        }
        starttls(ssl)
      end

      ITERATIONS.times{
        ssl.puts(str)
        assert_equal(str, ssl.gets)
      }

      ssl.close
    }
  end

  # Ten simultaneous connections to one server, used in round-robin order,
  # each get their own data echoed back.
  def test_parallel
    GC.start
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      ssls = []
      10.times{
        sock = TCPSocket.new("127.0.0.1", port)
        ssl = OpenSSL::SSL::SSLSocket.new(sock)
        ssl.connect
        ssl.sync_close = true
        ssls << ssl
      }
      str = "x" * 1000 + "\n"
      ITERATIONS.times{
        ssls.each{|ssl|
          ssl.puts(str)
          assert_equal(str, ssl.gets)
        }
      }
      ssls.each{|ssl| ssl.close }
    }
  end

  # SSLSocket#verify_result reflects the verification outcome: the default
  # chain error, a verify_callback that forces success, and a verify_callback
  # that rejects with a custom error code.
  def test_verify_result
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.set_params
      ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
      assert_raise(OpenSSL::SSL::SSLError){ ssl.connect }
      assert_equal(OpenSSL::X509::V_ERR_SELF_SIGNED_CERT_IN_CHAIN, ssl.verify_result)
      # sync_close is off, so the TCP socket must be closed explicitly.
      sock.close

      sock = TCPSocket.new("127.0.0.1", port)
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.set_params(
        :verify_callback => Proc.new do |preverify_ok, store_ctx|
          store_ctx.error = OpenSSL::X509::V_OK
          true
        end
      )
      ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
      ssl.connect
      assert_equal(OpenSSL::X509::V_OK, ssl.verify_result)
      ssl.close
      sock.close

      sock = TCPSocket.new("127.0.0.1", port)
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.set_params(
        :verify_callback => Proc.new do |preverify_ok, store_ctx|
          store_ctx.error = OpenSSL::X509::V_ERR_APPLICATION_VERIFICATION
          false
        end
      )
      ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
      assert_raise(OpenSSL::SSL::SSLError){ ssl.connect }
      assert_equal(OpenSSL::X509::V_ERR_APPLICATION_VERIFICATION, ssl.verify_result)
      sock.close
    }
  end

  # An exception raised inside verify_callback fails the handshake with
  # SSLError (not the callback's RuntimeError), and verify_result is
  # V_ERR_CERT_REJECTED even though the callback set V_OK first.
  def test_exception_in_verify_callback_is_ignored
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.set_params(
        :verify_callback => Proc.new do |preverify_ok, store_ctx|
          store_ctx.error = OpenSSL::X509::V_OK
          raise RuntimeError
        end
      )
      ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
      OpenSSL::TestUtils.silent do
        # SSLError, not RuntimeError
        assert_raise(OpenSSL::SSL::SSLError) { ssl.connect }
      end
      assert_equal(OpenSSL::X509::V_ERR_CERT_REJECTED, ssl.verify_result)
      ssl.close
      # sync_close is off, so ssl.close does not close the TCP socket.
      sock.close
    }
  end

  # SSLContext#set_params enables peer verification, applies TLS_DEFAULT_OPS
  # and excludes anonymous-DH (ADH) and SSLv2 ciphers.
  def test_sslctx_set_params
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.set_params
      assert_equal(OpenSSL::SSL::VERIFY_PEER, ctx.verify_mode)
      assert_equal(TLS_DEFAULT_OPS, ctx.options)
      ciphers = ctx.ciphers
      ciphers_versions = ciphers.map{|_, v, _, _| v }
      ciphers_names = ciphers.map{|v, _, _, _| v }
      assert(ciphers_names.none?{|v| /ADH/ =~ v })
      assert(ciphers_versions.none?{|v| /SSLv2/ =~ v })
      ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
      assert_raise(OpenSSL::SSL::SSLError){ ssl.connect }
      assert_equal(OpenSSL::X509::V_ERR_SELF_SIGNED_CERT_IN_CHAIN, ssl.verify_result)
      sock.close
    }
  end

  # post_connection_check and SSL.verify_certificate_identity against three
  # server certificates: the default one (matches only "localhost"), one with
  # DNS and IP subjectAltNames, and one with a wildcard DNS subjectAltName.
  def test_post_connection_check
    sslerr = OpenSSL::SSL::SSLError

    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      server_connect(port) { |ssl|
        assert_raise(sslerr){ssl.post_connection_check("localhost.localdomain")}
        assert_raise(sslerr){ssl.post_connection_check("127.0.0.1")}
        assert(ssl.post_connection_check("localhost"))
        assert_raise(sslerr){ssl.post_connection_check("foo.example.com")}

        cert = ssl.peer_cert
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "localhost.localdomain"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "127.0.0.1"))
        assert(OpenSSL::SSL.verify_certificate_identity(cert, "localhost"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "foo.example.com"))
      }
    }

    # Reissue the server certificate with explicit DNS and IP SANs; the SANs
    # take over from the subject, so "localhost" no longer matches.
    now = Time.now
    exts = [
      ["keyUsage","keyEncipherment,digitalSignature",true],
      ["subjectAltName","DNS:localhost.localdomain",false],
      ["subjectAltName","IP:127.0.0.1",false],
    ]
    @svr_cert = issue_cert(@svr, @svr_key, 4, now, now+1800, exts,
                           @ca_cert, @ca_key, OpenSSL::Digest::SHA1.new)
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      server_connect(port) { |ssl|
        assert(ssl.post_connection_check("localhost.localdomain"))
        assert(ssl.post_connection_check("127.0.0.1"))
        assert_raise(sslerr){ssl.post_connection_check("localhost")}
        assert_raise(sslerr){ssl.post_connection_check("foo.example.com")}

        cert = ssl.peer_cert
        assert(OpenSSL::SSL.verify_certificate_identity(cert, "localhost.localdomain"))
        assert(OpenSSL::SSL.verify_certificate_identity(cert, "127.0.0.1"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "localhost"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "foo.example.com"))
      }
    }

    # Reissue with a wildcard DNS SAN: it matches one label under
    # .localdomain, but neither the bare IP nor other hosts.
    now = Time.now
    exts = [
      ["keyUsage","keyEncipherment,digitalSignature",true],
      ["subjectAltName","DNS:*.localdomain",false],
    ]
    @svr_cert = issue_cert(@svr, @svr_key, 5, now, now+1800, exts,
                           @ca_cert, @ca_key, OpenSSL::Digest::SHA1.new)
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      server_connect(port) { |ssl|
        assert(ssl.post_connection_check("localhost.localdomain"))
        assert_raise(sslerr){ssl.post_connection_check("127.0.0.1")}
        assert_raise(sslerr){ssl.post_connection_check("localhost")}
        assert_raise(sslerr){ssl.post_connection_check("foo.example.com")}
        cert = ssl.peer_cert
        assert(OpenSSL::SSL.verify_certificate_identity(cert, "localhost.localdomain"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "127.0.0.1"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "localhost"))
        assert(!OpenSSL::SSL.verify_certificate_identity(cert, "foo.example.com"))
      }
    }
  end

  # verify_certificate_identity must compare the whole SAN value. A DNS entry
  # with an embedded NUL must not match its prefix, and the IP entries match
  # only the exact addresses asserted below. This is checked for both critical
  # and non-critical subjectAltName extensions.
  def test_verify_certificate_identity
    [true, false].each do |criticality|
      cert = create_null_byte_SAN_certificate(criticality)
      assert_equal(false, OpenSSL::SSL.verify_certificate_identity(cert, 'www.example.com'))
      assert_equal(true,  OpenSSL::SSL.verify_certificate_identity(cert, "www.example.com\0.evil.com"))
      assert_equal(false, OpenSSL::SSL.verify_certificate_identity(cert, '192.168.7.255'))
      assert_equal(true,  OpenSSL::SSL.verify_certificate_identity(cert, '192.168.7.1'))
      assert_equal(false, OpenSSL::SSL.verify_certificate_identity(cert, '13::17'))
      assert_equal(true,  OpenSSL::SSL.verify_certificate_identity(cert, '13:0:0:0:0:0:0:17'))
    end
  end

  # Builds an unsigned certificate whose subjectAltName has
  # "www.example.com\0.evil.com" as its first (DNS) entry, followed by
  # IP:192.168.7.1 and IP:13::17.
  #
  # ExtensionFactory rejects such a name, so the extension is created with a
  # placeholder and its DER is patched in place.
  #
  # critical - whether the subjectAltName extension is marked critical.
  # Returns the OpenSSL::X509::Certificate.
  def create_null_byte_SAN_certificate(critical = false)
    ef = OpenSSL::X509::ExtensionFactory.new
    cert = OpenSSL::X509::Certificate.new
    cert.subject = OpenSSL::X509::Name.parse "/DC=some/DC=site/CN=Some Site"
    ext = ef.create_ext('subjectAltName', 'DNS:placeholder,IP:192.168.7.1,IP:13::17', critical)
    ext_asn1 = OpenSSL::ASN1.decode(ext.to_der)
    # The extension is a SEQUENCE of OID, optional critical BOOLEAN, and an
    # OCTET STRING (universal tag 4) holding the DER-encoded SAN list.
    # Locate the last OCTET STRING, as the original search did, and reuse that
    # index when writing back.
    value_index = ext_asn1.value.rindex { |val| val.tag == 4 }
    san_list_asn1 = OpenSSL::ASN1.decode(ext_asn1.value[value_index].value)
    san_list_asn1.value[0].value = "www.example.com\0.evil.com"
    ext_asn1.value[value_index].value = san_list_asn1.to_der
    real_ext = OpenSSL::X509::Extension.new ext_asn1
    cert.add_extension(real_ext)
    cert
  end

  # SNI: the server's servername_cb receives the hostname the client set via
  # SSLSocket#hostname and returns either a context or nil; both
  # connections must complete and echo data.
  def test_tlsext_hostname
    skip "SSLSocket#hostname= is not supported by this build" unless OpenSSL::SSL::SSLSocket.method_defined?(:hostname)

    ctx_proc = Proc.new do |ctx, ssl|
      foo_ctx = ctx.dup

      ctx.servername_cb = Proc.new do |_ssl2, hostname|
        case hostname
        when 'foo.example.com'
          foo_ctx
        when 'bar.example.com'
          nil
        else
          raise "unknown hostname #{hostname.inspect}"
        end
      end
    end

    server_proc = Proc.new do |ctx, ssl|
      readwrite_loop(ctx, ssl)
    end

    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc, :server_proc => server_proc) do |server, port|
      2.times do |i|
        ctx = OpenSSL::SSL::SSLContext.new
        if defined?(OpenSSL::SSL::OP_NO_TICKET)
          # disable RFC4507 support
          ctx.options = OpenSSL::SSL::OP_NO_TICKET
        end
        server_connect(port, ctx) { |ssl|
          ssl.hostname = (i & 1 == 0) ? 'foo.example.com' : 'bar.example.com'
          str = "x" * 100 + "\n"
          ssl.puts(str)
          assert_equal(str, ssl.gets)
        }
      end
    end
  end

  # Writing multibyte UTF-8 data ("ä" repeated) returns the number of bytes
  # written, and the server reads back exactly those bytes.
  #
  # NOTE(review): the assertions run inside server_proc (the server side);
  # whether a failure there fails this test depends on utils' start_server.
  # TODO confirm.
  def test_multibyte_read_write
    #German a umlaut
    auml = [%w{ C3 A4 }.join('')].pack('H*')
    auml.force_encoding(Encoding::UTF_8)

    [10, 1000, 100000].each {|i|
      str = nil
      num_written = nil
      server_proc = Proc.new {|ctx, ssl|
        cmp = ssl.read
        # Measured before force_encoding, while cmp is still raw bytes.
        raw_size = cmp.size
        cmp.force_encoding(Encoding::UTF_8)
        assert_equal(str, cmp)
        assert_equal(num_written, raw_size)
        ssl.close
      }
      start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true, :server_proc => server_proc){|server, port|
        server_connect(port) { |ssl|
          str = auml * i
          num_written = ssl.write(str)
        }
      }
    }
  end

  # A server context whose options are exactly OP_ALL still interoperates
  # with a default client.
  def test_unset_OP_ALL
    ctx_proc = Proc.new { |ctx|
      # If OP_DONT_INSERT_EMPTY_FRAGMENTS is not defined, this test is
      # redundant because the default options already equal OP_ALL.
      # It still degrades gracefully, so keep it.
      ctx.options = OpenSSL::SSL::OP_ALL
    }
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc){|server, port|
      server_connect(port) { |ssl|
        ssl.puts('hello')
        assert_equal("hello\n", ssl.gets)
      }
    }
  end

  # Different OpenSSL versions react differently when facing an SSL/TLS version
  # that has been marked as forbidden, so any of these may be raised.
  HANDSHAKE_ERRORS = [OpenSSL::SSL::SSLError, Errno::ECONNRESET].freeze

# Protocol-version tests. Each group is defined only when this OpenSSL build
# lists the corresponding method in SSLContext::METHODS.
if OpenSSL::SSL::SSLContext::METHODS.include? :TLSv1

  # A server that sets OP_NO_SSLv3 refuses an SSLv3-only client.
  def test_forbid_ssl_v3_for_client
    ctx_proc = Proc.new { |ctx| ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_SSLv3 }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.ssl_version = :SSLv3
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

  # A client that sets OP_NO_SSLv3 refuses an SSLv3-only server.
  def test_forbid_ssl_v3_from_server
    start_server_version(:SSLv3) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_SSLv3
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

end

if OpenSSL::SSL::SSLContext::METHODS.include? :TLSv1_1

  # A TLSv1.1-only server negotiates "TLSv1.1" with a default client.
  def test_tls_v1_1
    start_server_version(:TLSv1_1) { |server, port|
      server_connect(port) { |ssl| assert_equal("TLSv1.1", ssl.ssl_version) }
    }
  end

  # A server that sets OP_NO_TLSv1 refuses a TLSv1-only client.
  def test_forbid_tls_v1_for_client
    ctx_proc = Proc.new { |ctx| ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1 }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.ssl_version = :TLSv1
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

  # A client that sets OP_NO_TLSv1 refuses a TLSv1-only server.
  def test_forbid_tls_v1_from_server
    start_server_version(:TLSv1) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

end

if OpenSSL::SSL::SSLContext::METHODS.include? :TLSv1_2

  # A TLSv1.2-only server negotiates "TLSv1.2" with a TLSv1.2 client.
  def test_tls_v1_2
    start_server_version(:TLSv1_2) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.ssl_version = :TLSv1_2_client
      server_connect(port, ctx) { |ssl| assert_equal("TLSv1.2", ssl.ssl_version) }
    }
  end if OpenSSL::OPENSSL_VERSION_NUMBER > 0x10001000

  # A server that sets OP_NO_TLSv1_1 refuses a TLSv1.1-only client.
  def test_forbid_tls_v1_1_for_client
    ctx_proc = Proc.new { |ctx| ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1_1 }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.ssl_version = :TLSv1_1
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end if defined?(OpenSSL::SSL::OP_NO_TLSv1_1)

  # A client that sets OP_NO_TLSv1_1 refuses a TLSv1.1-only server.
  def test_forbid_tls_v1_1_from_server
    start_server_version(:TLSv1_1) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1_1
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end if defined?(OpenSSL::SSL::OP_NO_TLSv1_1)

  # A server that sets OP_NO_TLSv1_2 refuses a TLSv1.2-only client.
  def test_forbid_tls_v1_2_for_client
    ctx_proc = Proc.new { |ctx| ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1_2 }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.ssl_version = :TLSv1_2
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end if defined?(OpenSSL::SSL::OP_NO_TLSv1_2)

  # A client that sets OP_NO_TLSv1_2 refuses a TLSv1.2-only server.
  def test_forbid_tls_v1_2_from_server
    start_server_version(:TLSv1_2) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_TLSv1_2
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end if defined?(OpenSSL::SSL::OP_NO_TLSv1_2)

end

  # The server's renegotiation_cb has run exactly once by the time a single
  # client connection is established.
  def test_renegotiation_cb
    num_handshakes = 0
    renegotiation_cb = Proc.new { |ssl| num_handshakes += 1 }
    ctx_proc = Proc.new { |ctx| ctx.renegotiation_cb = renegotiation_cb }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      server_connect(port) { |ssl|
        assert_equal(1, num_handshakes)
      }
    }
  end

# NPN (Next Protocol Negotiation) tests, defined only for OpenSSL newer than
# 0x10001000.
if OpenSSL::OPENSSL_VERSION_NUMBER > 0x10001000

  # The client's npn_select_cb picks from the server's advertised Array, and
  # the chosen protocol is reported by SSLSocket#npn_protocol.
  def test_npn_protocol_selection_ary
    advertised = ["http/1.1", "spdy/2"]
    ctx_proc = Proc.new { |ctx| ctx.npn_protocols = advertised }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      selector = lambda { |which|
        ctx = OpenSSL::SSL::SSLContext.new
        ctx.npn_select_cb = -> (protocols) { protocols.public_send(which) }
        server_connect(port, ctx) { |ssl|
          assert_equal(advertised.public_send(which), ssl.npn_protocol)
        }
      }
      selector.call(:first)
      selector.call(:last)
    }
  end

  # npn_protocols also accepts any object that responds to #each.
  def test_npn_protocol_selection_enum
    advertised = Object.new
    def advertised.each
      yield "http/1.1"
      yield "spdy/2"
    end
    ctx_proc = Proc.new { |ctx| ctx.npn_protocols = advertised }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      selector = lambda { |selected, which|
        ctx = OpenSSL::SSL::SSLContext.new
        ctx.npn_select_cb = -> (protocols) { protocols.to_a.public_send(which) }
        server_connect(port, ctx) { |ssl|
          assert_equal(selected, ssl.npn_protocol)
        }
      }
      selector.call("http/1.1", :first)
      selector.call("spdy/2", :last)
    }
  end

  # An exception raised in npn_select_cb propagates out of the connect.
  def test_npn_protocol_selection_cancel
    ctx_proc = Proc.new { |ctx| ctx.npn_protocols = ["http/1.1"] }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.npn_select_cb = -> (protocols) { raise RuntimeError.new }
      assert_raise(RuntimeError) { server_connect(port, ctx) }
    }
  end

  # A 256-byte advertised protocol name makes the handshake fail.
  def test_npn_advertised_protocol_too_long
    ctx_proc = Proc.new { |ctx| ctx.npn_protocols = ["a" * 256] }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.npn_select_cb = -> (protocols) { protocols.first }
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

  # A 256-byte protocol name returned by npn_select_cb makes the handshake fail.
  def test_npn_selected_protocol_too_long
    ctx_proc = Proc.new { |ctx| ctx.npn_protocols = ["http/1.1"] }
    start_server_version(:SSLv23, ctx_proc) { |server, port|
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.npn_select_cb = -> (protocols) { "a" * 256 }
      assert_raise(*HANDSHAKE_ERRORS) { server_connect(port, ctx) }
    }
  end

end

  # Running GC between creating an SSLSocket and connecting it, then closing
  # only the TCP socket and abandoning the SSLSocket, must not raise.
  def test_invalid_shutdown_by_gc
    assert_nothing_raised {
      start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
        10.times {
          sock = TCPSocket.new("127.0.0.1", port)
          ssl = OpenSSL::SSL::SSLSocket.new(sock)
          GC.start
          ssl.connect
          sock.close
        }
      }
    }
  end

  # Closing an SSLSocket (with sync_close) after its TCP socket has already
  # been closed must not raise.
  def test_close_after_socket_close
    start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){|server, port|
      sock = TCPSocket.new("127.0.0.1", port)
      ssl = OpenSSL::SSL::SSLSocket.new(sock)
      ssl.sync_close = true
      ssl.connect
      sock.close
      assert_nothing_raised do
        ssl.close
      end
    }
  end

  private

  # Starts a test server (utils' start_server, VERIFY_NONE) whose context is
  # pinned to the given ssl_version before ctx_proc, if any, adjusts it.
  #
  # version     - Symbol assigned to SSLContext#ssl_version (e.g. :TLSv1_2).
  # ctx_proc    - optional Proc called with the server SSLContext.
  # server_proc - optional Proc forwarded to start_server as :server_proc.
  # blk         - block receiving |server, port|.
  def start_server_version(version, ctx_proc=nil, server_proc=nil, &blk)
    ctx_wrap = Proc.new { |ctx|
      ctx.ssl_version = version
      ctx_proc.call(ctx) if ctx_proc
    }
    start_server(
      PORT,
      OpenSSL::SSL::VERIFY_NONE,
      true,
      :ctx_proc => ctx_wrap,
      :server_proc => server_proc,
      &blk
    )
  end

  # Opens a TCP connection to 127.0.0.1:port, wraps it in an SSLSocket
  # (using ctx when given), performs the handshake and yields the SSLSocket
  # if a block was given.
  #
  # The connection is always closed afterwards, including when the handshake
  # raises. sync_close makes closing the SSLSocket close the TCP socket too.
  # If the SSLSocket was never created, the bare TCP socket is closed instead.
  #
  # port - Integer server port.
  # ctx  - optional OpenSSL::SSL::SSLContext for the client.
  def server_connect(port, ctx=nil)
    sock = TCPSocket.new("127.0.0.1", port)
    ssl = ctx ? OpenSSL::SSL::SSLSocket.new(sock, ctx) : OpenSSL::SSL::SSLSocket.new(sock)
    ssl.sync_close = true
    ssl.connect
    # The handshake-failure tests call this without a block; if the handshake
    # unexpectedly succeeds they should see "nothing raised", not a
    # LocalJumpError.
    yield ssl if block_given?
  ensure
    if ssl
      ssl.close
    elsif sock
      sock.close
    end
  end
end

end
680