1begin
2  require "socket"
3rescue LoadError
4end
5
6require "test/unit"
7require "tempfile"
8require "timeout"
9require "tmpdir"
10require "thread"
11require "io/nonblock"
12
# Tests for UNIXSocket, UNIXServer and the AF_UNIX-specific parts of
# Socket: descriptor passing (SCM_RIGHTS), peer/sent credentials, unnamed
# socketpair addresses, Linux abstract-namespace addresses and autobind.
#
# Each test closes every socket and descriptor it creates -- including IOs
# received through SCM_RIGHTS -- rather than leaving them to the garbage
# collector, so a long test run does not accumulate open descriptors.
class TestSocket_UNIXSocket < Test::Unit::TestCase
  # Round-trips one pipe descriptor through send_io/recv_io.
  #
  # send_io(nil) is used as a capability probe: the NotImplementedError
  # branch covers platforms without descriptor passing, the TypeError
  # branch (nil is not an IO) covers platforms that support it.
  def test_fd_passing
    r1, w = IO.pipe
    s1, s2 = UNIXSocket.pair
    begin
      s1.send_io(nil)
    rescue NotImplementedError
      assert_raise(NotImplementedError) { s2.recv_io }
    rescue TypeError
      s1.send_io(r1)
      r2 = s2.recv_io
      # Same pipe, but a new descriptor number in this process.
      assert_equal(r1.stat.ino, r2.stat.ino)
      assert_not_equal(r1.fileno, r2.fileno)
      assert(r2.close_on_exec?)
      w.syswrite "a"
      assert_equal("a", r2.sysread(10))
    ensure
      s1.close
      s2.close
      w.close
      r1.close
      r2.close if r2 && !r2.closed?
    end
  end

  # Receives one message from +sock+ and asserts that its SCM_RIGHTS
  # control data carries one descriptor per IO in +sent+, in order, each
  # referring to the same file as the corresponding sent IO but with a
  # different, close-on-exec descriptor number.  Every received IO is
  # closed before returning, whether or not the assertions pass.
  def assert_received_rights(sock, sent)
    _data, _srcaddr, _flags, *ctls = sock.recvmsg(:scm_rights=>true)
    received = []
    begin
      ctls.each {|ctl|
        next if ctl.level != Socket::SOL_SOCKET || ctl.type != Socket::SCM_RIGHTS
        received.concat ctl.unix_rights
      }
      assert_equal(sent.length, received.length)
      sent.zip(received) {|orig, copy|
        assert_not_equal(orig.fileno, copy.fileno)
        assert(File.identical?(orig, copy))
        assert(copy.close_on_exec?)
      }
    ensure
      received.each {|io| io.close if !io.closed? }
    end
  end
  private :assert_received_rights

  # Sends 1 to 6 pipe descriptors in a single SCM_RIGHTS control message
  # given as a raw [level, type, packed-ints] triple.
  def test_fd_passing_n
    io_ary = []
    return if !defined?(Socket::SCM_RIGHTS)
    3.times { io_ary.concat IO.pipe }
    send_io_ary = []
    io_ary.each {|io|
      send_io_ary << io
      UNIXSocket.pair {|s1, s2|
        begin
          ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS,
                                          send_io_ary.map {|io2| io2.fileno }.pack("i!*")])
        rescue NotImplementedError
          return
        end
        assert_equal(1, ret)
        assert_received_rights(s2, send_io_ary)
      }
    }
  ensure
    io_ary.each {|io| io.close if !io.closed? }
  end

  # Same as test_fd_passing_n, but the control message is built with
  # Socket::AncillaryData.unix_rights.
  def test_fd_passing_n2
    io_ary = []
    return if !defined?(Socket::SCM_RIGHTS)
    return if !defined?(Socket::AncillaryData)
    3.times { io_ary.concat IO.pipe }
    send_io_ary = []
    io_ary.each {|io|
      send_io_ary << io
      UNIXSocket.pair {|s1, s2|
        begin
          ancdata = Socket::AncillaryData.unix_rights(*send_io_ary)
          ret = s1.sendmsg("\0", 0, nil, ancdata)
        rescue NotImplementedError
          return
        end
        assert_equal(1, ret)
        assert_received_rights(s2, send_io_ary)
      }
    }
  ensure
    io_ary.each {|io| io.close if !io.closed? }
  end

  # x threads each call recv_io y times on one non-blocking socket while
  # the main thread sends x * y descriptors; every thread must finish and
  # the shared counter must account for every received descriptor.
  def test_fd_passing_race_condition
    r1, w = IO.pipe
    s1, s2 = UNIXSocket.pair
    s1.nonblock = s2.nonblock = true
    lock = Mutex.new
    nr = 0
    x = 2
    y = 1000
    begin
      s1.send_io(nil)
    rescue NotImplementedError
      assert_raise(NotImplementedError) { s2.recv_io }
    rescue TypeError
      thrs = x.times.map do
        Thread.new do
          y.times do
            s2.recv_io.close
            lock.synchronize { nr += 1 }
          end
          true
        end
      end
      (x * y).times { s1.send_io r1 }
      assert_equal([true]*x, thrs.map { |t| t.value })
      assert_equal x * y, nr
    ensure
      s1.close
      s2.close
      w.close
      r1.close
    end
  end

  # sendmsg with a raw SCM_RIGHTS triple can be received with recv_io.
  def test_sendmsg
    return if !defined?(Socket::SCM_RIGHTS)
    IO.pipe {|r1, w|
      UNIXSocket.pair {|s1, s2|
        begin
          ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, [r1.fileno].pack("i!")])
        rescue NotImplementedError
          return
        end
        assert_equal(1, ret)
        r2 = s2.recv_io
        begin
          assert(File.identical?(r1, r2))
          assert(r2.close_on_exec?)
        ensure
          r2.close
        end
      }
    }
  end

  # sendmsg with an AncillaryData.int(:UNIX, :SOCKET, :RIGHTS, fd) message.
  def test_sendmsg_ancillarydata_int
    return if !defined?(Socket::SCM_RIGHTS)
    return if !defined?(Socket::AncillaryData)
    IO.pipe {|r1, w|
      UNIXSocket.pair {|s1, s2|
        begin
          ad = Socket::AncillaryData.int(:UNIX, :SOCKET, :RIGHTS, r1.fileno)
          ret = s1.sendmsg("\0", 0, nil, ad)
        rescue NotImplementedError
          return
        end
        assert_equal(1, ret)
        r2 = s2.recv_io
        begin
          assert(File.identical?(r1, r2))
        ensure
          r2.close
        end
      }
    }
  end

  # sendmsg with an AncillaryData.unix_rights(io) message.
  def test_sendmsg_ancillarydata_unix_rights
    return if !defined?(Socket::SCM_RIGHTS)
    return if !defined?(Socket::AncillaryData)
    IO.pipe {|r1, w|
      UNIXSocket.pair {|s1, s2|
        begin
          ad = Socket::AncillaryData.unix_rights(r1)
          ret = s1.sendmsg("\0", 0, nil, ad)
        rescue NotImplementedError
          return
        end
        assert_equal(1, ret)
        r2 = s2.recv_io
        begin
          assert(File.identical?(r1, r2))
        ensure
          r2.close
        end
      }
    }
  end

  # A descriptor sent with send_io shows up in recvmsg's ancillary data.
  def test_recvmsg
    return if !defined?(Socket::SCM_RIGHTS)
    return if !defined?(Socket::AncillaryData)
    IO.pipe {|r1, w|
      UNIXSocket.pair {|s1, s2|
        s1.send_io(r1)
        ret = s2.recvmsg(:scm_rights=>true)
        data, srcaddr, flags, *ctls = ret
        assert_equal("\0", data)
        if flags.nil?
          # struct msghdr is 4.3BSD style (msg_accrights field).
          assert_instance_of(Array, ctls)
          assert_equal(0, ctls.length)
        else
          # struct msghdr is POSIX/4.4BSD style (msg_control field).
          assert_equal(0, flags & (Socket::MSG_TRUNC|Socket::MSG_CTRUNC))
          assert_instance_of(Addrinfo, srcaddr)
          assert_instance_of(Array, ctls)
          assert_equal(1, ctls.length)
          ctl = ctls[0]
          assert_instance_of(Socket::AncillaryData, ctl)
          assert_equal(Socket::SOL_SOCKET, ctl.level)
          assert_equal(Socket::SCM_RIGHTS, ctl.type)
          assert_instance_of(String, ctl.data)
          ios = ctl.unix_rights
          begin
            assert_equal(1, ios.length)
            r2 = ios[0]
            assert(File.identical?(r1, r2))
            assert(r2.close_on_exec?)
          ensure
            # Close everything received, even if more than one IO arrived.
            ios.each {|io| io.close if !io.closed? }
          end
        end
      }
    }
  end

  # Yields a +klass+ instance (e.g. UNIXServer) bound to a fresh filesystem
  # path, together with that path.  The path comes from a Tempfile that is
  # deleted first so the socket can be created in its place.  Afterwards
  # the socket is closed and the socket file is removed.
  def bound_unix_socket(klass)
    tmpfile = Tempfile.new("s")
    path = tmpfile.path
    tmpfile.close(true)
    sock = klass.new(path)
    yield sock, path
  ensure
    sock.close if sock && !sock.closed?
    File.unlink path if path && File.socket?(path)
  end

  # addr/peeraddr/path of both ends of a connection to a bound server.
  def test_addr
    bound_unix_socket(UNIXServer) {|serv, path|
      UNIXSocket.open(path) {|c|
        s = serv.accept
        begin
          assert_equal(["AF_UNIX", path], c.peeraddr)
          assert_equal(["AF_UNIX", ""], c.addr)
          assert_equal(["AF_UNIX", ""], s.peeraddr)
          assert_equal(["AF_UNIX", path], s.addr)
          assert_equal(path, s.path)
          assert_equal("", c.path)
        ensure
          s.close
        end
      }
    }
  end

  # Server, client and accepted sockets are all created close-on-exec.
  def test_cloexec
    bound_unix_socket(UNIXServer) {|serv, path|
      UNIXSocket.open(path) {|c|
        s = serv.accept
        begin
          assert(serv.close_on_exec?)
          assert(c.close_on_exec?)
          assert(s.close_on_exec?)
        ensure
          s.close
        end
      }
    }
  end

  # Unnamed (socketpair) sockets report an empty path.
  def test_noname_path
    UNIXSocket.pair {|s1, s2|
      assert_equal("", s1.path)
      assert_equal("", s2.path)
    }
  end

  # Unnamed sockets report an empty local address.
  def test_noname_addr
    UNIXSocket.pair {|s1, s2|
      assert_equal(["AF_UNIX", ""], s1.addr)
      assert_equal(["AF_UNIX", ""], s2.addr)
    }
  end

  # Unnamed sockets report an empty peer address.
  def test_noname_peeraddr
    UNIXSocket.pair {|s1, s2|
      assert_equal(["AF_UNIX", ""], s1.peeraddr)
      assert_equal(["AF_UNIX", ""], s2.peeraddr)
    }
  end

  # Raw sockaddrs of unnamed sockets unpack to an empty path.
  def test_noname_unpack_sockaddr_un
    UNIXSocket.pair {|s1, s2|
      [s1.getsockname, s2.getsockname, s1.getpeername, s2.getpeername].each {|addr|
        # Some platforms return "" rather than a sockaddr_un with an empty
        # path; there is nothing to unpack in that case.
        next if addr == ""
        assert_equal("", Socket.unpack_sockaddr_un(addr))
      }
    }
  end

  # recvfrom on an unnamed socket reports an empty sender address.
  def test_noname_recvfrom
    UNIXSocket.pair {|s1, s2|
      s2.write("a")
      assert_equal(["a", ["AF_UNIX", ""]], s1.recvfrom(10))
    }
  end

  # recv_nonblock on an unnamed socket returns just the data.
  def test_noname_recv_nonblock
    UNIXSocket.pair {|s1, s2|
      s2.write("a")
      IO.select [s1]
      assert_equal("a", s1.recv_nonblock(10))
    }
  end

  # A 300-byte path is rejected with ArgumentError.
  def test_too_long_path
    assert_raise(ArgumentError) { Socket.sockaddr_un("a" * 300) }
    assert_raise(ArgumentError) { UNIXServer.new("a" * 300) }
  end

  # Linux abstract-namespace names (leading NUL) survive pack/unpack.
  def test_abstract_namespace
    return if /linux/ !~ RUBY_PLATFORM
    addr = Socket.pack_sockaddr_un("\0foo")
    assert_match(/\0foo\z/, addr)
    assert_equal("\0foo", Socket.unpack_sockaddr_un(addr))
  end

  # SOCK_DGRAM pairs preserve message boundaries, including empty datagrams.
  def test_dgram_pair
    s1, s2 = UNIXSocket.pair(Socket::SOCK_DGRAM)
    assert_raise(Errno::EAGAIN) { s1.recv_nonblock(10) }
    s2.send("", 0)
    s2.send("haha", 0)
    s2.send("", 0)
    s2.send("", 0)
    assert_equal("", s1.recv(10))
    assert_equal("haha", s1.recv(10))
    assert_equal("", s1.recv(10))
    assert_equal("", s1.recv(10))
    assert_raise(Errno::EAGAIN) { s1.recv_nonblock(10) }
  ensure
    s1.close if s1
    s2.close if s2
  end

  # sendmsg/recvmsg must not be confused by a stale errno (EAGAIN left
  # over from a failed read_nonblock on an unrelated pipe).
  def test_dgram_pair_sendrecvmsg_errno_set
    s1, s2 = to_close = UNIXSocket.pair(Socket::SOCK_DGRAM)
    pipe = IO.pipe
    to_close.concat(pipe)
    set_errno = lambda do
      begin
        pipe[0].read_nonblock(1)
        fail
      rescue => e
        assert(IO::WaitReadable === e)
      end
    end
    Timeout.timeout(10) do
      set_errno.call
      assert_equal(2, s1.sendmsg("HI"))
      set_errno.call
      assert_equal("HI", s2.recvmsg[0])
    end
  ensure
    to_close.each(&:close) if to_close
  end

  # Writing after shutdown(SHUT_WR) raises EPIPE; the other direction
  # keeps working.
  def test_epipe # [ruby-dev:34619]
    UNIXSocket.pair {|s1, s2|
      s1.shutdown(Socket::SHUT_WR)
      assert_raise(Errno::EPIPE) { s1.write "a" }
      assert_equal(nil, s2.read(1))
      s2.write "a"
      assert_equal("a", s1.read(1))
    }
  end

  # Socket.pair with a block yields both ends and returns the block value.
  def test_socket_pair_with_block
    pair = nil
    ret = Socket.pair(Socket::AF_UNIX, Socket::SOCK_STREAM, 0) {|s1, s2|
      pair = [s1, s2]
      :return_value
    }
    assert_equal(:return_value, ret)
    assert_kind_of(Socket, pair[0])
    assert_kind_of(Socket, pair[1])
  end

  # UNIXSocket.pair with a block yields two UNIXSockets.
  def test_unix_socket_pair_with_block
    pair = nil
    UNIXSocket.pair {|s1, s2|
      pair = [s1, s2]
    }
    assert_kind_of(UNIXSocket, pair[0])
    assert_kind_of(UNIXSocket, pair[1])
  end

  # Both ends of UNIXSocket.pair are close-on-exec.
  def test_unix_socket_pair_close_on_exec
    UNIXSocket.pair {|s1, s2|
      assert(s1.close_on_exec?)
      assert(s2.close_on_exec?)
    }
  end

  # Socket.open accepts integer and string family/type, and a bound
  # AF_UNIX address is not mistaken for AF_INET.
  def test_initialize
    Dir.mktmpdir {|d|
      Socket.open(Socket::AF_UNIX, Socket::SOCK_STREAM, 0) {|s|
        s.bind(Socket.pack_sockaddr_un("#{d}/s1"))
        addr = s.getsockname
        assert_nothing_raised { Socket.unpack_sockaddr_un(addr) }
        assert_raise(ArgumentError) { Socket.unpack_sockaddr_in(addr) }
      }
      Socket.open("AF_UNIX", "SOCK_STREAM", 0) {|s|
        s.bind(Socket.pack_sockaddr_un("#{d}/s2"))
        addr = s.getsockname
        assert_nothing_raised { Socket.unpack_sockaddr_un(addr) }
        assert_raise(ArgumentError) { Socket.unpack_sockaddr_in(addr) }
      }
    }
  end

  # The block form of unix_server_socket closes the socket and removes
  # the socket file when the block exits.
  def test_unix_server_socket
    Dir.mktmpdir {|d|
      path = "#{d}/sock"
      s0 = nil
      Socket.unix_server_socket(path) {|s|
        assert_equal(path, s.local_address.unix_path)
        assert(File.socket?(path))
        s0 = s
      }
      assert(s0.closed?)
      assert_raise(Errno::ENOENT) { File.stat path }
    }
  end

  # SO_PEERCRED on Linux reports this process's pid/euid/egid.
  def test_getcred_ucred
    return if /linux/ !~ RUBY_PLATFORM
    Dir.mktmpdir {|d|
      sockpath = "#{d}/sock"
      Socket.unix_server_socket(sockpath) {|serv|
        Socket.unix(sockpath) {|_c|
          s, = serv.accept
          begin
            cred = s.getsockopt(:SOCKET, :PEERCRED)
            inspect_str = cred.inspect
            assert_match(/ pid=#{$$} /, inspect_str)
            assert_match(/ euid=#{Process.euid} /, inspect_str)
            assert_match(/ egid=#{Process.egid} /, inspect_str)
            assert_match(/ \(ucred\)/, inspect_str)
          ensure
            s.close
          end
        }
      }
    }
  end

  # LOCAL_PEERCRED on FreeBSD/Darwin reports this process's euid.
  def test_getcred_xucred
    return if /freebsd|darwin/ !~ RUBY_PLATFORM
    Dir.mktmpdir {|d|
      sockpath = "#{d}/sock"
      Socket.unix_server_socket(sockpath) {|serv|
        Socket.unix(sockpath) {|_c|
          s, = serv.accept
          begin
            cred = s.getsockopt(0, Socket::LOCAL_PEERCRED)
            inspect_str = cred.inspect
            assert_match(/ euid=#{Process.euid} /, inspect_str)
            assert_match(/ \(xucred\)/, inspect_str)
          ensure
            s.close
          end
        }
      }
    }
  end

  # With SO_PASSCRED set, recvmsg on Linux returns an SCM_CREDENTIALS
  # message carrying the sender's pid/uid/gid.
  def test_sendcred_ucred
    return if /linux/ !~ RUBY_PLATFORM
    Dir.mktmpdir {|d|
      sockpath = "#{d}/sock"
      Socket.unix_server_socket(sockpath) {|serv|
        Socket.unix(sockpath) {|c|
          s, = serv.accept
          begin
            s.setsockopt(:SOCKET, :PASSCRED, 1)
            c.print "a"
            msg, _client_ai, _rflags, cred = s.recvmsg
            inspect_str = cred.inspect
            assert_equal("a", msg)
            assert_match(/ pid=#{$$} /, inspect_str)
            assert_match(/ uid=#{Process.uid} /, inspect_str)
            assert_match(/ gid=#{Process.gid} /, inspect_str)
            assert_match(/ \(ucred\)/, inspect_str)
          ensure
            s.close
          end
        }
      }
    }
  end

  # With LOCAL_CREDS set (NetBSD/FreeBSD), recvmsg returns a sockcred.
  def test_sendcred_sockcred
    return if /netbsd|freebsd/ !~ RUBY_PLATFORM
    Dir.mktmpdir {|d|
      sockpath = "#{d}/sock"
      Socket.unix_server_socket(sockpath) {|serv|
        Socket.unix(sockpath) {|c|
          s, = serv.accept
          begin
            s.setsockopt(0, Socket::LOCAL_CREDS, 1)
            c.print "a"
            msg, _client_ai, _rflags, cred = s.recvmsg
            assert_equal("a", msg)
            inspect_str = cred.inspect
            assert_match(/ uid=#{Process.uid} /, inspect_str)
            assert_match(/ euid=#{Process.euid} /, inspect_str)
            assert_match(/ gid=#{Process.gid} /, inspect_str)
            assert_match(/ egid=#{Process.egid} /, inspect_str)
            assert_match(/ \(sockcred\)/, inspect_str)
          ensure
            s.close
          end
        }
      }
    }
  end

  # On FreeBSD, sending an empty SCM_CREDS message makes the receiver
  # see a cmsgcred filled in by the kernel.
  def test_sendcred_cmsgcred
    return if /freebsd/ !~ RUBY_PLATFORM
    Dir.mktmpdir {|d|
      sockpath = "#{d}/sock"
      Socket.unix_server_socket(sockpath) {|serv|
        Socket.unix(sockpath) {|c|
          s, = serv.accept
          begin
            c.sendmsg("a", 0, nil, [:SOCKET, Socket::SCM_CREDS, ""])
            msg, _client_ai, _rflags, cred = s.recvmsg
            assert_equal("a", msg)
            inspect_str = cred.inspect
            assert_match(/ pid=#{$$} /, inspect_str)
            assert_match(/ uid=#{Process.uid} /, inspect_str)
            assert_match(/ euid=#{Process.euid} /, inspect_str)
            assert_match(/ gid=#{Process.gid} /, inspect_str)
            assert_match(/ \(cmsgcred\)/, inspect_str)
          ensure
            s.close
          end
        }
      }
    }
  end

  # getpeereid on both ends returns this process's [euid, egid], where
  # the platform implements it.
  def test_getpeereid
    Dir.mktmpdir {|d|
      path = "#{d}/sock"
      Socket.unix_server_socket(path) {|serv|
        Socket.unix(path) {|c|
          s, = serv.accept
          begin
            assert_equal([Process.euid, Process.egid], c.getpeereid)
            assert_equal([Process.euid, Process.egid], s.getpeereid)
          rescue NotImplementedError
            # getpeereid is unavailable on this platform; nothing to check.
          ensure
            s.close
          end
        }
      }
    }
  end

  # UNIXServer/UNIXSocket accept Linux abstract-namespace names.
  def test_abstract_unix_server
    return if /linux/ !~ RUBY_PLATFORM
    name = "\0ruby-test_unix"
    s0 = nil
    UNIXServer.open(name) {|s|
      assert_equal(name, s.local_address.unix_path)
      s0 = s
      UNIXSocket.open(name) {|c|
        sock = s.accept
        begin
          assert_equal(name, c.remote_address.unix_path)
        ensure
          sock.close
        end
      }
    }
    assert(s0.closed?)
  end

  # Connecting to an abstract name nobody listens on is refused.
  def test_abstract_unix_socket_econnrefused
    return if /linux/ !~ RUBY_PLATFORM
    name = "\0ruby-test_unix"
    assert_raise(Errno::ECONNREFUSED) do
      UNIXSocket.open(name) {}
    end
  end

  # Socket.unix_server_socket/Socket.unix accept abstract-namespace names.
  def test_abstract_unix_server_socket
    return if /linux/ !~ RUBY_PLATFORM
    name = "\0ruby-test_unix"
    s0 = nil
    Socket.unix_server_socket(name) {|s|
      assert_equal(name, s.local_address.unix_path)
      s0 = s
      Socket.unix(name) {|c|
        sock, = s.accept
        begin
          assert_equal(name, c.remote_address.unix_path)
        ensure
          sock.close
        end
      }
    }
    assert(s0.closed?)
  end

  # Binding to "" on Linux autobinds to an abstract name of a NUL byte
  # followed by five hex digits, and clients can connect to it.
  def test_autobind
    return if /linux/ !~ RUBY_PLATFORM
    s0 = nil
    Socket.unix_server_socket("") {|s|
      name = s.local_address.unix_path
      assert_match(/\A\0[0-9a-f]{5}\z/, name)
      s0 = s
      Socket.unix(name) {|c|
        sock, = s.accept
        begin
          assert_equal(name, c.remote_address.unix_path)
        ensure
          sock.close
        end
      }
    }
    assert(s0.closed?)
  end

end if defined?(UNIXSocket) && /cygwin/ !~ RUBY_PLATFORM
625