Skip to content

Commit

Permalink
Merge pull request #47 from 4atj/limit-uses
Browse files Browse the repository at this point in the history
Limit input uses
  • Loading branch information
4atj authored Apr 10, 2024
2 parents 702da33 + cdcbeed commit 6b15010
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 29 deletions.
22 changes: 12 additions & 10 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use crate::{
vec::Vector,
};

pub type Mask = u8;
pub type VarCount = [u8; INPUTS.len()];

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Expr {
pub left: Option<NonNull<Expr>>,
pub right: Option<NonNull<Expr>>,
pub op_idx: OpIndex,
pub var_mask: Mask,
pub var_count: VarCount,
pub output: Vector,
}
unsafe impl Send for Expr {}
Expand All @@ -26,11 +26,13 @@ impl Expr {
}

pub fn variable(index: usize, output: Vector) -> Self {
let mut var_count = [0; INPUTS.len()];
var_count[index] = 1;
Self {
left: None,
right: None,
op_idx: OP_INDEX_VARIABLE,
var_mask: 1 << index,
var_count,
output,
}
}
Expand All @@ -40,27 +42,27 @@ impl Expr {
left: None,
right: None,
op_idx: OP_INDEX_LITERAL,
var_mask: 0,
var_count: [0; INPUTS.len()],
output: Vector::constant(value),
}
}

pub fn is_literal(&self) -> bool {
self.var_mask == 0
self.var_count.iter().all(|&var_count| var_count == 0)
}

pub fn bin(
el: NonNull<Expr>,
er: NonNull<Expr>,
op_idx: OpIndex,
var_mask: Mask,
var_count: VarCount,
output: Vector,
) -> Self {
Self {
left: Some(el),
right: Some(er),
op_idx,
var_mask,
var_count,
output,
}
}
Expand All @@ -70,7 +72,7 @@ impl Expr {
left: None,
right: Some(er.into()),
op_idx,
var_mask: er.var_mask,
var_count: er.var_count,
output,
}
}
Expand All @@ -80,7 +82,7 @@ impl Expr {
left: None,
right: Some(er.into()),
op_idx: OP_INDEX_PARENS,
var_mask: er.var_mask,
var_count: er.var_count,
output: er.output.clone(),
}
}
Expand All @@ -101,7 +103,7 @@ impl Display for Expr {
write!(
f,
"{}",
INPUTS[self.var_mask.trailing_zeros() as usize].name
INPUTS[self.var_count.iter().position(|&c| c == 1).unwrap()].name
)?;
} else {
write!(f, "{}", self.output[0])?;
Expand Down
67 changes: 54 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod params;
#[cfg_attr(not(feature = "simd"), path = "vec.rs")]
pub mod vec;

use expr::{Expr, Mask, NonNullExpr};
use expr::{Expr, NonNullExpr, VarCount};
use operator::*;
use params::*;

Expand All @@ -35,19 +35,39 @@ fn positive_integer_length(mut k: Num) -> usize {
l
}

fn can_use_required_vars(mask: u8, length: usize) -> bool {
!USE_ALL_VARS || length + (INPUTS.len() - (mask.count_ones() as usize)) * 2 <= MAX_LENGTH
fn can_use_required_vars(var_count: VarCount, length: usize) -> bool {
let missing_uses: u8 = var_count
.iter()
.zip(INPUTS.iter())
.map(|(&c, i)| i.min_uses - std::cmp::min(c, i.min_uses))
.sum();
length + missing_uses as usize * 2 <= MAX_LENGTH
}

fn is_leaf_expr(op_idx: OpIndex, length: usize) -> bool {
length == MAX_LENGTH
|| length == MAX_LENGTH - 1 && (UNARY_OPERATORS.len() == 0 || op_idx.prec() < UnaryOp::PREC)
}

const fn has_unlimited_var() -> bool {
let mut i = 0;
while i < INPUTS.len() {
if INPUTS[i].max_uses == u8::MAX {
return true;
}
i += 1;
}
false
}

fn save(level: &mut CacheLevel, expr: Expr, n: usize, cache: &Cache, hashset_cache: &HashSetCache) {
const ALL_MASK: Mask = (1 << INPUTS.len()) - 1;
let uses_required_vars = expr
.var_count
.iter()
.zip(INPUTS.iter())
.all(|(&c, i)| c >= i.min_uses);

if (!USE_ALL_VARS || expr.var_mask == ALL_MASK) && Matcher::match_all(&expr) {
if uses_required_vars && Matcher::match_all(&expr) {
println!("{expr}");
return;
}
Expand All @@ -56,7 +76,14 @@ fn save(level: &mut CacheLevel, expr: Expr, n: usize, cache: &Cache, hashset_cac
return;
}

if !REUSE_VARS && expr.var_mask == ALL_MASK {
let cant_use_more_vars = !has_unlimited_var()
&& expr
.var_count
.iter()
.zip(INPUTS.iter())
.all(|(&c, i)| c == i.max_uses);

if cant_use_more_vars {
let mut mp: HashMap<Num, Num> = HashMap::new();
for i in 0..GOAL.len() {
if let Some(old) = mp.insert(expr.output[i], GOAL[i]) {
Expand Down Expand Up @@ -112,11 +139,18 @@ fn find_binary_operators(
if er.is_literal() && el.is_literal() {
return;
}
if !REUSE_VARS && (el.var_mask & er.var_mask != 0) {
return;
let mut var_count = el.var_count;
for ((l, &r), i) in var_count
.iter_mut()
.zip(er.var_count.iter())
.zip(INPUTS.iter())
{
*l += r;
if *l > i.max_uses {
return;
}
}
let mask = el.var_mask | er.var_mask;
if !can_use_required_vars(mask, n) {
if !can_use_required_vars(var_count, n) {
return;
}
seq!(idx in 0..100 {
Expand All @@ -140,7 +174,7 @@ fn find_binary_operators(
} else if let Some(output) = op.vec_apply(el.output.clone(), &er.output) {
save(
cn,
Expr::bin(el.into(), er.into(), op_idx, mask, output),
Expr::bin(el.into(), er.into(), op_idx, var_count, output),
n,
cache,
hashset_cache,
Expand Down Expand Up @@ -194,7 +228,7 @@ fn find_unary_operators(
n: usize,
er: &Expr,
) {
if !can_use_required_vars(er.var_mask, n) {
if !can_use_required_vars(er.var_count, n) {
return;
}
seq!(idx in 0..10 {
Expand Down Expand Up @@ -249,7 +283,7 @@ fn find_parens_expressions(
return;
}
for er in &cache[n - 2] {
if !can_use_required_vars(er.var_mask, n) {
if !can_use_required_vars(er.var_count, n) {
continue;
}
if er.op_idx < OP_INDEX_PARENS {
Expand Down Expand Up @@ -376,8 +410,15 @@ fn validate_input() {
GOAL.len(),
"INPUTS and GOAL must have equal length"
);

assert_ne!(i.max_uses, 0, "INPUTS maximum uses must be non-zero");
}

assert!(
INPUTS.iter().map(|i| i.min_uses as usize).sum::<usize>() * 2 <= MAX_LENGTH + 1,
"The minimum uses requirement will never be met"
);

let mut input_set = HashSet::new();
for i in 0..INPUTS[0].vec.len() {
let mut input = [0; INPUTS.len()];
Expand Down
10 changes: 4 additions & 6 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ pub type Num = i32;
pub struct Input {
pub name: &'static str,
pub vec: &'static [Num],
pub min_uses: u8,
pub max_uses: u8,
}

pub const INPUTS: &[Input] = &[Input {
name: "n",
vec: &['E' as i32, 'W' as i32, 'N' as i32, 'S' as i32],
min_uses: 1,
max_uses: 255,
}];

pub struct Matcher {}
Expand Down Expand Up @@ -101,9 +105,3 @@ pub const UNARY_OPERATORS: &[UnaryOp] = &[

/// Match leaf expressions 1 output at a time to avoid unnecessary precalculations
pub const MATCH_1BY1: bool = true;

/// Search expressions that use the same variable twice (like `x*x`).
pub const REUSE_VARS: bool = true;

/// Controls whether all declared variables should be always used.
pub const USE_ALL_VARS: bool = true;

0 comments on commit 6b15010

Please sign in to comment.