Skip to content

Commit

Permalink
Integrate cache validation
Browse files Browse the repository at this point in the history
  • Loading branch information
DJMcNab committed Mar 18, 2024
1 parent b93f4fd commit e973ce8
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 17 deletions.
20 changes: 19 additions & 1 deletion wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2303,6 +2303,7 @@ impl Global {
}

pub fn pipeline_cache_get_data<A: HalApi>(&self, id: id::PipelineCacheId) -> Option<Vec<u8>> {
use crate::pipeline_cache;
api_log!("PipelineCache::get_data");
let hub = A::hub(self);

Expand All @@ -2312,7 +2313,24 @@ impl Global {
return None;
}
if let Some(raw_cache) = cache.raw.as_ref() {
return unsafe { cache.device.raw().pipeline_cache_get_data(raw_cache) };
let vec = unsafe { cache.device.raw().pipeline_cache_get_data(raw_cache) };
let Some(mut vec) = vec else { return None };
let Some(validation_key) = cache.device.raw().pipeline_cache_validation_key()
else {
return None;
};
let mut header_contents = [0; pipeline_cache::HEADER_LENGTH];
pipeline_cache::add_cache_header(
&mut header_contents,
&vec,
&cache.device.adapter.raw.info,
validation_key,
);

let deleted = vec.splice(..1, header_contents).collect::<Vec<_>>();
debug_assert!(deleted.is_empty());

return Some(vec);
}
}
None
Expand Down
28 changes: 22 additions & 6 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3442,17 +3442,33 @@ impl<A: HalApi> Device<A> {
self: &Arc<Self>,
desc: &pipeline::PipelineCacheDescriptor,
) -> Result<pipeline::PipelineCache<A>, pipeline::CreatePipelineCacheError> {
use crate::pipeline_cache;
self.require_features(wgt::Features::PIPELINE_CACHE)?;
let mut cache_desc = hal::PipelineCacheDescriptor {
data: desc.data.as_deref(),
let data = if let Some((data, validation_key)) = desc
.data
.as_ref()
.zip(self.raw().pipeline_cache_validation_key())
{
let data = pipeline_cache::validate_pipeline_cache(
&data,
&self.adapter.raw.info,
validation_key,
);
match data {
Ok(data) => Some(data),
Err(e) if e.was_avoidable() || !desc.fallback => return Err(e.into()),
// If the error was unavoidable and we are asked to fallback, do so
Err(_) => None,
}
} else {
None
};
let cache_desc = hal::PipelineCacheDescriptor {
data,
label: desc.label.to_hal(self.instance_flags),
};
let raw = match unsafe { self.raw().create_pipeline_cache(&cache_desc) } {
Ok(raw) => raw,
Err(hal::PipelineCacheError::Validation) if desc.fallback => {
debug_assert!(cache_desc.data.take().is_some());
unsafe { self.raw().create_pipeline_cache(&cache_desc)? }
}
Err(e) => return Err(e.into()),
};
let cache = pipeline::PipelineCache {
Expand Down
3 changes: 1 addition & 2 deletions wgpu-core/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ pub enum CreatePipelineCacheError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("Pipeline cache validation failed")]
Validation(PipelineCacheValidationError),
Validation(#[from] PipelineCacheValidationError),
#[error(transparent)]
MissingFeatures(#[from] MissingFeatures),
#[error("Internal error: {0}")]
Expand All @@ -349,7 +349,6 @@ impl From<hal::PipelineCacheError> for CreatePipelineCacheError {
hal::PipelineCacheError::Device(device) => {
CreatePipelineCacheError::Device(device.into())
}
hal::PipelineCacheError::Validation => CreatePipelineCacheError::Validation,
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions wgpu-core/src/pipeline_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ impl<'a> Reader<'a> {
Some(start.try_into().expect("off-by-one-error in array size"))
}

fn read_u16(&mut self) -> Option<u16> {
self.read_array().map(u16::from_be_bytes)
}
// fn read_u16(&mut self) -> Option<u16> {
// self.read_array().map(u16::from_be_bytes)
// }
fn read_u32(&mut self) -> Option<u32> {
self.read_array().map(u32::from_be_bytes)
}
Expand Down Expand Up @@ -281,9 +281,9 @@ impl<'a> Writer<'a> {
Some(())
}

fn write_u16(&mut self, value: u16) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
// fn write_u16(&mut self, value: u16) -> Option<()> {
// self.write_array(&value.to_be_bytes())
// }
fn write_u32(&mut self, value: u32) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
Expand Down
5 changes: 3 additions & 2 deletions wgpu-hal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ pub enum PipelineError {
pub enum PipelineCacheError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("Pipeline cache had a validation error")]
Validation,
}

#[derive(Clone, Debug, Eq, PartialEq, Error)]
Expand Down Expand Up @@ -386,6 +384,9 @@ pub trait Device<A: Api>: WasmNotSendSync {
&self,
desc: &PipelineCacheDescriptor<'_>,
) -> Result<A::PipelineCache, PipelineCacheError>;
fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> {
None
}
unsafe fn destroy_pipeline_cache(&self, cache: A::PipelineCache);

unsafe fn create_query_set(
Expand Down
14 changes: 14 additions & 0 deletions wgpu-hal/src/vulkan/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,19 @@ impl super::Adapter {
unsafe { raw_device.get_device_queue(family_index, queue_index) }
};

let driver_version = self
.phd_capabilities
.properties
.driver_version
.to_be_bytes();
#[rustfmt::skip]
let pipeline_cache_validation_key = [
driver_version[0], driver_version[1], driver_version[2], driver_version[3],
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
];

let shared = Arc::new(super::DeviceShared {
raw: raw_device,
family_index,
Expand All @@ -1528,6 +1541,7 @@ impl super::Adapter {
timeline_semaphore: timeline_semaphore_fn,
ray_tracing: ray_tracing_fns,
},
pipeline_cache_validation_key,
vendor_id: self.phd_capabilities.properties.vendor_id,
timestamp_period: self.phd_capabilities.properties.limits.timestamp_period,
private_caps: self.private_caps.clone(),
Expand Down
3 changes: 3 additions & 0 deletions wgpu-hal/src/vulkan/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,9 @@ impl crate::Device<super::Api> for super::Device {

Ok(PipelineCache { raw })
}
fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> {
Some(self.shared.pipeline_cache_validation_key)
}
unsafe fn destroy_pipeline_cache(&self, cache: PipelineCache) {
unsafe { self.shared.raw.destroy_pipeline_cache(cache.raw, None) }
}
Expand Down
1 change: 1 addition & 0 deletions wgpu-hal/src/vulkan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ struct DeviceShared {
enabled_extensions: Vec<&'static CStr>,
extension_fns: DeviceExtensionFunctions,
vendor_id: u32,
pipeline_cache_validation_key: [u8; 16],
timestamp_period: f32,
private_caps: PrivateCapabilities,
workarounds: Workarounds,
Expand Down

0 comments on commit e973ce8

Please sign in to comment.