diff --git a/CHANGES b/CHANGES index f57e5ca13..6d6810c06 100644 --- a/CHANGES +++ b/CHANGES @@ -8,6 +8,8 @@ * Fix win32_spawn leaking dev_null fd in the parent process. (#906, Antonin Décimo) * Prefer SetHandleInformation to DuplicateHandle in set_close_on_exec for sockets. DuplicateHandle mustn't be used on sockets. (#907, Antonin Décimo) + * Lwt.pick and Lwt.choose select preferentially failed promises as per + documentation (#856, #874, Raman Varabets) ===== 5.5.0 ===== diff --git a/src/core/lwt.ml b/src/core/lwt.ml index 9dd731bf7..ec5858246 100644 --- a/src/core/lwt.ml +++ b/src/core/lwt.ml @@ -2642,15 +2642,28 @@ struct [choose]/[pick] implementation, which may actually be optimal anyway with Flambda. *) - let count_resolved_promises_in (ps : _ t list) = - let accumulate total p = - let Internal p = to_internal_promise p in - match (underlying p).state with - | Fulfilled _ -> total + 1 - | Rejected _ -> total + 1 - | Pending _ -> total + let count_resolved_promises_in (ps : 'a t list) = + let rec count_and_gather_rejected total rejected ps = + match ps with + | [] -> Result.Error (total, rejected) + | p :: ps -> + let Internal q = to_internal_promise p in + match (underlying q).state with + | Fulfilled _ -> count_and_gather_rejected total rejected ps + | Rejected _ -> count_and_gather_rejected (total + 1) (p :: rejected) ps + | Pending _ -> count_and_gather_rejected total rejected ps + in + let rec count_fulfilled total ps = + match ps with + | [] -> Result.Ok total + | p :: ps -> + let Internal q = to_internal_promise p in + match (underlying q).state with + | Fulfilled _ -> count_fulfilled (total + 1) ps + | Rejected _ -> count_and_gather_rejected 1 [p] ps + | Pending _ -> count_fulfilled total ps in - List.fold_left accumulate 0 ps + count_fulfilled 0 ps (* Evaluates to the [n]th promise in [ps], among only those promises in [ps] that are resolved. The caller is expected to ensure that there are at @@ -2704,7 +2717,7 @@ struct invalid_arg "Lwt.choose [] would return a promise that is pending forever"; match count_resolved_promises_in ps with - | 0 -> + | Result.Ok 0 -> let p = new_pending ~how_to_cancel:(propagate_cancel_to_several ps) in let callback result = @@ -2718,17 +2731,20 @@ struct to_public_promise p - | 1 -> + | Result.Ok 1 -> nth_resolved ps 0 - | n -> + | Result.Ok n -> + nth_resolved ps (Random.State.int (Lazy.force prng) n) + + | Result.Error (n, ps) -> nth_resolved ps (Random.State.int (Lazy.force prng) n) let pick ps = if ps = [] then invalid_arg "Lwt.pick [] would return a promise that is pending forever"; match count_resolved_promises_in ps with - | 0 -> + | Ok 0 -> let p = new_pending ~how_to_cancel:(propagate_cancel_to_several ps) in let callback result = @@ -2743,13 +2759,17 @@ struct to_public_promise p - | 1 -> + | Ok 1 -> nth_resolved_and_cancel_pending ps 0 - | n -> + | Ok n -> nth_resolved_and_cancel_pending ps (Random.State.int (Lazy.force prng) n) + | Error (n, qs) -> + List.iter cancel ps; + nth_resolved qs (Random.State.int (Lazy.force prng) n) + (* If [nchoose ps] or [npick ps] found all promises in [ps] pending, the diff --git a/test/core/test_lwt.ml b/test/core/test_lwt.ml index 7c14f5ecd..803b02364 100644 --- a/test/core/test_lwt.ml +++ b/test/core/test_lwt.ml @@ -2274,11 +2274,10 @@ let choose_tests = suite "choose" [ end; test "multiple resolved" begin fun () -> - (* This is run in a loop to exercise the internal PRNG. *) - let outcomes = Array.make 3 0 in + (* This is run in a loop to check that it consistently returns the failed + result as per documentation. *) let rec repeat n = - if n <= 0 then () - else + n <= 0 || begin let p = Lwt.choose [fst (Lwt.wait ()); @@ -2286,19 +2285,15 @@ let choose_tests = suite "choose" [ Lwt.fail Exception; Lwt.return "bar"] in - begin match Lwt.state p with - | Lwt.Return "foo" -> outcomes.(0) <- outcomes.(0) + 1 - | Lwt.Fail Exception -> outcomes.(1) <- outcomes.(1) + 1 - | Lwt.Return "bar" -> outcomes.(2) <- outcomes.(2) + 1 + match Lwt.state p with + | Lwt.Return "foo" -> false + | Lwt.Fail Exception -> repeat (n - 1) + | Lwt.Return "bar" -> false | _ -> assert false end [@ocaml.warning "-4"]; - repeat (n - 1) in - let count = 1000 in - repeat count; - Lwt.return - (outcomes.(0) > 0 && outcomes.(1) > 0 && outcomes.(2) > 0 && - outcomes.(0) + outcomes.(1) + outcomes.(2) = count) + let count = 100 in + Lwt.return (repeat count) end; test "pending" begin fun () -> @@ -2982,31 +2977,29 @@ let pick_tests = suite "pick" [ end; test "multiple resolved" begin fun () -> - (* This is run in a loop to exercise the internal PRNG. *) - let outcomes = Array.make 3 0 in + (* This is run in a loop to check that it consistently returns the failed + result as per documentation. *) let rec repeat n = - if n <= 0 then () - else + n <= 0 || begin + let (waiter, _) = Lwt.task () in let p = Lwt.pick - [fst (Lwt.wait ()); + [waiter; Lwt.return "foo"; Lwt.fail Exception; Lwt.return "bar"] in - begin match Lwt.state p with - | Lwt.Return "foo" -> outcomes.(0) <- outcomes.(0) + 1 - | Lwt.Fail Exception -> outcomes.(1) <- outcomes.(1) + 1 - | Lwt.Return "bar" -> outcomes.(2) <- outcomes.(2) + 1 + match Lwt.state p with + | Lwt.Return "foo" -> false + | Lwt.Fail Exception -> + Lwt.state waiter = Lwt.Fail Lwt.Canceled + && repeat (n - 1) + | Lwt.Return "bar" -> false | _ -> assert false end [@ocaml.warning "-4"]; - repeat (n - 1) in - let count = 1000 in - repeat count; - Lwt.return - (outcomes.(0) > 0 && outcomes.(1) > 0 && outcomes.(2) > 0 && - outcomes.(0) + outcomes.(1) + outcomes.(2) = count) + let count = 100 in + Lwt.return (repeat count) end; test "pending" begin fun () ->