diff --git a/sdk-api-kotlin/build.gradle.kts b/sdk-api-kotlin/build.gradle.kts index 705abf2b..24485614 100644 --- a/sdk-api-kotlin/build.gradle.kts +++ b/sdk-api-kotlin/build.gradle.kts @@ -22,6 +22,7 @@ dependencies { testImplementation(testingLibs.junit.jupiter) testImplementation(testingLibs.assertj) testImplementation(coreLibs.log4j.core) + testImplementation(coreLibs.protobuf.java) testImplementation(project(":sdk-core", "testArchive")) } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt index a174afd4..315aefb1 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Awaitables.kt @@ -8,11 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import com.google.protobuf.ByteString import dev.restate.sdk.common.Serde import dev.restate.sdk.common.syscalls.Deferred import dev.restate.sdk.common.syscalls.Result import dev.restate.sdk.common.syscalls.Syscalls +import java.nio.ByteBuffer import kotlinx.coroutines.CancellableContinuation import kotlinx.coroutines.suspendCancellableCoroutine @@ -79,14 +79,14 @@ internal abstract class BaseSingleMappedAwaitableImpl( internal open class SingleSerdeAwaitableImpl internal constructor( syscalls: Syscalls, - deferred: Deferred, + deferred: Deferred, private val serde: Serde, ) : - BaseSingleMappedAwaitableImpl( + BaseSingleMappedAwaitableImpl( SingleAwaitableImpl(syscalls, deferred), ) { @Suppress("UNCHECKED_CAST") - override suspend fun map(res: Result): Result { + override suspend fun map(res: Result): Result { return if (res.isSuccess) { // This propagates exceptions as non-terminal Result.success(serde.deserializeWrappingException(syscalls, res.value!!)) @@ -151,7 +151,7 @@ internal fun wrapAnyAwaitable(awaitables: List>): AnyAwaitable { internal class AwakeableImpl internal constructor( syscalls: Syscalls, - deferred: Deferred, + deferred: Deferred, serde: Serde, override val id: String ) : SingleSerdeAwaitableImpl(syscalls, deferred, serde), Awakeable {} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index 5268e249..b5a5524f 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -8,13 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import com.google.protobuf.ByteString import dev.restate.sdk.common.* import dev.restate.sdk.common.Target import dev.restate.sdk.common.syscalls.Deferred import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback import dev.restate.sdk.common.syscalls.Syscalls +import java.nio.ByteBuffer import kotlin.coroutines.resume import kotlin.time.Duration import kotlin.time.toJavaDuration @@ -33,8 +33,8 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) } override suspend fun get(key: StateKey): T? { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + val deferred: Deferred = + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.get(key.name(), completingContinuation(cont)) } @@ -109,8 +109,8 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) ): Awaitable { val input = inputSerde.serializeWrappingException(syscalls, parameter) - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + val deferred: Deferred = + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.call(target, input, completingContinuation(cont)) } @@ -136,19 +136,19 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) block: suspend () -> T ): T { val exitResult = - suspendCancellableCoroutine { cont: CancellableContinuation> + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.enterSideEffectBlock( name, object : EnterSideEffectSyscallCallback { - override fun onSuccess(t: ByteString?) { - val deferred: CompletableDeferred = CompletableDeferred() + override fun onSuccess(t: ByteBuffer?) { + val deferred: CompletableDeferred = CompletableDeferred() deferred.complete(t!!) cont.resume(deferred) } override fun onFailure(t: TerminalException) { - val deferred: CompletableDeferred = CompletableDeferred() + val deferred: CompletableDeferred = CompletableDeferred() deferred.completeExceptionally(t) cont.resume(deferred) } @@ -182,7 +182,7 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) val exitCallback = object : ExitSideEffectSyscallCallback { - override fun onSuccess(t: ByteString?) { + override fun onSuccess(t: ByteBuffer?) { exitResult.complete(t!!) } @@ -208,7 +208,7 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) override suspend fun awakeable(serde: Serde): Awakeable { val (aid, deferredResult) = suspendCancellableCoroutine { - cont: CancellableContinuation>> -> + cont: CancellableContinuation>> -> syscalls.awakeable(completingContinuation(cont)) } @@ -234,8 +234,8 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) inner class DurablePromiseImpl(private val key: DurablePromiseKey) : DurablePromise { override suspend fun awaitable(): Awaitable { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + val deferred: Deferred = + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.promise(key.name(), completingContinuation(cont)) } @@ -243,8 +243,8 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) } override suspend fun peek(): T? { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + val deferred: Deferred = + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.peekPromise(key.name(), completingContinuation(cont)) } @@ -265,8 +265,8 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) } override suspend fun isCompleted(): Boolean { - val deferred: Deferred = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + val deferred: Deferred = + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.peekPromise(key.name(), completingContinuation(cont)) } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt index 9c7ff1b5..a43aed08 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import com.google.protobuf.ByteString import dev.restate.sdk.common.TerminalException import dev.restate.sdk.common.syscalls.HandlerSpecification import dev.restate.sdk.common.syscalls.SyscallCallback import dev.restate.sdk.common.syscalls.Syscalls import io.opentelemetry.extension.kotlin.asContextElement +import java.nio.ByteBuffer import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -45,7 +45,7 @@ internal constructor( handlerSpecification: HandlerSpecification, syscalls: Syscalls, options: Options?, - callback: SyscallCallback + callback: SyscallCallback ) { val ctx: Context = ContextImpl(syscalls) @@ -57,7 +57,7 @@ internal constructor( .asContextElement(syscalls) + syscalls.request().otelContext()!!.asContextElement()) scope.launch { - val serializedResult: ByteString + val serializedResult: ByteBuffer try { // Parse input @@ -77,7 +77,7 @@ internal constructor( // Serialize output try { - serializedResult = handlerSpecification.responseSerde.serializeToByteString(res) + serializedResult = handlerSpecification.responseSerde.serializeToByteBuffer(res) } catch (e: Error) { throw e } catch (e: Throwable) { diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt index 5a001c65..2a157a66 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt @@ -8,10 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import com.google.protobuf.ByteString import dev.restate.sdk.common.DurablePromiseKey import dev.restate.sdk.common.Serde import dev.restate.sdk.common.StateKey +import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import kotlin.reflect.typeOf import kotlinx.serialization.KSerializer @@ -51,15 +51,15 @@ object KtSerdes { return ByteArray(0) } - override fun serializeToByteString(value: Unit?): ByteString { - return ByteString.EMPTY + override fun serializeToByteBuffer(value: Unit?): ByteBuffer { + return ByteBuffer.allocate(0) } override fun deserialize(value: ByteArray) { return } - override fun deserialize(byteString: ByteString) { + override fun deserialize(byteBuffer: ByteBuffer) { return } @@ -71,12 +71,12 @@ object KtSerdes { /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ fun json(serializer: KSerializer): Serde { return object : Serde { - override fun serialize(value: T?): ByteArray { + override fun serialize(value: T): ByteArray { return Json.encodeToString(serializer, value!!).encodeToByteArray() } - override fun deserialize(value: ByteArray?): T { - return Json.decodeFromString(serializer, String(value!!, StandardCharsets.UTF_8)) + override fun deserialize(value: ByteArray): T { + return Json.decodeFromString(serializer, String(value, StandardCharsets.UTF_8)) } override fun contentType(): String { diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt index 148e32f8..e9a6e360 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Util.kt @@ -8,10 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin -import com.google.protobuf.ByteString import dev.restate.sdk.common.Serde import dev.restate.sdk.common.syscalls.SyscallCallback import dev.restate.sdk.common.syscalls.Syscalls +import java.nio.ByteBuffer import kotlin.coroutines.resume import kotlinx.coroutines.CancellableContinuation import kotlinx.coroutines.CancellationException @@ -32,9 +32,9 @@ internal fun completingUnitContinuation( internal fun Serde.serializeWrappingException( syscalls: Syscalls, value: T? -): ByteString? { +): ByteBuffer { return try { - this.serializeToByteString(value) + this.serializeToByteBuffer(value) } catch (e: Exception) { syscalls.fail(e) throw CancellationException("Failed serialization", e) @@ -43,10 +43,10 @@ internal fun Serde.serializeWrappingException( internal fun Serde.deserializeWrappingException( syscalls: Syscalls, - byteString: ByteString + ByteBuffer: ByteBuffer ): T { return try { - this.deserialize(byteString) + this.deserialize(ByteBuffer) } catch (e: Exception) { syscalls.fail(e) throw CancellationException("Failed deserialization", e) diff --git a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java index b36bf5e2..2000628e 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java @@ -8,11 +8,11 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.Serde; import dev.restate.sdk.common.syscalls.Deferred; import dev.restate.sdk.common.syscalls.Result; import dev.restate.sdk.common.syscalls.Syscalls; +import java.nio.ByteBuffer; /** * An {@link Awakeable} is a special type of {@link Awaitable} which can be arbitrarily completed by @@ -28,11 +28,11 @@ *

NOTE: This interface MUST NOT be accessed concurrently since it can lead to different * orderings of user actions, corrupting the execution of the invocation. */ -public final class Awakeable extends Awaitable.MappedAwaitable { +public final class Awakeable extends Awaitable.MappedAwaitable { private final String identifier; - Awakeable(Syscalls syscalls, Deferred deferred, Serde serde, String identifier) { + Awakeable(Syscalls syscalls, Deferred deferred, Serde serde, String identifier) { super( Awaitable.single(syscalls, deferred), res -> { diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 68995e3a..72e40577 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -8,13 +8,13 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.*; import dev.restate.sdk.common.function.ThrowingSupplier; import dev.restate.sdk.common.syscalls.Deferred; import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback; import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback; import dev.restate.sdk.common.syscalls.Syscalls; +import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collection; import java.util.Map; @@ -43,7 +43,7 @@ public Request request() { @Override public Optional get(StateKey key) { - Deferred deferred = Util.blockOnSyscall(cb -> syscalls.get(key.name(), cb)); + Deferred deferred = Util.blockOnSyscall(cb -> syscalls.get(key.name(), cb)); if (!deferred.isCompleted()) { Util.blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb)); @@ -91,27 +91,27 @@ public Awaitable timer(Duration duration) { @Override public Awaitable call( Target target, Serde inputSerde, Serde outputSerde, T parameter) { - ByteString input = Util.serializeWrappingException(syscalls, inputSerde, parameter); - Deferred result = Util.blockOnSyscall(cb -> syscalls.call(target, input, cb)); + ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); + Deferred result = Util.blockOnSyscall(cb -> syscalls.call(target, input, cb)); return Awaitable.single(syscalls, result) .map(bs -> Util.deserializeWrappingException(syscalls, outputSerde, bs)); } @Override public void send(Target target, Serde inputSerde, T parameter) { - ByteString input = Util.serializeWrappingException(syscalls, inputSerde, parameter); + ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); Util.blockOnSyscall(cb -> syscalls.send(target, input, null, cb)); } @Override public void send(Target target, Serde inputSerde, T parameter, Duration delay) { - ByteString input = Util.serializeWrappingException(syscalls, inputSerde, parameter); + ByteBuffer input = Util.serializeWrappingException(syscalls, inputSerde, parameter); Util.blockOnSyscall(cb -> syscalls.send(target, input, delay, cb)); } @Override public T run(String name, Serde serde, ThrowingSupplier action) { - CompletableFuture> enterFut = new CompletableFuture<>(); + CompletableFuture> enterFut = new CompletableFuture<>(); syscalls.enterSideEffectBlock( name, new EnterSideEffectSyscallCallback() { @@ -121,7 +121,7 @@ public void onNotExecuted() { } @Override - public void onSuccess(ByteString result) { + public void onSuccess(ByteBuffer result) { enterFut.complete(CompletableFuture.completedFuture(result)); } @@ -137,7 +137,7 @@ public void onCancel(Throwable t) { }); // If a failure was stored, it's simply thrown here - CompletableFuture exitFut = Util.awaitCompletableFuture(enterFut); + CompletableFuture exitFut = Util.awaitCompletableFuture(enterFut); if (exitFut.isDone()) { // We already have a result, we don't need to execute the action return Util.deserializeWrappingException( @@ -147,7 +147,7 @@ public void onCancel(Throwable t) { ExitSideEffectSyscallCallback exitCallback = new ExitSideEffectSyscallCallback() { @Override - public void onSuccess(ByteString result) { + public void onSuccess(ByteBuffer result) { exitFut.complete(result); } @@ -188,7 +188,7 @@ public void onCancel(@Nullable Throwable t) { @Override public Awakeable awakeable(Serde serde) throws TerminalException { // Retrieve the awakeable - Map.Entry> awakeable = Util.blockOnSyscall(syscalls::awakeable); + Map.Entry> awakeable = Util.blockOnSyscall(syscalls::awakeable); return new Awakeable<>(syscalls, awakeable.getValue(), serde, awakeable.getKey()); } @@ -221,14 +221,14 @@ public DurablePromise durablePromise(DurablePromiseKey key) { return new DurablePromise<>() { @Override public Awaitable awaitable() { - Deferred result = Util.blockOnSyscall(cb -> syscalls.promise(key.name(), cb)); + Deferred result = Util.blockOnSyscall(cb -> syscalls.promise(key.name(), cb)); return Awaitable.single(syscalls, result) .map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs)); } @Override public Optional peek() { - Deferred deferred = + Deferred deferred = Util.blockOnSyscall(cb -> syscalls.peekPromise(key.name(), cb)); if (!deferred.isCompleted()) { @@ -241,7 +241,7 @@ public Optional peek() { @Override public boolean isCompleted() { - Deferred deferred = + Deferred deferred = Util.blockOnSyscall(cb -> syscalls.peekPromise(key.name(), cb)); if (!deferred.isCompleted()) { diff --git a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java index cd0f00f9..5f6462b5 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java +++ b/sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java @@ -8,12 +8,12 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.common.syscalls.HandlerSpecification; import dev.restate.sdk.common.syscalls.SyscallCallback; import dev.restate.sdk.common.syscalls.Syscalls; import io.opentelemetry.context.Scope; +import java.nio.ByteBuffer; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.function.BiConsumer; @@ -41,7 +41,7 @@ public void run( HandlerSpecification handlerSpecification, Syscalls syscalls, @Nullable Options options, - SyscallCallback callback) { + SyscallCallback callback) { if (options == null) { options = Options.DEFAULT; } @@ -92,9 +92,9 @@ public void run( } // Serialize output - ByteString serializedResult; + ByteBuffer serializedResult; try { - serializedResult = handlerSpecification.getResponseSerde().serializeToByteString(res); + serializedResult = handlerSpecification.getResponseSerde().serializeToByteBuffer(res); } catch (Error e) { throw e; } catch (Throwable e) { diff --git a/sdk-api/src/main/java/dev/restate/sdk/Util.java b/sdk-api/src/main/java/dev/restate/sdk/Util.java index fe055895..1851170e 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Util.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Util.java @@ -8,7 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.Serde; import dev.restate.sdk.common.function.ThrowingFunction; @@ -16,6 +15,7 @@ import dev.restate.sdk.common.syscalls.Result; import dev.restate.sdk.common.syscalls.SyscallCallback; import dev.restate.sdk.common.syscalls.Syscalls; +import java.nio.ByteBuffer; import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -78,12 +78,12 @@ static R executeMappingException(Syscalls syscalls, ThrowingFunction ByteString serializeWrappingException(Syscalls syscalls, Serde serde, T value) { - return executeMappingException(syscalls, serde::serializeToByteString, value); + static ByteBuffer serializeWrappingException(Syscalls syscalls, Serde serde, T value) { + return executeMappingException(syscalls, serde::serializeToByteBuffer, value); } static T deserializeWrappingException( - Syscalls syscalls, Serde serde, ByteString byteString) { + Syscalls syscalls, Serde serde, ByteBuffer byteString) { return executeMappingException(syscalls, serde::deserialize, byteString); } } diff --git a/sdk-common/build.gradle.kts b/sdk-common/build.gradle.kts index 7eb3e0b5..e1d8e5a7 100644 --- a/sdk-common/build.gradle.kts +++ b/sdk-common/build.gradle.kts @@ -8,7 +8,6 @@ description = "Common interfaces of the Restate SDK" dependencies { compileOnly(coreLibs.jspecify) - api(coreLibs.protobuf.java) api(platform(coreLibs.opentelemetry.bom)) api(coreLibs.opentelemetry.api) diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/CoreSerdes.java b/sdk-common/src/main/java/dev/restate/sdk/common/CoreSerdes.java index 1f84c94b..b6c7741a 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/CoreSerdes.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/CoreSerdes.java @@ -12,12 +12,12 @@ import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.function.ThrowingBiConsumer; import dev.restate.sdk.common.function.ThrowingFunction; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Objects; import org.jspecify.annotations.Nullable; @@ -39,8 +39,8 @@ public byte[] serialize(Void value) { } @Override - public ByteString serializeToByteString(@Nullable Void value) { - return ByteString.EMPTY; + public ByteBuffer serializeToByteBuffer(@Nullable Void value) { + return ByteBuffer.allocate(0); } @Override @@ -49,7 +49,7 @@ public Void deserialize(byte[] value) { } @Override - public Void deserialize(ByteString byteString) { + public Void deserialize(ByteBuffer byteBuffer) { return null; } @@ -73,6 +73,39 @@ public byte[] deserialize(byte[] value) { } }; + /** Pass through {@link Serde} for {@link ByteBuffer}. */ + public static Serde BYTE_BUFFER = + new Serde<>() { + + @Override + public byte[] serialize(@Nullable ByteBuffer byteBuffer) { + if (byteBuffer == null) { + return new byte[] {}; + } + if (byteBuffer.hasArray()) { + return byteBuffer.array(); + } + byte[] bytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(bytes); + return bytes; + } + + @Override + public ByteBuffer serializeToByteBuffer(@Nullable ByteBuffer value) { + return value; + } + + @Override + public ByteBuffer deserialize(byte[] value) { + return ByteBuffer.wrap(value); + } + + @Override + public ByteBuffer deserialize(ByteBuffer byteBuffer) { + return byteBuffer; + } + }; + /** {@link Serde} for {@link String}. This writes and reads {@link String} as JSON value. */ public static Serde JSON_STRING = usingJackson( diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Request.java b/sdk-common/src/main/java/dev/restate/sdk/common/Request.java index 763d934a..1b220963 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Request.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/Request.java @@ -8,8 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common; -import com.google.protobuf.ByteString; import io.opentelemetry.context.Context; +import java.nio.ByteBuffer; import java.util.Map; import java.util.Objects; @@ -17,13 +17,13 @@ public final class Request { private final InvocationId invocationId; private final Context otelContext; - private final ByteString body; + private final ByteBuffer body; private final Map headers; public Request( InvocationId invocationId, Context otelContext, - ByteString body, + ByteBuffer body, Map headers) { this.invocationId = invocationId; this.otelContext = otelContext; @@ -40,11 +40,11 @@ public Context otelContext() { } public byte[] body() { - return body.toByteArray(); + return CoreSerdes.BYTE_BUFFER.serialize(body); } - public ByteString bodyBuffer() { - return body; + public ByteBuffer bodyBuffer() { + return body.asReadOnlyBuffer(); } public Map headers() { diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java b/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java index 43bf2680..5e5d67c6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/Serde.java @@ -8,9 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common; -import com.google.protobuf.ByteString; -import com.google.protobuf.UnsafeByteOperations; import dev.restate.sdk.common.function.ThrowingFunction; +import java.nio.ByteBuffer; import java.util.Objects; import org.jspecify.annotations.Nullable; @@ -23,15 +22,20 @@ public interface Serde { byte[] serialize(@Nullable T value); - default ByteString serializeToByteString(@Nullable T value) { + default ByteBuffer serializeToByteBuffer(@Nullable T value) { // This is safe because we don't mutate the generated byte[] afterward. - return UnsafeByteOperations.unsafeWrap(serialize(value)); + return ByteBuffer.wrap(serialize(value)); } T deserialize(byte[] value); - default T deserialize(ByteString byteString) { - return deserialize(byteString.toByteArray()); + default T deserialize(ByteBuffer byteBuffer) { + if (byteBuffer.hasArray()) { + return deserialize(byteBuffer.array()); + } + byte[] bytes = new byte[byteBuffer.remaining()]; + byteBuffer.get(bytes); + return deserialize(bytes); } // --- Metadata about the serialized/deserialized content @@ -98,13 +102,13 @@ public byte[] serialize(@Nullable T value) { } @Override - public ByteString serializeToByteString(@Nullable T value) { - return inner.serializeToByteString(value); + public ByteBuffer serializeToByteBuffer(@Nullable T value) { + return inner.serializeToByteBuffer(value); } @Override - public T deserialize(ByteString byteString) { - return inner.deserialize(byteString); + public T deserialize(ByteBuffer byteBuffer) { + return inner.deserialize(byteBuffer); } @Override diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java index 1f590a37..43d0ed89 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ExitSideEffectSyscallCallback.java @@ -8,10 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common.syscalls; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.TerminalException; +import java.nio.ByteBuffer; -public interface ExitSideEffectSyscallCallback extends SyscallCallback { +public interface ExitSideEffectSyscallCallback extends SyscallCallback { /** This is user failure. */ void onFailure(TerminalException t); diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java index 7ca3863e..95915aa6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerRunner.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common.syscalls; -import com.google.protobuf.ByteString; +import java.nio.ByteBuffer; import org.jspecify.annotations.Nullable; public interface HandlerRunner { @@ -26,5 +26,5 @@ void run( HandlerSpecification handlerSpecification, Syscalls syscalls, @Nullable O options, - SyscallCallback callback); + SyscallCallback callback); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java index 07ed0ab5..f7128785 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java @@ -8,10 +8,10 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common.syscalls; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.Request; import dev.restate.sdk.common.Target; import dev.restate.sdk.common.TerminalException; +import java.nio.ByteBuffer; import java.time.Duration; import java.util.*; import java.util.List; @@ -40,13 +40,13 @@ public interface Syscalls { // Note: These are not supposed to be exposed to RestateContext, but they should be used through // gRPC APIs. - void writeOutput(ByteString value, SyscallCallback callback); + void writeOutput(ByteBuffer value, SyscallCallback callback); void writeOutput(TerminalException exception, SyscallCallback callback); // ----- State - void get(String name, SyscallCallback> callback); + void get(String name, SyscallCallback> callback); void getKeys(SyscallCallback>> callback); @@ -54,38 +54,38 @@ public interface Syscalls { void clearAll(SyscallCallback callback); - void set(String name, ByteString value, SyscallCallback callback); + void set(String name, ByteBuffer value, SyscallCallback callback); // ----- Syscalls void sleep(Duration duration, SyscallCallback> callback); - void call(Target target, ByteString parameter, SyscallCallback> callback); + void call(Target target, ByteBuffer parameter, SyscallCallback> callback); void send( Target target, - ByteString parameter, + ByteBuffer parameter, @Nullable Duration delay, SyscallCallback requestCallback); void enterSideEffectBlock(@Nullable String name, EnterSideEffectSyscallCallback callback); - void exitSideEffectBlock(ByteString toWrite, ExitSideEffectSyscallCallback callback); + void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback); void exitSideEffectBlockWithTerminalException( TerminalException toWrite, ExitSideEffectSyscallCallback callback); - void awakeable(SyscallCallback>> callback); + void awakeable(SyscallCallback>> callback); - void resolveAwakeable(String id, ByteString payload, SyscallCallback requestCallback); + void resolveAwakeable(String id, ByteBuffer payload, SyscallCallback requestCallback); void rejectAwakeable(String id, String reason, SyscallCallback requestCallback); - void promise(String key, SyscallCallback> callback); + void promise(String key, SyscallCallback> callback); - void peekPromise(String key, SyscallCallback> callback); + void peekPromise(String key, SyscallCallback> callback); - void resolvePromise(String key, ByteString payload, SyscallCallback> callback); + void resolvePromise(String key, ByteBuffer payload, SyscallCallback> callback); void rejectPromise(String key, String reason, SyscallCallback> callback); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java index ecf0ffd7..c7502460 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java @@ -11,11 +11,13 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageLite; +import com.google.protobuf.UnsafeByteOperations; import dev.restate.generated.service.protocol.Protocol; import dev.restate.generated.service.protocol.Protocol.*; import dev.restate.sdk.common.syscalls.Result; import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; +import java.nio.ByteBuffer; import java.util.Collection; import java.util.function.Function; import java.util.stream.Collectors; @@ -77,7 +79,7 @@ public void trace(OutputEntryMessage expected, Span span) { } static final class GetStateEntry - extends CompletableJournalEntry { + extends CompletableJournalEntry { static final GetStateEntry INSTANCE = new GetStateEntry(); @@ -111,9 +113,9 @@ void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual) } @Override - public Result parseEntryResult(GetStateEntryMessage actual) { + public Result parseEntryResult(GetStateEntryMessage actual) { if (actual.getResultCase() == GetStateEntryMessage.ResultCase.VALUE) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.FAILURE) { return Result.failure(Util.toRestateException(actual.getFailure())); } else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.EMPTY) { @@ -124,9 +126,9 @@ public Result parseEntryResult(GetStateEntryMessage actual) { } @Override - public Result parseCompletionResult(CompletionMessage actual) { + public Result parseCompletionResult(CompletionMessage actual) { if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { return Result.empty(); } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { @@ -138,7 +140,7 @@ public Result parseCompletionResult(CompletionMessage actual) { @Override void updateUserStateStoreWithEntry( GetStateEntryMessage expected, UserStateStore userStateStore) { - userStateStore.set(expected.getKey(), expected.getValue()); + userStateStore.set(expected.getKey(), expected.getValue().asReadOnlyByteBuffer()); } @Override @@ -146,7 +148,9 @@ GetStateEntryMessage tryCompleteWithUserStateStorage( GetStateEntryMessage expected, UserStateStore userStateStore) { UserStateStore.State value = userStateStore.get(expected.getKey()); if (value instanceof UserStateStore.Value) { - return expected.toBuilder().setValue(((UserStateStore.Value) value).getValue()).build(); + return expected.toBuilder() + .setValue(UnsafeByteOperations.unsafeWrap(((UserStateStore.Value) value).getValue())) + .build(); } else if (value instanceof UserStateStore.Empty) { return expected.toBuilder().setEmpty(Empty.getDefaultInstance()).build(); } @@ -159,7 +163,7 @@ void updateUserStateStorageWithCompletion( if (actual.hasEmpty()) { userStateStore.clear(expected.getKey()); } else { - userStateStore.set(expected.getKey(), actual.getValue()); + userStateStore.set(expected.getKey(), actual.getValue().asReadOnlyByteBuffer()); } } } @@ -333,7 +337,7 @@ void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual) @Override void updateUserStateStoreWithEntry( SetStateEntryMessage expected, UserStateStore userStateStore) { - userStateStore.set(expected.getKey(), expected.getValue()); + userStateStore.set(expected.getKey(), expected.getValue().asReadOnlyByteBuffer()); } } @@ -383,9 +387,9 @@ public Result parseCompletionResult(CompletionMessage actual) { static final class InvokeEntry extends CompletableJournalEntry { - private final Function> valueParser; + private final Function> valueParser; - InvokeEntry(Function> valueParser) { + InvokeEntry(Function> valueParser) { this.valueParser = valueParser; } @@ -427,7 +431,7 @@ void checkEntryHeader(CallEntryMessage expected, MessageLite actual) throws Prot @Override public Result parseEntryResult(CallEntryMessage actual) { if (actual.hasValue()) { - return valueParser.apply(actual.getValue()); + return valueParser.apply(actual.getValue().asReadOnlyByteBuffer()); } return Result.failure(Util.toRestateException(actual.getFailure())); } @@ -435,7 +439,7 @@ public Result parseEntryResult(CallEntryMessage actual) { @Override public Result parseCompletionResult(CompletionMessage actual) { if (actual.hasValue()) { - return valueParser.apply(actual.getValue()); + return valueParser.apply(actual.getValue().asReadOnlyByteBuffer()); } if (actual.hasFailure()) { return Result.failure(Util.toRestateException(actual.getFailure())); @@ -474,7 +478,7 @@ void checkEntryHeader(OneWayCallEntryMessage expected, MessageLite actual) } static final class AwakeableEntry - extends CompletableJournalEntry { + extends CompletableJournalEntry { static final AwakeableEntry INSTANCE = new AwakeableEntry(); private AwakeableEntry() {} @@ -495,17 +499,17 @@ public boolean hasResult(AwakeableEntryMessage actual) { } @Override - public Result parseEntryResult(AwakeableEntryMessage actual) { + public Result parseEntryResult(AwakeableEntryMessage actual) { if (actual.hasValue()) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } return Result.failure(Util.toRestateException(actual.getFailure())); } @Override - public Result parseCompletionResult(CompletionMessage actual) { + public Result parseCompletionResult(CompletionMessage actual) { if (actual.hasValue()) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } if (actual.hasFailure()) { return Result.failure(Util.toRestateException(actual.getFailure())); @@ -515,7 +519,7 @@ public Result parseCompletionResult(CompletionMessage actual) { } static final class GetPromiseEntry - extends CompletableJournalEntry { + extends CompletableJournalEntry { static final GetPromiseEntry INSTANCE = new GetPromiseEntry(); private GetPromiseEntry() {} @@ -547,17 +551,17 @@ public boolean hasResult(GetPromiseEntryMessage actual) { } @Override - public Result parseEntryResult(GetPromiseEntryMessage actual) { + public Result parseEntryResult(GetPromiseEntryMessage actual) { if (actual.hasValue()) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } return Result.failure(Util.toRestateException(actual.getFailure())); } @Override - public Result parseCompletionResult(CompletionMessage actual) { + public Result parseCompletionResult(CompletionMessage actual) { if (actual.hasValue()) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } if (actual.hasFailure()) { return Result.failure(Util.toRestateException(actual.getFailure())); @@ -567,7 +571,7 @@ public Result parseCompletionResult(CompletionMessage actual) { } static final class PeekPromiseEntry - extends CompletableJournalEntry { + extends CompletableJournalEntry { static final PeekPromiseEntry INSTANCE = new PeekPromiseEntry(); private PeekPromiseEntry() {} @@ -599,9 +603,9 @@ public boolean hasResult(PeekPromiseEntryMessage actual) { } @Override - public Result parseEntryResult(PeekPromiseEntryMessage actual) { + public Result parseEntryResult(PeekPromiseEntryMessage actual) { if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.VALUE) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } else if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.FAILURE) { return Result.failure(Util.toRestateException(actual.getFailure())); } else if (actual.getResultCase() == PeekPromiseEntryMessage.ResultCase.EMPTY) { @@ -612,9 +616,9 @@ public Result parseEntryResult(PeekPromiseEntryMessage actual) { } @Override - public Result parseCompletionResult(CompletionMessage actual) { + public Result parseCompletionResult(CompletionMessage actual) { if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) { - return Result.success(actual.getValue()); + return Result.success(actual.getValue().asReadOnlyByteBuffer()); } else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) { return Result.empty(); } else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java index ebe47d15..fb0e7c22 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java @@ -8,7 +8,6 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.Request; import dev.restate.sdk.common.Target; import dev.restate.sdk.common.TerminalException; @@ -16,6 +15,7 @@ import dev.restate.sdk.common.syscalls.EnterSideEffectSyscallCallback; import dev.restate.sdk.common.syscalls.ExitSideEffectSyscallCallback; import dev.restate.sdk.common.syscalls.SyscallCallback; +import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collection; import java.util.Map; @@ -33,7 +33,7 @@ class ExecutorSwitchingSyscalls implements SyscallsInternal { } @Override - public void writeOutput(ByteString value, SyscallCallback callback) { + public void writeOutput(ByteBuffer value, SyscallCallback callback) { syscallsExecutor.execute(() -> syscalls.writeOutput(value, callback)); } @@ -43,7 +43,7 @@ public void writeOutput(TerminalException throwable, SyscallCallback callb } @Override - public void get(String name, SyscallCallback> callback) { + public void get(String name, SyscallCallback> callback) { syscallsExecutor.execute(() -> syscalls.get(name, callback)); } @@ -63,7 +63,7 @@ public void clearAll(SyscallCallback callback) { } @Override - public void set(String name, ByteString value, SyscallCallback callback) { + public void set(String name, ByteBuffer value, SyscallCallback callback) { syscallsExecutor.execute(() -> syscalls.set(name, value, callback)); } @@ -74,14 +74,14 @@ public void sleep(Duration duration, SyscallCallback> callback) { @Override public void call( - Target target, ByteString parameter, SyscallCallback> callback) { + Target target, ByteBuffer parameter, SyscallCallback> callback) { syscallsExecutor.execute(() -> syscalls.call(target, parameter, callback)); } @Override public void send( Target target, - ByteString parameter, + ByteBuffer parameter, @Nullable Duration delay, SyscallCallback requestCallback) { syscallsExecutor.execute(() -> syscalls.send(target, parameter, delay, requestCallback)); @@ -93,7 +93,7 @@ public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback cal } @Override - public void exitSideEffectBlock(ByteString toWrite, ExitSideEffectSyscallCallback callback) { + public void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback) { syscallsExecutor.execute(() -> syscalls.exitSideEffectBlock(toWrite, callback)); } @@ -105,13 +105,13 @@ public void exitSideEffectBlockWithTerminalException( } @Override - public void awakeable(SyscallCallback>> callback) { + public void awakeable(SyscallCallback>> callback) { syscallsExecutor.execute(() -> syscalls.awakeable(callback)); } @Override public void resolveAwakeable( - String id, ByteString payload, SyscallCallback requestCallback) { + String id, ByteBuffer payload, SyscallCallback requestCallback) { syscallsExecutor.execute(() -> syscalls.resolveAwakeable(id, payload, requestCallback)); } @@ -121,18 +121,18 @@ public void rejectAwakeable(String id, String reason, SyscallCallback requ } @Override - public void promise(String key, SyscallCallback> callback) { + public void promise(String key, SyscallCallback> callback) { syscallsExecutor.execute(() -> syscalls.promise(key, callback)); } @Override - public void peekPromise(String key, SyscallCallback> callback) { + public void peekPromise(String key, SyscallCallback> callback) { syscallsExecutor.execute(() -> syscalls.peekPromise(key, callback)); } @Override public void resolvePromise( - String key, ByteString payload, SyscallCallback> callback) { + String key, ByteBuffer payload, SyscallCallback> callback) { syscallsExecutor.execute(() -> syscalls.resolvePromise(key, payload, callback)); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java index c5e5ec8e..7cbf56ff 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java @@ -228,7 +228,7 @@ void onStartMessage(MessageLite msg) { new Request( invocationId, Context.root().with(span), - inputEntry.getValue(), + inputEntry.getValue().asReadOnlyByteBuffer(), inputEntry.getHeadersList().stream() .collect( Collectors.toUnmodifiableMap( @@ -490,7 +490,7 @@ void completeSideEffectCallbackWithEntry( if (sideEffectEntry.hasFailure()) { callback.onFailure(Util.toRestateException(sideEffectEntry.getFailure())); } else { - callback.onSuccess(sideEffectEntry.getValue()); + callback.onSuccess(sideEffectEntry.getValue().asReadOnlyByteBuffer()); } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java index 7b5a1dcd..62191097 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java @@ -8,9 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import com.google.protobuf.ByteString; import dev.restate.sdk.common.TerminalException; import dev.restate.sdk.common.syscalls.*; +import java.nio.ByteBuffer; import java.util.concurrent.Executor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -74,12 +74,12 @@ public void start() { t -> {})); } - private void writeOutputAndEnd(SyscallsInternal syscalls, ByteString output) { + private void writeOutputAndEnd(SyscallsInternal syscalls, ByteBuffer output) { syscalls.writeOutput( output, SyscallCallback.ofVoid( () -> { - LOG.trace("Wrote output message:\n{}", output); + LOG.trace("Wrote output message"); this.end(syscalls, null); }, syscalls::fail)); @@ -118,7 +118,7 @@ public void run( HandlerSpecification spec, Syscalls syscalls, @Nullable O options, - SyscallCallback callback) { + SyscallCallback callback) { try { this.handler.run(spec, syscalls, options, callback); } catch (Throwable e) { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java index f01c0be3..a0e8e5d5 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java @@ -8,6 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; +import static dev.restate.sdk.core.Util.nioBufferToProtobufBuffer; + import com.google.protobuf.ByteString; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.Request; @@ -50,12 +52,15 @@ public Request request() { } @Override - public void writeOutput(ByteString value, SyscallCallback callback) { + public void writeOutput(ByteBuffer value, SyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("writeOutput success"); this.writeOutput( - Protocol.OutputEntryMessage.newBuilder().setValue(value).build(), callback); + Protocol.OutputEntryMessage.newBuilder() + .setValue(nioBufferToProtobufBuffer(value)) + .build(), + callback); }, callback); } @@ -81,7 +86,7 @@ private void writeOutput(Protocol.OutputEntryMessage entry, SyscallCallback> callback) { + public void get(String name, SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("get {}", name); @@ -137,14 +142,14 @@ public void clearAll(SyscallCallback callback) { } @Override - public void set(String name, ByteString value, SyscallCallback callback) { + public void set(String name, ByteBuffer value, SyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("set {}", name); this.stateMachine.processJournalEntry( Protocol.SetStateEntryMessage.newBuilder() .setKey(ByteString.copyFromUtf8(name)) - .setValue(value) + .setValue(nioBufferToProtobufBuffer(value)) .build(), SetStateEntry.INSTANCE, callback); @@ -169,7 +174,7 @@ public void sleep(Duration duration, SyscallCallback> callback) { @Override public void call( - Target target, ByteString parameter, SyscallCallback> callback) { + Target target, ByteBuffer parameter, SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("call {}", target); @@ -178,7 +183,7 @@ public void call( Protocol.CallEntryMessage.newBuilder() .setServiceName(target.getService()) .setHandlerName(target.getHandler()) - .setParameter(parameter); + .setParameter(nioBufferToProtobufBuffer(parameter)); if (target.getKey() != null) { builder.setKey(target.getKey()); } @@ -192,7 +197,7 @@ public void call( @Override public void send( Target target, - ByteString parameter, + ByteBuffer parameter, @Nullable Duration delay, SyscallCallback callback) { wrapAndPropagateExceptions( @@ -203,7 +208,7 @@ public void send( Protocol.OneWayCallEntryMessage.newBuilder() .setServiceName(target.getService()) .setHandlerName(target.getHandler()) - .setParameter(parameter); + .setParameter(nioBufferToProtobufBuffer(parameter)); if (target.getKey() != null) { builder.setKey(target.getKey()); } @@ -228,12 +233,15 @@ public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback cal } @Override - public void exitSideEffectBlock(ByteString toWrite, ExitSideEffectSyscallCallback callback) { + public void exitSideEffectBlock(ByteBuffer toWrite, ExitSideEffectSyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("exitSideEffectBlock with success"); this.stateMachine.exitSideEffectBlock( - Protocol.RunEntryMessage.newBuilder().setValue(toWrite).build(), callback); + Protocol.RunEntryMessage.newBuilder() + .setValue(nioBufferToProtobufBuffer(toWrite)) + .build(), + callback); }, callback); } @@ -254,7 +262,7 @@ public void exitSideEffectBlockWithTerminalException( } @Override - public void awakeable(SyscallCallback>> callback) { + public void awakeable(SyscallCallback>> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("awakeable"); @@ -271,7 +279,7 @@ public void awakeable(SyscallCallback>> c ByteString.copyFrom( ByteBuffer.allocate(4) .putInt( - ((SingleDeferredInternal) deferredResult) + ((SingleDeferredInternal) deferredResult) .entryIndex()) .rewind())); @@ -287,13 +295,14 @@ public void awakeable(SyscallCallback>> c @Override public void resolveAwakeable( - String serializedId, ByteString payload, SyscallCallback callback) { + String serializedId, ByteBuffer payload, SyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("resolveAwakeable"); completeAwakeable( serializedId, - Protocol.CompleteAwakeableEntryMessage.newBuilder().setValue(payload), + Protocol.CompleteAwakeableEntryMessage.newBuilder() + .setValue(nioBufferToProtobufBuffer(payload)), callback); }, callback); @@ -325,7 +334,7 @@ private void completeAwakeable( } @Override - public void promise(String key, SyscallCallback> callback) { + public void promise(String key, SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("promise"); @@ -338,7 +347,7 @@ public void promise(String key, SyscallCallback> callback) } @Override - public void peekPromise(String key, SyscallCallback> callback) { + public void peekPromise(String key, SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("peekPromise"); @@ -352,14 +361,14 @@ public void peekPromise(String key, SyscallCallback> callba @Override public void resolvePromise( - String key, ByteString payload, SyscallCallback> callback) { + String key, ByteBuffer payload, SyscallCallback> callback) { wrapAndPropagateExceptions( () -> { LOG.trace("resolvePromise"); this.stateMachine.processCompletableJournalEntry( Protocol.CompletePromiseEntryMessage.newBuilder() .setKey(key) - .setCompletionValue(payload) + .setCompletionValue(nioBufferToProtobufBuffer(payload)) .build(), CompletePromiseEntry.INSTANCE, callback); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java b/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java index e8bf7b63..e66136cb 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/UserStateStore.java @@ -9,6 +9,7 @@ package dev.restate.sdk.core; import com.google.protobuf.ByteString; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; @@ -31,13 +32,13 @@ private Empty() {} } static final class Value implements State { - private final ByteString value; + private final ByteBuffer value; - private Value(ByteString value) { + private Value(ByteBuffer value) { this.value = value; } - public ByteString getValue() { + public ByteBuffer getValue() { return value; } } @@ -50,14 +51,16 @@ public ByteString getValue() { this.map = new HashMap<>( map.entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> new Value(e.getValue())))); + .collect( + Collectors.toMap( + Map.Entry::getKey, e -> new Value(e.getValue().asReadOnlyByteBuffer())))); } public State get(ByteString key) { return this.map.getOrDefault(key, isPartial ? Unknown.INSTANCE : Empty.INSTANCE); } - public void set(ByteString key, ByteString value) { + public void set(ByteString key, ByteBuffer value) { this.map.put(key, new Value(value)); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java index 81dc8339..02479022 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java @@ -8,13 +8,16 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; +import com.google.protobuf.ByteString; import com.google.protobuf.MessageLite; +import com.google.protobuf.UnsafeByteOperations; import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; import java.io.PrintWriter; import java.io.StringWriter; +import java.nio.ByteBuffer; import java.util.Objects; import java.util.Optional; import java.util.function.Predicate; @@ -159,4 +162,9 @@ static boolean isEntry(MessageLite msg) { || msg instanceof Java.CombinatorAwaitableEntryMessage || msg instanceof Protocol.RunEntryMessage; } + + /** NOTE! This method rewinds the buffer!!! */ + static ByteString nioBufferToProtobufBuffer(ByteBuffer nioBuffer) { + return UnsafeByteOperations.unsafeWrap(nioBuffer.rewind()); + } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java index a1cf62df..5899bd19 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AwakeableIdTestSuite.java @@ -56,7 +56,10 @@ public Stream definitions() { assertThat(messages) .element(1) .asInstanceOf(type(Protocol.OutputEntryMessage.class)) - .extracting(out -> CoreSerdes.JSON_STRING.deserialize(out.getValue())) + .extracting( + out -> + CoreSerdes.JSON_STRING.deserialize( + out.getValue().asReadOnlyByteBuffer())) .isEqualTo(base64ExpectedAwakeableId); })); } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java index 1ff12852..446461e7 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java @@ -64,7 +64,8 @@ public static Protocol.StartMessage.Builder startMessage( e -> StateEntry.newBuilder() .setKey(ByteString.copyFromUtf8(e.getKey())) - .setValue(CoreSerdes.JSON_STRING.serializeToByteString(e.getValue())) + .setValue( + ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize(e.getValue()))) .build()) .collect(Collectors.toList())); } @@ -75,7 +76,7 @@ public static Protocol.CompletionMessage.Builder completionMessage(int index) { public static Protocol.CompletionMessage completionMessage( int index, Serde serde, T value) { - return completionMessage(index).setValue(serde.serializeToByteString(value)).build(); + return completionMessage(index).setValue(ByteString.copyFrom(serde.serialize(value))).build(); } public static Protocol.CompletionMessage completionMessage(int index, String value) { @@ -109,7 +110,7 @@ public static Protocol.InputEntryMessage inputMessage(byte[] value) { public static Protocol.InputEntryMessage inputMessage(Serde serde, T value) { return Protocol.InputEntryMessage.newBuilder() - .setValue(serde.serializeToByteString(value)) + .setValue(ByteString.copyFrom(serde.serialize(value))) .build(); } @@ -123,7 +124,7 @@ public static Protocol.InputEntryMessage inputMessage(int value) { public static Protocol.OutputEntryMessage outputMessage(Serde serde, T value) { return Protocol.OutputEntryMessage.newBuilder() - .setValue(serde.serializeToByteString(value)) + .setValue(ByteString.copyFrom(serde.serialize(value))) .build(); } @@ -170,7 +171,7 @@ public static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) { public static Protocol.GetStateEntryMessage getStateMessage( String key, Serde serde, T value) { - return getStateMessage(key).setValue(serde.serializeToByteString(value)).build(); + return getStateMessage(key).setValue(ByteString.copyFrom(serde.serialize(value))).build(); } public static Protocol.GetStateEntryMessage getStateMessage(String key, String value) { @@ -181,7 +182,7 @@ public static Protocol.SetStateEntryMessage setStateMessage( String key, Serde serde, T value) { return Protocol.SetStateEntryMessage.newBuilder() .setKey(ByteString.copyFromUtf8(key)) - .setValue(serde.serializeToByteString(value)) + .setValue(ByteString.copyFrom(serde.serialize(value))) .build(); } @@ -213,13 +214,13 @@ public static Protocol.CallEntryMessage.Builder invokeMessage(Target target, byt public static Protocol.CallEntryMessage.Builder invokeMessage( Target target, Serde reqSerde, T parameter) { - return invokeMessage(target).setParameter(reqSerde.serializeToByteString(parameter)); + return invokeMessage(target).setParameter(ByteString.copyFrom(reqSerde.serialize(parameter))); } public static Protocol.CallEntryMessage invokeMessage( Target target, Serde reqSerde, T parameter, Serde resSerde, R result) { return invokeMessage(target, reqSerde, parameter) - .setValue(resSerde.serializeToByteString(result)) + .setValue(ByteString.copyFrom(resSerde.serialize(result))) .build(); } @@ -237,7 +238,9 @@ public static Protocol.AwakeableEntryMessage.Builder awakeable() { } public static Protocol.AwakeableEntryMessage awakeable(String value) { - return awakeable().setValue(CoreSerdes.JSON_STRING.serializeToByteString(value)).build(); + return awakeable() + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize(value))) + .build(); } public static Protocol.GetPromiseEntryMessage.Builder getPromise(String key) { @@ -252,7 +255,7 @@ public static Protocol.CompletePromiseEntryMessage.Builder completePromise( String key, String value) { return Protocol.CompletePromiseEntryMessage.newBuilder() .setKey(key) - .setCompletionValue(CoreSerdes.JSON_STRING.serializeToByteString(value)); + .setCompletionValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize(value))); } public static Protocol.CompletePromiseEntryMessage.Builder completePromise( diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java index 756188e3..b607175b 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java @@ -15,6 +15,7 @@ import static org.assertj.core.api.InstanceOfAssertFactories.STRING; import static org.assertj.core.api.InstanceOfAssertFactories.type; +import com.google.protobuf.ByteString; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.TerminalException; @@ -41,14 +42,14 @@ public Stream definitions() { .withInput(startMessage(1), inputMessage("Till")) .expectingOutput( Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), suspensionMessage(1)) .named("Without optimization suspends"), this.sideEffect("Francesco") .withInput(startMessage(1), inputMessage("Till"), ackMessage(1)) .expectingOutput( Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), outputMessage("Hello Francesco"), END_MESSAGE) .named("Without optimization and with acks returns"), @@ -57,13 +58,13 @@ public Stream definitions() { .expectingOutput( Protocol.RunEntryMessage.newBuilder() .setName("get-my-name") - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), suspensionMessage(1)), this.consecutiveSideEffect("Francesco") .withInput(startMessage(1), inputMessage("Till")) .expectingOutput( Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), suspensionMessage(1)) .named("With optimization and without ack on first side effect will suspend"), this.consecutiveSideEffect("Francesco") @@ -71,9 +72,9 @@ public Stream definitions() { .onlyUnbuffered() .expectingOutput( Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("FRANCESCO")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("FRANCESCO"))), suspensionMessage(2)) .named("With optimization and ack on first side effect will suspend"), this.consecutiveSideEffect("Francesco") @@ -81,9 +82,9 @@ public Stream definitions() { .onlyUnbuffered() .expectingOutput( Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("Francesco"))), Protocol.RunEntryMessage.newBuilder() - .setValue(CoreSerdes.JSON_STRING.serializeToByteString("FRANCESCO")), + .setValue(ByteString.copyFrom(CoreSerdes.JSON_STRING.serialize("FRANCESCO"))), outputMessage("Hello FRANCESCO"), END_MESSAGE) .named("With optimization and ack on first and second side effect will resume"), diff --git a/sdk-http-vertx/build.gradle.kts b/sdk-http-vertx/build.gradle.kts index 2af8bb5c..b58ad322 100644 --- a/sdk-http-vertx/build.gradle.kts +++ b/sdk-http-vertx/build.gradle.kts @@ -12,6 +12,8 @@ dependencies { api(project(":sdk-common")) implementation(project(":sdk-core")) + implementation(coreLibs.protobuf.java) + // Vert.x implementation(platform(vertxLibs.vertx.bom)) implementation(vertxLibs.vertx.core) diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt index d9f33812..37ea9860 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt +++ b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/RestateHttpEndpointTest.kt @@ -119,7 +119,7 @@ internal class RestateHttpEndpointTest { request.write( encode( completionMessage(1) - .setValue(CoreSerdes.JSON_LONG.serializeToByteString(2)) + .setValue(ByteString.copyFrom(CoreSerdes.JSON_LONG.serialize(2))) .build())) // Wait for Set State Entry diff --git a/sdk-lambda/build.gradle.kts b/sdk-lambda/build.gradle.kts index e9a1c3da..e68d9fd0 100644 --- a/sdk-lambda/build.gradle.kts +++ b/sdk-lambda/build.gradle.kts @@ -13,6 +13,8 @@ dependencies { api(lambdaLibs.core) api(lambdaLibs.events) + implementation(coreLibs.protobuf.java) + implementation(platform(coreLibs.opentelemetry.bom)) implementation(coreLibs.opentelemetry.api) diff --git a/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java b/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java index 0b890293..65189e77 100644 --- a/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java +++ b/sdk-serde-protobuf/src/main/java/dev/restate/sdk/serde/protobuf/ProtobufSerdes.java @@ -8,11 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.serde.protobuf; -import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.MessageLite; -import com.google.protobuf.Parser; +import com.google.protobuf.*; import dev.restate.sdk.common.Serde; +import java.nio.ByteBuffer; import java.util.Objects; import org.jspecify.annotations.Nullable; @@ -37,16 +35,17 @@ public T deserialize(byte[] value) { } } - // -- We reimplement the ByteString variants here as it might be more efficient to use them. + // -- We reimplement the ByteBuffer variants here as it might be more efficient to use them. + @Override - public ByteString serializeToByteString(@Nullable T value) { - return Objects.requireNonNull(value).toByteString(); + public ByteBuffer serializeToByteBuffer(@Nullable T value) { + return Objects.requireNonNull(value).toByteString().asReadOnlyByteBuffer(); } @Override - public T deserialize(ByteString byteString) { + public T deserialize(ByteBuffer byteBuffer) { try { - return parser.parseFrom(byteString); + return parser.parseFrom(UnsafeByteOperations.unsafeWrap(byteBuffer.rewind())); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Cannot deserialize Protobuf object", e); }