Skip to content

Commit

Permalink
Implement request identity (#297)
Browse files Browse the repository at this point in the history
* Add RequestIdentityVerifier interface, to be used to implement request identity.

* Add RestateRequestIdentityVerifier implementation.

* Fix case with unsigned signature scheme.

* Fix description of module

* Better name for the factory method

* Add test and remove prefixing with ASN1, this seems not needed.
  • Loading branch information
slinkydeveloper authored Apr 29, 2024
1 parent 9718851 commit 7b4dd3f
Show file tree
Hide file tree
Showing 16 changed files with 315 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.auth;

import org.jspecify.annotations.Nullable;

/** Interface to verify requests. */
public interface RequestIdentityVerifier {

/** Abstraction for headers map. */
@FunctionalInterface
interface Headers {
@Nullable String get(String key);
}

/**
* @throws Exception if the request cannot be verified
*/
void verifyRequest(Headers headers) throws Exception;
}
4 changes: 2 additions & 2 deletions sdk-core/src/main/java/dev/restate/sdk/core/Entries.java
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ public Result<Collection<String>> parseCompletionResult(CompletionMessage actual
} catch (InvalidProtocolBufferException e) {
throw new ProtocolException(
"Cannot parse get state keys completion",
e,
ProtocolException.PROTOCOL_VIOLATION_CODE);
ProtocolException.PROTOCOL_VIOLATION_CODE,
e);
}
return Result.success(
stateKeys.getKeysList().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ void onStartMessage(MessageLite msg) {
this.fail(
new ProtocolException(
"Expected at least one entry with Input, got " + this.entriesToReplay + " entries",
null,
TerminalException.INTERNAL_SERVER_ERROR_CODE));
TerminalException.INTERNAL_SERVER_ERROR_CODE,
null));
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

public class ProtocolException extends RuntimeException {

static final int UNAUTHORIZED_CODE = 401;
static final int NOT_FOUND_CODE = 404;
static final int JOURNAL_MISMATCH_CODE = 570;
static final int PROTOCOL_VIOLATION_CODE = 571;
Expand All @@ -28,10 +29,10 @@ private ProtocolException(String message) {
}

private ProtocolException(String message, int code) {
this(message, null, code);
this(message, code, null);
}

public ProtocolException(String message, Throwable cause, int code) {
public ProtocolException(String message, int code, Throwable cause) {
super(message, cause);
this.code = code;
}
Expand Down Expand Up @@ -77,7 +78,11 @@ static ProtocolException methodNotFound(String serviceName, String handlerName)
static ProtocolException invalidSideEffectCall() {
return new ProtocolException(
"A syscall was invoked from within a side effect closure.",
null,
TerminalException.INTERNAL_SERVER_ERROR_CODE);
TerminalException.INTERNAL_SERVER_ERROR_CODE,
null);
}

static ProtocolException unauthorized(Throwable e) {
return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e);
}
}
25 changes: 23 additions & 2 deletions sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import dev.restate.sdk.auth.RequestIdentityVerifier;
import dev.restate.sdk.common.BindableServiceFactory;
import dev.restate.sdk.common.syscalls.HandlerDefinition;
import dev.restate.sdk.common.syscalls.ServiceDefinition;
Expand All @@ -32,14 +33,17 @@ public class RestateEndpoint {

private final Map<String, ServiceAndOptions<?>> services;
private final Tracer tracer;
private final RequestIdentityVerifier requestIdentityVerifier;
private final DeploymentManifest deploymentManifest;

private RestateEndpoint(
DeploymentManifestSchema.ProtocolMode protocolMode,
Map<String, ServiceAndOptions<?>> services,
Tracer tracer) {
Tracer tracer,
RequestIdentityVerifier requestIdentityVerifier) {
this.services = services;
this.tracer = tracer;
this.requestIdentityVerifier = requestIdentityVerifier;
this.deploymentManifest =
new DeploymentManifest(protocolMode, services.values().stream().map(c -> c.service));

Expand All @@ -49,6 +53,7 @@ private RestateEndpoint(
public ResolvedEndpointHandler resolve(
String componentName,
String handlerName,
RequestIdentityVerifier.Headers headers,
io.opentelemetry.context.Context otelContext,
LoggingContextSetter loggingContextSetter,
@Nullable Executor syscallExecutor)
Expand All @@ -65,6 +70,15 @@ public ResolvedEndpointHandler resolve(
throw ProtocolException.methodNotFound(componentName, handlerName);
}

// Verify request
if (requestIdentityVerifier != null) {
try {
requestIdentityVerifier.verifyRequest(headers);
} catch (Exception e) {
throw ProtocolException.unauthorized(e);
}
}

// Generate the span
Span span =
tracer
Expand Down Expand Up @@ -108,6 +122,7 @@ public static class Builder {

private final List<ServiceAndOptions<?>> services = new ArrayList<>();
private final DeploymentManifestSchema.ProtocolMode protocolMode;
private RequestIdentityVerifier requestIdentityVerifier;
private Tracer tracer = OpenTelemetry.noop().getTracer("NOOP");

public Builder(DeploymentManifestSchema.ProtocolMode protocolMode) {
Expand All @@ -124,12 +139,18 @@ public Builder withTracer(Tracer tracer) {
return this;
}

public Builder withRequestIdentityVerifier(RequestIdentityVerifier requestIdentityVerifier) {
this.requestIdentityVerifier = requestIdentityVerifier;
return this;
}

public RestateEndpoint build() {
return new RestateEndpoint(
this.protocolMode,
this.services.stream()
.collect(Collectors.toMap(c -> c.service.getServiceName(), Function.identity())),
tracer);
tracer,
requestIdentityVerifier);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public void executeTest(TestDefinitions.TestDefinition definition) {
server.resolve(
serviceDefinition.get(0).getServiceName(),
definition.getMethod(),
k -> null,
io.opentelemetry.context.Context.current(),
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
syscallsExecutor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public void executeTest(TestDefinition definition) {
server.resolve(
serviceDefinition.get(0).getServiceName(),
definition.getMethod(),
k -> null,
io.opentelemetry.context.Context.current(),
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public void handle(HttpServerRequest request) {
restateEndpoint.resolve(
serviceName,
handlerName,
request::getHeader,
otelContext,
ContextualData::put,
currentContextExecutor(vertxCurrentContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.http.vertx;

import dev.restate.sdk.auth.RequestIdentityVerifier;
import dev.restate.sdk.common.BindableService;
import dev.restate.sdk.common.syscalls.ServiceDefinition;
import dev.restate.sdk.core.RestateEndpoint;
Expand Down Expand Up @@ -103,7 +104,7 @@ public <O> RestateHttpEndpointBuilder bind(BindableService<O> service, O options
}

/**
* Add a {@link OpenTelemetry} implementation for tracing and metrics.
* Set the {@link OpenTelemetry} implementation for tracing and metrics.
*
* @see OpenTelemetry
*/
Expand All @@ -112,6 +113,18 @@ public RestateHttpEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetry)
return this;
}

/**
* Set the request identity verifier for this endpoint.
*
* <p>For the Restate implementation to use with Restate Cloud, check the module {@code
* sdk-request-identity}.
*/
public RestateHttpEndpointBuilder withRequestIdentityVerifier(
RequestIdentityVerifier requestIdentityVerifier) {
this.endpointBuilder.withRequestIdentityVerifier(requestIdentityVerifier);
return this;
}

/** Build and listen on the specified port. */
public void buildAndListen(int port) {
build().listen(port).onComplete(RestateHttpEndpointBuilder::handleStart);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private APIGatewayProxyResponseEvent handleInvoke(APIGatewayProxyRequestEvent in
this.restateEndpoint.resolve(
serviceName,
handlerName,
input.getHeaders()::get,
otelContext,
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.lambda;

import dev.restate.sdk.auth.RequestIdentityVerifier;
import dev.restate.sdk.common.BindableService;
import dev.restate.sdk.common.syscalls.ServiceDefinition;
import dev.restate.sdk.core.RestateEndpoint;
Expand Down Expand Up @@ -49,6 +50,18 @@ public RestateLambdaEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetr
return this;
}

/**
* Set the request identity verifier for this endpoint.
*
* <p>For the Restate implementation to use with Restate Cloud, check the module {@code
* sdk-request-identity}.
*/
public RestateLambdaEndpointBuilder withRequestIdentityVerifier(
RequestIdentityVerifier requestIdentityVerifier) {
this.restateEndpoint.withRequestIdentityVerifier(requestIdentityVerifier);
return this;
}

/** Build the {@link RestateLambdaEndpoint} serving the Restate service endpoint. */
public RestateLambdaEndpoint build() {
return new RestateLambdaEndpoint(this.restateEndpoint.build(), this.openTelemetry);
Expand Down
19 changes: 19 additions & 0 deletions sdk-request-identity/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
plugins {
`java-library`
`library-publishing-conventions`
}

description = "Restate SDK request identity implementation"

dependencies {
compileOnly(coreLibs.jspecify)

implementation(project(":sdk-common"))

// Dependencies for signing request tokens
implementation(coreLibs.jwt)
implementation(coreLibs.tink)

testImplementation(testingLibs.junit.jupiter)
testImplementation(testingLibs.assertj)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.auth.signing;

import java.util.Arrays;

// Copied and adapted from
// https://github.com/bitcoinj/bitcoinj/blob/7df957e4c6817036c096283c5f0dcb7e4d60c982/core/src/main/java/org/bitcoinj/base/Base58.java#L50
// License Apache 2.0
// Copyright 2011 Google Inc.
// Copyright 2018 Andreas Schildbach

class Base58 {
public static final char[] ALPHABET =
"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz".toCharArray();
private static final int[] INDEXES = new int[128];

static {
Arrays.fill(INDEXES, -1);
for (int i = 0; i < ALPHABET.length; i++) {
INDEXES[ALPHABET[i]] = i;
}
}

/**
* Decodes the given base58 string into the original data bytes.
*
* @param input the base58-encoded string to decode
* @return the decoded data bytes
*/
public static byte[] decode(String input) {
if (input.isEmpty()) {
return new byte[0];
}
// Convert the base58-encoded ASCII chars to a base58 byte sequence (base58 digits).
byte[] input58 = new byte[input.length()];
for (int i = 0; i < input.length(); ++i) {
char c = input.charAt(i);
int digit = c < 128 ? INDEXES[c] : -1;
if (digit < 0) {
throw new IllegalArgumentException(
String.format("Invalid character in Base58: 0x%04x", (int) c));
}
input58[i] = (byte) digit;
}
// Count leading zeros.
int zeros = 0;
while (zeros < input58.length && input58[zeros] == 0) {
++zeros;
}
// Convert base-58 digits to base-256 digits.
byte[] decoded = new byte[input.length()];
int outputStart = decoded.length;
for (int inputStart = zeros; inputStart < input58.length; ) {
decoded[--outputStart] = divmod(input58, inputStart, 58, 256);
if (input58[inputStart] == 0) {
++inputStart; // optimization - skip leading zeros
}
}
// Ignore extra leading zeroes that were added during the calculation.
while (outputStart < decoded.length && decoded[outputStart] == 0) {
++outputStart;
}
// Return decoded data (including original number of leading zeros).
return Arrays.copyOfRange(decoded, outputStart - zeros, decoded.length);
}

/**
* Divides a number, represented as an array of bytes each containing a single digit in the
* specified base, by the given divisor. The given number is modified in-place to contain the
* quotient, and the return value is the remainder.
*
* @param number the number to divide
* @param firstDigit the index within the array of the first non-zero digit (this is used for
* optimization by skipping the leading zeros)
* @param base the base in which the number's digits are represented (up to 256)
* @param divisor the number to divide by (up to 256)
* @return the remainder of the division operation
*/
private static byte divmod(byte[] number, int firstDigit, int base, int divisor) {
// this is just long division which accounts for the base of the input digits
int remainder = 0;
for (int i = firstDigit; i < number.length; i++) {
int digit = (int) number[i] & 0xFF;
int temp = remainder * base + digit;
number[i] = (byte) (temp / divisor);
remainder = temp % divisor;
}
return (byte) remainder;
}
}
Loading

0 comments on commit 7b4dd3f

Please sign in to comment.