diff --git a/third_party/move/move-vm/integration-tests/src/tests/loader_tests.rs b/third_party/move/move-vm/integration-tests/src/tests/loader_tests.rs index b3cdf389e1142..a9aaa77560317 100644 --- a/third_party/move/move-vm/integration-tests/src/tests/loader_tests.rs +++ b/third_party/move/move-vm/integration-tests/src/tests/loader_tests.rs @@ -5,7 +5,8 @@ use crate::compiler::compile_modules_in_file; use move_binary_format::{ file_format::{ - empty_module, AddressIdentifierIndex, IdentifierIndex, ModuleHandle, TableIndex, + empty_module, AbilitySet, AddressIdentifierIndex, IdentifierIndex, ModuleHandle, + ModuleHandleIndex, StructHandle, StructTypeParameter, TableIndex, }, CompiledModule, }; @@ -196,6 +197,37 @@ fn load_concurrent_many() { adapter.call_functions_async(30); } +#[test] +fn load_phantom_module() { + let data_store = InMemoryStorage::new(); + let mut adapter = Adapter::new(data_store); + let modules = get_modules(); + adapter.publish_modules(modules); + + let mut module = empty_module(); + module.address_identifiers[0] = WORKING_ACCOUNT; + module.identifiers[0] = Identifier::new("I").unwrap(); + module.identifiers.push(Identifier::new("H").unwrap()); + module.module_handles.push(ModuleHandle { + address: AddressIdentifierIndex(0), + name: IdentifierIndex((module.identifiers.len() - 1) as TableIndex), + }); + module.identifiers.push(Identifier::new("S").unwrap()); + module.struct_handles.push(StructHandle { + module: ModuleHandleIndex((module.module_handles.len() - 1) as TableIndex), + name: IdentifierIndex((module.identifiers.len() - 1) as TableIndex), + abilities: AbilitySet::EMPTY, + type_parameters: vec![StructTypeParameter { + constraints: AbilitySet::EMPTY, + is_phantom: false, + }], + }); + + let module_id = module.self_id(); + adapter.publish_modules(vec![module]); + adapter.vm.load_module(&module_id, &adapter.store).unwrap(); +} + #[test] fn deep_dependency_list_err_0() { let data_store = InMemoryStorage::new(); diff --git a/third_party/move/move-vm/integration-tests/src/tests/loader_tests_modules.move b/third_party/move/move-vm/integration-tests/src/tests/loader_tests_modules.move index dcdc5a4764034..ea26b7a2c55bc 100644 --- a/third_party/move/move-vm/integration-tests/src/tests/loader_tests_modules.move +++ b/third_party/move/move-vm/integration-tests/src/tests/loader_tests_modules.move @@ -170,4 +170,9 @@ address 0x2 { another_b } } + module H { + struct S { + f1: u64, + } + } } diff --git a/third_party/move/move-vm/runtime/src/loader/mod.rs b/third_party/move/move-vm/runtime/src/loader/mod.rs index 53d3906f7fdbe..cad0f0d4127cb 100644 --- a/third_party/move/move-vm/runtime/src/loader/mod.rs +++ b/third_party/move/move-vm/runtime/src/loader/mod.rs @@ -1542,22 +1542,12 @@ impl Script { let struct_name = script.identifier_at(struct_handle.name); let module_handle = script.module_handle_at(struct_handle.module); let module_id = script.module_id_for_handle(module_handle); - let struct_ = cache + cache .resolve_struct_by_name(struct_name, &module_id) + .map_err(|err| err.finish(Location::Script))? + .check_compatibility(struct_handle) .map_err(|err| err.finish(Location::Script))?; - if !struct_handle.abilities.is_subset(struct_.abilities) - || !struct_handle - .type_parameters - .iter() - .map(|ty| ty.is_phantom) - .eq(struct_.phantom_ty_args_mask.iter().cloned()) - { - return Err( - PartialVMError::new(StatusCode::UNKNOWN_INVARIANT_VIOLATION_ERROR) - .with_message("Ability definition of module mismatch".to_string()) - .finish(Location::Script), - ); - } + struct_names.push(Arc::new(StructIdentifier { module: module_id, name: struct_name.to_owned(), diff --git a/third_party/move/move-vm/runtime/src/loader/modules.rs b/third_party/move/move-vm/runtime/src/loader/modules.rs index 1979923d8dfc9..232721b570880 100644 --- a/third_party/move/move-vm/runtime/src/loader/modules.rs +++ b/third_party/move/move-vm/runtime/src/loader/modules.rs @@ -273,19 +273,9 @@ impl Module { let module_id = module.module_id_for_handle(module_handle); if module_handle != module.self_handle() { - let struct_ = cache.resolve_struct_by_name(struct_name, &module_id)?; - if !struct_handle.abilities.is_subset(struct_.abilities) - || !struct_handle - .type_parameters - .iter() - .map(|ty| ty.is_phantom) - .eq(struct_.phantom_ty_args_mask.iter().cloned()) - { - return Err(PartialVMError::new( - StatusCode::UNKNOWN_INVARIANT_VIOLATION_ERROR, - ) - .with_message("Ability definition of module mismatch".to_string())); - } + cache + .resolve_struct_by_name(struct_name, &module_id)? + .check_compatibility(struct_handle)?; } struct_names.push(Arc::new(StructIdentifier { module: module_id, diff --git a/third_party/move/move-vm/types/src/loaded_data/runtime_types.rs b/third_party/move/move-vm/types/src/loaded_data/runtime_types.rs index db100cb542bf0..fe27a728d5ed0 100644 --- a/third_party/move/move-vm/types/src/loaded_data/runtime_types.rs +++ b/third_party/move/move-vm/types/src/loaded_data/runtime_types.rs @@ -5,7 +5,9 @@ use derivative::Derivative; use move_binary_format::{ errors::{PartialVMError, PartialVMResult}, - file_format::{AbilitySet, SignatureToken, StructTypeParameter, TypeParameterIndex}, + file_format::{ + AbilitySet, SignatureToken, StructHandle, StructTypeParameter, TypeParameterIndex, + }, }; use move_core_types::{ gas_algebra::AbstractMemorySize, identifier::Identifier, language_storage::ModuleId, @@ -122,6 +124,34 @@ impl StructType { pub fn type_param_constraints(&self) -> impl ExactSizeIterator { self.type_parameters.iter().map(|param| ¶m.constraints) } + + // Check if the local struct handle is compatible with the defined struct type. + pub fn check_compatibility(&self, struct_handle: &StructHandle) -> PartialVMResult<()> { + if !self.abilities.is_subset(struct_handle.abilities) { + return Err( + PartialVMError::new(StatusCode::UNKNOWN_INVARIANT_VIOLATION_ERROR) + .with_message("Ability definition of module mismatch".to_string()), + ); + } + + if self.phantom_ty_args_mask.len() != struct_handle.type_parameters.len() + || !self + .phantom_ty_args_mask + .iter() + .zip(struct_handle.type_parameters.iter()) + .all(|(defined_is_phantom, local_type_parameter)| { + !local_type_parameter.is_phantom || *defined_is_phantom + }) + { + return Err( + PartialVMError::new(StatusCode::UNKNOWN_INVARIANT_VIOLATION_ERROR).with_message( + "Phantom type parameter definition of module mismatch".to_string(), + ), + ); + } + + Ok(()) + } } #[derive(Debug, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]