Skip to content

Commit

Permalink
Add endpoints to UWABI extensions (project-oak#2506)
Browse files Browse the repository at this point in the history
Add endpoints to UWABI extensions, add getter/setter for the endpoint in an UWABI extension with TODOs to remove them being exposed when the design allows, extend testing helpers with drop_extension, and finding extension by channel handle.
  • Loading branch information
mariaschett authored Jan 27, 2022
1 parent 4f7e1bf commit 3dca7f2
Showing 1 changed file with 79 additions and 32 deletions.
111 changes: 79 additions & 32 deletions oak_functions/loader/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,28 @@ pub trait ExtensionFactory {
/// `Uwabi` extension called by listening to a channel.
pub enum BoxedExtension {
Native(Box<dyn OakApiNativeExtension + Send + Sync>),
Uwabi(Box<dyn UwabiExtension + Send + Sync>),
Uwabi(BoxedUwabiExtension),
}

pub type BoxedUwabiExtension = Box<dyn UwabiExtension + Send + Sync>;

pub type BoxedExtensionFactory = Box<dyn ExtensionFactory + Send + Sync>;

/// Trait for implementing an extension which relies on UWABI.
pub trait UwabiExtension {
/// Get the channel handle to address this extension.
fn get_channel_handle(&self) -> ChannelHandle;

/// Get the endpoint.
// TODO(#2508): Stop exposing the endpoint for an extension as soon as we have a way the
// extension handles how it reads/writes into the endpoint.
fn get_endpoint_mut(&mut self) -> Option<&mut Endpoint>;

/// Set the endpoint if it has not been set before.
// TODO(#2510) We cannot set the endpoint when we `create` the extension, as this would require
// to change the `BoxedExtensionFactory` trait. This helps to keep the changes to the
// (existing) Native extensions minimal.
fn set_endpoint(&mut self, endpoint: Endpoint);
}

/// `WasmState` holds runtime values for a particular execution instance of Wasm, handling a
Expand All @@ -181,9 +194,8 @@ pub struct WasmState {
extensions_metadata: HashMap<String, (usize, wasmi::Signature)>,
/// A mapping from channel handles to the hosted endpoints of channels.
channel_switchboard: ChannelSwitchboard,
/// A mapping from channel handles to the endpoints for the Oak functions runtime.
// TODO(#2502) Add extension endpoints to the corresponding endpoint.
extensions_endpoints: HashMap<ChannelHandle, Endpoint>,
/// A list of UWABI extensions.
uwabi_extensions: Vec<BoxedUwabiExtension>,
}

impl WasmState {
Expand Down Expand Up @@ -496,7 +508,7 @@ impl WasmState {
extensions_indices: HashMap<usize, BoxedExtension>,
extensions_metadata: HashMap<String, (usize, wasmi::Signature)>,
channel_switchboard: ChannelSwitchboard,
extensions_endpoints: HashMap<ChannelHandle, Endpoint>,
uwabi_extensions: Vec<BoxedUwabiExtension>,
) -> anyhow::Result<WasmState> {
let mut abi = WasmState {
request_bytes,
Expand All @@ -507,7 +519,7 @@ impl WasmState {
extensions_indices: Some(extensions_indices),
extensions_metadata,
channel_switchboard,
extensions_endpoints,
uwabi_extensions,
};

let instance = wasmi::ModuleInstance::new(
Expand Down Expand Up @@ -682,10 +694,7 @@ impl WasmHandler {
fn init(&self, request_bytes: Vec<u8>) -> anyhow::Result<WasmState> {
let mut extensions_indices = HashMap::new();
let mut extensions_metadata = HashMap::new();
// Remove the extension_endpoints map as soon as we extend extensions by functionality for
// channels. Until then we use the ChannelHandle as keys.
// TODO(#2502).
let mut extensions_endpoints = HashMap::new();
let mut uwabi_extensions: Vec<BoxedUwabiExtension> = vec![];

let mut channel_switchboard = ChannelSwitchboard::new();

Expand All @@ -697,10 +706,11 @@ impl WasmHandler {
extensions_indices.insert(ind + EXTENSION_INDEX_OFFSET, extension);
extensions_metadata.insert(name, (ind + EXTENSION_INDEX_OFFSET, signature));
}
BoxedExtension::Uwabi(uwabi_extension) => {
BoxedExtension::Uwabi(mut uwabi_extension) => {
let channel_handle = uwabi_extension.get_channel_handle();
let endpoint = channel_switchboard.register(channel_handle);
extensions_endpoints.insert(channel_handle, endpoint);
uwabi_extension.set_endpoint(endpoint);
uwabi_extensions.push(uwabi_extension);
}
}
}
Expand All @@ -712,7 +722,7 @@ impl WasmHandler {
extensions_indices,
extensions_metadata,
channel_switchboard,
extensions_endpoints,
uwabi_extensions,
)
}

Expand Down Expand Up @@ -825,7 +835,7 @@ pub fn format_bytes(v: &[u8]) -> String {

// The Endpoint of a bidirectional channel.
#[derive(Debug)]
struct Endpoint {
pub struct Endpoint {
sender: Sender<AbiMessage>,
receiver: Receiver<AbiMessage>,
}
Expand Down Expand Up @@ -895,6 +905,7 @@ mod tests {
fn create(&self) -> anyhow::Result<BoxedExtension> {
let extension = TestingExtension {
logger: self.logger.clone(),
endpoint: None,
};
Ok(BoxedExtension::Uwabi(Box::new(extension)))
}
Expand All @@ -903,12 +914,26 @@ mod tests {
#[allow(dead_code)]
pub struct TestingExtension {
logger: Logger,
endpoint: Option<Endpoint>,
}

impl UwabiExtension for TestingExtension {
fn get_channel_handle(&self) -> oak_functions_abi::proto::ChannelHandle {
ChannelHandle::Testing
}

fn get_endpoint_mut(&mut self) -> Option<&mut Endpoint> {
match &mut self.endpoint {
Some(endpoint) => Some(endpoint),
None => None,
}
}

fn set_endpoint(&mut self, endpoint: Endpoint) {
if self.endpoint.is_none() {
self.endpoint = Some(endpoint);
}
}
}

#[test]
Expand Down Expand Up @@ -985,12 +1010,11 @@ mod tests {

#[tokio::test]
async fn test_hosted_channel_read_channel_closed() {
let channel_handle = ChannelHandle::Testing as i32;
let channel_handle = ChannelHandle::Testing;
let mut wasm_state = create_test_wasm_state();
let extension_endpoints = &mut wasm_state.extensions_endpoints;
// Remove the endpoint of the runtime closes one endpoint of the channel.
extension_endpoints.remove(&ChannelHandle::Testing);
let result = wasm_state.channel_read(channel_handle, 0, 0);
// Remove the extension to close one endpoint of the channel.
drop_extension(&mut wasm_state, channel_handle);
let result = wasm_state.channel_read(channel_handle as i32, 0, 0);
assert!(result.is_err());
assert_eq!(
ChannelStatus::ChannelEndpointDisconnected,
Expand Down Expand Up @@ -1087,9 +1111,8 @@ mod tests {
async fn test_hosted_channel_write_channel_closed() {
let channel_handle = ChannelHandle::Testing;
let mut wasm_state = create_test_wasm_state();
let extension_endpoints = &mut wasm_state.extensions_endpoints;
// Remove the endpoint of the runtime closes one endpoint of the channel.
extension_endpoints.remove(&channel_handle);
// Remove the extension to close one endpoint of the channel.
drop_extension(&mut wasm_state, channel_handle);
let result = wasm_state.channel_write(channel_handle as i32, 0, 0);
assert!(result.is_err());
assert_eq!(ChannelStatus::ChannelEndpointClosed, result.unwrap_err());
Expand All @@ -1114,29 +1137,53 @@ mod tests {
.expect("could not create wasm_state")
}

// Helper function to read from Endpoint associated to ChannelHandle extension in the runtime.
// Helper function for testing to drop the UWABI extension for the given ChannelHandle.
fn drop_extension(wasm_state: &mut WasmState, channel_handle: ChannelHandle) {
wasm_state
.uwabi_extensions
.retain(|uwabi_extension| uwabi_extension.get_channel_handle() != channel_handle);
}

// Helper function for testing to read from Endpoint associated to ChannelHandle extension in
// the runtime.
async fn read_from_runtime_endpoint(
wasm_state: &mut WasmState,
channel_handle: ChannelHandle,
) -> Vec<u8> {
let endpoint = wasm_state
.extensions_endpoints
.get_mut(&channel_handle)
.unwrap();
let endpoint = runtime_endpoint_for_channel_handle(wasm_state, channel_handle);
endpoint.receiver.try_recv().unwrap()
}

// Helper function to write to Endpoint associated to ChannelHandle extension in the runtime.
// Helper function for testing to write to Endpoint associated to ChannelHandle extension in the
// runtime.
async fn write_to_runtime_endpoint(
wasm_state: &mut WasmState,
channel_handle: ChannelHandle,
message: AbiMessage,
) {
let endpoint = wasm_state
.extensions_endpoints
.get_mut(&channel_handle)
.unwrap();
let endpoint = runtime_endpoint_for_channel_handle(wasm_state, channel_handle);
let result = endpoint.sender.send(message.to_vec().clone()).await;
assert!(result.is_ok());
}

// Helper function for testing to find the Endpoint associated to ChannelHandle in the runtime.
fn runtime_endpoint_for_channel_handle(
wasm_state: &mut WasmState,
channel_handle: ChannelHandle,
) -> &mut Endpoint {
// Find extension associated to ChannelHandle in WasmState.
let extension = wasm_state
.uwabi_extensions
.iter_mut()
.find(|uwabi_extension| {
let channel_handle_of_extension = uwabi_extension.get_channel_handle();
channel_handle_of_extension == channel_handle
})
.expect("No extension for channel handle.");

// Get endpoint from the extension.
extension
.get_endpoint_mut()
.expect("No endpoint set for extension.")
}
}

0 comments on commit 3dca7f2

Please sign in to comment.