diff --git a/src/unix/lwt_io.ml b/src/unix/lwt_io.ml index d8ef3cb6a8..c93a56af4e 100644 --- a/src/unix/lwt_io.ml +++ b/src/unix/lwt_io.ml @@ -1434,24 +1434,25 @@ let establish_server_base let closed_waiter, closed_wakener = Lwt.wait () in let rec loop () = Lwt.pick [Lwt_unix.accept sock >|= (fun x -> `Accept x); abort_waiter] >>= function - | `Accept(fd, _addr) -> - (try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ()); - let close = lazy (close_socket fd) in - f (of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:input - ~close:(fun () -> Lazy.force close) fd, - of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:output - ~close:(fun () -> Lazy.force close) fd); - loop () - | `Shutdown -> - Lwt_unix.close sock >>= fun () -> - (match sockaddr with - | Unix.ADDR_UNIX path when path <> "" && path.[0] <> '\x00' -> - Unix.unlink path; - Lwt.return_unit - | _ -> - Lwt.return_unit) [@ocaml.warning "-4"] >>= fun () -> - Lwt.wakeup closed_wakener (); - Lwt.return_unit + | `Accept(fd, addr) -> + (try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ()); + let close = lazy (close_socket fd) in + f addr + (of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:input + ~close:(fun () -> Lazy.force close) fd, + of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:output + ~close:(fun () -> Lazy.force close) fd); + loop () + | `Shutdown -> + Lwt_unix.close sock >>= fun () -> + (match sockaddr with + | Unix.ADDR_UNIX path when path <> "" && path.[0] <> '\x00' -> + Unix.unlink path; + Lwt.return_unit + | _ -> + Lwt.return_unit) [@ocaml.warning "-4"] >>= fun () -> + Lwt.wakeup closed_wakener (); + Lwt.return_unit in let started, signal_started = Lwt.wait () in @@ -1473,10 +1474,11 @@ let establish_server_deprecated ?fd ?buffer_size ?backlog sockaddr f = let blocking_bind fd addr = Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"] in + let f _addr c = f c in establish_server_base blocking_bind ?fd ?buffer_size ?backlog sockaddr f |> fst -let establish_server +let establish_server' ?fd ?buffer_size ?backlog ?(no_close = false) sockaddr f = let best_effort_close channel = (* First, check whether the channel is closed. f may have already tried to @@ -1495,13 +1497,13 @@ let establish_server Lwt.return_unit) in - let handler ((input_channel, output_channel) as channels) = + let handler addr ((input_channel, output_channel) as channels) = Lwt.async (fun () -> (* Not using Lwt.finalize here, to make sure that exceptions from [f] reach !Lwt.async_exception_hook before exceptions from closing the channels. *) Lwt.catch - (fun () -> f channels) + (fun () -> f addr channels) (fun exn -> !Lwt.async_exception_hook exn; Lwt.return_unit) @@ -1520,6 +1522,10 @@ let establish_server started >>= fun () -> Lwt.return server +let establish_server ?fd ?buffer_size ?backlog ?no_close sockaddr f = + let f _addr c = f c in + establish_server' ?fd ?buffer_size ?backlog ?no_close sockaddr f + let ignore_close ch = ignore (close ch) diff --git a/src/unix/lwt_io.mli b/src/unix/lwt_io.mli index 59474c12a1..abf30d5419 100644 --- a/src/unix/lwt_io.mli +++ b/src/unix/lwt_io.mli @@ -446,6 +446,17 @@ val establish_server : @since 3.0.0 *) +val establish_server' : + ?fd : Lwt_unix.file_descr -> + ?buffer_size : int -> + ?backlog : int -> + ?no_close : bool -> + Unix.sockaddr -> + (Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) -> + server Lwt.t +(** Like establish_server but allows you to access the client's socket + in the callback. *) + val shutdown_server : server -> unit Lwt.t (** Closes the given server's listening socket. The returned promise resolves when the [close(2)] system call completes. This function does not affect the