diff --git a/Cargo.toml b/Cargo.toml index 7dd6e0ce9e61a..0805c2890aeae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -274,6 +274,10 @@ path = "examples/reflection/trait_reflection.rs" name = "scene" path = "examples/scene/scene.rs" +[[example]] +name = "hot_shader_reloading" +path = "examples/shader/hot_shader_reloading.rs" + [[example]] name = "mesh_custom_attribute" path = "examples/shader/mesh_custom_attribute.rs" diff --git a/assets/shaders/hot.frag b/assets/shaders/hot.frag new file mode 100644 index 0000000000000..18c41c8cd5c04 --- /dev/null +++ b/assets/shaders/hot.frag @@ -0,0 +1,11 @@ +#version 450 + +layout(location = 0) out vec4 o_Target; + +layout(set = 2, binding = 0) uniform MyMaterial_color { + vec4 color; +}; + +void main() { + o_Target = color * 0.5; +} diff --git a/assets/shaders/hot.vert b/assets/shaders/hot.vert new file mode 100644 index 0000000000000..71a610e6e8463 --- /dev/null +++ b/assets/shaders/hot.vert @@ -0,0 +1,15 @@ +#version 450 + +layout(location = 0) in vec3 Vertex_Position; + +layout(set = 0, binding = 0) uniform Camera { + mat4 ViewProj; +}; + +layout(set = 1, binding = 0) uniform Transform { + mat4 Model; +}; + +void main() { + gl_Position = ViewProj * Model * vec4(Vertex_Position, 1.0); +} diff --git a/crates/bevy_render/src/lib.rs b/crates/bevy_render/src/lib.rs index ab2c20a25c86a..a2699ad8d7a96 100644 --- a/crates/bevy_render/src/lib.rs +++ b/crates/bevy_render/src/lib.rs @@ -44,6 +44,7 @@ use render_graph::{ RenderGraph, }; use renderer::{AssetRenderResourceBindings, RenderResourceBindings}; +use shader::ShaderLoader; #[cfg(feature = "hdr")] use texture::HdrTextureLoader; #[cfg(feature = "png")] @@ -87,6 +88,8 @@ impl Plugin for RenderPlugin { app.init_asset_loader::(); } + app.init_asset_loader::(); + if app.resources().get::().is_none() { app.resources_mut().insert(ClearColor::default()); } @@ -134,6 +137,7 @@ impl Plugin for RenderPlugin { camera::visible_entities_system, ) // TODO: turn these "resource systems" into graph nodes and remove the RENDER_RESOURCE stage + .add_system_to_stage(stage::RENDER_RESOURCE, shader::shader_update_system) .add_system_to_stage(stage::RENDER_RESOURCE, mesh::mesh_resource_provider_system) .add_system_to_stage(stage::RENDER_RESOURCE, Texture::texture_resource_system) .add_system_to_stage( diff --git a/crates/bevy_render/src/pipeline/pipeline_compiler.rs b/crates/bevy_render/src/pipeline/pipeline_compiler.rs index 563ebeb651e21..1bb10a519a9b2 100644 --- a/crates/bevy_render/src/pipeline/pipeline_compiler.rs +++ b/crates/bevy_render/src/pipeline/pipeline_compiler.rs @@ -2,7 +2,7 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto use crate::{ pipeline::{BindType, InputStepMode, VertexBufferDescriptor}, renderer::RenderResourceContext, - shader::{Shader, ShaderSource}, + shader::{Shader, ShaderError, ShaderSource}, }; use bevy_asset::{Assets, Handle}; use bevy_reflect::Reflect; @@ -60,6 +60,7 @@ struct SpecializedPipeline { #[derive(Debug, Default)] pub struct PipelineCompiler { specialized_shaders: HashMap, Vec>, + specialized_shader_pipelines: HashMap, Vec>>, specialized_pipelines: HashMap, Vec>, } @@ -70,7 +71,7 @@ impl PipelineCompiler { shaders: &mut Assets, shader_handle: &Handle, shader_specialization: &ShaderSpecialization, - ) -> Handle { + ) -> Result, ShaderError> { let specialized_shaders = self .specialized_shaders .entry(shader_handle.clone_weak()) @@ -80,7 +81,7 @@ impl PipelineCompiler { // don't produce new shader if the input source is already spirv if let ShaderSource::Spirv(_) = shader.source { - return shader_handle.clone_weak(); + return Ok(shader_handle.clone_weak()); } if let Some(specialized_shader) = @@ -91,7 +92,7 @@ impl PipelineCompiler { }) { // if shader has already been compiled with current configuration, use existing shader - specialized_shader.shader.clone_weak() + Ok(specialized_shader.shader.clone_weak()) } else { // if no shader exists with the current configuration, create new shader and compile let shader_def_vec = shader_specialization @@ -100,14 +101,14 @@ impl PipelineCompiler { .cloned() .collect::>(); let compiled_shader = - render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec)); + render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?; let specialized_handle = shaders.add(compiled_shader); let weak_specialized_handle = specialized_handle.clone_weak(); specialized_shaders.push(SpecializedShader { shader: specialized_handle, specialization: shader_specialization.clone(), }); - weak_specialized_handle + Ok(weak_specialized_handle) } } @@ -138,23 +139,31 @@ impl PipelineCompiler { ) -> Handle { let source_descriptor = pipelines.get(source_pipeline).unwrap(); let mut specialized_descriptor = source_descriptor.clone(); - specialized_descriptor.shader_stages.vertex = self.compile_shader( - render_resource_context, - shaders, - &specialized_descriptor.shader_stages.vertex, - &pipeline_specialization.shader_specialization, - ); + let specialized_vertex_shader = self + .compile_shader( + render_resource_context, + shaders, + &specialized_descriptor.shader_stages.vertex, + &pipeline_specialization.shader_specialization, + ) + .unwrap(); + specialized_descriptor.shader_stages.vertex = specialized_vertex_shader.clone_weak(); + let mut specialized_fragment_shader = None; specialized_descriptor.shader_stages.fragment = specialized_descriptor .shader_stages .fragment .as_ref() .map(|fragment| { - self.compile_shader( - render_resource_context, - shaders, - fragment, - &pipeline_specialization.shader_specialization, - ) + let shader = self + .compile_shader( + render_resource_context, + shaders, + fragment, + &pipeline_specialization.shader_specialization, + ) + .unwrap(); + specialized_fragment_shader = Some(shader.clone_weak()); + shader }); let mut layout = render_resource_context.reflect_pipeline_layout( @@ -244,6 +253,18 @@ impl PipelineCompiler { &shaders, ); + // track specialized shader pipelines + self.specialized_shader_pipelines + .entry(specialized_vertex_shader) + .or_insert_with(Default::default) + .push(source_pipeline.clone_weak()); + if let Some(specialized_fragment_shader) = specialized_fragment_shader { + self.specialized_shader_pipelines + .entry(specialized_fragment_shader) + .or_insert_with(Default::default) + .push(source_pipeline.clone_weak()); + } + let specialized_pipelines = self .specialized_pipelines .entry(source_pipeline.clone_weak()) @@ -282,4 +303,56 @@ impl PipelineCompiler { }) .flatten() } + + /// Update specialized shaders and remove any related specialized + /// pipelines and assets. + pub fn update_shader( + &mut self, + shader: &Handle, + pipelines: &mut Assets, + shaders: &mut Assets, + render_resource_context: &dyn RenderResourceContext, + ) -> Result<(), ShaderError> { + if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) { + for specialized_shader in specialized_shaders { + // Recompile specialized shader. If it fails, we bail immediately. + let shader_def_vec = specialized_shader + .specialization + .shader_defs + .iter() + .cloned() + .collect::>(); + let new_handle = + shaders.add(render_resource_context.get_specialized_shader( + shaders.get(shader).unwrap(), + Some(&shader_def_vec), + )?); + + // Replace handle and remove old from assets. + let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle); + shaders.remove(&old_handle); + + // Find source pipelines that use the old specialized + // shader, and remove from tracking. + if let Some(source_pipelines) = + self.specialized_shader_pipelines.remove(&old_handle) + { + // Remove all specialized pipelines from tracking + // and asset storage. They will be rebuilt on next + // draw. + for source_pipeline in source_pipelines { + if let Some(specialized_pipelines) = + self.specialized_pipelines.remove(&source_pipeline) + { + for p in specialized_pipelines { + pipelines.remove(p.pipeline); + } + } + } + } + } + } + + Ok(()) + } } diff --git a/crates/bevy_render/src/renderer/headless_render_resource_context.rs b/crates/bevy_render/src/renderer/headless_render_resource_context.rs index f1b0a08cf4a2c..ed182a51bb1d3 100644 --- a/crates/bevy_render/src/renderer/headless_render_resource_context.rs +++ b/crates/bevy_render/src/renderer/headless_render_resource_context.rs @@ -2,7 +2,7 @@ use super::RenderResourceContext; use crate::{ pipeline::{BindGroupDescriptorId, PipelineDescriptor}, renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId}, - shader::Shader, + shader::{Shader, ShaderError}, texture::{SamplerDescriptor, TextureDescriptor}, }; use bevy_asset::{Assets, Handle, HandleUntyped}; @@ -149,8 +149,12 @@ impl RenderResourceContext for HeadlessRenderResourceContext { size } - fn get_specialized_shader(&self, shader: &Shader, _macros: Option<&[String]>) -> Shader { - shader.clone() + fn get_specialized_shader( + &self, + shader: &Shader, + _macros: Option<&[String]>, + ) -> Result { + Ok(shader.clone()) } fn remove_stale_bind_groups(&self) {} diff --git a/crates/bevy_render/src/renderer/render_resource_context.rs b/crates/bevy_render/src/renderer/render_resource_context.rs index f398c64a6e133..365595fac7816 100644 --- a/crates/bevy_render/src/renderer/render_resource_context.rs +++ b/crates/bevy_render/src/renderer/render_resource_context.rs @@ -1,7 +1,7 @@ use crate::{ pipeline::{BindGroupDescriptorId, PipelineDescriptor, PipelineLayout}, renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId}, - shader::{Shader, ShaderLayout, ShaderStages}, + shader::{Shader, ShaderError, ShaderLayout, ShaderStages}, texture::{SamplerDescriptor, TextureDescriptor}, }; use bevy_asset::{Asset, Assets, Handle, HandleUntyped}; @@ -29,7 +29,11 @@ pub trait RenderResourceContext: Downcast + Send + Sync + 'static { fn create_buffer_with_data(&self, buffer_info: BufferInfo, data: &[u8]) -> BufferId; fn create_shader_module(&self, shader_handle: &Handle, shaders: &Assets); fn create_shader_module_from_source(&self, shader_handle: &Handle, shader: &Shader); - fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader; + fn get_specialized_shader( + &self, + shader: &Shader, + macros: Option<&[String]>, + ) -> Result; fn remove_buffer(&self, buffer: BufferId); fn remove_texture(&self, texture: TextureId); fn remove_sampler(&self, sampler: SamplerId); diff --git a/crates/bevy_render/src/shader/shader.rs b/crates/bevy_render/src/shader/shader.rs index 48fd56854443e..6fffa1e0ca7e8 100644 --- a/crates/bevy_render/src/shader/shader.rs +++ b/crates/bevy_render/src/shader/shader.rs @@ -1,7 +1,16 @@ +use crate::{ + pipeline::{PipelineCompiler, PipelineDescriptor}, + renderer::RenderResourceContext, +}; + use super::ShaderLayout; -use bevy_asset::Handle; +use bevy_app::{EventReader, Events}; +use bevy_asset::{AssetEvent, AssetLoader, Assets, Handle, LoadContext, LoadedAsset}; +use bevy_ecs::{Local, Res, ResMut}; use bevy_reflect::TypeUuid; +use bevy_utils::{tracing::error, BoxedFuture}; use std::marker::Copy; +use thiserror::Error; /// The stage of a shader #[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)] @@ -11,6 +20,18 @@ pub enum ShaderStage { Compute, } +/// An error that occurs during shader handling. +#[derive(Error, Debug)] +pub enum ShaderError { + /// Shader compilation error. + #[error("Shader compilation error: {0}")] + Compilation(String), + #[cfg(target_os = "ios")] + /// shaderc error. + #[error("shaderc error")] + ShaderC(#[from] shaderc::Error), +} + #[cfg(all(not(target_os = "ios"), not(target_arch = "wasm32")))] impl Into for ShaderStage { fn into(self) -> bevy_glsl_to_spirv::ShaderType { @@ -27,8 +48,9 @@ pub fn glsl_to_spirv( glsl_source: &str, stage: ShaderStage, shader_defs: Option<&[String]>, -) -> Vec { - bevy_glsl_to_spirv::compile(glsl_source, stage.into(), shader_defs).unwrap() +) -> Result, ShaderError> { + bevy_glsl_to_spirv::compile(glsl_source, stage.into(), shader_defs) + .map_err(ShaderError::Compilation) } #[cfg(target_os = "ios")] @@ -47,26 +69,24 @@ pub fn glsl_to_spirv( glsl_source: &str, stage: ShaderStage, shader_defs: Option<&[String]>, -) -> Vec { - let mut compiler = shaderc::Compiler::new().unwrap(); - let mut options = shaderc::CompileOptions::new().unwrap(); +) -> Result, ShaderError> { + let mut compiler = shaderc::Compiler::new()?; + let mut options = shaderc::CompileOptions::new()?; if let Some(shader_defs) = shader_defs { for def in shader_defs.iter() { options.add_macro_definition(def, None); } } - let binary_result = compiler - .compile_into_spirv( - glsl_source, - stage.into(), - "shader.glsl", - "main", - Some(&options), - ) - .unwrap(); + let binary_result = compiler.compile_into_spirv( + glsl_source, + stage.into(), + "shader.glsl", + "main", + Some(&options), + )?; - binary_result.as_binary().to_vec() + Ok(binary_result.as_binary().to_vec()) } fn bytes_to_words(bytes: &[u8]) -> Vec { @@ -114,19 +134,19 @@ impl Shader { } #[cfg(not(target_arch = "wasm32"))] - pub fn get_spirv(&self, macros: Option<&[String]>) -> Vec { + pub fn get_spirv(&self, macros: Option<&[String]>) -> Result, ShaderError> { match self.source { - ShaderSource::Spirv(ref bytes) => bytes.clone(), + ShaderSource::Spirv(ref bytes) => Ok(bytes.clone()), ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, self.stage, macros), } } #[cfg(not(target_arch = "wasm32"))] - pub fn get_spirv_shader(&self, macros: Option<&[String]>) -> Shader { - Shader { - source: ShaderSource::Spirv(self.get_spirv(macros)), + pub fn get_spirv_shader(&self, macros: Option<&[String]>) -> Result { + Ok(Shader { + source: ShaderSource::Spirv(self.get_spirv(macros)?), stage: self.stage, - } + }) } #[cfg(not(target_arch = "wasm32"))] @@ -188,3 +208,60 @@ impl ShaderStages { } } } + +#[derive(Default)] +pub struct ShaderLoader; + +impl AssetLoader for ShaderLoader { + fn load<'a>( + &'a self, + bytes: &'a [u8], + load_context: &'a mut LoadContext, + ) -> BoxedFuture<'a, Result<(), anyhow::Error>> { + Box::pin(async move { + let ext = load_context.path().extension().unwrap().to_str().unwrap(); + + let shader = match ext { + "vert" => Shader::from_glsl(ShaderStage::Vertex, std::str::from_utf8(bytes)?), + "frag" => Shader::from_glsl(ShaderStage::Fragment, std::str::from_utf8(bytes)?), + _ => panic!("unhandled extension: {}", ext), + }; + + load_context.set_default_asset(LoadedAsset::new(shader)); + Ok(()) + }) + } + + fn extensions(&self) -> &[&str] { + &["vert", "frag"] + } +} + +pub fn shader_update_system( + mut shaders: ResMut>, + mut pipelines: ResMut>, + shader_events: Res>>, + mut shader_event_reader: Local>>, + mut pipeline_compiler: ResMut, + render_resource_context: Res>, +) { + for event in shader_event_reader.iter(&shader_events) { + match event { + AssetEvent::Modified { handle } => { + if let Err(e) = pipeline_compiler.update_shader( + handle, + &mut pipelines, + &mut shaders, + &**render_resource_context, + ) { + error!("Failed to update shader: {}", e); + } + } + // Creating shaders on the fly is unhandled since they + // have to exist already when assigned to a pipeline. If a + // shader is removed the pipeline keeps using its + // specialized version. Maybe this should be a warning? + AssetEvent::Created { .. } | AssetEvent::Removed { .. } => (), + } + } +} diff --git a/crates/bevy_render/src/shader/shader_reflect.rs b/crates/bevy_render/src/shader/shader_reflect.rs index 77577ede7b9a8..13bfbeae38386 100644 --- a/crates/bevy_render/src/shader/shader_reflect.rs +++ b/crates/bevy_render/src/shader/shader_reflect.rs @@ -328,7 +328,8 @@ mod tests { } "#, ) - .get_spirv_shader(None); + .get_spirv_shader(None) + .unwrap(); let layout = vertex_shader.reflect_layout(true).unwrap(); assert_eq!( diff --git a/crates/bevy_wgpu/src/renderer/wgpu_render_resource_context.rs b/crates/bevy_wgpu/src/renderer/wgpu_render_resource_context.rs index cc428011206f3..dadb66792a4a4 100644 --- a/crates/bevy_wgpu/src/renderer/wgpu_render_resource_context.rs +++ b/crates/bevy_wgpu/src/renderer/wgpu_render_resource_context.rs @@ -12,7 +12,7 @@ use bevy_render::{ BindGroup, BufferId, BufferInfo, RenderResourceBinding, RenderResourceContext, RenderResourceId, SamplerId, TextureId, }, - shader::{glsl_to_spirv, Shader, ShaderSource}, + shader::{glsl_to_spirv, Shader, ShaderError, ShaderSource}, texture::{Extent3d, SamplerDescriptor, TextureDescriptor}, }; use bevy_utils::tracing::trace; @@ -251,7 +251,7 @@ impl RenderResourceContext for WgpuRenderResourceContext { fn create_shader_module_from_source(&self, shader_handle: &Handle, shader: &Shader) { let mut shader_modules = self.resources.shader_modules.write(); - let spirv: Cow<[u32]> = shader.get_spirv(None).into(); + let spirv: Cow<[u32]> = shader.get_spirv(None).unwrap().into(); let shader_module = self .device .create_shader_module(wgpu::ShaderModuleSource::SpirV(spirv)); @@ -574,14 +574,18 @@ impl RenderResourceContext for WgpuRenderResourceContext { } } - fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader { + fn get_specialized_shader( + &self, + shader: &Shader, + macros: Option<&[String]>, + ) -> Result { let spirv_data = match shader.source { ShaderSource::Spirv(ref bytes) => bytes.clone(), - ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, shader.stage, macros), + ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, shader.stage, macros)?, }; - Shader { + Ok(Shader { source: ShaderSource::Spirv(spirv_data), ..*shader - } + }) } } diff --git a/examples/shader/hot_shader_reloading.rs b/examples/shader/hot_shader_reloading.rs new file mode 100644 index 0000000000000..2c1ba667bd217 --- /dev/null +++ b/examples/shader/hot_shader_reloading.rs @@ -0,0 +1,80 @@ +use bevy::{ + prelude::*, + reflect::TypeUuid, + render::{ + mesh::shape, + pipeline::{PipelineDescriptor, RenderPipeline}, + render_graph::{base, AssetRenderResourcesNode, RenderGraph}, + renderer::RenderResources, + shader::ShaderStages, + }, +}; + +/// This example illustrates how to load shaders such that they can be +/// edited while the example is still running. +fn main() { + App::build() + .add_plugins(DefaultPlugins) + .add_asset::() + .add_startup_system(setup) + .run(); +} + +#[derive(RenderResources, Default, TypeUuid)] +#[uuid = "3bf9e364-f29d-4d6c-92cf-93298466c620"] +struct MyMaterial { + pub color: Color, +} + +fn setup( + commands: &mut Commands, + asset_server: ResMut, + mut pipelines: ResMut>, + mut meshes: ResMut>, + mut materials: ResMut>, + mut render_graph: ResMut, +) { + // Watch for changes + asset_server.watch_for_changes().unwrap(); + + // Create a new shader pipeline with shaders loaded from the asset directory + let pipeline_handle = pipelines.add(PipelineDescriptor::default_config(ShaderStages { + vertex: asset_server.load::("shaders/hot.vert"), + fragment: Some(asset_server.load::("shaders/hot.frag")), + })); + + // Add an AssetRenderResourcesNode to our Render Graph. This will bind MyMaterial resources to our shader + render_graph.add_system_node( + "my_material", + AssetRenderResourcesNode::::new(true), + ); + + // Add a Render Graph edge connecting our new "my_material" node to the main pass node. This ensures "my_material" runs before the main pass + render_graph + .add_node_edge("my_material", base::node::MAIN_PASS) + .unwrap(); + + // Create a new material + let material = materials.add(MyMaterial { + color: Color::rgb(0.0, 0.8, 0.0), + }); + + // Setup our world + commands + // cube + .spawn(MeshBundle { + mesh: meshes.add(Mesh::from(shape::Cube { size: 2.0 })), + render_pipelines: RenderPipelines::from_pipelines(vec![RenderPipeline::new( + pipeline_handle, + )]), + transform: Transform::from_translation(Vec3::new(0.0, 0.0, 0.0)), + ..Default::default() + }) + .with(material) + // camera + .spawn(Camera3dBundle { + transform: Transform::from_translation(Vec3::new(3.0, 5.0, -8.0)) + .looking_at(Vec3::default(), Vec3::unit_y()), + ..Default::default() + }); +}