Skip to content

Commit

Permalink
Fix handling for poll calls to non-generator futures (#161)
Browse files Browse the repository at this point in the history
## What Changed?

No longer attempts to use the "resolve original arguments" kind of async
handling on futures that have explicit `poll` implementations, as
opposed to being auto-implemented (e.g. `async fn` and async blocks).

Removes a loop in async arg resolution, as it was no longer exercised.

## Why Does It Need To?

Avoids spurious errors in trying to resolve async arguments.

Fixes #159 

## Checklist

- [x] Above description has been filled out so that upon quash merge we
have a
  good record of what changed.
- [x] New functions, methods, types are documented. Old documentation is
updated
  if necessary
- [ ] Documentation in Notion has been updated
- [x] Tests for new behaviors are provided
  - [ ] New test suites (if any) ave been added to the CI tests (in
`.github/workflows/rust.yml`) either as compiler test or integration
test.
*Or* justification for their omission from CI has been provided in this
PR
    description.
  • Loading branch information
JustusAdam authored Jul 23, 2024
1 parent 2d832e1 commit d53df7e
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 67 deletions.
141 changes: 90 additions & 51 deletions crates/flowistry_pdg_construction/src/async_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ pub enum AsyncType {
Trait,
}

/// Context for a call to [`Future::poll`](std::future::Future::poll), when
/// called on a future created via an `async fn` or an async block.
pub struct AsyncFnPollEnv<'tcx> {
/// If the generator came from an `async fn`, then this is that function. If
/// it is from an async block, this is `None`.
pub async_fn_parent: Option<Instance<'tcx>>,
/// Where was the `async fn` called, or where was the async block created.
pub creation_loc: Location,
/// A place which carries the runtime value representing the generator in
/// the caller.
pub generator_data: Place<'tcx>,
}

/// Stores ids that are needed to construct projections around async functions.
pub(crate) struct AsyncInfo {
pub poll_ready_variant_idx: VariantIdx,
Expand Down Expand Up @@ -184,18 +197,34 @@ pub enum AsyncDeterminationResult<T> {
NotAsync,
}

/// Does this instance refer to an `async fn` or `async {}`.
fn is_async_fn_or_block(tcx: TyCtxt, instance: Instance) -> bool {
// It turns out that the `DefId` of the [`poll`](std::future::Future::poll)
// impl for an `async fn` or async block is the same as the `DefId` of the
// generator itself. That means after resolution (e.g. on the `Instance`) we
// only need to call `tcx.generator_is_async`.
tcx.generator_is_async(instance.def_id())
}

impl<'tcx, 'mir> LocalAnalysis<'tcx, 'mir> {
/// Checks whether the function call, described by the unresolved `def_id`
/// and the resolved instance `resolved_fn` is a call to [`<T as
/// Future>::poll`](std::future::Future::poll) where `T` is the type of an
/// `async fn` or `async {}` created generator.
///
/// Resolves the original arguments that constituted the generator.
pub(crate) fn try_poll_call_kind<'a>(
&'a self,
def_id: DefId,
resolved_fn: Instance<'tcx>,
original_args: &'a [Operand<'tcx>],
) -> AsyncDeterminationResult<CallKind<'tcx>> {
let lang_items = self.tcx().lang_items();
if lang_items.future_poll_fn() == Some(def_id) {
if lang_items.future_poll_fn() == Some(def_id)
&& is_async_fn_or_block(self.tcx(), resolved_fn)
{
match self.find_async_args(original_args) {
Ok((fun, loc, args)) => {
AsyncDeterminationResult::Resolved(CallKind::AsyncPoll(fun, loc, args))
}
Ok(poll) => AsyncDeterminationResult::Resolved(CallKind::AsyncPoll(poll)),
Err(str) => AsyncDeterminationResult::Unresolvable(str),
}
} else {
Expand All @@ -207,12 +236,18 @@ impl<'tcx, 'mir> LocalAnalysis<'tcx, 'mir> {
fn find_async_args<'a>(
&'a self,
args: &'a [Operand<'tcx>],
) -> Result<(Instance<'tcx>, Location, Place<'tcx>), String> {
) -> Result<AsyncFnPollEnv<'tcx>, String> {
macro_rules! let_assert {
($p:pat = $e:expr, $($arg:tt)*) => {
let $p = $e else {
let msg = format!($($arg)*);
return Err(format!("Abandoning attempt to handle async because pattern {} could not be matched to {:?}: {}", stringify!($p), $e, msg));
return Err(format!(
"Abandoning attempt to handle async because pattern {} (line {}) could not be matched to {:?}: {}",
stringify!($p),
line!(),
$e,
msg
));
};
}
}
Expand Down Expand Up @@ -267,56 +302,60 @@ impl<'tcx, 'mir> LocalAnalysis<'tcx, 'mir> {
"Assignment to alias of pin::new input is not a call"
);

let mut chase_target = Err(&into_future_args[0]);

while let Err(target) = chase_target {
let async_fn_call_loc = get_def_for_op(target)?;
let stmt = &self.mono_body.stmt_at(async_fn_call_loc);
chase_target = match stmt {
Either::Right(Terminator {
kind:
TerminatorKind::Call {
func, destination, ..
},
..
}) => {
let (op, generics) = self.operand_to_def_id(func).unwrap();
Ok((op, generics, *destination, async_fn_call_loc))
let target = &into_future_args[0];
let creation_loc = get_def_for_op(target)?;
let stmt = &self.mono_body.stmt_at(creation_loc);
let (op, generics, generator_data) = match stmt {
Either::Right(Terminator {
kind:
TerminatorKind::Call {
func, destination, ..
},
..
}) => {
let (op, generics) = self.operand_to_def_id(func).unwrap();
(Some(op), generics, *destination)
}
Either::Left(Statement { kind, .. }) => match kind {
StatementKind::Assign(box (
lhs,
Rvalue::Aggregate(box AggregateKind::Generator(def_id, generic_args, _), _args),
)) => {
assert!(self.tcx().generator_is_async(*def_id));
(None, *generic_args, *lhs)
}
StatementKind::Assign(box (_, Rvalue::Use(target))) => {
let generics = self
.operand_to_def_id(target)
.ok_or_else(|| "Nope".to_string())?
.1;
(None, generics, target.place().unwrap())
}
Either::Left(Statement { kind, .. }) => match kind {
StatementKind::Assign(box (
lhs,
Rvalue::Aggregate(
box AggregateKind::Generator(def_id, generic_args, _),
_args,
),
)) => Ok((*def_id, *generic_args, *lhs, async_fn_call_loc)),
StatementKind::Assign(box (_, Rvalue::Use(target))) => {
let (op, generics) = self
.operand_to_def_id(target)
.ok_or_else(|| "Nope".to_string())?;
Ok((op, generics, target.place().unwrap(), async_fn_call_loc))
}
_ => {
panic!("Assignment to into_future input is not a call: {stmt:?}");
}
},
_ => {
panic!("Assignment to into_future input is not a call: {stmt:?}");
}
};
}

let (op, generics, calling_convention, async_fn_call_loc) = chase_target.unwrap();
},
_ => {
panic!("Assignment to into_future input is not a call: {stmt:?}");
}
};

let resolution = utils::try_resolve_function(
self.tcx(),
op,
self.tcx().param_env_reveal_all_normalized(self.def_id),
generics,
)
.ok_or_else(|| "Instance resolution failed".to_string())?;
let async_fn_parent = op
.map(|def_id| {
utils::try_resolve_function(
self.tcx(),
def_id,
self.tcx().param_env_reveal_all_normalized(self.def_id),
generics,
)
.ok_or_else(|| "Instance resolution failed".to_string())
})
.transpose()?;

Ok((resolution, async_fn_call_loc, calling_convention))
Ok(AsyncFnPollEnv {
async_fn_parent,
creation_loc,
generator_data,
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> {
args: &'a [Operand<'tcx>],
) -> CallingConvention<'tcx, 'a> {
match kind {
CallKind::AsyncPoll(_, _, ctx) => CallingConvention::Async(*ctx),
CallKind::AsyncPoll(poll) => CallingConvention::Async(poll.generator_data),
CallKind::Direct => CallingConvention::Direct(args),
CallKind::Indirect => CallingConvention::Indirect {
closure_arg: &args[0],
Expand Down
52 changes: 40 additions & 12 deletions crates/flowistry_pdg_construction/src/local_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,41 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
let Some(resolved_fn) =
utils::try_resolve_function(self.tcx(), called_def_id, param_env, generic_args)
else {
if let Some(d) = generic_args.iter().find(|arg| matches!(arg.unpack(), GenericArgKind::Type(t) if matches!(t.kind(), TyKind::Dynamic(..)))) {
self.tcx().sess.span_warn(self.tcx().def_span(called_def_id), format!("could not resolve instance due to dynamic argument: {d:?}"));
return None;
let dynamics = generic_args.iter()
.flat_map(|g| g.walk())
.filter(|arg| matches!(arg.unpack(), GenericArgKind::Type(t) if matches!(t.kind(), TyKind::Dynamic(..))))
.collect::<Box<[_]>>();
let mut msg = format!(
"instance resolution for call to function {} failed.",
tcx.def_path_str(called_def_id)
);
if !dynamics.is_empty() {
use std::fmt::Write;
write!(msg, " Dynamic arguments ").unwrap();
let mut first = true;
for dyn_ in dynamics.iter() {
if !first {
write!(msg, ", ").unwrap();
}
first = false;
write!(msg, "`{dyn_}`").unwrap();
}
write!(
msg,
" were found.\n\
These may have been injected by Paralegal to instantiate generics \n\
at the entrypoint (location of #[paralegal::analyze]).\n\
A likely reason why this may cause this resolution to fail is if the\n\
method or function this attempts to resolve has a `Sized` constraint.\n\
Such a constraint can be implicit if this is a type variable in a\n\
trait definition and no refutation (`?Sized` constraint) is present."
)
.unwrap();
self.tcx().sess.span_warn(span, msg);
} else {
tcx.sess.span_err(span, "instance resolution failed: too unspecific");
return None;
self.tcx().sess.span_err(span, msg);
}
return None;
};
let resolved_def_id = resolved_fn.def_id();
if log_enabled!(Level::Trace) && called_def_id != resolved_def_id {
Expand All @@ -410,7 +438,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
return Some(CallHandling::ApproxAsyncSM(handler));
};

let call_kind = self.classify_call_kind(called_def_id, resolved_def_id, args, span);
let call_kind = self.classify_call_kind(called_def_id, resolved_fn, args, span);

let calling_convention = CallingConvention::from_call_kind(&call_kind, args);

Expand Down Expand Up @@ -439,7 +467,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
callee: resolved_fn,
call_string: self.make_call_string(location),
is_cached,
async_parent: if let CallKind::AsyncPoll(resolution, _loc, _) = call_kind {
async_parent: if let CallKind::AsyncPoll(poll) = call_kind {
// Special case for async. We ask for skipping not on the closure, but
// on the "async" function that created it. This is needed for
// consistency in skipping. Normally, when "poll" is inlined, mutations
Expand All @@ -448,7 +476,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
// those mutations to occur. To ensure this we always ask for the
// "CallChanges" on the creator so that both creator and closure have
// the same view of whether they are inlined or "Skip"ped.
Some(resolution)
poll.async_fn_parent
} else {
None
},
Expand Down Expand Up @@ -637,14 +665,14 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
fn classify_call_kind<'b>(
&'b self,
def_id: DefId,
resolved_def_id: DefId,
resolved_fn: Instance<'tcx>,
original_args: &'b [Operand<'tcx>],
span: Span,
) -> CallKind<'tcx> {
match self.try_poll_call_kind(def_id, original_args) {
match self.try_poll_call_kind(def_id, resolved_fn, original_args) {
AsyncDeterminationResult::Resolved(r) => r,
AsyncDeterminationResult::NotAsync => self
.try_indirect_call_kind(resolved_def_id)
.try_indirect_call_kind(resolved_fn.def_id())
.unwrap_or(CallKind::Direct),
AsyncDeterminationResult::Unresolvable(reason) => {
self.tcx().sess.span_fatal(span, reason)
Expand Down Expand Up @@ -747,7 +775,7 @@ pub enum CallKind<'tcx> {
/// A call to a function variable, like `fn foo(f: impl Fn()) { f() }`
Indirect,
/// A poll to an async function, like `f.await`.
AsyncPoll(Instance<'tcx>, Location, Place<'tcx>),
AsyncPoll(AsyncFnPollEnv<'tcx>),
}

#[derive(strum::AsRefStr)]
Expand Down
Loading

0 comments on commit d53df7e

Please sign in to comment.