Skip to content

Commit

Permalink
Support both async fn and functions that return futures
Browse files Browse the repository at this point in the history
  • Loading branch information
kjvalencik committed Jul 16, 2024
1 parent a0fe39e commit 3cfbd7a
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 29 deletions.
9 changes: 5 additions & 4 deletions crates/neon-macros/src/export/function/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub(crate) struct Meta {
#[derive(Default)]
pub(super) enum Kind {
Async,
AsyncFn,
#[default]
Normal,
Task,
Expand All @@ -29,8 +30,8 @@ impl Meta {

fn force_context(&mut self, meta: syn::meta::ParseNestedMeta) -> syn::Result<()> {
match self.kind {
Kind::Normal | Kind::AsyncFn => {}
Kind::Async => return Err(meta.error(super::ASYNC_CX_ERROR)),
Kind::Normal => {}
Kind::Task => return Err(meta.error(super::TASK_CX_ERROR)),
}

Expand All @@ -40,8 +41,8 @@ impl Meta {
}

fn make_async(&mut self, meta: syn::meta::ParseNestedMeta) -> syn::Result<()> {
if self.context {
return Err(meta.error(super::ASYNC_CX_ERROR));
if matches!(self.kind, Kind::AsyncFn) {
return Err(meta.error(super::ASYNC_FN_ERROR));
}

self.kind = Kind::Async;
Expand Down Expand Up @@ -76,7 +77,7 @@ impl syn::parse::Parser for Parser {
let mut attr = Meta::default();

if item.sig.asyncness.is_some() {
attr.kind = Kind::Async;
attr.kind = Kind::AsyncFn;
}

let parser = syn::meta::parser(|meta| {
Expand Down
12 changes: 9 additions & 3 deletions crates/neon-macros/src/export/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::export::function::meta::Kind;
pub(crate) mod meta;

static ASYNC_CX_ERROR: &str = "`FunctionContext` is not allowed in async functions";
static ASYNC_FN_ERROR: &str = "`async` attribute should not be used with an `async fn`";
static TASK_CX_ERROR: &str = "`FunctionContext` is not allowed with `task` attribute";

pub(super) fn export(meta: meta::Meta, input: syn::ItemFn) -> proc_macro::TokenStream {
Expand Down Expand Up @@ -64,9 +65,14 @@ pub(super) fn export(meta: meta::Meta, input: syn::ItemFn) -> proc_macro::TokenS

// Generate the call to the original function
let call_body = match meta.kind {
Kind::Async => quote::quote!(
Kind::Async | Kind::AsyncFn => quote::quote!(
let (#(#tuple_fields,)*) = cx.args()?;
let fut = #name(#context_arg #(#args),*);
let fut = {
use neon::macro_internal::ToNeonFutureMarker;

(&fut).to_neon_future_marker().make_result(&mut cx, fut)?
};

neon::macro_internal::spawn(&mut cx, fut, |mut cx, res| #result_extract)
),
Expand Down Expand Up @@ -167,8 +173,8 @@ fn has_context_arg(meta: &meta::Meta, sig: &syn::Signature) -> syn::Result<bool>

// Context is only allowed for normal functions
match meta.kind {
Kind::Async => return Err(syn::Error::new(first.span(), ASYNC_CX_ERROR)),
Kind::Normal => {}
Kind::Normal | Kind::Async => {}
Kind::AsyncFn => return Err(syn::Error::new(first.span(), ASYNC_CX_ERROR)),
Kind::Task => return Err(syn::Error::new(first.span(), TASK_CX_ERROR)),
}

Expand Down
51 changes: 49 additions & 2 deletions crates/neon/src/macro_internal/futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::future::Future;

use crate::{
context::{Context, TaskContext},
result::JsResult,
types::JsValue,
result::{JsResult, NeonResult},
types::{extract::TryIntoJs, JsValue},
};

pub fn spawn<'cx, C, F, S>(cx: &mut C, fut: F, settle: S) -> JsResult<'cx, JsValue>
Expand All @@ -28,3 +28,50 @@ where

Ok(promise.upcast())
}

pub trait ToNeonFutureMarker {
type Marker;

fn to_neon_future_marker(&self) -> Self::Marker;
}

impl<T, E> ToNeonFutureMarker for Result<T, E> {
type Marker = NeonFutureMarkerResult;

fn to_neon_future_marker(&self) -> Self::Marker {
NeonFutureMarkerResult
}
}

impl<T> ToNeonFutureMarker for &T {
type Marker = NeonFutureMarkerValue;

fn to_neon_future_marker(&self) -> Self::Marker {
NeonFutureMarkerValue
}
}

pub struct NeonFutureMarkerResult;
pub struct NeonFutureMarkerValue;

impl NeonFutureMarkerResult {
pub fn make_result<'cx, C, T, E>(self, cx: &mut C, res: Result<T, E>) -> NeonResult<T>
where
C: Context<'cx>,
E: TryIntoJs<'cx>,
{
res.or_else(|err| {
let err = err.try_into_js(cx)?;
cx.throw(err)
})
}
}

impl NeonFutureMarkerValue {
pub fn make_result<'cx, C, T>(self, _cx: &mut C, res: T) -> NeonResult<T>
where
C: Context<'cx>,
{
Ok(res)
}
}
2 changes: 1 addition & 1 deletion crates/neon/src/macro_internal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
};

#[cfg(all(feature = "napi-6", feature = "futures"))]
pub use self::futures::spawn;
pub use self::futures::*;

#[cfg(all(feature = "napi-6", feature = "futures"))]
mod futures;
Expand Down
54 changes: 38 additions & 16 deletions test/napi/src/js/futures.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
use {
neon::{prelude::*, types::buffer::TypedArray},
once_cell::sync::OnceCell,
tokio::runtime::Runtime,
};

fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
use std::future::Future;

RUNTIME.get_or_try_init(|| {
let runtime = Runtime::new().or_else(|err| cx.throw_error(&err.to_string()))?;
let handle = runtime.handle().clone();
use neon::{
prelude::*,
types::{buffer::TypedArray, extract::Error},
};

neon::RUNTIME.get_or_init(cx, || Box::new(handle));

Ok(runtime)
})
}
use crate::runtime;

// Accepts two functions that take no parameters and return numbers.
// Resolves with the sum of the two numbers.
Expand Down Expand Up @@ -85,3 +75,35 @@ pub fn lazy_async_sum(mut cx: FunctionContext) -> JsResult<JsPromise> {

Ok(promise)
}

#[neon::export]
async fn async_fn_add(a: f64, b: f64) -> f64 {
a + b
}

#[neon::export(async)]
fn async_add(a: f64, b: f64) -> impl Future<Output = f64> {
async move { a + b }
}

#[neon::export]
async fn async_fn_div(a: f64, b: f64) -> Result<f64, Error> {
if b == 0.0 {
return Err(Error::from("Divide by zero"));
}

Ok(a / b)
}

#[neon::export(async)]
fn async_div(cx: &mut FunctionContext) -> NeonResult<impl Future<Output = Result<f64, Error>>> {
let (a, b): (f64, f64) = cx.args()?;

Ok(async move {
if b == 0.0 {
return Err(Error::from("Divide by zero"));
}

Ok(a / b)
})
}
10 changes: 7 additions & 3 deletions test/napi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use neon::prelude::*;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;

use crate::js::{
arrays::*, boxed::*, coercions::*, date::*, errors::*, functions::*, numbers::*, objects::*,
Expand Down Expand Up @@ -27,6 +29,7 @@ mod js {

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
neon::RUNTIME.get_or_try_init(&mut cx, |cx| Ok(Box::new(runtime(cx)?.handle().clone())))?;
neon::registered().export(&mut cx)?;

assert!(neon::registered().into_iter().next().is_some());
Expand Down Expand Up @@ -416,7 +419,8 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
Ok(())
}

#[neon::export]
async fn async_add(a: f64, b: f64) -> f64 {
a + b
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();

RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(&err.to_string())))
}

0 comments on commit 3cfbd7a

Please sign in to comment.