Skip to content

Commit

Permalink
Add Atomic.compare_exchange (#3368)
Browse files Browse the repository at this point in the history
* cmpxchg

* dedup impl

* val_bool
  • Loading branch information
TheNumbat authored Dec 12, 2024
1 parent 6a49953 commit 17d57bd
Show file tree
Hide file tree
Showing 22 changed files with 121 additions and 27 deletions.
15 changes: 15 additions & 0 deletions backend/cmm_helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4207,6 +4207,21 @@ let atomic_compare_and_set ~dbg atomic ~old_value ~new_value =
[atomic; old_value; new_value],
dbg )

let atomic_compare_exchange ~dbg atomic ~old_value ~new_value =
Cop
( Cextcall
{ func = "caml_atomic_compare_exchange";
builtin = false;
returns = true;
effects = Arbitrary_effects;
coeffects = Has_coeffects;
ty = typ_val;
ty_args = [];
alloc = false
},
[atomic; old_value; new_value],
dbg )

type even_or_odd =
| Even
| Odd
Expand Down
7 changes: 7 additions & 0 deletions backend/cmm_helpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,13 @@ val atomic_compare_and_set :
new_value:expression ->
expression

val atomic_compare_exchange :
dbg:Debuginfo.t ->
expression ->
old_value:expression ->
new_value:expression ->
expression

val emit_gc_roots_table : symbols:symbol list -> phrase list -> phrase list

val perform : dbg:Debuginfo.t -> expression -> expression
Expand Down
4 changes: 3 additions & 1 deletion bytecomp/bytegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ let preserve_tailcall_for_prim = function
| Pbigstring_set_64 _ | Pbigstring_set_128 _
| Pprobe_is_enabled _ | Pobj_dup
| Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _
| Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _
| Patomic_exchange | Patomic_compare_exchange
| Patomic_cas | Patomic_fetch_add | Patomic_load _
| Pdls_get | Preinterpret_tagged_int63_as_unboxed_int64
| Preinterpret_unboxed_int64_as_tagged_int63 | Ppoll ->
false
Expand Down Expand Up @@ -653,6 +654,7 @@ let comp_primitive stack_info p sz args =
| Pobj_dup -> Kccall("caml_obj_dup", 1)
| Patomic_load _ -> Kccall("caml_atomic_load", 1)
| Patomic_exchange -> Kccall("caml_atomic_exchange", 2)
| Patomic_compare_exchange -> Kccall("caml_atomic_compare_exchange", 3)
| Patomic_cas -> Kccall("caml_atomic_cas", 3)
| Patomic_fetch_add -> Kccall("caml_atomic_fetch_add", 2)
| Pdls_get -> Kccall("caml_domain_dls_get", 1)
Expand Down
6 changes: 5 additions & 1 deletion lambda/lambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ type primitive =
(* Atomic operations *)
| Patomic_load of {immediate_or_pointer : immediate_or_pointer}
| Patomic_exchange
| Patomic_compare_exchange
| Patomic_cas
| Patomic_fetch_add
(* Inhibition of optimisation *)
Expand Down Expand Up @@ -1923,6 +1924,7 @@ let primitive_may_allocate : primitive -> locality_mode option = function
Some alloc_heap
| Patomic_load _
| Patomic_exchange
| Patomic_compare_exchange
| Patomic_cas
| Patomic_fetch_add
| Pdls_get
Expand Down Expand Up @@ -2087,7 +2089,8 @@ let primitive_can_raise prim =
| Punbox_vector _ | Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _
| Punboxed_product_field _ | Pget_header _ ->
false
| Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ -> false
| Patomic_exchange | Patomic_compare_exchange
| Patomic_cas | Patomic_fetch_add | Patomic_load _ -> false
| Prunstack | Pperform | Presume | Preperform -> true (* XXX! *)
| Pdls_get | Ppoll | Preinterpret_tagged_int63_as_unboxed_int64
| Preinterpret_unboxed_int64_as_tagged_int63 ->
Expand Down Expand Up @@ -2317,6 +2320,7 @@ let primitive_result_layout (p : primitive) =
| Patomic_load { immediate_or_pointer = Immediate } -> layout_int
| Patomic_load { immediate_or_pointer = Pointer } -> layout_any_value
| Patomic_exchange
| Patomic_compare_exchange
| Patomic_cas
| Patomic_fetch_add
| Pdls_get -> layout_any_value
Expand Down
1 change: 1 addition & 0 deletions lambda/lambda.mli
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ type primitive =
(* Atomic operations *)
| Patomic_load of {immediate_or_pointer : immediate_or_pointer}
| Patomic_exchange
| Patomic_compare_exchange
| Patomic_cas
| Patomic_fetch_add
(* Inhibition of optimisation *)
Expand Down
2 changes: 2 additions & 0 deletions lambda/printlambda.ml
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ let primitive ppf = function
| Immediate -> fprintf ppf "atomic_load_imm"
| Pointer -> fprintf ppf "atomic_load_ptr")
| Patomic_exchange -> fprintf ppf "atomic_exchange"
| Patomic_compare_exchange -> fprintf ppf "atomic_compare_exchange"
| Patomic_cas -> fprintf ppf "atomic_cas"
| Patomic_fetch_add -> fprintf ppf "atomic_fetch_add"
| Popaque _ -> fprintf ppf "opaque"
Expand Down Expand Up @@ -1071,6 +1072,7 @@ let name_of_primitive = function
| Immediate -> "atomic_load_imm"
| Pointer -> "atomic_load_ptr")
| Patomic_exchange -> "Patomic_exchange"
| Patomic_compare_exchange -> "Patomic_compare_exchange"
| Patomic_cas -> "Patomic_cas"
| Patomic_fetch_add -> "Patomic_fetch_add"
| Popaque _ -> "Popaque"
Expand Down
3 changes: 2 additions & 1 deletion lambda/tmc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,8 @@ let rec choice ctx t =
| Prunstack | Pperform | Presume | Preperform | Pdls_get

(* we don't handle atomic primitives *)
| Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _
| Patomic_exchange | Patomic_compare_exchange
| Patomic_cas | Patomic_fetch_add | Patomic_load _
| Punbox_float _ | Pbox_float (_, _)
| Punbox_int _ | Pbox_int _
| Punbox_vector _ | Pbox_vector (_, _)
Expand Down
4 changes: 3 additions & 1 deletion lambda/translprim.ml
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ let lookup_primitive loc ~poly_mode ~poly_sort pos p =
| "%atomic_load" ->
Primitive ((Patomic_load {immediate_or_pointer=Pointer}), 1)
| "%atomic_exchange" -> Primitive (Patomic_exchange, 2)
| "%atomic_compare_exchange" -> Primitive (Patomic_compare_exchange, 3)
| "%atomic_cas" -> Primitive (Patomic_cas, 3)
| "%atomic_fetch_add" -> Primitive (Patomic_fetch_add, 2)
| "%runstack" ->
Expand Down Expand Up @@ -1792,7 +1793,8 @@ let lambda_primitive_needs_event_after = function
| Parrayblit _
| Parraylength _ | Parrayrefu _ | Parraysetu _ | Pisint _ | Pisnull | Pisout
| Pprobe_is_enabled _
| Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _
| Patomic_exchange | Patomic_compare_exchange
| Patomic_cas | Patomic_fetch_add | Patomic_load _
| Pintofbint _ | Pctconst _ | Pbswap16 | Pint_as_pointer _ | Popaque _
| Pdls_get
| Pobj_magic _ | Punbox_float _ | Punbox_int _ | Punbox_vector _
Expand Down
1 change: 1 addition & 0 deletions lambda/value_rec_compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ let compute_static_size lam =
| Pint_as_pointer _
| Patomic_load _
| Patomic_exchange
| Patomic_compare_exchange
| Patomic_cas
| Patomic_fetch_add
| Popaque _
Expand Down
4 changes: 2 additions & 2 deletions middle_end/flambda2/from_lambda/closure_conversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1046,8 +1046,8 @@ let close_primitive acc env ~let_bound_ids_with_kinds named
| Pbox_vector (_, _)
| Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _
| Punboxed_product_field _ | Pget_header _ | Prunstack | Pperform
| Presume | Preperform | Patomic_exchange | Patomic_cas
| Patomic_fetch_add | Pdls_get | Ppoll | Patomic_load _
| Presume | Preperform | Patomic_exchange | Patomic_compare_exchange
| Patomic_cas | Patomic_fetch_add | Pdls_get | Ppoll | Patomic_load _
| Preinterpret_tagged_int63_as_unboxed_int64
| Preinterpret_unboxed_int64_as_tagged_int63 ->
(* Inconsistent with outer match *)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,8 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
atomic ) ]
| Patomic_exchange, [[atomic]; [new_value]] ->
[Binary (Atomic_exchange, atomic, new_value)]
| Patomic_compare_exchange, [[atomic]; [old_value]; [new_value]] ->
[Ternary (Atomic_compare_exchange, atomic, old_value, new_value)]
| Patomic_cas, [[atomic]; [old_value]; [new_value]] ->
[Ternary (Atomic_compare_and_set, atomic, old_value, new_value)]
| Patomic_fetch_add, [[atomic]; [i]] ->
Expand Down Expand Up @@ -2443,7 +2445,8 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
| Pfloatarray_set_128 _ | Pfloat_array_set_128 _ | Pint_array_set_128 _
| Punboxed_float_array_set_128 _ | Punboxed_float32_array_set_128 _
| Punboxed_int32_array_set_128 _ | Punboxed_int64_array_set_128 _
| Punboxed_nativeint_array_set_128 _ | Patomic_cas ),
| Punboxed_nativeint_array_set_128 _ | Patomic_cas
| Patomic_compare_exchange ),
( []
| [_]
| [_; _]
Expand Down
2 changes: 1 addition & 1 deletion middle_end/flambda2/parser/flambda_to_fexpr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ let ternop env (op : Flambda_primitive.ternary_primitive) : Fexpr.ternop =
let ask = fexpr_of_array_set_kind env ask in
Array_set (ak, ask)
| Bytes_or_bigstring_set (blv, saw) -> Bytes_or_bigstring_set (blv, saw)
| Bigarray_set _ | Atomic_compare_and_set ->
| Bigarray_set _ | Atomic_compare_and_set | Atomic_compare_exchange ->
Misc.fatal_errorf "TODO: Ternary primitive: %a"
Flambda_primitive.Without_args.print
(Flambda_primitive.Without_args.Ternary op)
Expand Down
7 changes: 7 additions & 0 deletions middle_end/flambda2/simplify/simplify_ternary_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ let simplify_atomic_compare_and_set ~original_prim dacc ~original_term _dbg
(P.result_kind' original_prim)
~original_term

let simplify_atomic_compare_exchange ~original_prim dacc ~original_term _dbg
~arg1:_ ~arg1_ty:_ ~arg2:_ ~arg2_ty:_ ~arg3:_ ~arg3_ty:_ ~result_var =
SPR.create_unknown dacc ~result_var
(P.result_kind' original_prim)
~original_term

let simplify_ternary_primitive dacc original_prim (prim : P.ternary_primitive)
~arg1 ~arg1_ty ~arg2 ~arg2_ty ~arg3 ~arg3_ty dbg ~result_var =
let original_term = Named.create_prim original_prim dbg in
Expand All @@ -84,6 +90,7 @@ let simplify_ternary_primitive dacc original_prim (prim : P.ternary_primitive)
| Bigarray_set (num_dimensions, bigarray_kind, bigarray_layout) ->
simplify_bigarray_set ~num_dimensions bigarray_kind bigarray_layout
| Atomic_compare_and_set -> simplify_atomic_compare_and_set ~original_prim
| Atomic_compare_exchange -> simplify_atomic_compare_exchange ~original_prim
in
simplifier dacc ~original_term dbg ~arg1 ~arg1_ty ~arg2 ~arg2_ty ~arg3
~arg3_ty ~result_var
3 changes: 2 additions & 1 deletion middle_end/flambda2/terms/code_size.ml
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ let ternary_prim_size prim =
5 (* ~ 3 block_load + 2 block_set *)
| Bigarray_set (_dims, _kind, _layout) -> 2
(* ~ 1 block_load + 1 block_set *)
| Atomic_compare_and_set -> does_not_need_caml_c_call_extcall_size
| Atomic_compare_and_set | Atomic_compare_exchange ->
does_not_need_caml_c_call_extcall_size

let variadic_prim_size prim args =
match (prim : Flambda_primitive.variadic_primitive) with
Expand Down
23 changes: 14 additions & 9 deletions middle_end/flambda2/terms/flambda_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1856,11 +1856,12 @@ type ternary_primitive =
| Bytes_or_bigstring_set of bytes_like_value * string_accessor_width
| Bigarray_set of num_dimensions * Bigarray_kind.t * Bigarray_layout.t
| Atomic_compare_and_set
| Atomic_compare_exchange

let ternary_primitive_eligible_for_cse p =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ->
| Atomic_compare_and_set | Atomic_compare_exchange ->
false

let compare_ternary_primitive p1 p2 =
Expand All @@ -1870,6 +1871,7 @@ let compare_ternary_primitive p1 p2 =
| Bytes_or_bigstring_set _ -> 1
| Bigarray_set _ -> 2
| Atomic_compare_and_set -> 3
| Atomic_compare_exchange -> 4
in
match p1, p2 with
| Array_set (kind1, set_kind1), Array_set (kind2, set_kind2) ->
Expand All @@ -1888,7 +1890,7 @@ let compare_ternary_primitive p1 p2 =
let c = Stdlib.compare kind1 kind2 in
if c <> 0 then c else Stdlib.compare layout1 layout2
| ( ( Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ),
| Atomic_compare_and_set | Atomic_compare_exchange ),
_ ) ->
Stdlib.compare
(ternary_primitive_numbering p1)
Expand All @@ -1910,6 +1912,7 @@ let print_ternary_primitive ppf p =
"@[(Bigarray_set (num_dimensions@ %d)@ (kind@ %a)@ (layout@ %a))@]"
num_dimensions Bigarray_kind.print kind Bigarray_layout.print layout
| Atomic_compare_and_set -> fprintf ppf "Atomic_compare_and_set"
| Atomic_compare_exchange -> fprintf ppf "Atomic_compare_exchange"

let args_kind_of_ternary_primitive p =
match p with
Expand Down Expand Up @@ -1939,43 +1942,45 @@ let args_kind_of_ternary_primitive p =
bigstring_kind, bytes_or_bigstring_index_kind, K.naked_vec128
| Bigarray_set (_, kind, _) ->
bigarray_kind, bigarray_index_kind, Bigarray_kind.element_kind kind
| Atomic_compare_and_set -> K.value, K.value, K.value
| Atomic_compare_and_set | Atomic_compare_exchange ->
K.value, K.value, K.value

let result_kind_of_ternary_primitive p : result_kind =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ -> Unit
| Atomic_compare_and_set -> Singleton K.value
| Atomic_compare_and_set | Atomic_compare_exchange -> Singleton K.value

let effects_and_coeffects_of_ternary_primitive p :
Effects.t * Coeffects.t * Placement.t =
match p with
| Array_set _ -> writing_to_an_array
| Bytes_or_bigstring_set _ -> writing_to_bytes_or_bigstring
| Bigarray_set (_, kind, _) -> writing_to_a_bigarray kind
| Atomic_compare_and_set -> Arbitrary_effects, Has_coeffects, Strict
| Atomic_compare_and_set | Atomic_compare_exchange ->
Arbitrary_effects, Has_coeffects, Strict

let ternary_classify_for_printing p =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ->
| Atomic_compare_and_set | Atomic_compare_exchange ->
Neither

let free_names_ternary_primitive p =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ->
| Atomic_compare_and_set | Atomic_compare_exchange ->
Name_occurrences.empty

let apply_renaming_ternary_primitive p _ =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ->
| Atomic_compare_and_set | Atomic_compare_exchange ->
p

let ids_for_export_ternary_primitive p =
match p with
| Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _
| Atomic_compare_and_set ->
| Atomic_compare_and_set | Atomic_compare_exchange ->
Ids_for_export.empty

type variadic_primitive =
Expand Down
1 change: 1 addition & 0 deletions middle_end/flambda2/terms/flambda_primitive.mli
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ type ternary_primitive =
| Bytes_or_bigstring_set of bytes_like_value * string_accessor_width
| Bigarray_set of num_dimensions * Bigarray_kind.t * Bigarray_layout.t
| Atomic_compare_and_set
| Atomic_compare_exchange

(** Primitives taking zero or more arguments. *)
type variadic_primitive =
Expand Down
2 changes: 2 additions & 0 deletions middle_end/flambda2/to_cmm/to_cmm_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,8 @@ let ternary_primitive _env dbg f x y z =
bigarray_store ~dbg kind ~bigarray:x ~index:y ~new_value:z
| Atomic_compare_and_set ->
C.atomic_compare_and_set ~dbg x ~old_value:y ~new_value:z
| Atomic_compare_exchange ->
C.atomic_compare_exchange ~dbg x ~old_value:y ~new_value:z

let variadic_primitive _env dbg f args =
match (f : P.variadic_primitive) with
Expand Down
19 changes: 13 additions & 6 deletions runtime/memory.c
Original file line number Diff line number Diff line change
Expand Up @@ -376,27 +376,34 @@ CAMLprim value caml_atomic_exchange (value ref, value v)
return ret;
}

CAMLprim value caml_atomic_cas (value ref, value oldv, value newv)
CAMLprim value caml_atomic_compare_exchange (value ref, value oldv, value newv)
{
if (caml_domain_alone()) {
value* p = Op_val(ref);
if (*p == oldv) {
*p = newv;
write_barrier(ref, 0, oldv, newv);
return Val_int(1);
return oldv;
} else {
return Val_int(0);
return *p;
}
} else {
atomic_value* p = &Op_atomic_val(ref)[0];
int cas_ret = atomic_compare_exchange_strong(p, &oldv, newv);
atomic_thread_fence(memory_order_release); /* generates `dmb ish` on Arm64*/
if (cas_ret) {
write_barrier(ref, 0, oldv, newv);
return Val_int(1);
} else {
return Val_int(0);
}
return oldv;
}
}

CAMLprim value caml_atomic_cas (value ref, value oldv, value newv)
{
if (caml_atomic_compare_exchange(ref, oldv, newv) == oldv) {
return Val_true;
} else {
return Val_false;
}
}

Expand Down
15 changes: 12 additions & 3 deletions runtime4/misc.c
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,23 @@ CAMLprim value caml_atomic_load(value ref)
return Field(ref, 0);
}

CAMLprim value caml_atomic_cas(value ref, value oldv, value newv)
CAMLprim value caml_atomic_compare_exchange(value ref, value oldv, value newv)
{
value* p = Op_val(ref);
if (*p == oldv) {
caml_modify(p, newv);
return Val_int(1);
return oldv;
} else {
return *p;
}
}

CAMLprim value caml_atomic_cas(value ref, value oldv, value newv)
{
if (caml_atomic_compare_exchange(ref, oldv, newv) == oldv) {
return Val_true;
} else {
return Val_int(0);
return Val_false;
}
}

Expand Down
Loading

0 comments on commit 17d57bd

Please sign in to comment.