Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Dec 4, 2024
1 parent 478324f commit 4e356fc
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 90 deletions.
119 changes: 65 additions & 54 deletions src/Temporalio/Bridge/CustomSlotSupplier.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Temporalio.Bridge.Interop;

namespace Temporalio.Bridge
{
Expand Down Expand Up @@ -37,36 +37,19 @@ internal unsafe CustomSlotSupplier(
try_reserve = FunctionPointer<Interop.CustomTryReserveSlotCallback>(TryReserve),
mark_used = FunctionPointer<Interop.CustomMarkSlotUsedCallback>(MarkUsed),
release = FunctionPointer<Interop.CustomReleaseSlotCallback>(Release),
free = FunctionPointer<Interop.CustomSlotImplFreeCallback>(Free),
};

PinCallbackHolder(interopCallbacks);
}

private static void SetCancelTokenOnCtx(ref SlotReserveCtx ctx, CancellationTokenSource cancelTokenSrc)
private unsafe void Reserve(Interop.SlotReserveCtx* ctx, void* sender)
{
unsafe
{
try
{
var handle = GCHandle.Alloc(cancelTokenSrc);
fixed (Interop.SlotReserveCtx* p = &ctx)
{
Interop.Methods.set_reserve_cancel_target(p, GCHandle.ToIntPtr(handle).ToPointer());
}
}
catch (Exception e)
{
Console.WriteLine($"Error setting cancel token on ctx: {e}");
throw;
}
}
}

private unsafe void Reserve(Interop.SlotReserveCtx ctx, void* sender)
{
SafeReserve(ctx, new IntPtr(sender));
SafeReserve(new IntPtr(ctx), new IntPtr(sender));
}

// Note that this is always called by Rust, either because the call is cancelled or because
// it completed. Therefore the GCHandle is always freed.
private unsafe void CancelReserve(void* tokenSrc)
{
var handle = GCHandle.FromIntPtr(new IntPtr(tokenSrc));
Expand All @@ -75,44 +58,56 @@ private unsafe void CancelReserve(void* tokenSrc)
handle.Free();
}

private void SafeReserve(Interop.SlotReserveCtx ctx, IntPtr sender)
private void SafeReserve(IntPtr ctx, IntPtr sender)
{
var reserveTask = Task.Run(async () =>
_ = Task.Run(async () =>
{
var cancelTokenSrc = new System.Threading.CancellationTokenSource();
SetCancelTokenOnCtx(ref ctx, cancelTokenSrc);
while (true)
using (var cancelTokenSrc = new System.Threading.CancellationTokenSource())
{
try
unsafe
{
var permit = await userSupplier.ReserveSlotAsync(
new(ctx), cancelTokenSrc.Token).ConfigureAwait(false);
var usedPermitId = AddPermitToMap(permit);
unsafe
{
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
}
cancelTokenSrc.Dispose();
return;
var srcHandle = GCHandle.Alloc(cancelTokenSrc);
Interop.Methods.set_reserve_cancel_target(
(Interop.SlotReserveCtx*)ctx.ToPointer(),
GCHandle.ToIntPtr(srcHandle).ToPointer());
}
catch (OperationCanceledException)
while (true)
{
cancelTokenSrc.Dispose();
return;
}
try
{
ConfiguredTaskAwaitable<Temporalio.Worker.Tuning.ISlotPermit> reserveTask;
unsafe
{
reserveTask = userSupplier.ReserveSlotAsync(
new((Interop.SlotReserveCtx*)ctx.ToPointer()),
cancelTokenSrc.Token).ConfigureAwait(false);
}
var permit = await reserveTask;
unsafe
{
var usedPermitId = AddPermitToMap(permit);
Interop.Methods.complete_async_reserve(sender.ToPointer(), new(usedPermitId));
}
return;
}
catch (OperationCanceledException)
{
return;
}
#pragma warning disable CA1031 // We are ok catching all exceptions here
catch (Exception e)
{
catch (Exception e)
{
#pragma warning restore CA1031
logger.LogError(e, "Error reserving slot");
logger.LogError(e, "Error reserving slot");
}
// Wait for a bit to avoid spamming errors
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
}
// Wait for a bit to avoid spamming errors
await Task.Delay(1000, cancelTokenSrc.Token).ConfigureAwait(false);
}
});
}

private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx* ctx)
{
Temporalio.Worker.Tuning.ISlotPermit? maybePermit;
try
Expand All @@ -135,11 +130,16 @@ private unsafe UIntPtr TryReserve(Interop.SlotReserveCtx ctx)
return new(usedPermitId);
}

private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
private unsafe void MarkUsed(Interop.SlotMarkUsedCtx* ctx)
{
try
{
userSupplier.MarkSlotUsed(new(ctx, permits[ctx.slot_permit.ToUInt32()]));
Temporalio.Worker.Tuning.ISlotPermit permit;
lock (permits)
{
permit = permits[(*ctx).slot_permit.ToUInt32()];
}
userSupplier.MarkSlotUsed(new(ctx, permit));
}
#pragma warning disable CA1031 // We are ok catching all exceptions here
catch (Exception e)
Expand All @@ -149,20 +149,31 @@ private void MarkUsed(Interop.SlotMarkUsedCtx ctx)
}
}

private void Release(Interop.SlotReleaseCtx ctx)
private unsafe void Release(Interop.SlotReleaseCtx* ctx)
{
var permitId = ctx.slot_permit.ToUInt32();
var permitId = (*ctx).slot_permit.ToUInt32();
Temporalio.Worker.Tuning.ISlotPermit permit;
lock (permits)
{
permit = permits[permitId];
}
try
{
userSupplier.ReleaseSlot(new(ctx, permits[permitId]));
userSupplier.ReleaseSlot(new(ctx, permit));
}
#pragma warning disable CA1031 // We are ok catching all exceptions here
catch (Exception e)
{
#pragma warning restore CA1031
logger.LogError(e, "Error releasing slot");
}
permits.Remove(permitId);
finally
{
lock (permits)
{
permits.Remove(permitId);
}
}
}

private uint AddPermitToMap(Temporalio.Worker.Tuning.ISlotPermit permit)
Expand Down
22 changes: 14 additions & 8 deletions src/Temporalio/Bridge/Interop/Interop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -579,14 +579,14 @@ internal unsafe partial struct SlotReserveCtx
}

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx, void* sender);
internal unsafe delegate void CustomReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx, void* sender);

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal unsafe delegate void CustomCancelReserveCallback(void* token_source);

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: NativeTypeName("uintptr_t")]
internal delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("struct SlotReserveCtx")] SlotReserveCtx ctx);
internal unsafe delegate UIntPtr CustomTryReserveSlotCallback([NativeTypeName("const struct SlotReserveCtx *")] SlotReserveCtx* ctx);

internal enum SlotInfo_Tag
{
Expand Down Expand Up @@ -680,7 +680,7 @@ internal partial struct SlotMarkUsedCtx
}

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal delegate void CustomMarkSlotUsedCallback([NativeTypeName("struct SlotMarkUsedCtx")] SlotMarkUsedCtx ctx);
internal unsafe delegate void CustomMarkSlotUsedCallback([NativeTypeName("const struct SlotMarkUsedCtx *")] SlotMarkUsedCtx* ctx);

internal unsafe partial struct SlotReleaseCtx
{
Expand All @@ -692,7 +692,10 @@ internal unsafe partial struct SlotReleaseCtx
}

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal delegate void CustomReleaseSlotCallback([NativeTypeName("struct SlotReleaseCtx")] SlotReleaseCtx ctx);
internal unsafe delegate void CustomReleaseSlotCallback([NativeTypeName("const struct SlotReleaseCtx *")] SlotReleaseCtx* ctx);

[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal unsafe delegate void CustomSlotImplFreeCallback([NativeTypeName("const struct CustomSlotSupplierCallbacks *")] CustomSlotSupplierCallbacks* userimpl);

internal partial struct CustomSlotSupplierCallbacks
{
Expand All @@ -710,6 +713,9 @@ internal partial struct CustomSlotSupplierCallbacks

[NativeTypeName("CustomReleaseSlotCallback")]
public IntPtr release;

[NativeTypeName("CustomSlotImplFreeCallback")]
public IntPtr free;
}

internal unsafe partial struct CustomSlotSupplierCallbacksImpl
Expand All @@ -729,7 +735,7 @@ internal unsafe partial struct SlotSupplier
{
public SlotSupplier_Tag tag;

[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L472_C3")]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L475_C3")]
public _Anonymous_e__Union Anonymous;

internal ref FixedSizeSlotSupplier fixed_size
Expand Down Expand Up @@ -769,15 +775,15 @@ internal ref CustomSlotSupplierCallbacksImpl custom
internal unsafe partial struct _Anonymous_e__Union
{
[FieldOffset(0)]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L473_C5")]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
public _Anonymous1_e__Struct Anonymous1;

[FieldOffset(0)]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L476_C5")]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
public _Anonymous2_e__Struct Anonymous2;

[FieldOffset(0)]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L479_C5")]
[NativeTypeName("__AnonymousRecord_temporal-sdk-bridge_L482_C5")]
public _Anonymous3_e__Struct Anonymous3;

internal partial struct _Anonymous1_e__Struct
Expand Down
10 changes: 5 additions & 5 deletions src/Temporalio/Bridge/NativeInvokeableClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Temporalio.Bridge
/// Extend this class to help with making a class that has callbacks which are invoked by Rust.
/// </summary>
/// <typeparam name="T">The native type that holds the function ptrs for callbacks to C#.</typeparam>
internal class NativeInvokeableClass<T>
internal abstract class NativeInvokeableClass<T>
where T : unmanaged
{
private readonly List<GCHandle> handles = new();
Expand All @@ -23,7 +23,7 @@ internal class NativeInvokeableClass<T>
/// the callbacks via <see cref="FunctionPointer"/>. Also adds `this` to the handle list.
/// </summary>
/// <param name="value">The native type to pin.</param>
internal void PinCallbackHolder(T value)
private protected void PinCallbackHolder(T value)
{
// Pin the callback holder & set it as the first handle
var holderHandle = GCHandle.Alloc(value, GCHandleType.Pinned);
Expand All @@ -42,7 +42,7 @@ internal void PinCallbackHolder(T value)
/// <typeparam name="TF">The native type of the function pointer.</typeparam>
/// <param name="func">The C# method to use for the callback.</param>
/// <returns>The function pointer to the C# method.</returns>
internal IntPtr FunctionPointer<TF>(TF func)
private protected IntPtr FunctionPointer<TF>(TF func)
where TF : Delegate
{
var handle = GCHandle.Alloc(func);
Expand All @@ -53,8 +53,8 @@ internal IntPtr FunctionPointer<TF>(TF func)
/// <summary>
/// Free the memory of the native type and all the function pointers.
/// </summary>
/// <param name="meter">The native type to free.</param>
internal unsafe void Free(T* meter)
/// <param name="ptr">The native type to free.</param>
private protected unsafe void Free(T* ptr)
{
// Free in order which frees function pointers first then object handles
foreach (var handle in handles)
Expand Down
11 changes: 7 additions & 4 deletions src/Temporalio/Bridge/include/temporal-sdk-bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,14 @@ typedef struct SlotReserveCtx {
void *token_src;
} SlotReserveCtx;

typedef void (*CustomReserveSlotCallback)(struct SlotReserveCtx ctx, void *sender);
typedef void (*CustomReserveSlotCallback)(const struct SlotReserveCtx *ctx, void *sender);

typedef void (*CustomCancelReserveCallback)(void *token_source);

/**
* Must return C#-tracked id for the permit. A zero value means no permit was reserved.
*/
typedef uintptr_t (*CustomTryReserveSlotCallback)(struct SlotReserveCtx ctx);
typedef uintptr_t (*CustomTryReserveSlotCallback)(const struct SlotReserveCtx *ctx);

typedef enum SlotInfo_Tag {
WorkflowSlotInfo,
Expand Down Expand Up @@ -437,7 +437,7 @@ typedef struct SlotMarkUsedCtx {
uintptr_t slot_permit;
} SlotMarkUsedCtx;

typedef void (*CustomMarkSlotUsedCallback)(struct SlotMarkUsedCtx ctx);
typedef void (*CustomMarkSlotUsedCallback)(const struct SlotMarkUsedCtx *ctx);

typedef struct SlotReleaseCtx {
const struct SlotInfo *slot_info;
Expand All @@ -447,14 +447,17 @@ typedef struct SlotReleaseCtx {
uintptr_t slot_permit;
} SlotReleaseCtx;

typedef void (*CustomReleaseSlotCallback)(struct SlotReleaseCtx ctx);
typedef void (*CustomReleaseSlotCallback)(const struct SlotReleaseCtx *ctx);

typedef void (*CustomSlotImplFreeCallback)(const struct CustomSlotSupplierCallbacks *userimpl);

typedef struct CustomSlotSupplierCallbacks {
CustomReserveSlotCallback reserve;
CustomCancelReserveCallback cancel_reserve;
CustomTryReserveSlotCallback try_reserve;
CustomMarkSlotUsedCallback mark_used;
CustomReleaseSlotCallback release;
CustomSlotImplFreeCallback free;
} CustomSlotSupplierCallbacks;

typedef struct CustomSlotSupplierCallbacksImpl {
Expand Down
Loading

0 comments on commit 4e356fc

Please sign in to comment.