Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tsparser: improve auth handler errors #1511

Merged
merged 5 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tsparser/src/legacymeta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,12 @@ impl<'a> MetaBuilder<'a> {

Dependent::Gateway((_b, gw)) => {
let auth_handler = if let Some(auth_handler) = &gw.auth_handler {
let ah = auth_handlers
.get(&auth_handler.id)
.ok_or(anyhow::anyhow!("auth handler not found"))?;
let Some(ah) = auth_handlers.get(&auth_handler.id) else {
HANDLER.with(|handler| {
handler.span_err(gw.range, "auth handler not found")
});
fredr marked this conversation as resolved.
Show resolved Hide resolved
continue;
};

let service_name = self
.service_for_range(&ah.range)
Expand Down
10 changes: 8 additions & 2 deletions tsparser/src/parser/resources/apis/authhandler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ pub const AUTHHANDLER_PARSER: ResourceParser = ResourceParser {
let request = pass.type_checker.resolve_type(module.clone(), &r.request);
let response = pass.type_checker.resolve_type(module.clone(), &r.response);

let fields = iface_fields(pass.type_checker, &request)?;
let fields = match iface_fields(pass.type_checker, &request) {
Ok(fields) => fields,
Err(e) => {
e.report();
continue;
}
};

for (_, v) in fields {
if !v.is_custom() {
Expand All @@ -66,7 +72,7 @@ pub const AUTHHANDLER_PARSER: ResourceParser = ResourceParser {
.type_checker
.resolve_obj(pass.module.clone(), &ast::Expr::Ident(r.bind_name.clone()));

let encoding = describe_auth_handler(pass.type_checker.state(), request, response)?;
let encoding = describe_auth_handler(pass.type_checker.state(), request, response);

let resource = Resource::AuthHandler(Lrc::new(AuthHandler {
range: r.range,
Expand Down
115 changes: 76 additions & 39 deletions tsparser/src/parser/resources/apis/encoding.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashMap;

use anyhow::{bail, Context, Result};
use anyhow::{bail, Context};
use litparser::Sp;
use thiserror::Error;

use crate::parser::resources::apis::api::{Method, Methods};
use crate::parser::respath::Path;
Expand All @@ -10,6 +12,7 @@ use crate::parser::types::{
Type, TypeChecker,
};
use crate::parser::Range;
use crate::span_err::{ErrorWithSpanExt, SpErr};

/// Describes how an API endpoint can be encoded on the wire.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -84,8 +87,8 @@ pub struct ResponseEncoding {

#[derive(Debug, Clone)]
pub struct AuthHandlerEncoding {
pub auth_param: Type,
pub auth_data: Type,
pub auth_param: Sp<Type>,
pub auth_data: Sp<Type>,
}

pub struct RequestParamsByLoc<'a> {
Expand Down Expand Up @@ -172,13 +175,16 @@ pub fn describe_stream_endpoint(
tc: &TypeChecker,
methods: Methods,
path: Path,
req: Option<Type>,
resp: Option<Type>,
handshake: Option<Type>,
) -> Result<EndpointEncoding> {
let resp = resp
.map(|t| unwrap_promise(tc.state(), &t).clone())
.and_then(drop_empty_or_void);
req: Option<Sp<Type>>,
resp: Option<Sp<Type>>,
handshake: Option<Sp<Type>>,
) -> anyhow::Result<EndpointEncoding> {
let resp = if let Some(resp) = resp {
let (span, resp) = resp.split();
drop_empty_or_void(unwrap_promise(tc.state(), &resp).clone()).map(|t| Sp::new(span, t))
} else {
None
};

let default_method = default_method(&methods);

Expand Down Expand Up @@ -210,29 +216,36 @@ pub fn describe_stream_endpoint(
path
};

let raw_handshake_schema = handshake.map(|sp| sp.take());
let raw_req_schema = req.map(|sp| sp.take());
let raw_resp_schema = resp.map(|sp| sp.take());

Ok(EndpointEncoding {
path,
methods,
default_method,
req: req_enc,
resp: resp_enc,
handshake: handshake_enc,
raw_handshake_schema: handshake,
raw_req_schema: req,
raw_resp_schema: resp,
raw_handshake_schema,
raw_req_schema,
raw_resp_schema,
})
}
pub fn describe_endpoint(
tc: &TypeChecker,
methods: Methods,
path: Path,
req: Option<Type>,
resp: Option<Type>,
req: Option<Sp<Type>>,
resp: Option<Sp<Type>>,
raw: bool,
) -> Result<EndpointEncoding> {
let resp = resp
.map(|t| unwrap_promise(tc.state(), &t).clone())
.and_then(drop_empty_or_void);
) -> anyhow::Result<EndpointEncoding> {
let resp = if let Some(resp) = resp {
let (span, resp) = resp.split();
drop_empty_or_void(unwrap_promise(tc.state(), &resp).clone()).map(|t| Sp::new(span, t))
} else {
None
};

let default_method = default_method(&methods);

Expand All @@ -241,6 +254,9 @@ pub fn describe_endpoint(

let path = rewrite_path_types(&req_enc[0], path, raw).context("parse path param types")?;

let raw_req_schema = req.map(|sp| sp.take());
let raw_resp_schema = resp.map(|sp| sp.take());

Ok(EndpointEncoding {
path,
methods,
Expand All @@ -249,8 +265,8 @@ pub fn describe_endpoint(
resp: resp_enc,
handshake: None,
raw_handshake_schema: None,
raw_req_schema: req,
raw_resp_schema: resp,
raw_req_schema,
raw_resp_schema,
})
}

Expand All @@ -275,9 +291,9 @@ fn describe_req(
tc: &TypeChecker,
methods: &Methods,
path: Option<&Path>,
req_schema: &Option<Type>,
req_schema: &Option<Sp<Type>>,
raw: bool,
) -> Result<(Vec<RequestEncoding>, Option<FieldMap>)> {
) -> anyhow::Result<(Vec<RequestEncoding>, Option<FieldMap>)> {
let Some(req_schema) = req_schema else {
// We don't have any request schema. This is valid if and only if
// we have no path parameters or it's a raw endpoint.
Expand Down Expand Up @@ -330,8 +346,8 @@ fn describe_req(
fn describe_resp(
tc: &TypeChecker,
_methods: &Methods,
resp_schema: &Option<Type>,
) -> Result<(ResponseEncoding, Option<FieldMap>)> {
resp_schema: &Option<Sp<Type>>,
) -> anyhow::Result<(ResponseEncoding, Option<FieldMap>)> {
let Some(resp_schema) = resp_schema else {
return Ok((ResponseEncoding { params: vec![] }, None));
};
Expand All @@ -350,15 +366,16 @@ fn describe_resp(

pub fn describe_auth_handler(
ctx: &ResolveState,
params: Type,
response: Type,
) -> Result<AuthHandlerEncoding> {
params: Sp<Type>,
response: Sp<Type>,
) -> AuthHandlerEncoding {
let (span, response) = response.split();
let response = unwrap_promise(ctx, &response).clone();

Ok(AuthHandlerEncoding {
AuthHandlerEncoding {
auth_param: params,
auth_data: response,
})
auth_data: Sp::new(span, response),
}
}

fn default_method(methods: &Methods) -> Method {
Expand Down Expand Up @@ -415,30 +432,50 @@ impl Field {
}
}

pub(crate) fn iface_fields<'a>(tc: &'a TypeChecker, typ: &'a Type) -> Result<FieldMap> {
fn to_fields(state: &ResolveState, iface: &Interface) -> Result<FieldMap> {
#[derive(Error, Debug)]
pub enum Error {
#[error("expected named interface type, found {0}")]
ExpectedNamedInterfaceType(String),
#[error("invalid custom type field")]
InvalidCustomType(#[source] anyhow::Error),
}

pub(crate) fn iface_fields<'a>(
tc: &'a TypeChecker,
typ: &'a Sp<Type>,
) -> Result<FieldMap, SpErr<Error>> {
fn to_fields<'a>(
state: &'a ResolveState,
iface: &'a Interface,
) -> Result<FieldMap, SpErr<Error>> {
let mut map = HashMap::new();
for f in &iface.fields {
if let FieldName::String(name) = &f.name {
map.insert(name.clone(), rewrite_custom_type_field(state, f, name)?);
map.insert(
name.clone(),
rewrite_custom_type_field(state, f, name)
.map_err(Error::InvalidCustomType)
.map_err(|e| e.with_span(f.range.into()))?,
);
}
}
Ok(map)
}

let span = typ.span();
let typ = unwrap_promise(tc.state(), typ);
match typ {
Type::Basic(Basic::Void) => Ok(HashMap::new()),
Type::Interface(iface) => to_fields(tc.state(), iface),
Type::Named(named) => {
let underlying = tc.underlying(named.obj.module_id, typ);
let underlying = Sp::new(span, tc.underlying(named.obj.module_id, typ));
iface_fields(tc, &underlying)
}
_ => anyhow::bail!("expected named interface type, found {:?}", typ),
_ => Err(Error::ExpectedNamedInterfaceType(format!("{typ:?}")).with_span(span)),
}
}

fn extract_path_params(path: &Path, fields: &mut FieldMap) -> Result<Vec<Param>> {
fn extract_path_params(path: &Path, fields: &mut FieldMap) -> anyhow::Result<Vec<Param>> {
let mut params = Vec::new();
for (index, seg) in path.dynamic_segments().enumerate() {
let name = seg.lit_or_name();
Expand Down Expand Up @@ -491,7 +528,7 @@ fn extract_loc_params(fields: &FieldMap, default_loc: ParamLocation) -> Vec<Para
params
}

fn rewrite_path_types(req: &RequestEncoding, path: Path, raw: bool) -> Result<Path> {
fn rewrite_path_types(req: &RequestEncoding, path: Path, raw: bool) -> anyhow::Result<Path> {
use crate::parser::respath::{Segment, ValueType};
// Get the path params into a map, keyed by name.
let path_params = req
Expand Down Expand Up @@ -539,7 +576,7 @@ fn rewrite_custom_type_field(
ctx: &ResolveState,
field: &InterfaceField,
field_name: &str,
) -> Result<Field> {
) -> anyhow::Result<Field> {
let standard_field = Field {
name: field_name.to_string(),
typ: field.typ.clone(),
Expand Down
10 changes: 4 additions & 6 deletions tsparser/src/parser/resources/apis/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ pub static SERVICE_PARSER: ResourceParser = ResourceParser {

if i > 0 {
HANDLER.with(|h| {
h.struct_span_err(
h.span_err(
r.range,
"cannot have multiple service declarations in the same module",
)
.emit();
);
});
continue;
}
Expand All @@ -52,11 +51,10 @@ pub static SERVICE_PARSER: ResourceParser = ResourceParser {
FilePath::Real(buf) if buf.ends_with("encore.service.ts") => {}
_ => {
HANDLER.with(|h| {
h.struct_span_err(
h.span_err(
r.range,
"service declarations are only allowed in encore.service.ts",
)
.emit();
);
});
continue;
}
Expand Down
16 changes: 10 additions & 6 deletions tsparser/src/parser/types/type_resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt::Debug;
use std::ops::Deref;
use std::rc::Rc;

use litparser::Sp;
use swc_common::errors::HANDLER;
use swc_common::sync::Lrc;
use swc_common::{Span, Spanned};
Expand Down Expand Up @@ -33,18 +34,21 @@ impl TypeChecker {
&self.ctx
}

pub fn resolve_type(&self, module: Lrc<module_loader::Module>, expr: &ast::TsType) -> Type {
pub fn resolve_type(&self, module: Lrc<module_loader::Module>, expr: &ast::TsType) -> Sp<Type> {
// Ensure the module is initialized.
let module_id = module.id;
_ = self.ctx.get_or_init_module(module);

let ctx = Ctx::new(&self.ctx, module_id);
let typ = ctx.typ(expr);
match ctx.concrete(&typ) {
New(typ) => typ,
Changed(typ) => typ.clone(),
Same(_) => typ,
}
Sp::new(
expr.span(),
match ctx.concrete(&typ) {
New(typ) => typ,
Changed(typ) => typ.clone(),
Same(_) => typ,
},
)
}

pub fn concrete(&self, module_id: ModuleId, typ: &Type) -> Type {
Expand Down
47 changes: 46 additions & 1 deletion tsparser/src/span_err.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use swc_common::{errors::HANDLER, Spanned};
use swc_common::{errors::HANDLER, Span, Spanned};

pub trait ErrReporter {
fn err(&self, msg: &str);
Expand All @@ -12,3 +12,48 @@ where
HANDLER.with(|h| h.span_err(self.span(), msg));
}
}

#[derive(Debug)]
pub struct SpErr<E> {
span: Span,
error: E,
}

impl<E> SpErr<E>
where
E: std::error::Error,
{
pub fn new(span: Span, error: E) -> Self {
SpErr { span, error }
}

pub fn report(&self) {
HANDLER.with(|handler| handler.span_err(self.span, &self.error.to_string()))
}
}

impl<E> std::error::Error for SpErr<E>
where
E: std::error::Error,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.error.source()
}
}

impl<E> std::fmt::Display for SpErr<E>
where
E: std::fmt::Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.error, f)
}
}

pub trait ErrorWithSpanExt: std::error::Error + Sized {
fn with_span(self, span: Span) -> SpErr<Self> {
SpErr::new(span, self)
}
}

impl<T> ErrorWithSpanExt for T where T: std::error::Error {}
Loading