Skip to content

Commit

Permalink
Separate the type of threads (covariant) from the type of thread wake…
Browse files Browse the repository at this point in the history
…ners (contravariant)

Ignore-this: fe7165d0365a398cdb7f485b328dfbde

darcs-hash:20090610131130-29d5e-0b8a377b418498c25bc2a68dcb49de19f31683b5
  • Loading branch information
vouillon at pps.jussieu.fr committed Jun 10, 2009
1 parent 2993247 commit 363f67a
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 67 deletions.
43 changes: 27 additions & 16 deletions src/lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

exception Canceled

type +'a t
type -'a u

(* Reason for a thread to be a sleeping thread: *)
type sleep_reason =
| Task
Expand All @@ -49,27 +52,33 @@ and 'a thread_state =
- [sleep_reason] is the reason why the thread is sleeping
- [waiters] is the list of waiters, which are thunk functions
*)
| Repr of 'a t
| Repr of 'a thread_repr
(* [Repr t] a thread which behaves the same as [t] *)

and 'a t = 'a thread_state ref
and 'a thread_repr = 'a thread_state ref

external thread_repr : 'a t -> 'a thread_repr = "%identity"
external thread : 'a thread_repr -> 'a t = "%identity"
external wakener : 'a thread_repr -> 'a u = "%identity"
external wakener_repr : 'a u -> 'a thread_repr = "%identity"

(* Returns the represent of a thread, updating non-direct references: *)
let rec repr t =
let rec repr_rec t =
match !t with
| Repr t' -> let t'' = repr t' in if t'' != t' then t := Repr t''; t''
| Repr t' -> let t'' = repr_rec t' in if t'' != t' then t := Repr t''; t''
| _ -> t
let repr t = repr_rec (thread_repr t)

let run_waiters waiters t =
Lwt_sequence.iter_l (fun f -> f t) waiters
Lwt_sequence.iter_l (fun f -> f (thread t)) waiters

(* Restarts a sleeping thread [t]:
- run all its waiters
- set his state to the terminated state [state]
*)
let restart t state caller =
let t = repr t in
let t = repr_rec (wakener_repr t) in
match !t with
| Sleep((Wait | Task), waiters) ->
t := state;
Expand All @@ -86,7 +95,7 @@ let wakeup_exn t e = restart t (Fail e) "wakeup_exn"
let cancel t =
match !(repr t) with
| Sleep(Task, _) ->
wakeup_exn t Canceled
wakeup_exn (wakener (thread_repr t)) Canceled
| Sleep(Temp f, _) ->
f ()
| _ ->
Expand Down Expand Up @@ -122,18 +131,20 @@ let rec connect t1 t2 =
(* [t1] is not asleep: *)
invalid_arg "connect"

let return v = thread (ref (Return v))
let fail e = thread (ref (Fail e))
let temp f = thread (ref (Sleep(Temp f, Lwt_sequence.create ())))
let wait _ =
let t = ref (Sleep(Wait, Lwt_sequence.create ())) in (thread t, wakener t)
let task _ =
let t = ref (Sleep(Task, Lwt_sequence.create ())) in (thread t, wakener t)

(* apply function, reifying explicit exceptions into the thread type
apply: ('a -(exn)-> 'b t) -> ('a -(n)-> 'b t)
semantically a natural transformation TE -> T, where T is the thread
monad, which is layered over exception monad E.
*)
let apply f x = try f x with e -> ref (Fail e)

let return v = ref (Return v)
let fail e = ref (Fail e)
let temp f = ref (Sleep(Temp f, Lwt_sequence.create ()))
let wait _ = ref (Sleep(Wait, Lwt_sequence.create ()))
let task _ = ref (Sleep(Task, Lwt_sequence.create ()))
let apply f x = try f x with e -> fail e

let new_waiter waiters f =
Lwt_sequence.add_r f waiters
Expand Down Expand Up @@ -235,7 +246,7 @@ let rec nth_ready l n =
| _ when n > 0 ->
nth_ready rem (n - 1)
| _ ->
x
thread x

let ready_count l =
List.fold_left (fun acc x -> match !(repr x) with Sleep _ -> acc | _ -> acc + 1) 0 l
Expand Down Expand Up @@ -309,7 +320,7 @@ let join l =
let rec bind_sleepers = function
| [] ->
(* If no thread is sleeping, returns now: *)
if !sleeping = 0 then res := Return ();
if !sleeping = 0 then thread_repr res := Return ();
res
| t :: l -> match !(repr t) with
| Fail exn ->
Expand Down
18 changes: 11 additions & 7 deletions src/lwt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
You should use [catch] instead. *)


type 'a t
type +'a t
(** The type of threads returning a result of type ['a]. *)

val return : 'a -> 'a t
Expand Down Expand Up @@ -112,17 +112,21 @@ val ignore_result : 'a t -> unit
Note that if the thread [t] yields and later fails, the
exception will not be raised at this point in the program. *)

val wait : unit -> 'a t
(** [wait ()] is a thread which sleeps forever (unless it is
resumed by one of the functions [wakeup], [wakeup_exn] below).
type 'a u
(** The type of thread wakeners. *)

val wait : unit -> 'a t * 'a u
(** [wait ()] is a pair of a thread which sleeps forever (unless
it is resumed by one of the functions [wakeup], [wakeup_exn]
below) and the corresponding wakener.
This thread does not block the execution of the remainder of
the program (except of course, if another thread tries to
wait for its termination). *)

val wakeup : 'a t -> 'a -> unit
val wakeup : 'a u -> 'a -> unit
(** [wakeup t e] makes the sleeping thread [t] terminate and
return the value of the expression [e]. *)
val wakeup_exn : 'a t -> exn -> unit
val wakeup_exn : 'a u -> exn -> unit
(** [wakeup_exn t e] makes the sleeping thread [t] fail with the
exception [e]. *)

Expand All @@ -147,7 +151,7 @@ val state : 'a t -> 'a state
exception Canceled
(** Canceled threads fails with this exception *)

val task : unit -> 'a t
val task : unit -> 'a t * 'a u
(** [task ()] creates a sleeping thread that can be canceled using
{!cancel} *)

Expand Down
20 changes: 10 additions & 10 deletions src/lwt_io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ and 'mode channel = {
channel : 'mode _channel;
(* The real channel *)

mutable queued : unit Lwt.t Lwt_sequence.t;
mutable queued : unit Lwt.u Lwt_sequence.t;
(* Queued operations *)
}

Expand All @@ -110,7 +110,7 @@ and 'mode _channel = {
(* Position of the end of data int the buffer. It is equal to
[length] for output channels. *)

abort : int Lwt.t;
abort : int Lwt.t * int Lwt.u;
(* Thread which is wakeup with an exception when the channel is
closed. *)

Expand Down Expand Up @@ -183,7 +183,7 @@ let perform_io ch = match ch.main.state with
(size, ch.length - size)
| Output ->
(0, ch.ptr) in
lwt n = choose [ch.abort; lwt_unix_call (fun _ -> ch.perform_io ch.buffer ptr len)] in
lwt n = choose [fst ch.abort; lwt_unix_call (fun _ -> ch.perform_io ch.buffer ptr len)] in
(* Never trust user functions... *)
if n < 0 || n > len then
fail (Failure (Printf.sprintf "Lwt_io: invalid result of the [%s] function(request=%d,result=%d)"
Expand Down Expand Up @@ -296,10 +296,10 @@ let primitive f wrapper = match wrapper.state with
return ()

| Busy_primitive | Busy_atomic _ ->
let w = task () in
let (res, w) = task () in
let node = Lwt_sequence.add_r w wrapper.queued in
Lwt.on_cancel w (fun _ -> Lwt_sequence.remove node);
lwt () = w in
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
lwt () = res in
begin match wrapper.state with
| Closed ->
(* The channel has been closed while we were waiting *)
Expand Down Expand Up @@ -343,10 +343,10 @@ let atomic f wrapper = match wrapper.state with
return ()

| Busy_primitive | Busy_atomic _ ->
let w = task () in
let (res, w) = task () in
let node = Lwt_sequence.add_r w wrapper.queued in
Lwt.on_cancel w (fun _ -> Lwt_sequence.remove node);
lwt () = w in
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
lwt () = res in
begin match wrapper.state with
| Closed ->
(* The channel has been closed while we were waiting *)
Expand Down Expand Up @@ -391,7 +391,7 @@ let rec abort wrapper = match wrapper.state with
wrapper.state <- Closed;
(* Abort any current real reading/writing operation on the
channel: *)
wakeup_exn wrapper.channel.abort (closed_channel wrapper.channel);
wakeup_exn (snd wrapper.channel.abort) (closed_channel wrapper.channel);
Lazy.force wrapper.channel.close

let close wrapper = match wrapper.channel.mode with
Expand Down
6 changes: 3 additions & 3 deletions src/lwt_monitor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ open Lwt

module Condition =
struct
type 'a t = 'a Lwt.t Queue.t
type 'a t = 'a Lwt.u Queue.t

let wait cvar =
let thread = Lwt.wait () in
Queue.add thread cvar;
let (thread, w) = Lwt.wait () in
Queue.add w cvar;
thread

let notify cvar arg =
Expand Down
6 changes: 3 additions & 3 deletions src/lwt_mutex.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@

open Lwt

type t = { mutable locked : bool; mutable waiters : unit Lwt.t Lwt_sequence.t }
type t = { mutable locked : bool; mutable waiters : unit Lwt.u Lwt_sequence.t }

let create () = { locked = false; waiters = Lwt_sequence.create () }

let rec lock m =
if m.locked then begin
let res = Lwt.task () in
let node = Lwt_sequence.add_r res m.waiters in
let (res, w) = Lwt.task () in
let node = Lwt_sequence.add_r w m.waiters in
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
res
end else begin
Expand Down
16 changes: 8 additions & 8 deletions src/lwt_mvar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ type 'a t = {
mutable contents : 'a option;
(* Current contents *)

mutable writers : ('a * unit Lwt.t) Lwt_sequence.t;
mutable writers : ('a * unit Lwt.u) Lwt_sequence.t;
(* Threads waiting to put a value *)

mutable readers : 'a Lwt.t Lwt_sequence.t;
mutable readers : 'a Lwt.u Lwt_sequence.t;
(* Threads waiting for a value *)
}

Expand All @@ -64,10 +64,10 @@ let put mvar v =
end;
return_unit
| Some _ ->
let w = Lwt.task () in
let (res, w) = Lwt.task () in
let node = Lwt_sequence.add_r (v, w) mvar.writers in
Lwt.on_cancel w (fun _ -> Lwt_sequence.remove node);
w
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
res

let take mvar =
match mvar.contents with
Expand All @@ -81,7 +81,7 @@ let take mvar =
end;
Lwt.return v
| None ->
let w = Lwt.task () in
let (res, w) = Lwt.task () in
let node = Lwt_sequence.add_r w mvar.readers in
Lwt.on_cancel w (fun _ -> Lwt_sequence.remove node);
w
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
res
6 changes: 3 additions & 3 deletions src/lwt_pool.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type 'a t =
max : int;
mutable count : int;
list : 'a Queue.t;
waiters : 'a Lwt.t Queue.t }
waiters : 'a Lwt.u Queue.t }

let create m ?(check = fun _ f -> f true) create =
{ max = m;
Expand All @@ -59,8 +59,8 @@ let acquire p =
if p.count < p.max then
create_member p
else begin
let r = wait () in
Queue.push r p.waiters;
let (r, w) = wait () in
Queue.push w p.waiters;
r
end
Expand Down
15 changes: 8 additions & 7 deletions src/lwt_preemptive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ type thread = {
let workers : thread Queue.t = Queue.create ()

(* Queue of clients waiting for a worker to be available: *)
let waiters : thread Lwt.t Lwt_sequence.t = Lwt_sequence.create ()
let waiters : thread Lwt.u Lwt_sequence.t = Lwt_sequence.create ()

(* Mapping from thread ids to client lwt-thread: *)
let clients : (int, unit Lwt.t) Hashtbl.t = Hashtbl.create 16
let clients : (int, unit Lwt.u) Hashtbl.t = Hashtbl.create 16

(* Code executed by a worker: *)
let rec worker_loop worker =
Expand Down Expand Up @@ -129,10 +129,10 @@ let rec get_worker _ =
else if !threads_count < !max_threads then
return (make_worker ())
else begin
let w = Lwt.task () in
let (res, w) = Lwt.task () in
let node = Lwt_sequence.add_r w waiters in
Lwt.on_cancel w (fun _ -> Lwt_sequence.remove node);
w
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
res
end

(* +-----------------------------------------------------------------+
Expand Down Expand Up @@ -223,13 +223,14 @@ let detach f args =
result := `Failure exn
in
lwt worker = get_worker () in
let w = Lwt.wait () in
let (res, w) = Lwt.wait () in
Hashtbl.add clients (Thread.id worker.thread) w;
(* Send the task to the worker: *)
Event.sync (Event.send worker.task_channel task);
try_lwt
(* Wait for notification of the dispatcher: *)
w >> match !result with
res >>
match !result with
| `Nothing ->
fail (Failure "Lwt_preemptive.detach")
| `Success v ->
Expand Down
15 changes: 8 additions & 7 deletions src/lwt_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ open Lwt
type sleep = {
time : float;
mutable canceled : bool;
thread : unit Lwt.t;
thread : unit Lwt.u;
}

module SleepQueue =
Expand All @@ -51,9 +51,9 @@ let sleep_queue = ref SleepQueue.empty
let new_sleeps = ref []

let sleep d =
let res = Lwt.task () in
let (res, w) = Lwt.task () in
let t = if d <= 0. then 0. else Unix.gettimeofday () +. d in
let sleeper = { time = t; canceled = false; thread = res } in
let sleeper = { time = t; canceled = false; thread = w } in
new_sleeps := sleeper :: !new_sleeps;
Lwt.on_cancel res (fun _ -> sleeper.canceled <- true);
res
Expand Down Expand Up @@ -190,9 +190,10 @@ let rec retry_syscall set ch cont action =
ignore (Lwt_sequence.add_r (fun _ -> retry_syscall set ch cont action) (get_actions ch set))

let register_action set ch action =
let res = Lwt.task () in
let (res, w) = Lwt.task () in
let actions = get_actions ch set in
let node = Lwt_sequence.add_r (fun _ -> retry_syscall set ch res action) actions in
let node =
Lwt_sequence.add_r (fun _ -> retry_syscall set ch w action) actions in
(* Unregister the action on cancel: *)
Lwt.on_cancel res begin fun _ ->
Lwt_sequence.remove node;
Expand Down Expand Up @@ -548,8 +549,8 @@ let waitpid flags pid =
return res
else begin
ignore (Lazy.force init_wait_pid);
let res = Lwt.task () in
let node = Lwt_sequence.add_l (res, flags, pid) wait_children in
let (res, w) = Lwt.task () in
let node = Lwt_sequence.add_l (w, flags, pid) wait_children in
Lwt.on_cancel res (fun _ -> Lwt_sequence.remove node);
res
end
Expand Down
Loading

0 comments on commit 363f67a

Please sign in to comment.