1require_relative 'utils'
2
3if defined?(OpenSSL)
4
5require 'socket'
6require_relative '../ruby/ut_eof'
7
module SSLPair
  # Creates an SSL server listening on an ephemeral port of 127.0.0.1.
  #
  # The context only allows anonymous Diffie-Hellman ("ADH") cipher suites,
  # so no certificate is configured; DH parameters are supplied by
  # OpenSSL::TestUtils::TEST_KEY_DH1024.
  #
  # Returns an OpenSSL::SSL::SSLServer wrapping a TCPServer. If wrapping
  # fails, the TCPServer is closed before the error propagates.
  def server
    host = "127.0.0.1"
    port = 0
    ctx = OpenSSL::SSL::SSLContext.new()
    ctx.ciphers = "ADH"
    ctx.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 }
    tcps = TCPServer.new(host, port)
    begin
      OpenSSL::SSL::SSLServer.new(tcps, ctx)
    rescue
      tcps.close
      raise
    end
  end

  # Connects to 127.0.0.1:+port+ and performs a client-side SSL handshake
  # restricted to "ADH" cipher suites.
  #
  # Returns the connected OpenSSL::SSL::SSLSocket with sync_close enabled, so
  # closing it also closes the underlying TCPSocket. If the handshake fails,
  # the TCPSocket is closed before the error propagates.
  def client(port)
    host = "127.0.0.1"
    ctx = OpenSSL::SSL::SSLContext.new()
    ctx.ciphers = "ADH"
    s = TCPSocket.new(host, port)
    begin
      ssl = OpenSSL::SSL::SSLSocket.new(s, ctx)
      ssl.connect
    rescue
      s.close
      raise
    end
    ssl.sync_close = true
    ssl
  end

  # Builds a connected client/server SSL socket pair over loopback.
  #
  # A background thread accepts exactly one connection and then closes the
  # listening server, while the current thread connects as the client.
  #
  # With a block: yields (client, server), closes both sockets afterwards
  # (unless the block already did) and returns the block's value.
  # Without a block: returns [client, server]; the caller must close them.
  #
  # On failure, the accept thread is stopped, the listening socket and any
  # already-connected client socket are closed, and the error propagates.
  def ssl_pair
    ssls = server
    th = Thread.new {
      ns = ssls.accept
      ssls.close
      ns
    }
    port = ssls.to_io.addr[1]
    c = client(port)
    begin
      s = th.value
    rescue
      # Server-side accept failed; do not leak the client connection.
      c.close unless c.closed?
      raise
    end
    if block_given?
      begin
        yield c, s
      ensure
        c.close unless c.closed?
        s.close unless s.closed?
      end
    else
      return c, s
    end
  ensure
    if th && th.alive?
      th.kill
      th.join
    end
    # The accept thread normally closes the listener itself; close it here
    # too in case that thread was killed or failed before reaching close.
    if ssls
      listener = ssls.to_io
      listener.close unless listener.closed?
    end
  end
end
58
class OpenSSL::TestEOF1 < Test::Unit::TestCase
  include TestEOF
  include SSLPair

  # Hook used by the shared TestEOF cases.
  #
  # Yields the client end of an SSL pair. A helper thread writes +content+
  # from the server end and then closes it, so the yielded socket sees
  # +content+ followed by EOF.
  #
  # Before returning the block's value, the writer thread is joined, which
  # surfaces any error it raised. Both sockets are also closed (by ssl_pair).
  def open_file(content)
    ssl_pair {|s1, s2|
      th = nil
      begin
        th = Thread.new { s2 << content; s2.close }
        yield s1
      ensure
        th.join if th
      end
    }
  end
end
69
class OpenSSL::TestEOF2 < Test::Unit::TestCase
  include TestEOF
  include SSLPair

  # Hook used by the shared TestEOF cases; mirror image of TestEOF1.
  #
  # Yields the server end of an SSL pair. A helper thread writes +content+
  # from the client end and then closes it, so the yielded socket sees
  # +content+ followed by EOF.
  #
  # Before returning the block's value, the writer thread is joined, which
  # surfaces any error it raised. Both sockets are also closed (by ssl_pair).
  def open_file(content)
    ssl_pair {|s1, s2|
      th = nil
      begin
        th = Thread.new { s1 << content; s1.close }
        yield s2
      ensure
        th.join if th
      end
    }
  end
end
80
class OpenSSL::TestPair < Test::Unit::TestCase
  include SSLPair

  # A single byte written on one end is returned by getc on the other.
  def test_getc
    ssl_pair {|s1, s2|
      s1 << "a"
      assert_equal(?a, s2.getc)
    }
  end

  # readpartial returns data left in the read buffer by gets, then newly
  # arrived data. After the peer closes, it raises EOFError on every call,
  # while a zero-length request still returns "".
  def test_readpartial
    ssl_pair {|s1, s2|
      s2.write "a\nbcd"
      assert_equal("a\n", s1.gets)
      assert_equal("bcd", s1.readpartial(10))
      s2.write "efg"
      assert_equal("efg", s1.readpartial(10))
      s2.close
      assert_raise(EOFError) { s1.readpartial(10) }
      assert_raise(EOFError) { s1.readpartial(10) }
      assert_equal("", s1.readpartial(0))
    }
  end

  # read returns "" when the peer closes without sending anything.
  def test_readall
    ssl_pair {|s1, s2|
      s2.close
      assert_equal("", s1.read)
    }
  end

  # readline raises EOFError when the peer closes without sending anything.
  def test_readline
    ssl_pair {|s1, s2|
      s2.close
      assert_raise(EOFError) { s1.readline }
    }
  end

  # puts terminates output with "\n" even when $/ is set to something else.
  def test_puts_meta
    ssl_pair {|s1, s2|
      begin
        old = $/
        $/ = '*'
        s1.puts 'a'
      ensure
        $/ = old
      end
      s1.close
      assert_equal("a\n", s2.read)
    }
  end

  # puts with no arguments writes a single newline.
  def test_puts_empty
    ssl_pair {|s1, s2|
      s1.puts
      s1.close
      assert_equal("\n", s2.read)
    }
  end

  # With nothing to read, read_nonblock raises an SSLError that is also an
  # IO::WaitReadable. Later reads that mix read_nonblock and gets must
  # consume buffered data correctly ([ruby-core:20298]).
  def test_read_nonblock
    ssl_pair {|s1, s2|
      err = nil
      assert_raise(OpenSSL::SSL::SSLError) {
        begin
          s2.read_nonblock(10)
        ensure
          err = $!
        end
      }
      assert_kind_of(IO::WaitReadable, err)
      s1.write "abc\ndef\n"
      IO.select([s2])
      assert_equal("ab", s2.read_nonblock(2))
      assert_equal("c\n", s2.gets)
      ret = nil
      assert_nothing_raised("[ruby-core:20298]") { ret = s2.read_nonblock(10) }
      assert_equal("def\n", ret)
    }
  end

  # Every byte that write_nonblock reports as written (up to the point where
  # it would block) reaches the peer.
  def test_write_nonblock
    ssl_pair {|s1, s2|
      n = 0
      begin
        n += s1.write_nonblock("a" * 100000)
        n += s1.write_nonblock("b" * 100000)
        n += s1.write_nonblock("c" * 100000)
        n += s1.write_nonblock("d" * 100000)
        n += s1.write_nonblock("e" * 100000)
        n += s1.write_nonblock("f" * 100000)
      rescue IO::WaitWritable
      end
      s1.close
      assert_equal(n, s2.read.length)
    }
  end

  # write_nonblock keeps ordering relative to data already written by write.
  def test_write_nonblock_with_buffered_data
    ssl_pair {|s1, s2|
      s1.write "foo"
      s1.write_nonblock("bar")
      s1.write "baz"
      s1.close
      assert_equal("foobarbaz", s2.read)
    }
  end

  # Drives both handshake sides with connect_nonblock / accept_nonblock,
  # retrying via IO.select whenever either side would block.
  def test_connect_accept_nonblock
    host = "127.0.0.1"
    port = 0
    # The server context supplies DH parameters; the client context does not.
    # Keep them in separate variables: the accept thread below reads its
    # context through the closure, so reusing one variable for the client
    # context would let a late-starting thread pick up the wrong one.
    server_ctx = OpenSSL::SSL::SSLContext.new()
    server_ctx.ciphers = "ADH"
    server_ctx.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 }
    serv = TCPServer.new(host, port)

    port = serv.connect_address.ip_port

    sock1 = TCPSocket.new(host, port)
    sock2 = serv.accept
    serv.close

    th = Thread.new {
      s2 = OpenSSL::SSL::SSLSocket.new(sock2, server_ctx)
      s2.sync_close = true
      begin
        sleep 0.2
        s2.accept_nonblock
      rescue IO::WaitReadable
        IO.select([s2])
        retry
      rescue IO::WaitWritable
        IO.select(nil, [s2])
        retry
      end
      s2
    }

    sleep 0.1
    client_ctx = OpenSSL::SSL::SSLContext.new()
    client_ctx.ciphers = "ADH"
    s1 = OpenSSL::SSL::SSLSocket.new(sock1, client_ctx)
    begin
      sleep 0.2
      s1.connect_nonblock
    rescue IO::WaitReadable
      IO.select([s1])
      retry
    rescue IO::WaitWritable
      IO.select(nil, [s1])
      retry
    end
    s1.sync_close = true

    s2 = th.value

    s1.print "a\ndef"
    assert_equal("a\n", s2.gets)
  ensure
    # If the client side failed before th.value, the accept thread may still
    # be waiting in IO.select; stop it before closing its sockets.
    if th && th.alive?
      th.kill
      th.join
    end
    s1.close if s1 && !s1.closed?
    s2.close if s2 && !s2.closed?
    serv.close if serv && !serv.closed?
    sock1.close if sock1 && !sock1.closed?
    sock2.close if sock2 && !sock2.closed?
  end

end
248
249end
250