From 155905ca7db15949801c30774042ec0ad573bfe9 Mon Sep 17 00:00:00 2001 From: Anton Bachin Date: Sun, 27 Nov 2016 18:15:25 -0600 Subject: [PATCH 1/2] Concurrent version of Lwt_unix.bind The new function is called Lwt_unix.Versioned.bind_2. It will replace the current Lwt_unix.bind in a major release. --- doc/examples/unix/relay.ml | 2 +- src/unix/lwt_io.ml | 2 +- src/unix/lwt_unix.h | 1 + src/unix/lwt_unix.ml | 14 ++++++++ src/unix/lwt_unix.mli | 32 +++++++++++++++++- src/unix/lwt_unix_unix.c | 43 ++++++++++++++++++------ src/unix/lwt_unix_windows.c | 1 + tests/test.ml | 25 ++++++++------ tests/test.mli | 4 +++ tests/unix/test_lwt_unix.ml | 67 ++++++++++++++++++++++++++++++++++++- tests/unix/test_mcast.ml | 3 +- 11 files changed, 168 insertions(+), 26 deletions(-) diff --git a/doc/examples/unix/relay.ml b/doc/examples/unix/relay.ml index d7f6ac111b..539be32199 100644 --- a/doc/examples/unix/relay.ml +++ b/doc/examples/unix/relay.ml @@ -125,7 +125,7 @@ let%lwt () = (* Initialize the listening address. *) let sock = Lwt_unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true; - Lwt_unix.bind sock src_addr; + let%lwt () = Lwt_unix.Versioned.bind_2 sock src_addr in Lwt_unix.listen sock 1024; ignore (Lwt_log.notice "waiting for connection"); diff --git a/src/unix/lwt_io.ml b/src/unix/lwt_io.ml index bc522e785d..44d55feae9 100644 --- a/src/unix/lwt_io.ml +++ b/src/unix/lwt_io.ml @@ -1425,7 +1425,7 @@ let establish_server ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sock | Some fd -> fd in Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true; - Lwt_unix.bind sock sockaddr; + (Lwt_unix.bind sock sockaddr) [@ocaml.warning "-3"]; Lwt_unix.listen sock backlog; let abort_waiter, abort_wakener = Lwt.wait () in let abort_waiter = abort_waiter >>= fun () -> Lwt.return `Shutdown in diff --git a/src/unix/lwt_unix.h b/src/unix/lwt_unix.h index 1adf013177..c165b981b8 100644 --- a/src/unix/lwt_unix.h +++ b/src/unix/lwt_unix.h @@ -28,6 +28,7 @@ #include #include #include +#include /* The macro to get the file-descriptor from a value. */ #if defined(LWT_ON_WINDOWS) diff --git a/src/unix/lwt_unix.ml b/src/unix/lwt_unix.ml index e88febe5c8..e21e92af33 100644 --- a/src/unix/lwt_unix.ml +++ b/src/unix/lwt_unix.ml @@ -1579,6 +1579,14 @@ let bind ch addr = check_descriptor ch; Unix.bind ch.fd addr +external bind_job : Unix.file_descr -> Unix.sockaddr -> unit job = + "lwt_unix_bind_job" + +let bind' fd addr = + match Sys.win32, addr with + | true, _ | false, Unix.ADDR_INET _ -> Lwt.return (Unix.bind fd.fd addr) + | false, Unix.ADDR_UNIX _ -> run_job (bind_job fd.fd addr) + let listen ch cnt = check_descriptor ch; Unix.listen ch.fd cnt @@ -2328,3 +2336,9 @@ let () = Some(Printf.sprintf "Unix.Unix_error(Unix.%s, %S, %S)" error func arg) | _ -> None) + +module Versioned = +struct + let bind_1 = bind + let bind_2 = bind' +end diff --git a/src/unix/lwt_unix.mli b/src/unix/lwt_unix.mli index 6ff77ffca2..cb15d873df 100644 --- a/src/unix/lwt_unix.mli +++ b/src/unix/lwt_unix.mli @@ -835,7 +835,19 @@ val socketpair : socket_domain -> socket_type -> int -> file_descr * file_descr (** Wrapper for [Unix.socketpair] *) val bind : file_descr -> sockaddr -> unit - (** Wrapper for [Unix.bind] *) + [@@ocaml.deprecated +"This function will soon return threads (-> unit Lwt.t), because the bind system +call can block for Unix domain sockets. See + https://github.com/ocsigen/lwt/issues/230 +To keep using the current signature, use Lwt_unix.Versioned.bind_1 +To use the new non-blocking version immediately, use Lwt_unix.Versioned.bind_2"] +(** Binds an address to the given socket. This is the cooperative analog of + {{:http://caml.inria.fr/pub/docs/manual-ocaml/libref/Unix.html#VALbind} + [Unix.bind]}. See also + {{:http://man7.org/linux/man-pages/man3/bind.3p.html} [bind(3p)]}. + + @deprecated Will be replaced by {!Versioned.bind_2}, whose result type is + [unit Lwt.t] instead of [unit]. *) val listen : file_descr -> int -> unit (** Wrapper for [Unix.listen] *) @@ -1355,6 +1367,24 @@ val set_affinity : ?pid : int -> int list -> unit (** [set_affinity ?pid cpus] sets the list of CPUs the given process is allowed to run on. *) +(** {2 Versioned interfaces} *) + +(** Versioned variants of APIs undergoing breaking changes. *) +module Versioned : +sig + val bind_1 : file_descr -> sockaddr -> unit + [@@ocaml.deprecated +"Deprecated in favor of Lwt_unix.Versioned.bind_2. See + https://github.com/ocsigen/lwt/issues/230"] + (** Alias for the current {!Lwt_unix.bind}. + + @deprecated Use {!bind_2}. *) + + val bind_2 : file_descr -> sockaddr -> unit Lwt.t + (** Like {!Lwt_unix.bind}, but evaluates to an Lwt thread, in order to avoid + blocking the process in case the given socket is a Unix domain socket. *) +end + (**/**) val run : 'a Lwt.t -> 'a diff --git a/src/unix/lwt_unix_unix.c b/src/unix/lwt_unix_unix.c index dca1fe453d..a40e56231e 100644 --- a/src/unix/lwt_unix_unix.c +++ b/src/unix/lwt_unix_unix.c @@ -341,16 +341,6 @@ value lwt_unix_bytes_send(value fd, value buf, value ofs, value len, value flags extern int socket_domain_table[]; extern int socket_type_table[]; -union sock_addr_union { - struct sockaddr s_gen; - struct sockaddr_un s_unix; - struct sockaddr_in s_inet; - struct sockaddr_in6 s_inet6; -}; - -CAMLexport value alloc_sockaddr (union sock_addr_union * addr /*in*/, - socklen_t addr_len, int close_on_error); - value lwt_unix_recvfrom(value fd, value buf, value ofs, value len, value flags) { CAMLparam5(fd, buf, ofs, len, flags); @@ -2834,6 +2824,39 @@ CAMLprim value lwt_unix_getnameinfo_job(value sockaddr, value opts) return lwt_unix_alloc_job(&job->job); } +/* bind */ + +struct job_bind { + struct lwt_unix_job job; + int fd; + union sock_addr_union addr; + socklen_param_type addr_len; + int result; + int error_code; +}; + +static void worker_bind(struct job_bind *job) +{ + job->result = bind(job->fd, &job->addr.s_gen, job->addr_len); + job->error_code = errno; +} + +static value result_bind(struct job_bind *job) +{ + LWT_UNIX_CHECK_JOB(job, job->result != 0, "bind"); + lwt_unix_free_job(&job->job); + return Val_unit; +} + +CAMLprim value lwt_unix_bind_job(value fd, value address) +{ + LWT_UNIX_INIT_JOB(job, bind, 0); + job->fd = Int_val(fd); + get_sockaddr(address, &job->addr, &job->addr_len); + + return lwt_unix_alloc_job(&job->job); +} + /* +-----------------------------------------------------------------+ | Termios conversion | +-----------------------------------------------------------------+ */ diff --git a/src/unix/lwt_unix_windows.c b/src/unix/lwt_unix_windows.c index b158e691cf..4921e4b6a7 100644 --- a/src/unix/lwt_unix_windows.c +++ b/src/unix/lwt_unix_windows.c @@ -575,3 +575,4 @@ LWT_NOT_AVAILABLE1(unix_invalidate_dir) LWT_NOT_AVAILABLE3(unix_writev) LWT_NOT_AVAILABLE3(unix_writev_job) LWT_NOT_AVAILABLE1(unix_iov_max) +LWT_NOT_AVAILABLE2(unix_bind_job) diff --git a/tests/test.ml b/tests/test.ml index 35442939ad..bb16abe638 100644 --- a/tests/test.ml +++ b/tests/test.ml @@ -96,18 +96,21 @@ let run name suites = in loop_suites 0 0 1 suites +let temp_name = + let rng = Random.State.make_self_init () in + fun () -> + let number = Random.State.int rng 10000 in + Printf.sprintf "_build/lwt-testing-%04d" number + let temp_file () = Filename.temp_file ~temp_dir:"_build" "lwt-testing-" "" -let temp_directory = - let rng = Random.State.make_self_init () in - fun () -> - let rec attempt () = - let number = Random.State.int rng 10000 in - let path = Printf.sprintf "_build/lwt-testing-%04d" number in - try - Unix.mkdir path 0o755; - path - with Not_found -> attempt () - in +let temp_directory () = + let rec attempt () = + let path = temp_name () in + try + Unix.mkdir path 0o755; + path + with Unix.Unix_error (Unix.EEXIST, "mkdir", _) -> attempt () + in attempt () diff --git a/tests/test.mli b/tests/test.mli index d7b3173f8f..32051202ef 100644 --- a/tests/test.mli +++ b/tests/test.mli @@ -43,6 +43,10 @@ val run : string -> suite list -> unit (** Run all the given tests and exit the program with an exit code of [0] if all tests succeeded and with [1] otherwise. *) +val temp_name : unit -> string +(** Generates the name of a temporary file (or directory) in [_build/]. Note + that a file at the path may already exist. *) + val temp_file : unit -> string (** Creates a temporary file in [_build/] and evaluates to its path. *) diff --git a/tests/unix/test_lwt_unix.ml b/tests/unix/test_lwt_unix.ml index d143ee0920..2866848d44 100644 --- a/tests/unix/test_lwt_unix.ml +++ b/tests/unix/test_lwt_unix.ml @@ -417,8 +417,73 @@ let writev_tests = not (Lwt_unix.IO_vectors.is_empty io_vectors))); ] +let bind_tests = [ + test "bind: basic" + (fun () -> + let address = Unix.(ADDR_INET (inet_addr_loopback, 56100)) in + let socket = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in + + Lwt.finalize + (fun () -> + Lwt_unix.Versioned.bind_2 socket address >>= fun () -> + Lwt.return (Unix.getsockname (Lwt_unix.unix_file_descr socket))) + (fun () -> + Lwt_unix.close socket) + + >>= fun address' -> + + Lwt.return (address' = address)); + + test "bind: Unix domain" ~only_if:(fun () -> not Sys.win32) + (fun () -> + let socket = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in + + let rec bind_loop attempts = + if attempts <= 0 then + Lwt.fail (Unix.Unix_error (Unix.EADDRINUSE, "bind", "")) + else + let path = temp_name () in + let address = Unix.(ADDR_UNIX path) in + Lwt.catch + (fun () -> + Lwt_unix.Versioned.bind_2 socket address >>= fun () -> + Lwt.return_true) + (function + | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) -> Lwt.return_false + | e -> Lwt.fail e) [@ocaml.warning "-4"] >>= fun bound -> + if bound then + Lwt.return path + else + bind_loop (attempts - 1) + in + + Lwt.finalize + (fun () -> + bind_loop 5 >>= fun chosen_path -> + let actual_path = + Unix.getsockname (Lwt_unix.unix_file_descr socket) in + Lwt.return (chosen_path, actual_path)) + (fun () -> + Lwt_unix.close socket) + >>= fun (chosen_path, actual_path) -> + + let actual_path = + match actual_path with + | Unix.ADDR_UNIX path -> path + | Unix.ADDR_INET _ -> assert false + in + + (try Unix.unlink chosen_path + with _ -> ()); + (try Unix.unlink actual_path + with _ -> ()); + + Lwt.return (chosen_path = actual_path)); +] + let suite = suite "lwt_unix" (utimes_tests @ readdir_tests @ - writev_tests) + writev_tests @ + bind_tests) diff --git a/tests/unix/test_mcast.ml b/tests/unix/test_mcast.ml index db046a48f7..4b707a5555 100644 --- a/tests/unix/test_mcast.ml +++ b/tests/unix/test_mcast.ml @@ -30,7 +30,6 @@ let mcast_port = 4321 let child join fd = (* Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; *) - Lwt_unix.(bind fd (ADDR_INET (Unix.inet_addr_any, mcast_port))); if join then Lwt_unix.mcast_add_membership fd (Unix.inet_addr_of_string mcast_addr); let buf = Bytes.create 50 in Lwt_unix.with_timeout 0.1 (fun () -> Lwt_unix.read fd buf 0 (Bytes.length buf)) >>= fun n -> @@ -58,6 +57,8 @@ let test_mcast name join set_loop = let t () = Lwt.catch (fun () -> + Lwt_unix.(Versioned.bind_2 + fd1 (ADDR_INET (Unix.inet_addr_any, mcast_port))) >>= fun () -> let t1 = child join fd1 in let t2 = parent set_loop fd2 in Lwt.join [t1; t2] >>= fun () -> Lwt.return true From 2527f03bc32a892d346b86be32cee450331120ff Mon Sep 17 00:00:00 2001 From: Anton Bachin Date: Mon, 28 Nov 2016 10:52:37 -0600 Subject: [PATCH 2/2] Lwt_io.establish_server should return a thread --- src/unix/lwt_io.ml | 34 ++++++++++++++++++++++++++++------ src/unix/lwt_io.mli | 6 ++++-- tests/unix/test_lwt_io.ml | 34 ++++++++++++++++------------------ 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/unix/lwt_io.ml b/src/unix/lwt_io.ml index 44d55feae9..913a9e983b 100644 --- a/src/unix/lwt_io.ml +++ b/src/unix/lwt_io.ml @@ -1419,14 +1419,14 @@ let shutdown_server_2 server = Lazy.force server.shutdown let shutdown_server server = Lwt.async (fun () -> shutdown_server_2 server) -let establish_server ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sockaddr f = +let _establish_server_base + bind ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sockaddr f = let sock = match fd with | None -> Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0 | Some fd -> fd in Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true; - (Lwt_unix.bind sock sockaddr) [@ocaml.warning "-3"]; - Lwt_unix.listen sock backlog; + let abort_waiter, abort_wakener = Lwt.wait () in let abort_waiter = abort_waiter >>= fun () -> Lwt.return `Shutdown in (* Signals that the listening socket has been closed. *) @@ -1452,8 +1452,25 @@ let establish_server ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sock Lwt.wakeup closed_wakener (); Lwt.return_unit in - ignore (loop ()); - { shutdown = lazy (Lwt.wakeup abort_wakener (); closed_waiter) } + + let started, signal_started = Lwt.wait () in + Lwt.ignore_result begin + bind sock sockaddr >>= fun () -> + Lwt_unix.listen sock backlog; + Lwt.wakeup signal_started (); + loop () + end; + + let server = {shutdown = lazy (Lwt.wakeup abort_wakener (); closed_waiter)} in + + server, started + +let establish_server ?fd ?buffer_size ?backlog sockaddr f = + let blocking_bind fd addr = + Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"] + in + _establish_server_base blocking_bind ?fd ?buffer_size ?backlog sockaddr f + |> fst let establish_server_safe ?fd ?buffer_size ?backlog sockaddr f = let best_effort_close channel = @@ -1488,7 +1505,12 @@ let establish_server_safe ?fd ?buffer_size ?backlog sockaddr f = >>= fun () -> best_effort_close output_channel) in - establish_server ?fd ?buffer_size ?backlog sockaddr handler + let server, started = + _establish_server_base + Lwt_unix.Versioned.bind_2 + ?fd ?buffer_size ?backlog sockaddr handler in + started >>= fun () -> + Lwt.return server let ignore_close ch = ignore (close ch) diff --git a/src/unix/lwt_io.mli b/src/unix/lwt_io.mli index e8e39e5754..60b231b54f 100644 --- a/src/unix/lwt_io.mli +++ b/src/unix/lwt_io.mli @@ -422,7 +422,8 @@ val establish_server : "The signature and semantics of this function will soon change: - the callback parameter f will evaluate to a thread (-> unit Lwt.t), - channels will be closed automatically when that thread completes, to avoid - leaking file descriptors. + leaking file descriptors, and +- the result will be a thread (-> server Lwt.t). This will be breaking change. See https://github.com/ocsigen/lwt/pull/258 To keep the current functionality, use Lwt_io.Versioned.establish_server_1 @@ -599,7 +600,8 @@ sig ?fd : Lwt_unix.file_descr -> ?buffer_size : int -> ?backlog : int -> - Unix.sockaddr -> (input_channel * output_channel -> unit Lwt.t) -> server + Unix.sockaddr -> (input_channel * output_channel -> unit Lwt.t) -> + server Lwt.t (** [establish_server_safe ?fd ?buffer_size ?backlog sockaddr f] creates a server which listens for incoming connections. New connections are passed to [f]. When threads returned by [f] complete, the connections are closed diff --git a/tests/unix/test_lwt_io.ml b/tests/unix/test_lwt_io.ml index 3bb599f621..ba59994089 100644 --- a/tests/unix/test_lwt_io.ml +++ b/tests/unix/test_lwt_io.ml @@ -38,16 +38,16 @@ struct let with_client f = let handler_finished, notify_handler_finished = Lwt.wait () in - let server = - Lwt_io.Versioned.establish_server_2 - local - (fun channels -> - Lwt.finalize - (fun () -> f channels) - (fun () -> - Lwt.wakeup notify_handler_finished (); - Lwt.return_unit)) - in + Lwt_io.Versioned.establish_server_2 + local + (fun channels -> + Lwt.finalize + (fun () -> f channels) + (fun () -> + Lwt.wakeup notify_handler_finished (); + Lwt.return_unit)) + + >>= fun server -> let client_finished = Lwt_io.with_connection @@ -367,8 +367,8 @@ let suite = suite "lwt_io" [ let in_channel' = ref Lwt_io.stdin in let out_channel' = ref Lwt_io.stdout in - let server = - Lwt_io.Versioned.establish_server_2 local (fun _ -> Lwt.return_unit) in + Lwt_io.Versioned.establish_server_2 local (fun _ -> Lwt.return_unit) + >>= fun server -> Lwt_io.with_connection local (fun (in_channel, out_channel) -> in_channel' := in_channel; @@ -400,12 +400,10 @@ let suite = suite "lwt_io" [ let handler_started, notify_handler_started = Lwt.wait () in let finish_server, resume_server = Lwt.wait () in - let server = - Lwt_io.Versioned.establish_server_2 local - (fun _ -> - Lwt.wakeup notify_handler_started (); - finish_server) - in + Lwt_io.Versioned.establish_server_2 local + (fun _ -> + Lwt.wakeup notify_handler_started (); + finish_server) >>= fun server -> expecting_ebadf (fun () -> Lwt_io.with_connection local (fun (in_channel, out_channel) ->