re-create compute from example

This commit is contained in:
2025-08-18 21:41:49 -04:00
parent b1f7524856
commit 81e8c2504f
8 changed files with 469 additions and 18 deletions

View File

@@ -1,30 +1,27 @@
use bevy::{
asset::RenderAssetUsages,
prelude::*,
render::{
RenderApp,
render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages},
},
render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages},
window::PrimaryWindow,
};
use crate::render::pipeline::{TracerPipelinePlugin, TracerRenderTextures, TracerUniforms};
pub struct Blackhole;
impl Plugin for Blackhole {
fn build(&self, app: &mut App) {
app.register_type::<RenderTextures>();
app.register_type::<TracerRenderTextures>();
app.add_systems(Startup, setup);
let render_app = app.sub_app_mut(RenderApp);
render_app.add_systems(Startup, init_pipeline);
app.add_plugins(TracerPipelinePlugin);
app.insert_resource(TracerUniforms {
sky_color: LinearRgba::BLUE,
});
}
}
#[derive(Resource, Reflect)]
#[reflect(Resource)]
struct RenderTextures(pub Handle<Image>, pub Handle<Image>);
fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>, window: Single<&Window, With<PrimaryWindow>>) {
let size = window.physical_size();
@@ -39,7 +36,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>, window: Sing
let mut image = Image::new_fill(
extent,
TextureDimension::D2,
&[0; PIXEL_SIZE],
&[255; PIXEL_SIZE],
PIXEL_FORMAT,
RenderAssetUsages::RENDER_WORLD,
);
@@ -63,7 +60,5 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>, window: Sing
commands.spawn(Camera2d);
commands.insert_resource(RenderTextures(img0, img1));
commands.insert_resource(TracerRenderTextures(img0, img1));
}
fn init_pipeline(mut commands: Commands) {}

View File

@@ -1,10 +1,13 @@
use app::Blackhole;
use bevy::prelude::*;
use bevy::{prelude::*, window::PresentMode};
use bevy::window::PresentMode;
use bevy_inspector_egui::{bevy_egui::EguiPlugin, quick::WorldInspectorPlugin};
mod app;
mod render;
pub const SHADER_ASSET_PATH: &str = "trace.wgsl";
pub const WORKGROUP_SIZE: u32 = 8;
const NAME: &str = "Black Hole";
fn main() {

2
src/render/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod node;
pub mod pipeline;

101
src/render/node.rs Normal file
View File

@@ -0,0 +1,101 @@
use bevy::{
prelude::*,
render::{
render_graph::{self},
render_resource::{CachedPipelineState, ComputePassDescriptor, PipelineCache, PipelineCacheError},
renderer::RenderContext,
},
};
use crate::render::pipeline::{TracerImageBindGroups, TracerPipeline};
use crate::{SHADER_ASSET_PATH, WORKGROUP_SIZE};
pub enum TracerState {
Loading,
Init,
Update(usize),
}
pub struct TracerNode {
state: TracerState,
}
impl Default for TracerNode {
fn default() -> Self {
Self {
state: TracerState::Loading,
}
}
}
impl render_graph::Node for TracerNode {
fn update(&mut self, world: &mut World) {
let pipeline = world.resource::<TracerPipeline>();
let pipeline_cache = world.resource::<PipelineCache>();
// if the corresponding pipeline has loaded, transition to the next stage
match self.state {
TracerState::Loading => {
match pipeline_cache.get_compute_pipeline_state(pipeline.init_pipeline) {
CachedPipelineState::Ok(_) => {
self.state = TracerState::Init;
}
// If the shader hasn't loaded yet, just wait.
CachedPipelineState::Err(PipelineCacheError::ShaderNotLoaded(_)) => {}
CachedPipelineState::Err(err) => {
panic!("Initializing assets/{SHADER_ASSET_PATH}:\n{err}")
}
_ => {}
}
}
TracerState::Init => {
if let CachedPipelineState::Ok(_) = pipeline_cache.get_compute_pipeline_state(pipeline.update_pipeline)
{
self.state = TracerState::Update(1);
}
}
TracerState::Update(0) => {
self.state = TracerState::Update(1);
}
TracerState::Update(1) => {
self.state = TracerState::Update(0);
}
TracerState::Update(_) => unreachable!(),
}
}
fn run(
&self,
_graph: &mut render_graph::RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), render_graph::NodeRunError> {
let bind_groups = &world.resource::<TracerImageBindGroups>().0;
let pipeline_cache = world.resource::<PipelineCache>();
let pipeline = world.resource::<TracerPipeline>();
let mut pass = render_context
.command_encoder()
.begin_compute_pass(&ComputePassDescriptor::default());
// select the pipeline based on the current state
match self.state {
TracerState::Loading => {}
TracerState::Init => {
let init_pipeline = pipeline_cache.get_compute_pipeline(pipeline.init_pipeline).unwrap();
pass.set_bind_group(0, &bind_groups[0], &[]);
pass.set_pipeline(init_pipeline);
pass.dispatch_workgroups(1920 / WORKGROUP_SIZE, 1080 / WORKGROUP_SIZE, 1);
}
TracerState::Update(index) => {
if let Some(update_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.update_pipeline) {
pass.set_bind_group(0, &bind_groups[index], &[]);
pass.set_pipeline(update_pipeline);
pass.dispatch_workgroups(1920 / WORKGROUP_SIZE, 1080 / WORKGROUP_SIZE, 1);
}
}
}
Ok(())
}
}

188
src/render/pipeline.rs Normal file
View File

@@ -0,0 +1,188 @@
use std::borrow::Cow;
use bevy::{
prelude::*,
render::{
Render, RenderApp, RenderSet,
extract_resource::{ExtractResource, ExtractResourcePlugin},
render_asset::RenderAssets,
render_graph::{RenderGraph, RenderLabel},
render_resource::{
BindGroup, BindGroupEntries, BindGroupLayout, BindGroupLayoutEntries, CachedComputePipelineId,
ComputePipelineDescriptor, PipelineCache, ShaderStages, ShaderType, StorageTextureAccess, TextureFormat,
UniformBuffer,
binding_types::{texture_storage_2d, uniform_buffer},
},
renderer::{RenderDevice, RenderQueue},
texture::GpuImage,
},
};
use crate::{SHADER_ASSET_PATH, render::node::TracerNode};
#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)]
pub struct TracerLabel;
#[derive(Resource, Reflect, ExtractResource, Clone)]
#[reflect(Resource)]
pub struct TracerRenderTextures(pub Handle<Image>, pub Handle<Image>);
#[derive(Resource)]
pub struct TracerPipeline {
pub texture_bind_group_layout: BindGroupLayout,
pub init_pipeline: CachedComputePipelineId,
pub update_pipeline: CachedComputePipelineId,
}
#[derive(Resource, Clone, ExtractResource, ShaderType, Default)]
pub struct TracerUniforms {
pub sky_color: LinearRgba,
}
pub struct TracerPipelinePlugin;
impl Plugin for TracerPipelinePlugin {
fn build(&self, app: &mut App) {
app.add_plugins((
ExtractResourcePlugin::<TracerRenderTextures>::default(),
ExtractResourcePlugin::<TracerUniforms>::default(),
));
app.init_resource::<TracerUniforms>();
let render_app = app.sub_app_mut(RenderApp);
// render_app.add_systems(Startup, init_pipeline);
render_app.add_systems(Render, prepare_bind_groups.in_set(RenderSet::PrepareBindGroups));
let mut render_graph = render_app.world_mut().resource_mut::<RenderGraph>();
render_graph.add_node(TracerLabel, TracerNode::default());
render_graph.add_node_edge(TracerLabel, bevy::render::graph::CameraDriverLabel);
}
fn finish(&self, app: &mut App) {
let render_app = app.sub_app_mut(RenderApp);
render_app.init_resource::<TracerPipeline>();
}
}
impl FromWorld for TracerPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.resource::<RenderDevice>();
let texture_bind_group_layout = render_device.create_bind_group_layout(
"TracerImages",
&BindGroupLayoutEntries::sequential(
ShaderStages::COMPUTE,
(
texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly),
texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly),
uniform_buffer::<TracerUniforms>(false),
),
),
);
let shader = world.load_asset(SHADER_ASSET_PATH);
let pipeline_cache = world.resource::<PipelineCache>();
let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
layout: vec![texture_bind_group_layout.clone()],
shader: shader.clone(),
entry_point: Cow::from("init"),
label: None,
zero_initialize_workgroup_memory: false,
push_constant_ranges: Default::default(),
shader_defs: Default::default(),
});
let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
layout: vec![texture_bind_group_layout.clone()],
shader,
entry_point: Cow::from("update"),
label: None,
zero_initialize_workgroup_memory: false,
push_constant_ranges: Default::default(),
shader_defs: Default::default(),
});
return TracerPipeline {
texture_bind_group_layout,
init_pipeline,
update_pipeline,
};
}
}
fn init_pipeline(
mut commands: Commands,
render_device: Res<RenderDevice>,
asset_server: Res<AssetServer>,
pipeline_cache: Res<PipelineCache>,
) {
let texture_bind_group_layout = render_device.create_bind_group_layout(
"TracerImages",
&BindGroupLayoutEntries::sequential(
ShaderStages::COMPUTE,
(
texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly),
texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly),
uniform_buffer::<TracerUniforms>(false),
),
),
);
let shader = asset_server.load(SHADER_ASSET_PATH);
let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
layout: vec![texture_bind_group_layout.clone()],
shader: shader.clone(),
entry_point: Cow::from("init"),
label: None,
zero_initialize_workgroup_memory: false,
push_constant_ranges: Default::default(),
shader_defs: Default::default(),
});
let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
layout: vec![texture_bind_group_layout.clone()],
shader,
entry_point: Cow::from("update"),
label: None,
zero_initialize_workgroup_memory: false,
push_constant_ranges: Default::default(),
shader_defs: Default::default(),
});
commands.insert_resource(TracerPipeline {
texture_bind_group_layout,
init_pipeline,
update_pipeline,
});
}
#[derive(Resource)]
pub struct TracerImageBindGroups(pub [BindGroup; 2]);
fn prepare_bind_groups(
mut commands: Commands,
pipeline: Res<TracerPipeline>,
gpu_images: Res<RenderAssets<GpuImage>>,
tracer_images: Res<TracerRenderTextures>,
tracer_uniforms: Res<TracerUniforms>,
render_device: Res<RenderDevice>,
queue: Res<RenderQueue>,
) {
let view_a = gpu_images.get(&tracer_images.0).unwrap();
let view_b = gpu_images.get(&tracer_images.1).unwrap();
// Uniform buffer is used here to demonstrate how to set up a uniform in a compute shader
// Alternatives such as storage buffers or push constants may be more suitable for your use case
let mut uniform_buffer = UniformBuffer::from(tracer_uniforms.into_inner());
uniform_buffer.write_buffer(&render_device, &queue);
let bind_group_0 = render_device.create_bind_group(
None,
&pipeline.texture_bind_group_layout,
&BindGroupEntries::sequential((&view_a.texture_view, &view_b.texture_view, &uniform_buffer)),
);
let bind_group_1 = render_device.create_bind_group(
None,
&pipeline.texture_bind_group_layout,
&BindGroupEntries::sequential((&view_b.texture_view, &view_a.texture_view, &uniform_buffer)),
);
commands.insert_resource(TracerImageBindGroups([bind_group_0, bind_group_1]));
}