From 81e8c2504f55a109ec3da649ca7541a64af32e17 Mon Sep 17 00:00:00 2001 From: Amatsugu Date: Mon, 18 Aug 2025 21:41:49 -0400 Subject: [PATCH] re-create compute from example --- Cargo.lock | 88 +++++++++++++++++++ Cargo.toml | 2 +- assets/trace.wgsl | 74 ++++++++++++++++ src/app.rs | 27 +++--- src/main.rs | 5 +- src/render/mod.rs | 2 + src/render/node.rs | 101 ++++++++++++++++++++++ src/render/pipeline.rs | 188 +++++++++++++++++++++++++++++++++++++++++ 8 files changed, 469 insertions(+), 18 deletions(-) create mode 100644 src/render/mod.rs create mode 100644 src/render/node.rs create mode 100644 src/render/pipeline.rs diff --git a/Cargo.lock b/Cargo.lock index 8238c62..0afef2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -549,6 +549,7 @@ dependencies = [ "futures-io", "futures-lite", "js-sys", + "notify-debouncer-full", "parking_lot", "ron", "serde", @@ -2343,6 +2344,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "file-id" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bc904b9bbefcadbd8e3a9fb0d464a9b979de6324c03b3c663e8994f46a5be36" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "fixedbitset" version = "0.5.7" @@ -2430,6 +2440,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" +[[package]] +name = "fsevent-sys" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" +dependencies = [ + "libc", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -2979,6 +2998,26 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" +[[package]] +name = "kqueue" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "ktx2" version = "0.3.0" @@ -3190,6 +3229,18 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "mio" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +dependencies = [ + "libc", + "log", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", +] + [[package]] name = "naga" version = "24.0.0" @@ -3344,6 +3395,43 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "notify" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" +dependencies = [ + "bitflags 2.9.1", + "fsevent-sys", + "inotify", + "kqueue", + "libc", + "log", + "mio", + "notify-types", + "walkdir", + "windows-sys 0.60.2", +] + +[[package]] +name = "notify-debouncer-full" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2d88b1a7538054351c8258338df7c931a590513fb3745e8c15eb9ff4199b8d1" +dependencies = [ + "file-id", + "log", + "notify", + "notify-types", + "walkdir", +] + +[[package]] +name = "notify-types" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e0826a989adedc2a244799e823aece04662b66609d96af8dff7ac6df9a8925d" + [[package]] name = "ntapi" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index b848df7..d68038a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,5 @@ version = "0.1.0" edition = "2024" [dependencies] -bevy = {version = "0.16.1", features = ["bevy_image"]} +bevy = {version = "0.16.1", features = ["bevy_image", "file_watcher"]} bevy-inspector-egui = "0.31.0" diff --git a/assets/trace.wgsl b/assets/trace.wgsl index e69de29..4bfc9e1 100644 --- a/assets/trace.wgsl +++ b/assets/trace.wgsl @@ -0,0 +1,74 @@ + +@group(0) @binding(0) var input: texture_storage_2d; + +@group(0) @binding(1) var output: texture_storage_2d; + +@group(0) @binding(2) var config: TracerUniforms; + +struct TracerUniforms { + sky_color: vec4, +} + +fn hash(value: u32) -> u32 { + var state = value; + state = state ^ 2747636419u; + state = state * 2654435769u; + state = state ^ (state >> 16u); + state = state * 2654435769u; + state = state ^ (state >> 16u); + state = state * 2654435769u; + return state; +} + +fn randomFloat(value: u32) -> f32 { + return f32(hash(value)) / 4294967295.0; +} + +@compute @workgroup_size(8, 8, 1) +fn init(@builtin(global_invocation_id) invocation_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) { + let location = vec2(i32(invocation_id.x), i32(invocation_id.y)); + + let randomNumber = randomFloat((invocation_id.y << 16u) | invocation_id.x); + let alive = randomNumber > 0.9; + // Use alpha channel to keep track of cell's state + let color = vec4(config.sky_color.rgb, f32(alive)); + + textureStore(output, location, color); +} + +fn is_alive(location: vec2, offset_x: i32, offset_y: i32) -> i32 { + let value: vec4 = textureLoad(input, location + vec2(offset_x, offset_y)); + return i32(value.a); +} + +fn count_alive(location: vec2) -> i32 { + return is_alive(location, -1, -1) + + is_alive(location, -1, 0) + + is_alive(location, -1, 1) + + is_alive(location, 0, -1) + + is_alive(location, 0, 1) + + is_alive(location, 1, -1) + + is_alive(location, 1, 0) + + is_alive(location, 1, 1); +} + +@compute @workgroup_size(8, 8, 1) +fn update(@builtin(global_invocation_id) invocation_id: vec3) { +let location = vec2(i32(invocation_id.x), i32(invocation_id.y)); + + let n_alive = count_alive(location); + + var alive: bool; + + if (n_alive == 3) { + alive = true; + } else if (n_alive == 2) { + let currently_alive = is_alive(location, 0, 0); + alive = bool(currently_alive); + } else { + alive = false; + } + let color = vec4(config.sky_color.rgb , f32(alive)); + + textureStore(output, location, color); +} diff --git a/src/app.rs b/src/app.rs index 9300156..fdd4718 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,30 +1,27 @@ use bevy::{ asset::RenderAssetUsages, prelude::*, - render::{ - RenderApp, - render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages}, - }, + render::render_resource::{Extent3d, TextureDimension, TextureFormat, TextureUsages}, window::PrimaryWindow, }; + +use crate::render::pipeline::{TracerPipelinePlugin, TracerRenderTextures, TracerUniforms}; + pub struct Blackhole; impl Plugin for Blackhole { fn build(&self, app: &mut App) { - app.register_type::(); + app.register_type::(); app.add_systems(Startup, setup); - let render_app = app.sub_app_mut(RenderApp); - - render_app.add_systems(Startup, init_pipeline); + app.add_plugins(TracerPipelinePlugin); + app.insert_resource(TracerUniforms { + sky_color: LinearRgba::BLUE, + }); } } -#[derive(Resource, Reflect)] -#[reflect(Resource)] -struct RenderTextures(pub Handle, pub Handle); - fn setup(mut commands: Commands, mut images: ResMut>, window: Single<&Window, With>) { let size = window.physical_size(); @@ -39,7 +36,7 @@ fn setup(mut commands: Commands, mut images: ResMut>, window: Sing let mut image = Image::new_fill( extent, TextureDimension::D2, - &[0; PIXEL_SIZE], + &[255; PIXEL_SIZE], PIXEL_FORMAT, RenderAssetUsages::RENDER_WORLD, ); @@ -63,7 +60,5 @@ fn setup(mut commands: Commands, mut images: ResMut>, window: Sing commands.spawn(Camera2d); - commands.insert_resource(RenderTextures(img0, img1)); + commands.insert_resource(TracerRenderTextures(img0, img1)); } - -fn init_pipeline(mut commands: Commands) {} diff --git a/src/main.rs b/src/main.rs index ebeff33..7271ea9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,13 @@ use app::Blackhole; use bevy::prelude::*; -use bevy::{prelude::*, window::PresentMode}; +use bevy::window::PresentMode; use bevy_inspector_egui::{bevy_egui::EguiPlugin, quick::WorldInspectorPlugin}; mod app; +mod render; +pub const SHADER_ASSET_PATH: &str = "trace.wgsl"; +pub const WORKGROUP_SIZE: u32 = 8; const NAME: &str = "Black Hole"; fn main() { diff --git a/src/render/mod.rs b/src/render/mod.rs new file mode 100644 index 0000000..07caabe --- /dev/null +++ b/src/render/mod.rs @@ -0,0 +1,2 @@ +pub mod node; +pub mod pipeline; diff --git a/src/render/node.rs b/src/render/node.rs new file mode 100644 index 0000000..1a16d3e --- /dev/null +++ b/src/render/node.rs @@ -0,0 +1,101 @@ +use bevy::{ + prelude::*, + render::{ + render_graph::{self}, + render_resource::{CachedPipelineState, ComputePassDescriptor, PipelineCache, PipelineCacheError}, + renderer::RenderContext, + }, +}; + +use crate::render::pipeline::{TracerImageBindGroups, TracerPipeline}; +use crate::{SHADER_ASSET_PATH, WORKGROUP_SIZE}; + +pub enum TracerState { + Loading, + Init, + Update(usize), +} + +pub struct TracerNode { + state: TracerState, +} + +impl Default for TracerNode { + fn default() -> Self { + Self { + state: TracerState::Loading, + } + } +} + +impl render_graph::Node for TracerNode { + fn update(&mut self, world: &mut World) { + let pipeline = world.resource::(); + let pipeline_cache = world.resource::(); + + // if the corresponding pipeline has loaded, transition to the next stage + match self.state { + TracerState::Loading => { + match pipeline_cache.get_compute_pipeline_state(pipeline.init_pipeline) { + CachedPipelineState::Ok(_) => { + self.state = TracerState::Init; + } + // If the shader hasn't loaded yet, just wait. + CachedPipelineState::Err(PipelineCacheError::ShaderNotLoaded(_)) => {} + CachedPipelineState::Err(err) => { + panic!("Initializing assets/{SHADER_ASSET_PATH}:\n{err}") + } + _ => {} + } + } + TracerState::Init => { + if let CachedPipelineState::Ok(_) = pipeline_cache.get_compute_pipeline_state(pipeline.update_pipeline) + { + self.state = TracerState::Update(1); + } + } + TracerState::Update(0) => { + self.state = TracerState::Update(1); + } + TracerState::Update(1) => { + self.state = TracerState::Update(0); + } + TracerState::Update(_) => unreachable!(), + } + } + + fn run( + &self, + _graph: &mut render_graph::RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), render_graph::NodeRunError> { + let bind_groups = &world.resource::().0; + let pipeline_cache = world.resource::(); + let pipeline = world.resource::(); + + let mut pass = render_context + .command_encoder() + .begin_compute_pass(&ComputePassDescriptor::default()); + + // select the pipeline based on the current state + match self.state { + TracerState::Loading => {} + TracerState::Init => { + let init_pipeline = pipeline_cache.get_compute_pipeline(pipeline.init_pipeline).unwrap(); + pass.set_bind_group(0, &bind_groups[0], &[]); + pass.set_pipeline(init_pipeline); + pass.dispatch_workgroups(1920 / WORKGROUP_SIZE, 1080 / WORKGROUP_SIZE, 1); + } + TracerState::Update(index) => { + if let Some(update_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.update_pipeline) { + pass.set_bind_group(0, &bind_groups[index], &[]); + pass.set_pipeline(update_pipeline); + pass.dispatch_workgroups(1920 / WORKGROUP_SIZE, 1080 / WORKGROUP_SIZE, 1); + } + } + } + + Ok(()) + } +} diff --git a/src/render/pipeline.rs b/src/render/pipeline.rs new file mode 100644 index 0000000..a3e90b4 --- /dev/null +++ b/src/render/pipeline.rs @@ -0,0 +1,188 @@ +use std::borrow::Cow; + +use bevy::{ + prelude::*, + render::{ + Render, RenderApp, RenderSet, + extract_resource::{ExtractResource, ExtractResourcePlugin}, + render_asset::RenderAssets, + render_graph::{RenderGraph, RenderLabel}, + render_resource::{ + BindGroup, BindGroupEntries, BindGroupLayout, BindGroupLayoutEntries, CachedComputePipelineId, + ComputePipelineDescriptor, PipelineCache, ShaderStages, ShaderType, StorageTextureAccess, TextureFormat, + UniformBuffer, + binding_types::{texture_storage_2d, uniform_buffer}, + }, + renderer::{RenderDevice, RenderQueue}, + texture::GpuImage, + }, +}; + +use crate::{SHADER_ASSET_PATH, render::node::TracerNode}; + +#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] +pub struct TracerLabel; + +#[derive(Resource, Reflect, ExtractResource, Clone)] +#[reflect(Resource)] +pub struct TracerRenderTextures(pub Handle, pub Handle); + +#[derive(Resource)] +pub struct TracerPipeline { + pub texture_bind_group_layout: BindGroupLayout, + pub init_pipeline: CachedComputePipelineId, + pub update_pipeline: CachedComputePipelineId, +} + +#[derive(Resource, Clone, ExtractResource, ShaderType, Default)] +pub struct TracerUniforms { + pub sky_color: LinearRgba, +} + +pub struct TracerPipelinePlugin; + +impl Plugin for TracerPipelinePlugin { + fn build(&self, app: &mut App) { + app.add_plugins(( + ExtractResourcePlugin::::default(), + ExtractResourcePlugin::::default(), + )); + app.init_resource::(); + let render_app = app.sub_app_mut(RenderApp); + + // render_app.add_systems(Startup, init_pipeline); + render_app.add_systems(Render, prepare_bind_groups.in_set(RenderSet::PrepareBindGroups)); + + let mut render_graph = render_app.world_mut().resource_mut::(); + render_graph.add_node(TracerLabel, TracerNode::default()); + render_graph.add_node_edge(TracerLabel, bevy::render::graph::CameraDriverLabel); + } + + fn finish(&self, app: &mut App) { + let render_app = app.sub_app_mut(RenderApp); + render_app.init_resource::(); + } +} + +impl FromWorld for TracerPipeline { + fn from_world(world: &mut World) -> Self { + let render_device = world.resource::(); + + let texture_bind_group_layout = render_device.create_bind_group_layout( + "TracerImages", + &BindGroupLayoutEntries::sequential( + ShaderStages::COMPUTE, + ( + texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly), + texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly), + uniform_buffer::(false), + ), + ), + ); + let shader = world.load_asset(SHADER_ASSET_PATH); + let pipeline_cache = world.resource::(); + let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + layout: vec![texture_bind_group_layout.clone()], + shader: shader.clone(), + entry_point: Cow::from("init"), + label: None, + zero_initialize_workgroup_memory: false, + push_constant_ranges: Default::default(), + shader_defs: Default::default(), + }); + + let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + layout: vec![texture_bind_group_layout.clone()], + shader, + entry_point: Cow::from("update"), + label: None, + zero_initialize_workgroup_memory: false, + push_constant_ranges: Default::default(), + shader_defs: Default::default(), + }); + + return TracerPipeline { + texture_bind_group_layout, + init_pipeline, + update_pipeline, + }; + } +} + +fn init_pipeline( + mut commands: Commands, + render_device: Res, + asset_server: Res, + pipeline_cache: Res, +) { + let texture_bind_group_layout = render_device.create_bind_group_layout( + "TracerImages", + &BindGroupLayoutEntries::sequential( + ShaderStages::COMPUTE, + ( + texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly), + texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly), + uniform_buffer::(false), + ), + ), + ); + let shader = asset_server.load(SHADER_ASSET_PATH); + let init_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + layout: vec![texture_bind_group_layout.clone()], + shader: shader.clone(), + entry_point: Cow::from("init"), + label: None, + zero_initialize_workgroup_memory: false, + push_constant_ranges: Default::default(), + shader_defs: Default::default(), + }); + + let update_pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + layout: vec![texture_bind_group_layout.clone()], + shader, + entry_point: Cow::from("update"), + label: None, + zero_initialize_workgroup_memory: false, + push_constant_ranges: Default::default(), + shader_defs: Default::default(), + }); + + commands.insert_resource(TracerPipeline { + texture_bind_group_layout, + init_pipeline, + update_pipeline, + }); +} + +#[derive(Resource)] +pub struct TracerImageBindGroups(pub [BindGroup; 2]); + +fn prepare_bind_groups( + mut commands: Commands, + pipeline: Res, + gpu_images: Res>, + tracer_images: Res, + tracer_uniforms: Res, + render_device: Res, + queue: Res, +) { + let view_a = gpu_images.get(&tracer_images.0).unwrap(); + let view_b = gpu_images.get(&tracer_images.1).unwrap(); + + // Uniform buffer is used here to demonstrate how to set up a uniform in a compute shader + // Alternatives such as storage buffers or push constants may be more suitable for your use case + let mut uniform_buffer = UniformBuffer::from(tracer_uniforms.into_inner()); + uniform_buffer.write_buffer(&render_device, &queue); + + let bind_group_0 = render_device.create_bind_group( + None, + &pipeline.texture_bind_group_layout, + &BindGroupEntries::sequential((&view_a.texture_view, &view_b.texture_view, &uniform_buffer)), + ); + let bind_group_1 = render_device.create_bind_group( + None, + &pipeline.texture_bind_group_layout, + &BindGroupEntries::sequential((&view_b.texture_view, &view_a.texture_view, &uniform_buffer)), + ); + commands.insert_resource(TracerImageBindGroups([bind_group_0, bind_group_1])); +}