diff --git a/crates/wasmtime/src/component/linker.rs b/crates/wasmtime/src/component/linker.rs index cf2bbb302cca..c97308fa0685 100644 --- a/crates/wasmtime/src/component/linker.rs +++ b/crates/wasmtime/src/component/linker.rs @@ -4,9 +4,11 @@ use crate::component::matching::TypeChecker; use crate::component::{Component, ComponentNamedList, Instance, InstancePre, Lift, Lower, Val}; use crate::{AsContextMut, Engine, Module, StoreContextMut}; use anyhow::{anyhow, bail, Context, Result}; +use indexmap::IndexMap; use std::collections::hash_map::{Entry, HashMap}; use std::future::Future; use std::marker; +use std::ops::Deref; use std::pin::Pin; use std::sync::Arc; use wasmtime_environ::component::TypeDef; @@ -22,6 +24,7 @@ pub struct Linker { engine: Engine, strings: Strings, map: NameMap, + path: Vec, allow_shadowing: bool, _marker: marker::PhantomData T>, } @@ -38,7 +41,9 @@ pub struct Strings { /// a "bag of named items", so each [`LinkerInstance`] can further define items /// internally. pub struct LinkerInstance<'a, T> { - engine: Engine, + engine: &'a Engine, + path: &'a mut Vec, + path_len: usize, strings: &'a mut Strings, map: &'a mut NameMap, allow_shadowing: bool, @@ -63,6 +68,7 @@ impl Linker { strings: Strings::default(), map: NameMap::default(), allow_shadowing: false, + path: Vec::new(), _marker: marker::PhantomData, } } @@ -85,7 +91,9 @@ impl Linker { /// the root namespace. pub fn root(&mut self) -> LinkerInstance<'_, T> { LinkerInstance { - engine: self.engine.clone(), + engine: &self.engine, + path: &mut self.path, + path_len: 0, strings: &mut self.strings, map: &mut self.map, allow_shadowing: self.allow_shadowing, @@ -230,7 +238,9 @@ impl Linker { impl LinkerInstance<'_, T> { fn as_mut(&mut self) -> LinkerInstance<'_, T> { LinkerInstance { - engine: self.engine.clone(), + engine: self.engine, + path: self.path, + path_len: self.path_len, strings: self.strings, map: self.map, allow_shadowing: self.allow_shadowing, @@ -310,21 +320,39 @@ impl LinkerInstance<'_, T> { name: &str, func: F, ) -> Result<()> { - for (import_name, ty) in component.env_component().import_types.values() { - if name == import_name { - if let TypeDef::ComponentFunc(index) = ty { - let name = self.strings.intern(name); - return self.insert( - name, - Definition::Func(HostFunc::new_dynamic(func, *index, component.types())), - ); + let mut map = &component + .env_component() + .import_types + .values() + .map(|(k, v)| (k.clone(), *v)) + .collect::>(); + + for name in self.path.iter().copied().take(self.path_len) { + let name = self.strings.strings[name].deref(); + if let Some(ty) = map.get(name) { + if let TypeDef::ComponentInstance(index) = ty { + map = &component.types()[*index].exports; } else { - bail!("import `{name}` has the wrong type (expected a function)"); + bail!("import `{name}` has the wrong type (expected a component instance)"); } + } else { + bail!("import `{name}` not found"); } } - Err(anyhow!("import `{name}` not found")) + if let Some(ty) = map.get(name) { + if let TypeDef::ComponentFunc(index) = ty { + let name = self.strings.intern(name); + return self.insert( + name, + Definition::Func(HostFunc::new_dynamic(func, *index, component.types())), + ); + } else { + bail!("import `{name}` has the wrong type (expected a function)"); + } + } else { + Err(anyhow!("import `{name}` not found")) + } } // TODO: define func_new_async @@ -367,6 +395,9 @@ impl LinkerInstance<'_, T> { Definition::Instance(map) => map, _ => unreachable!(), }; + self.path.truncate(self.path_len); + self.path.push(name); + self.path_len += 1; Ok(self) } diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index be5ba39bedf5..705348e9dfa7 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -169,6 +169,111 @@ fn simple() -> Result<()> { Ok(()) } +#[test] +fn functions_in_instances() -> Result<()> { + let component = r#" + (component + (type $import-type (instance + (export "a" (func (param "a" string))) + )) + (import (interface "test:test/foo") (instance $import (type $import-type))) + (alias export $import "a" (func $log)) + + (core module $libc + (memory (export "memory") 1) + + (func (export "realloc") (param i32 i32 i32 i32) (result i32) + unreachable) + ) + (core instance $libc (instantiate $libc)) + (core func $log_lower + (canon lower (func $log) (memory $libc "memory") (realloc (func $libc "realloc"))) + ) + (core module $m + (import "libc" "memory" (memory 1)) + (import "host" "log" (func $log (param i32 i32))) + + (func (export "call") + i32.const 5 + i32.const 11 + call $log) + + (data (i32.const 5) "hello world") + ) + (core instance $i (instantiate $m + (with "libc" (instance $libc)) + (with "host" (instance (export "log" (func $log_lower)))) + )) + (func $call + (canon lift (core func $i "call")) + ) + (component $c + (import "import-call" (func $f)) + (export "call" (func $f)) + ) + (instance $export (instantiate $c + (with "import-call" (func $call)) + )) + (export (interface "test:test/foo") (instance $export)) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, None); + assert!(store.data().is_none()); + + // First, test the static API + + let mut linker = Linker::new(&engine); + linker.instance("test:test/foo")?.func_wrap( + "a", + |mut store: StoreContextMut<'_, Option>, (arg,): (WasmStr,)| -> Result<_> { + let s = arg.to_str(&store)?.to_string(); + assert!(store.data().is_none()); + *store.data_mut() = Some(s); + Ok(()) + }, + )?; + let instance = linker.instantiate(&mut store, &component)?; + let func = instance + .exports(&mut store) + .instance("test:test/foo") + .unwrap() + .typed_func::<(), ()>("call")?; + func.call(&mut store, ())?; + assert_eq!(store.data().as_ref().unwrap(), "hello world"); + + // Next, test the dynamic API + + *store.data_mut() = None; + let mut linker = Linker::new(&engine); + linker.instance("test:test/foo")?.func_new( + &component, + "a", + |mut store: StoreContextMut<'_, Option>, args, _results| { + if let Val::String(s) = &args[0] { + assert!(store.data().is_none()); + *store.data_mut() = Some(s.to_string()); + Ok(()) + } else { + panic!() + } + }, + )?; + let instance = linker.instantiate(&mut store, &component)?; + let func = instance + .exports(&mut store) + .instance("test:test/foo") + .unwrap() + .func("call") + .unwrap(); + func.call(&mut store, &[], &mut [])?; + assert_eq!(store.data().as_ref().unwrap(), "hello world"); + + Ok(()) +} + #[test] fn attempt_to_leave_during_malloc() -> Result<()> { let component = r#"