Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Now a standalone Result can be returned from handlers #22

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions examples/counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,33 @@ use restate_sdk::prelude::*;
#[restate_sdk::object]
trait Counter {
#[shared]
async fn get() -> HandlerResult<u64>;
async fn add(val: u64) -> HandlerResult<u64>;
async fn increment() -> HandlerResult<u64>;
async fn reset() -> HandlerResult<()>;
async fn get() -> Result<u64, TerminalError>;
async fn add(val: u64) -> Result<u64, TerminalError>;
async fn increment() -> Result<u64, TerminalError>;
async fn reset() -> Result<(), TerminalError>;
}

struct CounterImpl;

const COUNT: &str = "count";

impl Counter for CounterImpl {
async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult<u64> {
async fn get(&self, ctx: SharedObjectContext<'_>) -> Result<u64, TerminalError> {
Ok(ctx.get::<u64>(COUNT).await?.unwrap_or(0))
}

async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult<u64> {
async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result<u64, TerminalError> {
let current = ctx.get::<u64>(COUNT).await?.unwrap_or(0);
let new = current + val;
ctx.set(COUNT, new);
Ok(new)
}

async fn increment(&self, ctx: ObjectContext<'_>) -> HandlerResult<u64> {
async fn increment(&self, ctx: ObjectContext<'_>) -> Result<u64, TerminalError> {
self.add(ctx, 1).await
}

async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> {
async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> {
ctx.clear(COUNT);
Ok(())
}
Expand Down
8 changes: 4 additions & 4 deletions examples/failures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use restate_sdk::prelude::*;
#[restate_sdk::service]
trait FailureExample {
#[name = "doRun"]
async fn do_run() -> HandlerResult<()>;
async fn do_run() -> Result<(), TerminalError>;
}

struct FailureExampleImpl;
Expand All @@ -14,14 +14,14 @@ struct FailureExampleImpl;
struct MyError;

impl FailureExample for FailureExampleImpl {
async fn do_run(&self, context: Context<'_>) -> HandlerResult<()> {
async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> {
context
.run(|| async move {
if rand::thread_rng().next_u32() % 4 == 0 {
return Err(TerminalError::new("Failed!!!").into());
Err(TerminalError::new("Failed!!!"))?
}

Err(MyError.into())
Err(MyError)?
})
.await?;

Expand Down
5 changes: 3 additions & 2 deletions examples/greeter.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use restate_sdk::prelude::*;
use std::convert::Infallible;

#[restate_sdk::service]
trait Greeter {
async fn greet(name: String) -> HandlerResult<String>;
async fn greet(name: String) -> Result<String, Infallible>;
}

struct GreeterImpl;

impl Greeter for GreeterImpl {
async fn greet(&self, _: Context<'_>, name: String) -> HandlerResult<String> {
async fn greet(&self, _: Context<'_>, name: String) -> Result<String, Infallible> {
Ok(format!("Greetings {name}"))
}
}
Expand Down
7 changes: 5 additions & 2 deletions examples/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use std::collections::HashMap;

#[restate_sdk::service]
trait RunExample {
async fn do_run() -> HandlerResult<Json<HashMap<String, String>>>;
async fn do_run() -> Result<Json<HashMap<String, String>>, HandlerError>;
}

struct RunExampleImpl(reqwest::Client);

impl RunExample for RunExampleImpl {
async fn do_run(&self, context: Context<'_>) -> HandlerResult<Json<HashMap<String, String>>> {
async fn do_run(
&self,
context: Context<'_>,
) -> Result<Json<HashMap<String, String>>, HandlerError> {
let res = context
.run(|| async move {
let req = self.0.get("https://httpbin.org/ip").build()?;
Expand Down
51 changes: 33 additions & 18 deletions macros/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ pub(crate) struct Handler {
pub(crate) restate_name: String,
pub(crate) ident: Ident,
pub(crate) arg: Option<PatType>,
pub(crate) output: Type,
pub(crate) output_ok: Type,
pub(crate) output_err: Type,
}

impl Parse for Handler {
Expand Down Expand Up @@ -192,17 +193,18 @@ impl Parse for Handler {
let return_type: ReturnType = input.parse()?;
input.parse::<Token![;]>()?;

let output: Type = match &return_type {
ReturnType::Default => {
parse_quote!(())
}
let (ok_ty, err_ty) = match &return_type {
ReturnType::Default => return Err(Error::new(
return_type.span(),
"The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type",
)),
ReturnType::Type(_, ty) => {
if let Some(ty) = extract_handler_result_parameter(ty) {
ty
if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) {
(ok_ty, err_ty)
} else {
return Err(Error::new(
return_type.span(),
"Only restate_sdk::prelude::HandlerResult is supported as return type",
"Only Result or restate_sdk::prelude::HandlerResult is supported as return type",
));
}
}
Expand All @@ -229,7 +231,8 @@ impl Parse for Handler {
restate_name,
ident,
arg: args.pop(),
output,
output_ok: ok_ty,
output_err: err_ty,
})
}
}
Expand Down Expand Up @@ -263,14 +266,16 @@ fn read_literal_attribute_name(attr: &Attribute) -> Result<Option<String>> {
.transpose()
}

fn extract_handler_result_parameter(ty: &Type) -> Option<Type> {
fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> {
let path = match ty {
Type::Path(ty) => &ty.path,
_ => return None,
};

let last = path.segments.last().unwrap();
if last.ident != "HandlerResult" {
let is_result = last.ident == "Result";
let is_handler_result = last.ident == "HandlerResult";
if !is_result && !is_handler_result {
return None;
}

Expand All @@ -279,12 +284,22 @@ fn extract_handler_result_parameter(ty: &Type) -> Option<Type> {
_ => return None,
};

if bracketed.args.len() != 1 {
return None;
}

match &bracketed.args[0] {
GenericArgument::Type(arg) => Some(arg.clone()),
_ => None,
if is_handler_result && bracketed.args.len() == 1 {
match &bracketed.args[0] {
GenericArgument::Type(arg) => Some((
arg.clone(),
parse_quote!(::restate_sdk::prelude::HandlerError),
)),
_ => None,
}
} else if is_result && bracketed.args.len() == 2 {
match (&bracketed.args[0], &bracketed.args[1]) {
(GenericArgument::Type(ok_arg), GenericArgument::Type(err_arg)) => {
Some((ok_arg.clone(), err_arg.clone()))
}
_ => None,
}
} else {
None
}
}
8 changes: 4 additions & 4 deletions macros/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<'a> ServiceGenerator<'a> {
let handler_fns = handlers
.iter()
.map(
|Handler { attrs, ident, arg, is_shared, output, .. }| {
|Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| {
let args = arg.iter();

let ctx = match (&service_ty, is_shared) {
Expand All @@ -68,7 +68,7 @@ impl<'a> ServiceGenerator<'a> {

quote! {
#( #attrs )*
fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future<Output=::restate_sdk::prelude::HandlerResult<#output>> + ::core::marker::Send;
fn #ident(&self, context: #ctx, #( #args ),*) -> impl std::future::Future<Output=Result<#output_ok, #output_err>> + ::core::marker::Send;
}
},
);
Expand Down Expand Up @@ -130,7 +130,7 @@ impl<'a> ServiceGenerator<'a> {
quote! {
#handler_literal => {
#get_input_and_call
let res = fut.await;
let res = fut.await.map_err(::restate_sdk::errors::HandlerError::from);
ctx.handle_handler_result(res);
ctx.end();
Ok(())
Expand Down Expand Up @@ -302,7 +302,7 @@ impl<'a> ServiceGenerator<'a> {
ty, ..
}) => quote! { #ty }
};
let res_ty = &handler.output;
let res_ty = &handler.output_ok;
let input = match &handler.arg {
None => quote! { () },
Some(_) => quote! { req }
Expand Down
36 changes: 18 additions & 18 deletions src/endpoint/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,14 @@ impl ContextInternal {
.sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into()));
}

pub fn run<'a, Run, Fut, Res>(
pub fn run<'a, Run, Fut, Out>(
&'a self,
run_closure: Run,
) -> impl crate::context::RunFuture<Result<Res, TerminalError>> + Send + Sync + 'a
) -> impl crate::context::RunFuture<Result<Out, TerminalError>> + Send + Sync + 'a
where
Run: RunClosure<Fut = Fut, Output = Res> + Send + Sync + 'a,
Fut: Future<Output = HandlerResult<Res>> + Send + Sync + 'a,
Res: Serialize + Deserialize + 'static,
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync + 'a,
Fut: Future<Output = HandlerResult<Out>> + Send + Sync + 'a,
Out: Serialize + Deserialize + 'static,
{
let this = Arc::clone(&self.inner);

Expand Down Expand Up @@ -631,12 +631,12 @@ impl<Run, Fut, Ret> RunFuture<Run, Fut, Ret> {
}
}

impl<Run, Fut, Ret> crate::context::RunFuture<Result<Result<Ret, TerminalError>, Error>>
for RunFuture<Run, Fut, Ret>
impl<Run, Fut, Out> crate::context::RunFuture<Result<Result<Out, TerminalError>, Error>>
for RunFuture<Run, Fut, Out>
where
Run: RunClosure<Fut = Fut, Output = Ret> + Send + Sync,
Fut: Future<Output = HandlerResult<Ret>> + Send + Sync,
Ret: Serialize + Deserialize,
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync,
Fut: Future<Output = HandlerResult<Out>> + Send + Sync,
Out: Serialize + Deserialize,
{
fn with_retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
self.retry_policy = RetryPolicy::Exponential {
Expand All @@ -655,13 +655,13 @@ where
}
}

impl<Run, Fut, Res> Future for RunFuture<Run, Fut, Res>
impl<Run, Fut, Out> Future for RunFuture<Run, Fut, Out>
where
Run: RunClosure<Fut = Fut, Output = Res> + Send + Sync,
Res: Serialize + Deserialize,
Fut: Future<Output = HandlerResult<Res>> + Send + Sync,
Run: RunClosure<Fut = Fut, Output = Out> + Send + Sync,
Out: Serialize + Deserialize,
Fut: Future<Output = HandlerResult<Out>> + Send + Sync,
{
type Output = Result<Result<Res, TerminalError>, Error>;
type Output = Result<Result<Out, TerminalError>, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
Expand All @@ -681,7 +681,7 @@ where
// Enter the side effect
match enter_result.map_err(ErrorInner::VM)? {
RunEnterResult::Executed(NonEmptyValue::Success(mut v)) => {
let t = Res::deserialize(&mut v).map_err(|e| {
let t = Out::deserialize(&mut v).map_err(|e| {
ErrorInner::Deserialization {
syscall: "run",
err: Box::new(e),
Expand All @@ -707,7 +707,7 @@ where
}
RunStateProj::ClosureRunning { start_time, fut } => {
let res = match ready!(fut.poll(cx)) {
Ok(t) => RunExitResult::Success(Res::serialize(&t).map_err(|e| {
Ok(t) => RunExitResult::Success(Out::serialize(&t).map_err(|e| {
ErrorInner::Serialization {
syscall: "run",
err: Box::new(e),
Expand Down Expand Up @@ -752,7 +752,7 @@ where
}
.into()),
Value::Success(mut s) => {
let t = Res::deserialize(&mut s).map_err(|e| {
let t = Out::deserialize(&mut s).map_err(|e| {
ErrorInner::Deserialization {
syscall: "run",
err: Box::new(e),
Expand Down
2 changes: 0 additions & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,4 @@ impl From<TerminalError> for Failure {
}

/// Result type for a Restate handler.
///
/// All Restate handlers *MUST* use this type as return type for their handlers.
pub type HandlerResult<T> = Result<T, HandlerError>;
4 changes: 2 additions & 2 deletions test-services/src/failing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Failing for FailingImpl {
error_message: String,
) -> HandlerResult<()> {
context
.run(|| async move { Err::<(), _>(TerminalError::new(error_message).into()) })
.run(|| async move { Err(TerminalError::new(error_message))? })
.await?;

unreachable!("This should be unreachable")
Expand All @@ -92,7 +92,7 @@ impl Failing for FailingImpl {
cloned_counter.store(0, Ordering::SeqCst);
Ok(current_attempt)
} else {
Err(anyhow!("Failed at attempt {current_attempt}").into())
Err(anyhow!("Failed at attempt {current_attempt}"))?
}
})
.with_retry_policy(
Expand Down
6 changes: 6 additions & 0 deletions tests/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ trait MyService {
async fn no_output() -> HandlerResult<()>;

async fn no_input_no_output() -> HandlerResult<()>;

async fn std_result() -> Result<(), std::io::Error>;

async fn std_result_with_terminal_error() -> Result<(), TerminalError>;

async fn std_result_with_handler_error() -> Result<(), HandlerError>;
}

#[restate_sdk::object]
Expand Down