From 17d57bd96c560985fde899eafe53a95d4ac1b18f Mon Sep 17 00:00:00 2001 From: Max Slater Date: Thu, 12 Dec 2024 16:42:39 -0500 Subject: [PATCH] Add Atomic.compare_exchange (#3368) * cmpxchg * dedup impl * val_bool --- backend/cmm_helpers.ml | 15 ++++++++++++ backend/cmm_helpers.mli | 7 ++++++ bytecomp/bytegen.ml | 4 +++- lambda/lambda.ml | 6 ++++- lambda/lambda.mli | 1 + lambda/printlambda.ml | 2 ++ lambda/tmc.ml | 3 ++- lambda/translprim.ml | 4 +++- lambda/value_rec_compiler.ml | 1 + .../from_lambda/closure_conversion.ml | 4 ++-- .../lambda_to_flambda_primitives.ml | 5 +++- .../flambda2/parser/flambda_to_fexpr.ml | 2 +- .../simplify/simplify_ternary_primitive.ml | 7 ++++++ middle_end/flambda2/terms/code_size.ml | 3 ++- .../flambda2/terms/flambda_primitive.ml | 23 +++++++++++-------- .../flambda2/terms/flambda_primitive.mli | 1 + .../flambda2/to_cmm/to_cmm_primitive.ml | 2 ++ runtime/memory.c | 19 ++++++++++----- runtime4/misc.c | 15 +++++++++--- stdlib/atomic.ml | 1 + stdlib/atomic.mli | 5 ++++ .../tests/lib-atomic/test_atomic_cmpxchg.ml | 18 +++++++++++++++ 22 files changed, 121 insertions(+), 27 deletions(-) create mode 100644 testsuite/tests/lib-atomic/test_atomic_cmpxchg.ml diff --git a/backend/cmm_helpers.ml b/backend/cmm_helpers.ml index 83e82438d84..3931e0eb7c6 100644 --- a/backend/cmm_helpers.ml +++ b/backend/cmm_helpers.ml @@ -4207,6 +4207,21 @@ let atomic_compare_and_set ~dbg atomic ~old_value ~new_value = [atomic; old_value; new_value], dbg ) +let atomic_compare_exchange ~dbg atomic ~old_value ~new_value = + Cop + ( Cextcall + { func = "caml_atomic_compare_exchange"; + builtin = false; + returns = true; + effects = Arbitrary_effects; + coeffects = Has_coeffects; + ty = typ_val; + ty_args = []; + alloc = false + }, + [atomic; old_value; new_value], + dbg ) + type even_or_odd = | Even | Odd diff --git a/backend/cmm_helpers.mli b/backend/cmm_helpers.mli index b075082c4db..f8eada76861 100644 --- a/backend/cmm_helpers.mli +++ b/backend/cmm_helpers.mli @@ -1011,6 +1011,13 @@ val atomic_compare_and_set : new_value:expression -> expression +val atomic_compare_exchange : + dbg:Debuginfo.t -> + expression -> + old_value:expression -> + new_value:expression -> + expression + val emit_gc_roots_table : symbols:symbol list -> phrase list -> phrase list val perform : dbg:Debuginfo.t -> expression -> expression diff --git a/bytecomp/bytegen.ml b/bytecomp/bytegen.ml index 6fa48e5e285..95e5137d184 100644 --- a/bytecomp/bytegen.ml +++ b/bytecomp/bytegen.ml @@ -195,7 +195,8 @@ let preserve_tailcall_for_prim = function | Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pprobe_is_enabled _ | Pobj_dup | Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ - | Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ + | Patomic_exchange | Patomic_compare_exchange + | Patomic_cas | Patomic_fetch_add | Patomic_load _ | Pdls_get | Preinterpret_tagged_int63_as_unboxed_int64 | Preinterpret_unboxed_int64_as_tagged_int63 | Ppoll -> false @@ -653,6 +654,7 @@ let comp_primitive stack_info p sz args = | Pobj_dup -> Kccall("caml_obj_dup", 1) | Patomic_load _ -> Kccall("caml_atomic_load", 1) | Patomic_exchange -> Kccall("caml_atomic_exchange", 2) + | Patomic_compare_exchange -> Kccall("caml_atomic_compare_exchange", 3) | Patomic_cas -> Kccall("caml_atomic_cas", 3) | Patomic_fetch_add -> Kccall("caml_atomic_fetch_add", 2) | Pdls_get -> Kccall("caml_domain_dls_get", 1) diff --git a/lambda/lambda.ml b/lambda/lambda.ml index 3844ee97a79..a151ee51b1d 100644 --- a/lambda/lambda.ml +++ b/lambda/lambda.ml @@ -304,6 +304,7 @@ type primitive = (* Atomic operations *) | Patomic_load of {immediate_or_pointer : immediate_or_pointer} | Patomic_exchange + | Patomic_compare_exchange | Patomic_cas | Patomic_fetch_add (* Inhibition of optimisation *) @@ -1923,6 +1924,7 @@ let primitive_may_allocate : primitive -> locality_mode option = function Some alloc_heap | Patomic_load _ | Patomic_exchange + | Patomic_compare_exchange | Patomic_cas | Patomic_fetch_add | Pdls_get @@ -2087,7 +2089,8 @@ let primitive_can_raise prim = | Punbox_vector _ | Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _ | Punboxed_product_field _ | Pget_header _ -> false - | Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ -> false + | Patomic_exchange | Patomic_compare_exchange + | Patomic_cas | Patomic_fetch_add | Patomic_load _ -> false | Prunstack | Pperform | Presume | Preperform -> true (* XXX! *) | Pdls_get | Ppoll | Preinterpret_tagged_int63_as_unboxed_int64 | Preinterpret_unboxed_int64_as_tagged_int63 -> @@ -2317,6 +2320,7 @@ let primitive_result_layout (p : primitive) = | Patomic_load { immediate_or_pointer = Immediate } -> layout_int | Patomic_load { immediate_or_pointer = Pointer } -> layout_any_value | Patomic_exchange + | Patomic_compare_exchange | Patomic_cas | Patomic_fetch_add | Pdls_get -> layout_any_value diff --git a/lambda/lambda.mli b/lambda/lambda.mli index c0ca0cebcdc..f72780dda1c 100644 --- a/lambda/lambda.mli +++ b/lambda/lambda.mli @@ -299,6 +299,7 @@ type primitive = (* Atomic operations *) | Patomic_load of {immediate_or_pointer : immediate_or_pointer} | Patomic_exchange + | Patomic_compare_exchange | Patomic_cas | Patomic_fetch_add (* Inhibition of optimisation *) diff --git a/lambda/printlambda.ml b/lambda/printlambda.ml index 6819f6935a5..ed0e1326695 100644 --- a/lambda/printlambda.ml +++ b/lambda/printlambda.ml @@ -897,6 +897,7 @@ let primitive ppf = function | Immediate -> fprintf ppf "atomic_load_imm" | Pointer -> fprintf ppf "atomic_load_ptr") | Patomic_exchange -> fprintf ppf "atomic_exchange" + | Patomic_compare_exchange -> fprintf ppf "atomic_compare_exchange" | Patomic_cas -> fprintf ppf "atomic_cas" | Patomic_fetch_add -> fprintf ppf "atomic_fetch_add" | Popaque _ -> fprintf ppf "opaque" @@ -1071,6 +1072,7 @@ let name_of_primitive = function | Immediate -> "atomic_load_imm" | Pointer -> "atomic_load_ptr") | Patomic_exchange -> "Patomic_exchange" + | Patomic_compare_exchange -> "Patomic_compare_exchange" | Patomic_cas -> "Patomic_cas" | Patomic_fetch_add -> "Patomic_fetch_add" | Popaque _ -> "Popaque" diff --git a/lambda/tmc.ml b/lambda/tmc.ml index 389418f2091..4a65037e891 100644 --- a/lambda/tmc.ml +++ b/lambda/tmc.ml @@ -902,7 +902,8 @@ let rec choice ctx t = | Prunstack | Pperform | Presume | Preperform | Pdls_get (* we don't handle atomic primitives *) - | Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ + | Patomic_exchange | Patomic_compare_exchange + | Patomic_cas | Patomic_fetch_add | Patomic_load _ | Punbox_float _ | Pbox_float (_, _) | Punbox_int _ | Pbox_int _ | Punbox_vector _ | Pbox_vector (_, _) diff --git a/lambda/translprim.ml b/lambda/translprim.ml index dd9716be0d7..97a4c54c707 100644 --- a/lambda/translprim.ml +++ b/lambda/translprim.ml @@ -865,6 +865,7 @@ let lookup_primitive loc ~poly_mode ~poly_sort pos p = | "%atomic_load" -> Primitive ((Patomic_load {immediate_or_pointer=Pointer}), 1) | "%atomic_exchange" -> Primitive (Patomic_exchange, 2) + | "%atomic_compare_exchange" -> Primitive (Patomic_compare_exchange, 3) | "%atomic_cas" -> Primitive (Patomic_cas, 3) | "%atomic_fetch_add" -> Primitive (Patomic_fetch_add, 2) | "%runstack" -> @@ -1792,7 +1793,8 @@ let lambda_primitive_needs_event_after = function | Parrayblit _ | Parraylength _ | Parrayrefu _ | Parraysetu _ | Pisint _ | Pisnull | Pisout | Pprobe_is_enabled _ - | Patomic_exchange | Patomic_cas | Patomic_fetch_add | Patomic_load _ + | Patomic_exchange | Patomic_compare_exchange + | Patomic_cas | Patomic_fetch_add | Patomic_load _ | Pintofbint _ | Pctconst _ | Pbswap16 | Pint_as_pointer _ | Popaque _ | Pdls_get | Pobj_magic _ | Punbox_float _ | Punbox_int _ | Punbox_vector _ diff --git a/lambda/value_rec_compiler.ml b/lambda/value_rec_compiler.ml index b6a7bf4c78b..51568107cf6 100644 --- a/lambda/value_rec_compiler.ml +++ b/lambda/value_rec_compiler.ml @@ -345,6 +345,7 @@ let compute_static_size lam = | Pint_as_pointer _ | Patomic_load _ | Patomic_exchange + | Patomic_compare_exchange | Patomic_cas | Patomic_fetch_add | Popaque _ diff --git a/middle_end/flambda2/from_lambda/closure_conversion.ml b/middle_end/flambda2/from_lambda/closure_conversion.ml index 2f387a1aa28..d3c15be9aa3 100644 --- a/middle_end/flambda2/from_lambda/closure_conversion.ml +++ b/middle_end/flambda2/from_lambda/closure_conversion.ml @@ -1046,8 +1046,8 @@ let close_primitive acc env ~let_bound_ids_with_kinds named | Pbox_vector (_, _) | Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _ | Punboxed_product_field _ | Pget_header _ | Prunstack | Pperform - | Presume | Preperform | Patomic_exchange | Patomic_cas - | Patomic_fetch_add | Pdls_get | Ppoll | Patomic_load _ + | Presume | Preperform | Patomic_exchange | Patomic_compare_exchange + | Patomic_cas | Patomic_fetch_add | Pdls_get | Ppoll | Patomic_load _ | Preinterpret_tagged_int63_as_unboxed_int64 | Preinterpret_unboxed_int64_as_tagged_int63 -> (* Inconsistent with outer match *) diff --git a/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml b/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml index ae05e49a5e9..c91ccd3663b 100644 --- a/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml +++ b/middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml @@ -2319,6 +2319,8 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) atomic ) ] | Patomic_exchange, [[atomic]; [new_value]] -> [Binary (Atomic_exchange, atomic, new_value)] + | Patomic_compare_exchange, [[atomic]; [old_value]; [new_value]] -> + [Ternary (Atomic_compare_exchange, atomic, old_value, new_value)] | Patomic_cas, [[atomic]; [old_value]; [new_value]] -> [Ternary (Atomic_compare_and_set, atomic, old_value, new_value)] | Patomic_fetch_add, [[atomic]; [i]] -> @@ -2443,7 +2445,8 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list) | Pfloatarray_set_128 _ | Pfloat_array_set_128 _ | Pint_array_set_128 _ | Punboxed_float_array_set_128 _ | Punboxed_float32_array_set_128 _ | Punboxed_int32_array_set_128 _ | Punboxed_int64_array_set_128 _ - | Punboxed_nativeint_array_set_128 _ | Patomic_cas ), + | Punboxed_nativeint_array_set_128 _ | Patomic_cas + | Patomic_compare_exchange ), ( [] | [_] | [_; _] diff --git a/middle_end/flambda2/parser/flambda_to_fexpr.ml b/middle_end/flambda2/parser/flambda_to_fexpr.ml index 31fb631a928..d0e0c5e0863 100644 --- a/middle_end/flambda2/parser/flambda_to_fexpr.ml +++ b/middle_end/flambda2/parser/flambda_to_fexpr.ml @@ -645,7 +645,7 @@ let ternop env (op : Flambda_primitive.ternary_primitive) : Fexpr.ternop = let ask = fexpr_of_array_set_kind env ask in Array_set (ak, ask) | Bytes_or_bigstring_set (blv, saw) -> Bytes_or_bigstring_set (blv, saw) - | Bigarray_set _ | Atomic_compare_and_set -> + | Bigarray_set _ | Atomic_compare_and_set | Atomic_compare_exchange -> Misc.fatal_errorf "TODO: Ternary primitive: %a" Flambda_primitive.Without_args.print (Flambda_primitive.Without_args.Ternary op) diff --git a/middle_end/flambda2/simplify/simplify_ternary_primitive.ml b/middle_end/flambda2/simplify/simplify_ternary_primitive.ml index 112dea064cc..444886f1560 100644 --- a/middle_end/flambda2/simplify/simplify_ternary_primitive.ml +++ b/middle_end/flambda2/simplify/simplify_ternary_primitive.ml @@ -73,6 +73,12 @@ let simplify_atomic_compare_and_set ~original_prim dacc ~original_term _dbg (P.result_kind' original_prim) ~original_term +let simplify_atomic_compare_exchange ~original_prim dacc ~original_term _dbg + ~arg1:_ ~arg1_ty:_ ~arg2:_ ~arg2_ty:_ ~arg3:_ ~arg3_ty:_ ~result_var = + SPR.create_unknown dacc ~result_var + (P.result_kind' original_prim) + ~original_term + let simplify_ternary_primitive dacc original_prim (prim : P.ternary_primitive) ~arg1 ~arg1_ty ~arg2 ~arg2_ty ~arg3 ~arg3_ty dbg ~result_var = let original_term = Named.create_prim original_prim dbg in @@ -84,6 +90,7 @@ let simplify_ternary_primitive dacc original_prim (prim : P.ternary_primitive) | Bigarray_set (num_dimensions, bigarray_kind, bigarray_layout) -> simplify_bigarray_set ~num_dimensions bigarray_kind bigarray_layout | Atomic_compare_and_set -> simplify_atomic_compare_and_set ~original_prim + | Atomic_compare_exchange -> simplify_atomic_compare_exchange ~original_prim in simplifier dacc ~original_term dbg ~arg1 ~arg1_ty ~arg2 ~arg2_ty ~arg3 ~arg3_ty ~result_var diff --git a/middle_end/flambda2/terms/code_size.ml b/middle_end/flambda2/terms/code_size.ml index 00363d8dcd8..05c3f1814e4 100644 --- a/middle_end/flambda2/terms/code_size.ml +++ b/middle_end/flambda2/terms/code_size.ml @@ -414,7 +414,8 @@ let ternary_prim_size prim = 5 (* ~ 3 block_load + 2 block_set *) | Bigarray_set (_dims, _kind, _layout) -> 2 (* ~ 1 block_load + 1 block_set *) - | Atomic_compare_and_set -> does_not_need_caml_c_call_extcall_size + | Atomic_compare_and_set | Atomic_compare_exchange -> + does_not_need_caml_c_call_extcall_size let variadic_prim_size prim args = match (prim : Flambda_primitive.variadic_primitive) with diff --git a/middle_end/flambda2/terms/flambda_primitive.ml b/middle_end/flambda2/terms/flambda_primitive.ml index 96a85b41c66..e34a17beedb 100644 --- a/middle_end/flambda2/terms/flambda_primitive.ml +++ b/middle_end/flambda2/terms/flambda_primitive.ml @@ -1856,11 +1856,12 @@ type ternary_primitive = | Bytes_or_bigstring_set of bytes_like_value * string_accessor_width | Bigarray_set of num_dimensions * Bigarray_kind.t * Bigarray_layout.t | Atomic_compare_and_set + | Atomic_compare_exchange let ternary_primitive_eligible_for_cse p = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set -> + | Atomic_compare_and_set | Atomic_compare_exchange -> false let compare_ternary_primitive p1 p2 = @@ -1870,6 +1871,7 @@ let compare_ternary_primitive p1 p2 = | Bytes_or_bigstring_set _ -> 1 | Bigarray_set _ -> 2 | Atomic_compare_and_set -> 3 + | Atomic_compare_exchange -> 4 in match p1, p2 with | Array_set (kind1, set_kind1), Array_set (kind2, set_kind2) -> @@ -1888,7 +1890,7 @@ let compare_ternary_primitive p1 p2 = let c = Stdlib.compare kind1 kind2 in if c <> 0 then c else Stdlib.compare layout1 layout2 | ( ( Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set ), + | Atomic_compare_and_set | Atomic_compare_exchange ), _ ) -> Stdlib.compare (ternary_primitive_numbering p1) @@ -1910,6 +1912,7 @@ let print_ternary_primitive ppf p = "@[(Bigarray_set (num_dimensions@ %d)@ (kind@ %a)@ (layout@ %a))@]" num_dimensions Bigarray_kind.print kind Bigarray_layout.print layout | Atomic_compare_and_set -> fprintf ppf "Atomic_compare_and_set" + | Atomic_compare_exchange -> fprintf ppf "Atomic_compare_exchange" let args_kind_of_ternary_primitive p = match p with @@ -1939,12 +1942,13 @@ let args_kind_of_ternary_primitive p = bigstring_kind, bytes_or_bigstring_index_kind, K.naked_vec128 | Bigarray_set (_, kind, _) -> bigarray_kind, bigarray_index_kind, Bigarray_kind.element_kind kind - | Atomic_compare_and_set -> K.value, K.value, K.value + | Atomic_compare_and_set | Atomic_compare_exchange -> + K.value, K.value, K.value let result_kind_of_ternary_primitive p : result_kind = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ -> Unit - | Atomic_compare_and_set -> Singleton K.value + | Atomic_compare_and_set | Atomic_compare_exchange -> Singleton K.value let effects_and_coeffects_of_ternary_primitive p : Effects.t * Coeffects.t * Placement.t = @@ -1952,30 +1956,31 @@ let effects_and_coeffects_of_ternary_primitive p : | Array_set _ -> writing_to_an_array | Bytes_or_bigstring_set _ -> writing_to_bytes_or_bigstring | Bigarray_set (_, kind, _) -> writing_to_a_bigarray kind - | Atomic_compare_and_set -> Arbitrary_effects, Has_coeffects, Strict + | Atomic_compare_and_set | Atomic_compare_exchange -> + Arbitrary_effects, Has_coeffects, Strict let ternary_classify_for_printing p = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set -> + | Atomic_compare_and_set | Atomic_compare_exchange -> Neither let free_names_ternary_primitive p = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set -> + | Atomic_compare_and_set | Atomic_compare_exchange -> Name_occurrences.empty let apply_renaming_ternary_primitive p _ = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set -> + | Atomic_compare_and_set | Atomic_compare_exchange -> p let ids_for_export_ternary_primitive p = match p with | Array_set _ | Bytes_or_bigstring_set _ | Bigarray_set _ - | Atomic_compare_and_set -> + | Atomic_compare_and_set | Atomic_compare_exchange -> Ids_for_export.empty type variadic_primitive = diff --git a/middle_end/flambda2/terms/flambda_primitive.mli b/middle_end/flambda2/terms/flambda_primitive.mli index 9087da2caa2..d59e2f17fae 100644 --- a/middle_end/flambda2/terms/flambda_primitive.mli +++ b/middle_end/flambda2/terms/flambda_primitive.mli @@ -510,6 +510,7 @@ type ternary_primitive = | Bytes_or_bigstring_set of bytes_like_value * string_accessor_width | Bigarray_set of num_dimensions * Bigarray_kind.t * Bigarray_layout.t | Atomic_compare_and_set + | Atomic_compare_exchange (** Primitives taking zero or more arguments. *) type variadic_primitive = diff --git a/middle_end/flambda2/to_cmm/to_cmm_primitive.ml b/middle_end/flambda2/to_cmm/to_cmm_primitive.ml index c8b39c8ec92..5b9262728a0 100644 --- a/middle_end/flambda2/to_cmm/to_cmm_primitive.ml +++ b/middle_end/flambda2/to_cmm/to_cmm_primitive.ml @@ -1018,6 +1018,8 @@ let ternary_primitive _env dbg f x y z = bigarray_store ~dbg kind ~bigarray:x ~index:y ~new_value:z | Atomic_compare_and_set -> C.atomic_compare_and_set ~dbg x ~old_value:y ~new_value:z + | Atomic_compare_exchange -> + C.atomic_compare_exchange ~dbg x ~old_value:y ~new_value:z let variadic_primitive _env dbg f args = match (f : P.variadic_primitive) with diff --git a/runtime/memory.c b/runtime/memory.c index 6cf49f0a921..fade3ec6c55 100644 --- a/runtime/memory.c +++ b/runtime/memory.c @@ -376,16 +376,16 @@ CAMLprim value caml_atomic_exchange (value ref, value v) return ret; } -CAMLprim value caml_atomic_cas (value ref, value oldv, value newv) +CAMLprim value caml_atomic_compare_exchange (value ref, value oldv, value newv) { if (caml_domain_alone()) { value* p = Op_val(ref); if (*p == oldv) { *p = newv; write_barrier(ref, 0, oldv, newv); - return Val_int(1); + return oldv; } else { - return Val_int(0); + return *p; } } else { atomic_value* p = &Op_atomic_val(ref)[0]; @@ -393,10 +393,17 @@ CAMLprim value caml_atomic_cas (value ref, value oldv, value newv) atomic_thread_fence(memory_order_release); /* generates `dmb ish` on Arm64*/ if (cas_ret) { write_barrier(ref, 0, oldv, newv); - return Val_int(1); - } else { - return Val_int(0); } + return oldv; + } +} + +CAMLprim value caml_atomic_cas (value ref, value oldv, value newv) +{ + if (caml_atomic_compare_exchange(ref, oldv, newv) == oldv) { + return Val_true; + } else { + return Val_false; } } diff --git a/runtime4/misc.c b/runtime4/misc.c index 2e0a60fe255..8a60aa2b4f1 100644 --- a/runtime4/misc.c +++ b/runtime4/misc.c @@ -252,14 +252,23 @@ CAMLprim value caml_atomic_load(value ref) return Field(ref, 0); } -CAMLprim value caml_atomic_cas(value ref, value oldv, value newv) +CAMLprim value caml_atomic_compare_exchange(value ref, value oldv, value newv) { value* p = Op_val(ref); if (*p == oldv) { caml_modify(p, newv); - return Val_int(1); + return oldv; + } else { + return *p; + } +} + +CAMLprim value caml_atomic_cas(value ref, value oldv, value newv) +{ + if (caml_atomic_compare_exchange(ref, oldv, newv) == oldv) { + return Val_true; } else { - return Val_int(0); + return Val_false; } } diff --git a/stdlib/atomic.ml b/stdlib/atomic.ml index 652a33be39d..9a60e2c0b44 100644 --- a/stdlib/atomic.ml +++ b/stdlib/atomic.ml @@ -19,6 +19,7 @@ external make_contended : 'a -> 'a t = "caml_atomic_make_contended" external get : 'a t -> 'a = "%atomic_load" external exchange : 'a t -> 'a -> 'a = "%atomic_exchange" external compare_and_set : 'a t -> 'a -> 'a -> bool = "%atomic_cas" +external compare_exchange : 'a t -> 'a -> 'a -> 'a = "%atomic_compare_exchange" external fetch_and_add : int t -> int -> int = "%atomic_fetch_add" external ignore : 'a -> unit = "%ignore" diff --git a/stdlib/atomic.mli b/stdlib/atomic.mli index 2d6db297e7c..ef6cab8f7df 100644 --- a/stdlib/atomic.mli +++ b/stdlib/atomic.mli @@ -61,6 +61,11 @@ val exchange : 'a t -> 'a -> 'a otherwise. *) val compare_and_set : 'a t -> 'a -> 'a -> bool +(** [compare_exchange r seen v] sets the new value of [r] to [v] only + if its current value is physically equal to [seen] -- the comparison + and the set occur atomically. Returns the previous value. *) +val compare_exchange : 'a t -> 'a -> 'a -> 'a + (** [fetch_and_add r n] atomically increments the value of [r] by [n], and returns the current value (before the increment). *) val fetch_and_add : int t -> int -> int diff --git a/testsuite/tests/lib-atomic/test_atomic_cmpxchg.ml b/testsuite/tests/lib-atomic/test_atomic_cmpxchg.ml new file mode 100644 index 00000000000..d14896d57d5 --- /dev/null +++ b/testsuite/tests/lib-atomic/test_atomic_cmpxchg.ml @@ -0,0 +1,18 @@ +(* TEST *) + +let r = Atomic.make 1 +let () = assert (Atomic.get r = 1) + +let () = Atomic.set r 2 +let () = assert (Atomic.get r = 2) + +let () = assert (Atomic.exchange r 3 = 2) + +let () = assert (Atomic.compare_exchange r 3 4 = 3) +let () = assert (Atomic.get r = 4) + +let () = assert (Atomic.compare_exchange r 3 (-4) = 4) +let () = assert (Atomic.get r = 4) + +let () = assert (Atomic.compare_exchange r 3 4 = 4) +let () = assert (Atomic.get r = 4)