1(*
2    Title:      Standard Basis Library: Internet Sockets
3    Author:     David Matthews
4    Copyright   David Matthews 2000, 2016, 2019
5
6    This library is free software; you can redistribute it and/or
7    modify it under the terms of the GNU Lesser General Public
8    License version 2.1 as published by the Free Software Foundation.
9    
10    This library is distributed in the hope that it will be useful,
11    but WITHOUT ANY WARRANTY; without even the implied warranty of
12    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13    Lesser General Public License for more details.
14    
15    You should have received a copy of the GNU Lesser General Public
16    License along with this library; if not, write to the Free Software
17    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
18*)
19
signature NET_HOST_DB =
sig
    eqtype in_addr
    eqtype addr_family

    (* Host database entries and their fields. *)
    type entry
    val name : entry -> string
    val aliases : entry -> string list
    val addrType : entry -> addr_family
    val addr : entry -> in_addr
    val addrs : entry -> in_addr list

    (* Host database look-ups. *)
    val getByName : string -> entry option
    val getByAddr : in_addr -> entry option
    val getHostName : unit -> string

    (* Conversions between the textual and internal address forms. *)
    val scan : (char, 'strm) StringCvt.reader -> (in_addr, 'strm) StringCvt.reader
    val fromString : string -> in_addr option
    val toString : in_addr -> string
end;
39
40local
41    fun power2 0 = 1: LargeInt.int
42    |   power2 n = 2 * power2(n-1)
43    val p32 = power2 32
44    val p24 = power2 24
45
46    fun scan getc src =
47    let
48        (* Read a number as either decimal, hex or octal up to the
49           given limit. Stops when it reaches the limit or finds a
50           character it doesn't recognise. *)
51        fun readNum base acc limit src =
52        let
53            fun addDigit d src =
54            let
55                val n = case acc of SOME(n, _) => n | NONE => 0
56                val next = n * LargeInt.fromInt base + LargeInt.fromInt d
57            in
58                (* If we are below the limit we can continue. *)
59                if next < limit
60                then readNum base (SOME(next, src)) limit src
61                else acc
62            end
63        in
64            case getc src of
65                NONE => acc
66            |   SOME(ch, src') =>
67                    if Char.isDigit ch andalso
68                        ch < Char.chr(Char.ord #"0" + base)
69                    then addDigit (Char.ord ch - Char.ord #"0") src'
70                    else if base = 16 andalso (ch >= #"A" andalso ch <= #"F")
71                    then addDigit (Char.ord ch - Char.ord #"A" + 10) src'
72                    else if base = 16 andalso (ch >= #"a" andalso ch <= #"f")
73                    then addDigit (Char.ord ch - Char.ord #"a" + 10) src'
74                    else acc
75        end
76
77        (* Read a number.  If it starts with 0x or 0X treat it
78           as hex, otherwise if it starts with 0 treat as octal
79           otherwise decimal. *)
80        fun scanNum limit src =
81            case getc src of
82                NONE => NONE
83            |   SOME (#"0", src') =>
84                (
85                    case getc src' of
86                        SOME(ch, src'') =>
87                            if ch = #"x" orelse ch = #"X"
88                            then
89                                (
90                                (* If it is invalid we have still read a
91                                   zero so return that. *)
92                                case readNum 16 NONE limit src'' of
93                                    NONE => SOME(0, src')
94                                |   res => res
95                                )
96                            else (* Octal - include the zero. *)
97                                readNum 8 NONE limit src
98                    |   NONE => SOME(0, src') (* Just the zero. *)
99                )
100            |   SOME (_, _) => (* Treat it as a decimal number. *)
101                    readNum 10 NONE limit src
102
103        fun scanAddr src limit i acc =
104            case scanNum limit src of
105                NONE => NONE
106            |   SOME(n, src') =>
107                let
108                    val res = acc*256 + n (* This is the accumulated result. *)
109                in
110                    (* If the result is more than 24 bits or we've read
111                       all the sections we're finished. *)
112                    if res >= p24 orelse i = 1 then SOME(res, src')
113                    else
114                        case getc src' of
115                            SOME (#".", src'') =>
116                            (
117                                (* The limit for sections other than the
118                                   first is 256. *)
119                                case scanAddr src'' 256 (i-1) res of
120                                    NONE => SOME(res, src') (* Return what we had. *)
121                                |   r => r
122                            )
123                        |   _ => SOME(res, src') (* Return what we've got. *)
124                end
125    in
126        scanAddr src p32 4 (* Four sections in all. *) 0
127    end (* scan *)
128    
129    structure INet4Addr :>
130    sig
131        eqtype in_addr
132        type inet
133        type sock_addr = inet Socket.sock_addr
134        val inetAF : Socket.AF.addr_family
135 
136        val scan : (char, 'a) StringCvt.reader
137                  -> (in_addr, 'a) StringCvt.reader
138        val fromString : string -> in_addr option
139        val toString : in_addr -> string
140
141        val toAddr : in_addr * int -> sock_addr
142        val fromAddr : sock_addr -> in_addr * int
143        val any : int -> sock_addr
144    end
145    =
146    struct
147        type in_addr = LargeInt.int
148       
149        abstype inet = ABSTRACT with end;
150
151        type sock_addr = inet Socket.sock_addr
152
153        val inetAF =
154            case Socket.AF.fromString "INET" of
155                NONE => raise OS.SysErr("Missing address family", NONE)
156            |   SOME s => s
157       
158        val scan = scan
159        and fromString = StringCvt.scanString scan
160    
161        fun toString (n: in_addr) =
162        let
163            fun pr n i =
164                (if i > 0 then pr (n div 256) (i-1) ^ "." else "") ^
165                    LargeInt.toString (n mod 256)
166                
167        in
168            pr n 3 (* Always generate 4 numbers. *)
169        end
170
171        val toAddr: in_addr * int -> sock_addr = RunCall.rtsCallFull2 "PolyNetworkCreateIP4Address"
172        and fromAddr: sock_addr -> in_addr * int = RunCall.rtsCallFull1 "PolyNetworkGetAddressAndPortFromIP4"
173
174        local
175            val getAddrAny: unit -> in_addr = RunCall.rtsCallFull0 "PolyNetworkReturnIP4AddressAny"
176            val iAddrAny: in_addr = getAddrAny()
177        in
178            fun any (p: int) : sock_addr = toAddr(iAddrAny, p)
179        end
180
181    end
182
183in
184    structure NetHostDB :> NET_HOST_DB where type in_addr = INet4Addr.in_addr where type addr_family = Socket.AF.addr_family =
185    struct
186        open INet4Addr
187        type addr_family = Socket.AF.addr_family
188        type entry = string * string list * addr_family * in_addr list
189        val name: entry -> string = #1
190        (* aliases now always returns the empty list. *)
191        val aliases : entry -> string list = #2
192        val addrType : entry -> addr_family = #3
193        val addrs : entry -> in_addr list = #4
194    
195        (* Addr returns the first address in the list. There should always
196           be at least one entry. *)
197        fun addr e =
198            case addrs e of
199                a :: _ => a
200             |  [] => raise OS.SysErr("No address returned", NONE)
201    
202        val getHostName: unit -> string = RunCall.rtsCallFull0 "PolyNetworkGetHostName"
203
204        local
205            type addrInfo = int * Socket.AF.addr_family * int * int * sock_addr * string
206            val getAddrInfo: string * addr_family -> addrInfo list =
207                RunCall.rtsCallFull2 "PolyNetworkGetAddrInfo"
208        in
209            fun getByName s =
210            (
211                case getAddrInfo(s, inetAF) of
212                    [] => NONE
213                |   l as ((_, family, _, _, _, name) :: _) =>
214                        SOME (name, [], family, map (#1 o fromAddr o #5) l)
215            ) handle OS.SysErr _ => NONE
216        end
217
218        local
219            (* This does a reverse lookup of the address to return the name. *)
220            val doCall: sock_addr -> string = RunCall.rtsCallFull1 "PolyNetworkGetNameInfo"
221        in
222            fun getByAddr n =
223            (
224                (* Create an entry out of this.  We could do a forward look-up
225                   of the resulting address but there doesn't seem to be any point. *)
226                SOME(doCall(toAddr(n, 0)), [], inetAF, [n])
227            ) handle OS.SysErr _ => NONE
228        end
229     end
230
231    and INetSock =
232    struct
233        open INet4Addr
234 
235        type 'sock_type sock = (inet, 'sock_type) Socket.sock
236        type 'mode stream_sock = 'mode Socket.stream sock
237
238        type dgram_sock = Socket.dgram sock
239
240        local
241            val doSetOpt: int * OS.IO.iodesc * int -> unit =
242                RunCall.rtsCallFull3 "PolyNetworkSetOption"
243            val doGetOpt: int * OS.IO.iodesc -> int =
244                RunCall.rtsCallFull2 "PolyNetworkGetOption"
245        in
246            structure UDP =
247            struct
248                fun socket () = GenericSock.socket(inetAF, Socket.SOCK.dgram)
249                fun socket' p = GenericSock.socket'(inetAF, Socket.SOCK.dgram, p)
250            end
251
252            structure TCP =
253            struct
254                fun socket () = GenericSock.socket(inetAF, Socket.SOCK.stream)
255                fun socket' p = GenericSock.socket'(inetAF, Socket.SOCK.stream, p)
256
257                fun getNODELAY(LibraryIOSupport.SOCK s: 'mode stream_sock): bool = doGetOpt(16, s) <> 0
258
259                fun setNODELAY (LibraryIOSupport.SOCK s: 'mode stream_sock, b: bool): unit =
260                    doSetOpt(15, s, if b then 1 else 0)
261            end
262        end 
263
264    end;
265end;
266
267
(* SOCKET and INET_SOCK mention NetHostDB in their types, so they can only
   be (re)declared after NetHostDB exists.  This ordering is admittedly
   awkward. *)

(* Rebind SOCKET with its address-family type realised as
   NetHostDB.addr_family. *)
signature SOCKET =
    sig include SOCKET end
        where type AF.addr_family = NetHostDB.addr_family;
272
signature INET_SOCK =
sig
    type inet

    (* Socket types specialised to the Internet address family. *)
    type 'st sock = (inet, 'st) Socket.sock
    type 'm stream_sock = 'm Socket.stream sock
    type dgram_sock = Socket.dgram sock
    type sock_addr = inet Socket.sock_addr

    val inetAF : Socket.AF.addr_family

    (* Conversions between (host address, port) pairs and socket addresses. *)
    val toAddr : NetHostDB.in_addr * int -> sock_addr
    val fromAddr : sock_addr -> NetHostDB.in_addr * int
    val any : int -> sock_addr

    (* Datagram sockets. *)
    structure UDP :
    sig
        val socket : unit -> dgram_sock
        val socket' : int -> dgram_sock
    end

    (* Stream sockets, plus the TCP_NODELAY option. *)
    structure TCP :
    sig
        val socket : unit -> 'm stream_sock
        val socket' : int -> 'm stream_sock
        val getNODELAY : 'm stream_sock -> bool
        val setNODELAY : 'm stream_sock * bool -> unit
    end
end;
300
(* Rebind INetSock, sealing it opaquely with INET_SOCK.  This presumably
   happens here, rather than where INetSock is defined, because INET_SOCK
   mentions NetHostDB.in_addr and so can only be declared after NetHostDB. *)
structure INetSock :> INET_SOCK = INetSock;
302
local
    (* Pretty printer that shows a NetHostDB.in_addr in dotted form.  It is
       installed here, outside the structure, because NetHostDB is sealed
       opaquely and in_addr is abstract outside it. *)
    fun showInAddr _ _ (address: NetHostDB.in_addr) =
        PolyML.PrettyString (NetHostDB.toString address)
in
    val () = PolyML.addPrettyPrinter showInAddr
end;
311