Skip to content

Commit

Permalink
Unify temporal and spatial taps
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Oct 15, 2024
1 parent 819c63c commit 715368d
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 87 deletions.
14 changes: 5 additions & 9 deletions blade-helpers/src/hud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ impl ExposeHud for blade_render::RayConfig {
&mut self.environment_importance_sampling,
"Env importance sampling",
);
ui.checkbox(&mut self.temporal_tap, "Temporal tap");
ui.add(egui::widgets::Slider::new(&mut self.tap_count, 0..=10).text("Tap count"));
ui.add(egui::widgets::Slider::new(&mut self.tap_radius, 1..=50).text("Tap radius (px)"));
ui.add(
egui::widgets::Slider::new(&mut self.temporal_history, 0..=50).text("Temporal history"),
egui::widgets::Slider::new(&mut self.tap_confidence_near, 1..=50)
.text("Max confidence"),
);
ui.add(egui::widgets::Slider::new(&mut self.spatial_taps, 0..=10).text("Spatial taps"));
ui.add(
egui::widgets::Slider::new(&mut self.spatial_tap_history, 0..=50)
.text("Spatial tap history"),
);
ui.add(
egui::widgets::Slider::new(&mut self.spatial_radius, 1..=50)
.text("Spatial radius (px)"),
egui::widgets::Slider::new(&mut self.tap_confidence_far, 1..=50).text("Min confidence"),
);
ui.add(
egui::widgets::Slider::new(&mut self.t_start, 0.001..=0.5)
Expand Down
9 changes: 4 additions & 5 deletions blade-helpers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ pub fn default_ray_config() -> blade_render::RayConfig {
blade_render::RayConfig {
num_environment_samples: 1,
environment_importance_sampling: false,
temporal_tap: true,
temporal_history: 10,
spatial_taps: 1,
spatial_tap_history: 10,
spatial_radius: 20,
tap_count: 2,
tap_radius: 20,
tap_confidence_near: 15,
tap_confidence_far: 10,
t_start: 0.01,
pairwise_mis: true,
defensive_mis: 0.1,
Expand Down
94 changes: 36 additions & 58 deletions blade-render/code/ray-trace.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,22 @@
const RAY_FLAG_CULL_NO_OPAQUE: u32 = 0x80u;

const PI: f32 = 3.1415926;
const MAX_RESERVOIRS: u32 = 2u;
const MAX_RESERVOIRS: u32 = 4u;
// See "DECOUPLING SHADING AND REUSE" in
// "Rearchitecting Spatiotemporal Resampling for Production"
const DECOUPLED_SHADING: bool = false;

// We are considering 2x2 grid, so must be <= 4
const FACTOR_TEMPORAL_CANDIDATES: u32 = 1u;
// How many more candidates to consder than the taps we need
const FACTOR_SPATIAL_CANDIDATES: u32 = 3u;
// Has to be at least discarding the 2x2 block
const MIN_SPATIAL_REUSE_DISTANCE: i32 = 7;
const FACTOR_CANDIDATES: u32 = 3u;

struct MainParams {
frame_index: u32,
num_environment_samples: u32,
environment_importance_sampling: u32,
temporal_tap: u32,
temporal_history: u32,
spatial_taps: u32,
spatial_tap_history: u32,
spatial_radius: i32,
tap_count: u32,
tap_radius: f32,
tap_confidence_near: f32,
tap_confidence_far: f32,
t_start: f32,
use_pairwise_mis: u32,
defensive_mis: f32,
Expand Down Expand Up @@ -124,13 +119,13 @@ fn normalize_reservoir(r: ptr<function, LiveReservoir>, history: f32) {
(*r).history = history;
}
}
fn unpack_reservoir(f: StoredReservoir, max_history: u32, radiance: vec3<f32>) -> LiveReservoir {
fn unpack_reservoir(f: StoredReservoir, max_confidence: f32, radiance: vec3<f32>) -> LiveReservoir {
var r: LiveReservoir;
r.selected_light_index = f.light_index;
r.selected_uv = f.light_uv;
r.selected_target_score = f.target_score;
r.selected_radiance = radiance;
let history = min(f.confidence, f32(max_history));
let history = min(f.confidence, max_confidence);
r.weight_sum = f.contribution_weight * f.target_score * history;
r.history = history;
return r;
Expand Down Expand Up @@ -234,7 +229,9 @@ fn evaluate_brdf(surface: Surface, dir: vec3<f32>) -> f32 {
return lambert_brdf * max(0.0, lambert_term);
}

fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, direction: vec3<f32>, debug_len: f32) -> bool {
var<private> debug_len: f32;

fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, direction: vec3<f32>, debug_color: u32) -> bool {
var rq: ray_query;
let flags = RAY_FLAG_TERMINATE_ON_FIRST_HIT | RAY_FLAG_CULL_NO_OPAQUE;
rayQueryInitialize(&rq, acs,
Expand All @@ -244,8 +241,8 @@ fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, directio
let intersection = rayQueryGetCommittedIntersection(&rq);

let occluded = intersection.kind != RAY_QUERY_INTERSECTION_NONE;
if (debug_len != 0.0) {
let color = select(0xFFFFFFu, 0x0000FFu, occluded);
if (DEBUG_MODE && debug_len > 0.0) {
let color = select(0xFFFFFFu, 0x808080u, occluded) & debug_color;
debug_line(position, position + debug_len * direction, color);
}
return occluded;
Expand Down Expand Up @@ -284,7 +281,8 @@ fn make_target_score(color: vec3<f32>) -> TargetScore {
}

fn estimate_target_score_with_occlusion(
surface: Surface, position: vec3<f32>, light_index: u32, light_uv: vec2<f32>, acs: acceleration_structure, debug_len: f32
surface: Surface, position: vec3<f32>, light_index: u32, light_uv: vec2<f32>, acs: acceleration_structure,
debug_color: u32,
) -> TargetScore {
if (light_index != 0u) {
return TargetScore();
Expand All @@ -298,7 +296,7 @@ fn estimate_target_score_with_occlusion(
return TargetScore();
}

if (check_ray_occluded(acs, position, direction, debug_len)) {
if (check_ray_occluded(acs, position, direction, debug_color)) {
return TargetScore();
} else {
//Note: same as `evaluate_reflected_light`
Expand All @@ -307,7 +305,7 @@ fn estimate_target_score_with_occlusion(
}
}

fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debug_len: f32) -> f32 {
fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debug_color: u32) -> f32 {
let dir = map_equirect_uv_to_dir(ls.uv);
if (dot(dir, surface.flat_normal) <= 0.0) {
return 0.0;
Expand All @@ -323,7 +321,7 @@ fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debu
return 0.0;
}

if (check_ray_occluded(acc_struct, start_pos, dir, debug_len)) {
if (check_ray_occluded(acc_struct, start_pos, dir, debug_color)) {
return 0.0;
}

Expand All @@ -338,7 +336,7 @@ struct RestirOutput {
radiance: vec3<f32>,
}

fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomState>, enable_debug: bool) -> RestirOutput {
fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomState>) -> RestirOutput {
let ray_dir = get_ray_direction(camera, pixel);
let pixel_index = get_reservoir_index(pixel, camera);
if (surface.depth == 0.0) {
Expand All @@ -350,7 +348,6 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_Depth) {
textureStore(out_debug, pixel, vec4<f32>(1.0 / surface.depth));
}
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);
let position = camera.position + surface.depth * ray_dir;
let normal = qrot(surface.basis, vec3<f32>(0.0, 0.0, 1.0));

Expand All @@ -363,7 +360,7 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
ls = sample_light_from_sphere(rng);
}

let brdf = evaluate_sample(ls, surface, position, debug_len);
let brdf = evaluate_sample(ls, surface, position, 0x00FF00u);
if (brdf > 0.0) {
let other = make_reservoir(ls, 0u, vec3<f32>(brdf));
merge_reservoir(&canonical, other, random_gen(rng));
Expand All @@ -373,36 +370,17 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
}

let center_coord = get_prev_pixel(pixel, position);
let center_pixel = vec2<i32>(center_coord);
// Trick to start with closer pixels: we derive the "further"
// pixel in 2x2 grid by considering the sum.
let further_pixel = vec2<i32>(center_coord - 0.5) + vec2<i32>(center_coord + 0.5) - center_pixel;

// First, gather the list of reservoirs to merge with
var accepted_reservoir_indices = array<i32, MAX_RESERVOIRS>();
var accepted_count = 0u;
var temporal_index = ~0u;
let num_temporal_candidates = parameters.temporal_tap * FACTOR_TEMPORAL_CANDIDATES;
let num_candidates = num_temporal_candidates + parameters.spatial_taps * FACTOR_SPATIAL_CANDIDATES;
let max_samples = min(MAX_RESERVOIRS, 1u + parameters.spatial_taps);
let max_samples = min(MAX_RESERVOIRS, parameters.tap_count);
let num_candidates = max_samples * FACTOR_CANDIDATES;

for (var tap = 0u; tap < num_candidates && accepted_count < max_samples; tap += 1u) {
var other_pixel = center_pixel;
if (tap < num_temporal_candidates) {
if (temporal_index < tap) {
continue;
}
let mask = vec2<u32>(tap) & vec2<u32>(1u, 2u);
other_pixel = select(center_pixel, further_pixel, mask != vec2<u32>(0u));
} else {
let r0 = max(center_pixel - vec2<i32>(parameters.spatial_radius), vec2<i32>(0));
let r1 = min(center_pixel + vec2<i32>(parameters.spatial_radius + 1), vec2<i32>(prev_camera.target_size));
other_pixel = vec2<i32>(mix(vec2<f32>(r0), vec2<f32>(r1), vec2<f32>(random_gen(rng), random_gen(rng))));
let diff = other_pixel - center_pixel;
if (dot(diff, diff) < MIN_SPATIAL_REUSE_DISTANCE) {
continue;
}
}
let radius = parameters.tap_radius * random_gen(rng);
let offset = radius * sample_circle(random_gen(rng));
let other_pixel = vec2<i32>(center_coord + offset);

let other_index = get_reservoir_index(other_pixel, prev_camera);
if (other_index < 0) {
Expand All @@ -419,9 +397,6 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
continue;
}

if (tap < num_temporal_candidates) {
temporal_index = accepted_count;
}
accepted_reservoir_indices[accepted_count] = other_index;
accepted_count += 1u;
}
Expand All @@ -444,25 +419,26 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
for (var rid = 0u; rid < accepted_count; rid += 1u) {
let neighbor_index = accepted_reservoir_indices[rid];
let neighbor = prev_reservoirs[neighbor_index];
let neighbor_pixel = get_pixel_from_reservoir_index(neighbor_index, prev_camera);

let max_history = select(parameters.spatial_tap_history, parameters.temporal_history, rid == temporal_index);
let offset = vec2<f32>(neighbor_pixel) - center_coord;
let max_confidence = mix(parameters.tap_confidence_near, parameters.tap_confidence_far, length(offset) / parameters.tap_radius);
var other: LiveReservoir;
if (parameters.use_pairwise_mis != 0u) {
let neighbor_pixel = get_pixel_from_reservoir_index(neighbor_index, prev_camera);
let neighbor_history = min(neighbor.confidence, f32(max_history));
let neighbor_history = min(neighbor.confidence, max_confidence);
{ // scoping this to hint the register allocation
let neighbor_surface = read_prev_surface(neighbor_pixel);
let neighbor_dir = get_ray_direction(prev_camera, neighbor_pixel);
let neighbor_position = prev_camera.position + neighbor_surface.depth * neighbor_dir;

let t_canonical_at_neighbor = estimate_target_score_with_occlusion(
neighbor_surface, neighbor_position, canonical.selected_light_index, canonical.selected_uv, prev_acc_struct, debug_len);
neighbor_surface, neighbor_position, canonical.selected_light_index, canonical.selected_uv, prev_acc_struct, 0xFF0000u);
let r_canonical = ratio(canonical.history * canonical.selected_target_score * inv_count, neighbor_history * t_canonical_at_neighbor.score);
mis_canonical += mis_scale * r_canonical;
}

let t_neighbor_at_canonical = estimate_target_score_with_occlusion(
surface, position, neighbor.light_index, neighbor.light_uv, acc_struct, debug_len);
surface, position, neighbor.light_index, neighbor.light_uv, acc_struct, 0x0000FFu);
let r_neighbor = ratio(neighbor_history * neighbor.target_score, canonical.history * t_neighbor_at_canonical.score * inv_count);
let mis_neighbor = mis_scale * r_neighbor;

Expand All @@ -473,8 +449,8 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
other.selected_radiance = t_neighbor_at_canonical.color;
other.weight_sum = t_neighbor_at_canonical.score * neighbor.contribution_weight * mis_neighbor;
} else {
let radiance = evaluate_reflected_light(surface, other.selected_light_index, other.selected_uv);
other = unpack_reservoir(neighbor, max_history, radiance);
let radiance = evaluate_reflected_light(surface, neighbor.light_index, neighbor.light_uv);
other = unpack_reservoir(neighbor, max_confidence, radiance);
}

if (DECOUPLED_SHADING) {
Expand Down Expand Up @@ -522,7 +498,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let surface = read_surface(vec2<i32>(global_id.xy));
let enable_debug = DEBUG_MODE && all(global_id.xy == debug.mouse_pos);
let enable_restir_debug = (debug.draw_flags & DebugDrawFlags_RESTIR) != 0u && enable_debug;
let ro = compute_restir(surface, vec2<i32>(global_id.xy), &rng, enable_restir_debug);
debug_len = select(0.0, surface.depth * 0.2, enable_restir_debug);

let ro = compute_restir(surface, vec2<i32>(global_id.xy), &rng);

let color = ro.radiance;
if (enable_debug) {
Expand Down
27 changes: 12 additions & 15 deletions blade-render/src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,10 @@ pub struct DebugConfig {
pub struct RayConfig {
pub num_environment_samples: u32,
pub environment_importance_sampling: bool,
pub temporal_tap: bool,
pub temporal_history: u32,
pub spatial_taps: u32,
pub spatial_tap_history: u32,
pub spatial_radius: u32,
pub tap_count: u32,
pub tap_radius: u32,
pub tap_confidence_near: u32,
pub tap_confidence_far: u32,
pub t_start: f32,
/// Evaluate MIS factor for ReSTIR in a pair-wise fashion.
/// Adds 2 extra visibility rays per reused sample.
Expand Down Expand Up @@ -372,11 +371,10 @@ struct MainParams {
frame_index: u32,
num_environment_samples: u32,
environment_importance_sampling: u32,
temporal_tap: u32,
temporal_history: u32,
spatial_taps: u32,
spatial_tap_history: u32,
spatial_radius: u32,
tap_count: u32,
tap_radius: f32,
tap_confidence_near: f32,
tap_confidence_far: f32,
t_start: f32,
use_pairwise_mis: u32,
defensive_mis: f32,
Expand Down Expand Up @@ -1172,11 +1170,10 @@ impl Renderer {
num_environment_samples: ray_config.num_environment_samples,
environment_importance_sampling: ray_config.environment_importance_sampling
as u32,
temporal_tap: ray_config.temporal_tap as u32,
temporal_history: ray_config.temporal_history,
spatial_taps: ray_config.spatial_taps,
spatial_tap_history: ray_config.spatial_tap_history,
spatial_radius: ray_config.spatial_radius,
tap_count: ray_config.tap_count,
tap_radius: ray_config.tap_radius as f32,
tap_confidence_near: ray_config.tap_confidence_near as f32,
tap_confidence_far: ray_config.tap_confidence_far as f32,
t_start: ray_config.t_start,
use_pairwise_mis: ray_config.pairwise_mis as u32,
defensive_mis: ray_config.defensive_mis,
Expand Down

0 comments on commit 715368d

Please sign in to comment.