Skip to content

Commit

Permalink
Compute Pipeline Specialization (#3979)
Browse files Browse the repository at this point in the history
# Objective

- Fixes #3970
- To support Bevy's shader abstraction(shader defs, shader imports and hot shader reloading) for compute shaders, I have followed carts advice and change the `PipelinenCache` to accommodate both compute and render pipelines.

## Solution

- renamed `RenderPipelineCache` to `PipelineCache`
- Cached Pipelines are now represented by an enum (render, compute)
- split the `SpecializedPipelines` into `SpecializedRenderPipelines` and `SpecializedComputePipelines`
- updated the game of life example

## Open Questions

- should `SpecializedRenderPipelines` and `SpecializedComputePipelines` be merged and how would we do that?
- should the `get_render_pipeline` and `get_compute_pipeline` methods be merged?
- is pipeline specialization for different entry points a good pattern




Co-authored-by: Kurt Kühnert <[email protected]>
Co-authored-by: Carter Anderson <[email protected]>
  • Loading branch information
3 people committed Mar 23, 2022
1 parent 0a4136d commit 9e450f2
Show file tree
Hide file tree
Showing 21 changed files with 516 additions and 318 deletions.
26 changes: 13 additions & 13 deletions crates/bevy_core_pipeline/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use bevy_render::{
color::Color,
render_graph::{EmptyNode, RenderGraph, SlotInfo, SlotType},
render_phase::{
batch_phase_system, sort_phase_system, BatchedPhaseItem, CachedPipelinePhaseItem,
batch_phase_system, sort_phase_system, BatchedPhaseItem, CachedRenderPipelinePhaseItem,
DrawFunctionId, DrawFunctions, EntityPhaseItem, PhaseItem, RenderPhase,
},
render_resource::*,
Expand Down Expand Up @@ -198,7 +198,7 @@ impl Plugin for CorePipelinePlugin {
pub struct Transparent2d {
pub sort_key: FloatOrd,
pub entity: Entity,
pub pipeline: CachedPipelineId,
pub pipeline: CachedRenderPipelineId,
pub draw_function: DrawFunctionId,
/// Range in the vertex buffer of this item
pub batch_range: Option<Range<u32>>,
Expand All @@ -225,9 +225,9 @@ impl EntityPhaseItem for Transparent2d {
}
}

impl CachedPipelinePhaseItem for Transparent2d {
impl CachedRenderPipelinePhaseItem for Transparent2d {
#[inline]
fn cached_pipeline(&self) -> CachedPipelineId {
fn cached_pipeline(&self) -> CachedRenderPipelineId {
self.pipeline
}
}
Expand All @@ -244,7 +244,7 @@ impl BatchedPhaseItem for Transparent2d {

pub struct Opaque3d {
pub distance: f32,
pub pipeline: CachedPipelineId,
pub pipeline: CachedRenderPipelineId,
pub entity: Entity,
pub draw_function: DrawFunctionId,
}
Expand All @@ -270,16 +270,16 @@ impl EntityPhaseItem for Opaque3d {
}
}

impl CachedPipelinePhaseItem for Opaque3d {
impl CachedRenderPipelinePhaseItem for Opaque3d {
#[inline]
fn cached_pipeline(&self) -> CachedPipelineId {
fn cached_pipeline(&self) -> CachedRenderPipelineId {
self.pipeline
}
}

pub struct AlphaMask3d {
pub distance: f32,
pub pipeline: CachedPipelineId,
pub pipeline: CachedRenderPipelineId,
pub entity: Entity,
pub draw_function: DrawFunctionId,
}
Expand All @@ -305,16 +305,16 @@ impl EntityPhaseItem for AlphaMask3d {
}
}

impl CachedPipelinePhaseItem for AlphaMask3d {
impl CachedRenderPipelinePhaseItem for AlphaMask3d {
#[inline]
fn cached_pipeline(&self) -> CachedPipelineId {
fn cached_pipeline(&self) -> CachedRenderPipelineId {
self.pipeline
}
}

pub struct Transparent3d {
pub distance: f32,
pub pipeline: CachedPipelineId,
pub pipeline: CachedRenderPipelineId,
pub entity: Entity,
pub draw_function: DrawFunctionId,
}
Expand All @@ -340,9 +340,9 @@ impl EntityPhaseItem for Transparent3d {
}
}

impl CachedPipelinePhaseItem for Transparent3d {
impl CachedRenderPipelinePhaseItem for Transparent3d {
#[inline]
fn cached_pipeline(&self) -> CachedPipelineId {
fn cached_pipeline(&self) -> CachedRenderPipelineId {
self.pipeline
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_pbr/src/material.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use bevy_render::{
SetItemPipeline, TrackedRenderPass,
},
render_resource::{
BindGroup, BindGroupLayout, RenderPipelineCache, RenderPipelineDescriptor, Shader,
BindGroup, BindGroupLayout, PipelineCache, RenderPipelineDescriptor, Shader,
SpecializedMeshPipeline, SpecializedMeshPipelineError, SpecializedMeshPipelines,
},
renderer::RenderDevice,
Expand Down Expand Up @@ -307,7 +307,7 @@ pub fn queue_material_meshes<M: SpecializedMaterial>(
transparent_draw_functions: Res<DrawFunctions<Transparent3d>>,
material_pipeline: Res<MaterialPipeline<M>>,
mut pipelines: ResMut<SpecializedMeshPipelines<MaterialPipeline<M>>>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
mut pipeline_cache: ResMut<PipelineCache>,
msaa: Res<Msaa>,
render_meshes: Res<RenderAssets<Mesh>>,
render_materials: Res<RenderAssets<M>>,
Expand Down
10 changes: 5 additions & 5 deletions crates/bevy_pbr/src/render/light.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use bevy_render::{
render_asset::RenderAssets,
render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType},
render_phase::{
CachedPipelinePhaseItem, DrawFunctionId, DrawFunctions, EntityPhaseItem,
CachedRenderPipelinePhaseItem, DrawFunctionId, DrawFunctions, EntityPhaseItem,
EntityRenderCommand, PhaseItem, RenderCommandResult, RenderPhase, SetItemPipeline,
TrackedRenderPass,
},
Expand Down Expand Up @@ -1055,7 +1055,7 @@ pub fn queue_shadows(
casting_meshes: Query<&Handle<Mesh>, Without<NotShadowCaster>>,
render_meshes: Res<RenderAssets<Mesh>>,
mut pipelines: ResMut<SpecializedMeshPipelines<ShadowPipeline>>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
mut pipeline_cache: ResMut<PipelineCache>,
view_lights: Query<&ViewLightEntities>,
mut view_light_shadow_phases: Query<(&LightEntity, &mut RenderPhase<Shadow>)>,
point_light_entities: Query<&CubemapVisibleEntities, With<ExtractedPointLight>>,
Expand Down Expand Up @@ -1119,7 +1119,7 @@ pub fn queue_shadows(
pub struct Shadow {
pub distance: f32,
pub entity: Entity,
pub pipeline: CachedPipelineId,
pub pipeline: CachedRenderPipelineId,
pub draw_function: DrawFunctionId,
}

Expand All @@ -1143,9 +1143,9 @@ impl EntityPhaseItem for Shadow {
}
}

impl CachedPipelinePhaseItem for Shadow {
impl CachedRenderPipelinePhaseItem for Shadow {
#[inline]
fn cached_pipeline(&self) -> CachedPipelineId {
fn cached_pipeline(&self) -> CachedRenderPipelineId {
self.pipeline
}
}
Expand Down
18 changes: 8 additions & 10 deletions crates/bevy_pbr/src/wireframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@ use bevy_asset::{load_internal_asset, Handle, HandleUntyped};
use bevy_core_pipeline::Opaque3d;
use bevy_ecs::{prelude::*, reflect::ReflectComponent};
use bevy_reflect::{Reflect, TypeUuid};
use bevy_render::mesh::MeshVertexBufferLayout;
use bevy_render::render_resource::{
PolygonMode, RenderPipelineDescriptor, SpecializedMeshPipeline, SpecializedMeshPipelineError,
SpecializedMeshPipelines,
};
use bevy_render::{
mesh::Mesh,
mesh::{Mesh, MeshVertexBufferLayout},
render_asset::RenderAssets,
render_phase::{AddRenderCommand, DrawFunctions, RenderPhase, SetItemPipeline},
render_resource::{RenderPipelineCache, Shader},
render_resource::{
PipelineCache, PolygonMode, RenderPipelineDescriptor, Shader, SpecializedMeshPipeline,
SpecializedMeshPipelineError, SpecializedMeshPipelines,
},
view::{ExtractedView, Msaa},
RenderApp, RenderStage,
};
Expand Down Expand Up @@ -109,8 +107,8 @@ fn queue_wireframes(
render_meshes: Res<RenderAssets<Mesh>>,
wireframe_config: Res<WireframeConfig>,
wireframe_pipeline: Res<WireframePipeline>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
mut specialized_pipelines: ResMut<SpecializedMeshPipelines<WireframePipeline>>,
mut pipelines: ResMut<SpecializedMeshPipelines<WireframePipeline>>,
mut pipeline_cache: ResMut<PipelineCache>,
msaa: Res<Msaa>,
mut material_meshes: QuerySet<(
QueryState<(Entity, &Handle<Mesh>, &MeshUniform)>,
Expand All @@ -132,7 +130,7 @@ fn queue_wireframes(
if let Some(mesh) = render_meshes.get(mesh_handle) {
let key = msaa_key
| MeshPipelineKey::from_primitive_topology(mesh.primitive_topology);
let pipeline_id = specialized_pipelines.specialize(
let pipeline_id = pipelines.specialize(
&mut pipeline_cache,
&wireframe_pipeline,
key,
Expand Down
13 changes: 8 additions & 5 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
extern crate core;

pub mod camera;
pub mod color;
pub mod mesh;
Expand Down Expand Up @@ -36,7 +38,7 @@ use crate::{
mesh::MeshPlugin,
primitives::{CubemapFrusta, Frustum},
render_graph::RenderGraph,
render_resource::{RenderPipelineCache, Shader, ShaderLoader},
render_resource::{PipelineCache, Shader, ShaderLoader},
renderer::render_system,
texture::ImagePlugin,
view::{ViewPlugin, WindowRenderPlugin},
Expand Down Expand Up @@ -146,12 +148,13 @@ impl Plugin for RenderPlugin {
.init_resource::<ScratchRenderWorld>()
.register_type::<Frustum>()
.register_type::<CubemapFrusta>();
let render_pipeline_cache = RenderPipelineCache::new(device.clone());

let pipeline_cache = PipelineCache::new(device.clone());
let asset_server = app.world.resource::<AssetServer>().clone();

let mut render_app = App::empty();
let mut extract_stage =
SystemStage::parallel().with_system(RenderPipelineCache::extract_shaders);
SystemStage::parallel().with_system(PipelineCache::extract_shaders);
// don't apply buffers when the stage finishes running
// extract stage runs on the app world, but the buffers are applied to the render world
extract_stage.set_apply_buffers(false);
Expand All @@ -163,15 +166,15 @@ impl Plugin for RenderPlugin {
.add_stage(
RenderStage::Render,
SystemStage::parallel()
.with_system(RenderPipelineCache::process_pipeline_queue_system)
.with_system(PipelineCache::process_pipeline_queue_system)
.with_system(render_system.exclusive_system().at_end()),
)
.add_stage(RenderStage::Cleanup, SystemStage::parallel())
.insert_resource(instance)
.insert_resource(device)
.insert_resource(queue)
.insert_resource(adapter_info)
.insert_resource(render_pipeline_cache)
.insert_resource(pipeline_cache)
.insert_resource(asset_server)
.init_resource::<RenderGraph>();

Expand Down
15 changes: 9 additions & 6 deletions crates/bevy_render/src/render_phase/draw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
render_phase::TrackedRenderPass,
render_resource::{CachedPipelineId, RenderPipelineCache},
render_resource::{CachedRenderPipelineId, PipelineCache},
};
use bevy_app::App;
use bevy_ecs::{
Expand Down Expand Up @@ -162,8 +162,8 @@ pub trait EntityPhaseItem: PhaseItem {
fn entity(&self) -> Entity;
}

pub trait CachedPipelinePhaseItem: PhaseItem {
fn cached_pipeline(&self) -> CachedPipelineId;
pub trait CachedRenderPipelinePhaseItem: PhaseItem {
fn cached_pipeline(&self) -> CachedRenderPipelineId;
}

/// A [`PhaseItem`] that can be batched dynamically.
Expand Down Expand Up @@ -224,16 +224,19 @@ impl<P: EntityPhaseItem, E: EntityRenderCommand> RenderCommand<P> for E {
}

pub struct SetItemPipeline;
impl<P: CachedPipelinePhaseItem> RenderCommand<P> for SetItemPipeline {
type Param = SRes<RenderPipelineCache>;
impl<P: CachedRenderPipelinePhaseItem> RenderCommand<P> for SetItemPipeline {
type Param = SRes<PipelineCache>;
#[inline]
fn render<'w>(
_view: Entity,
item: &P,
pipeline_cache: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
if let Some(pipeline) = pipeline_cache.into_inner().get(item.cached_pipeline()) {
if let Some(pipeline) = pipeline_cache
.into_inner()
.get_render_pipeline(item.cached_pipeline())
{
pass.set_render_pipeline(pipeline);
RenderCommandResult::Success
} else {
Expand Down
14 changes: 7 additions & 7 deletions crates/bevy_render/src/render_resource/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ pub use wgpu::{
BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType,
BlendComponent, BlendFactor, BlendOperation, BlendState, BufferAddress, BufferBinding,
BufferBindingType, BufferDescriptor, BufferSize, BufferUsages, ColorTargetState, ColorWrites,
CommandEncoder, CommandEncoderDescriptor, CompareFunction, ComputePassDescriptor,
ComputePipelineDescriptor, DepthBiasState, DepthStencilState, Extent3d, Face,
Features as WgpuFeatures, FilterMode, FragmentState as RawFragmentState, FrontFace,
ImageCopyBuffer, ImageCopyBufferBase, ImageCopyTexture, ImageCopyTextureBase, ImageDataLayout,
ImageSubresourceRange, IndexFormat, Limits as WgpuLimits, LoadOp, MapMode, MultisampleState,
Operations, Origin3d, PipelineLayout, PipelineLayoutDescriptor, PolygonMode, PrimitiveState,
PrimitiveTopology, RenderPassColorAttachment, RenderPassDepthStencilAttachment,
CommandEncoder, CommandEncoderDescriptor, CompareFunction, ComputePass, ComputePassDescriptor,
ComputePipelineDescriptor as RawComputePipelineDescriptor, DepthBiasState, DepthStencilState,
Extent3d, Face, Features as WgpuFeatures, FilterMode, FragmentState as RawFragmentState,
FrontFace, ImageCopyBuffer, ImageCopyBufferBase, ImageCopyTexture, ImageCopyTextureBase,
ImageDataLayout, ImageSubresourceRange, IndexFormat, Limits as WgpuLimits, LoadOp, MapMode,
MultisampleState, Operations, Origin3d, PipelineLayout, PipelineLayoutDescriptor, PolygonMode,
PrimitiveState, PrimitiveTopology, RenderPassColorAttachment, RenderPassDepthStencilAttachment,
RenderPassDescriptor, RenderPipelineDescriptor as RawRenderPipelineDescriptor,
SamplerBindingType, SamplerDescriptor, ShaderModule, ShaderModuleDescriptor, ShaderSource,
ShaderStages, StencilFaceState, StencilOperation, StencilState, StorageTextureAccess,
Expand Down
13 changes: 13 additions & 0 deletions crates/bevy_render/src/render_resource/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,16 @@ pub struct FragmentState {
/// The color state of the render targets.
pub targets: Vec<ColorTargetState>,
}

/// Describes a compute pipeline.
#[derive(Clone, Debug)]
pub struct ComputePipelineDescriptor {
pub label: Option<Cow<'static, str>>,
pub layout: Option<Vec<BindGroupLayout>>,
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
/// The name of the entry point in the compiled shader. There must be a
/// function with this name in the shader.
pub entry_point: Cow<'static, str>,
}
Loading

0 comments on commit 9e450f2

Please sign in to comment.