Skip to content

Commit

Permalink
Duplicate signal strength to all outputs (MCHPR#129)
Browse files Browse the repository at this point in the history
* decrease size of DirectLink to 4 bytes

* duplicate signal strength to all outputs
This allows input states to be checked with simd
because the input nodes don't need to be fetched

* cleanup old code

* remove erroring line

* make index of NodeId private again

* Make TickScheduler::NUM_QUEUES constant

* remove from_ne_bytes from last_index_positive

* only get needed repeater inputs

* replace magic number with NUM_QUEUES

* don't crash the program when ss is larger than 15
when creating a forward link

* fix minor typo

* rename INPUT_MASK to BOOL_INPUT_MASK

* Run cargo fmt

* address nits

* Add check for maximum inputs.
This isn't really necessary right now, but if we ever add optimizations which combine nodes it might become likely that a node will have more than 255 inputs.

* make proper use of input counters in from_compile_node

* Make clamp_weights pass mandatory

* Use get_unchecked_mut() for incrementing signal strength counters

* skip updating if output power did not change

* wrap input arrays into a struct to ensure they are aligned

* run cargo fmt
  • Loading branch information
BramOtte committed Jan 3, 2024
1 parent 728a848 commit 7b88d53
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 24 deletions.
67 changes: 43 additions & 24 deletions crates/core/src/redpiler/backend/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@ struct ForwardLink {
}

impl ForwardLink {
pub fn new(id: NodeId, side: bool, mut ss: u8) -> Self {
pub fn new(id: NodeId, side: bool, ss: u8) -> Self {
assert!(id.index() < (1 << 27));
if ss >= 16 {
ss = 15;
}
// the clamp_weights compile pass should ensure ss < 16
assert!(ss < 16);
Self {
data: (id.index() as u32) << 5 | if side { 1 << 4 } else { 0 } | ss as u32,
}
Expand Down Expand Up @@ -154,14 +153,20 @@ impl NodeType {
}
}

/// Per-signal-strength input counters for one input kind (default or side) of a node.
///
/// `ss_counts[s]` is the number of incoming links currently contributing
/// signal strength `s`; when an input changes power, one count moves from the
/// old strength slot to the new one.
///
/// The array is wrapped in its own struct so that `repr(align(16))` can be
/// applied to it: other code in this file reinterprets the 16 bytes as a
/// single `u128` to test many slots at once.
#[repr(align(16))]
#[derive(Debug, Clone, Default)]
struct NodeInput {
// One counter per signal strength; the index is the signal strength.
ss_counts: [u8; 16],
}

// struct is 128 bytes to fit nicely into cachelines
// which are usually 64 bytes, it can vary but is almost always a power of 2
#[derive(Debug, Clone)]
#[repr(align(128))]
pub struct Node {
ty: NodeType,
default_inputs: [u8; 16],
side_inputs: [u8; 16],
default_inputs: NodeInput,
side_inputs: NodeInput,
updates: SmallVec<[ForwardLink; 18]>,

facing_diode: bool,
Expand Down Expand Up @@ -189,14 +194,14 @@ impl Node {
stats: &mut FinalGraphStats,
) -> Self {
let node = &graph[node_idx];

const MAX_INPUTS: usize = 255;

let mut default_input_count = 0;
let mut side_input_count = 0;

let mut default_inputs = [0; 16];
let mut side_inputs = [0; 16];
let mut default_inputs = NodeInput { ss_counts: [0; 16] };
let mut side_inputs = NodeInput { ss_counts: [0; 16] };
for edge in graph.edges_directed(node_idx, Direction::Incoming) {
let weight = edge.weight();
let distance = weight.ss;
Expand All @@ -205,17 +210,20 @@ impl Node {
match weight.ty {
LinkType::Default => {
if default_input_count >= MAX_INPUTS {
panic!("Exceeded the maximum number of default inputs {}", MAX_INPUTS);
panic!(
"Exceeded the maximum number of default inputs {}",
MAX_INPUTS
);
}
default_input_count += 1;
default_inputs[ss as usize] += 1;
},
default_inputs.ss_counts[ss as usize] += 1;
}
LinkType::Side => {
if side_input_count >= MAX_INPUTS {
panic!("Exceeded the maximum number of side inputs {}", MAX_INPUTS);
}
side_input_count += 1;
side_inputs[ss as usize] += 1;
side_inputs.ss_counts[ss as usize] += 1;
}
}
}
Expand Down Expand Up @@ -405,8 +413,19 @@ impl DirectBackend {
} else {
&mut update_ref.default_inputs
};
inputs[old_power.saturating_sub(distance) as usize] -= 1;
inputs[new_power.saturating_sub(distance) as usize] += 1;

let old_power = old_power.saturating_sub(distance);
let new_power = new_power.saturating_sub(distance);

if old_power == new_power {
continue;
}

// Safety: signal strength is never larger than 15
unsafe {
*inputs.ss_counts.get_unchecked_mut(old_power as usize) -= 1;
*inputs.ss_counts.get_unchecked_mut(new_power as usize) += 1;
}

update_node(&mut self.scheduler, &mut self.nodes, update);
}
Expand Down Expand Up @@ -672,11 +691,11 @@ const BOOL_INPUT_MASK: u128 = u128::from_ne_bytes([
]);

fn get_bool_input(node: &Node) -> bool {
u128::from_ne_bytes(node.default_inputs) & BOOL_INPUT_MASK != 0
u128::from_le_bytes(node.default_inputs.ss_counts) & BOOL_INPUT_MASK != 0
}

fn get_bool_side(node: &Node) -> bool {
u128::from_ne_bytes(node.side_inputs) & BOOL_INPUT_MASK != 0
u128::from_le_bytes(node.side_inputs.ss_counts) & BOOL_INPUT_MASK != 0
}

fn last_index_positive(array: &[u8; 16]) -> u32 {
Expand All @@ -690,17 +709,17 @@ fn last_index_positive(array: &[u8; 16]) -> u32 {
}

fn get_all_input(node: &Node) -> (u8, u8) {
let input_power = last_index_positive(&node.default_inputs) as u8;
let input_power = last_index_positive(&node.default_inputs.ss_counts) as u8;

let side_input_power = last_index_positive(&node.side_inputs) as u8;
let side_input_power = last_index_positive(&node.side_inputs.ss_counts) as u8;

(input_power, side_input_power)
}

fn get_decoder_state(node: &Node) -> u32 {
let mut new_state = 0;
for i in 1..node.default_inputs.len() {
let input = node.default_inputs[i];
for i in 1..node.default_inputs.ss_counts.len() {
let input = node.default_inputs.ss_counts[i];
new_state = (new_state << 1) | if input > 0 {1} else {0};
}
new_state
Expand Down Expand Up @@ -821,8 +840,8 @@ fn update_node(scheduler: &mut TickScheduler, nodes: &mut Nodes, node_id: NodeId
let old = if node.will_be_powered {15 - distance} else {0};
{
let update = &mut nodes[update];
update.default_inputs[old as usize] -= 1;
update.default_inputs[new as usize] += 1;
update.default_inputs.ss_counts[old as usize] -= 1;
update.default_inputs.ss_counts[new as usize] += 1;
}
update_node(scheduler, nodes, update);
}
Expand Down
5 changes: 5 additions & 0 deletions crates/core/src/redpiler/passes/clamp_weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@ impl<W: World> Pass<W> for ClampWeights {
/// Drops every edge whose signal-strength weight `ss` is 15 or more,
/// keeping only edges with `ss < 15`. The compiler options and input are unused.
fn run_pass(&self, graph: &mut CompileGraph, _: &CompilerOptions, _: &CompilerInput<'_, W>) {
    graph.retain_edges(|frozen, edge_idx| {
        let weight = &frozen[edge_idx];
        weight.ss < 15
    });
}

/// Reports whether this pass should run; always returns `true`, so the
/// compiler options cannot disable it.
fn should_run(&self, _: &CompilerOptions) -> bool {
// Mandatory: `ForwardLink::new` asserts `ss < 16` and relies on the
// clamp_weights pass having already limited the edge weights.
true
}
}

0 comments on commit 7b88d53

Please sign in to comment.