Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion hugr-core/src/std_extensions/collections/borrow_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array"
/// Reported unique name of the extension
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 1);
pub const VERSION: semver::Version = semver::Version::new(0, 1, 2);

/// A linear, unsafe, fixed-length collection of values.
///
Expand Down Expand Up @@ -123,6 +123,8 @@ pub enum BArrayUnsafeOpDef {
discard_all_borrowed,
/// `new_all_borrowed<size, elem_ty>: () -> borrow_array<size, elem_ty>`
new_all_borrowed,
/// is_borrowed<N, T>: borrow_array<N, T>, usize -> bool, borrow_array<N, T>
is_borrowed,
}

impl BArrayUnsafeOpDef {
Expand Down Expand Up @@ -166,6 +168,13 @@ impl BArrayUnsafeOpDef {
Self::new_all_borrowed => {
PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![array_ty]))
}
Self::is_borrowed => PolyFuncTypeRV::new(
params,
FuncValueType::new(
vec![array_ty.clone(), usize_t],
vec![crate::extension::prelude::bool_t(), array_ty],
Comment on lines +174 to +175
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type order is a bit weird, with the array in first place for the input but 2nd in the output.
But I see that we're doing the same with borrow so 🤷

),
),
}
.into()
}
Expand Down Expand Up @@ -210,6 +219,7 @@ impl MakeOpDef for BArrayUnsafeOpDef {
"Discard a borrow array where all elements have been borrowed"
}
Self::new_all_borrowed => "Create a new borrow array that contains no elements",
Self::is_borrowed => "Test whether an element in a borrow array has been borrowed",
}
.into()
}
Expand Down Expand Up @@ -719,6 +729,38 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder {
.outputs_arr();
Ok(arr)
}

/// Adds an operation to test whether an element in a borrow array has been borrowed.
///
/// # Arguments
///
/// * `elem_ty` - The type of the elements in the array.
/// * `size` - The size of the array.
/// * `input` - The wire representing the array.
/// * `index` - The wire representing the index to test.
///
/// # Errors
///
/// Returns an error if building the operation fails.
///
/// # Returns
///
/// A tuple containing:
/// * The wire representing the boolean result (true if borrowed).
/// * The wire representing the updated array.
fn add_is_borrowed(
&mut self,
elem_ty: Type,
size: u64,
input: Wire,
index: Wire,
) -> Result<(Wire, Wire), BuildError> {
let op = BArrayUnsafeOpDef::is_borrowed.instantiate(&[size.into(), elem_ty.into()])?;
let [is_borrowed, arr] = self
.add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])?
.outputs_arr();
Ok((is_borrowed, arr))
}
}

impl<D: Dataflow> BArrayOpBuilder for D {}
Expand Down Expand Up @@ -804,4 +846,25 @@ mod test {
builder.finish_hugr_with_outputs([arr_with_put]).unwrap()
};
}
#[test]
fn test_is_borrowed() {
let size = 4;
let elem_ty = qb_t();
let arr_ty = borrow_array_type(size, elem_ty.clone());

let mut builder =
DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t(), arr_ty])).unwrap();
let idx = builder.add_load_value(ConstUsize::new(2));
let [arr] = builder.input_wires_arr();
// Borrow the element at index 2
let (qb, arr_with_borrowed) = builder
.add_borrow_array_borrow(elem_ty.clone(), size, arr, idx)
.unwrap();
let (_is_borrowed, arr_after_check) = builder
.add_is_borrowed(elem_ty.clone(), size, arr_with_borrowed, idx)
.unwrap();
builder
.finish_hugr_with_outputs([qb, arr_after_check])
.unwrap();
}
}
159 changes: 145 additions & 14 deletions hugr-llvm/src/extension/collections/borrow_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,20 +599,16 @@ impl MaskCheck {
None,
|ctx, [mask_ptr, idx]| {
// Compute mask bitarray block index via `idx // BLOCK_SIZE`
let mask_ptr = mask_ptr.into_pointer_value();
let idx = idx.into_int_value();
let usize_t = usize_ty(&ctx.typing_session());
let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false);
let builder = ctx.builder();
let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?;
let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? };
let block = builder.build_load(block_ptr, "")?.into_int_value();

// Extract bit from the block at position `idx % BLOCK_SIZE`
let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?;
let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?;
let bit =
builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?;
let (
BlockData {
block_ptr,
block,
idx_in_block,
},
bit,
) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?;

let panic_bb = ctx.build_positioned_new_block("panic", None, |ctx, panic_bb| {
let err: &ConstError = match self {
MaskCheck::CheckNotBorrowed | MaskCheck::Borrow => &ERR_ALREADY_BORROWED,
Expand Down Expand Up @@ -651,6 +647,38 @@ impl MaskCheck {
}
}

struct BlockData<'c> {
block_ptr: PointerValue<'c>,
block: IntValue<'c>,
idx_in_block: IntValue<'c>,
}

fn inspect_mask_idx_bit<'c, H: HugrView<Node = Node>>(
ctx: &mut EmitFuncContext<'c, '_, H>,
mask_ptr: BasicValueEnum<'c>,
idx: BasicValueEnum<'c>,
) -> Result<(BlockData<'c>, IntValue<'c>)> {
let usize_t = usize_ty(&ctx.typing_session());
let mask_ptr = mask_ptr.into_pointer_value();
let idx = idx.into_int_value();
let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false);
let builder = ctx.builder();
let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?;
let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? };
let block = builder.build_load(block_ptr, "")?.into_int_value();
let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?;
let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?;
let bit = builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?;
Ok((
BlockData {
block_ptr,
block,
idx_in_block,
},
bit,
))
}

struct MaskInfo<'a> {
mask_ptr: PointerValue<'a>,
offset: IntValue<'a>,
Expand Down Expand Up @@ -787,6 +815,27 @@ fn build_mask_padding1d<'c, H: HugrView<Node = Node>>(
Ok(())
}

/// Emits a check that returns whether a specific array element is borrowed (true) or not (false).
pub fn build_is_borrowed_bit<'c, H: HugrView<Node = Node>>(
ctx: &mut EmitFuncContext<'c, '_, H>,
mask_ptr: PointerValue<'c>,
idx: IntValue<'c>,
) -> Result<inkwell::values::IntValue<'c>> {
// Wrap the check into a function instead of inlining
const FUNC_NAME: &str = "__barray_is_borrowed";
get_or_make_function(
ctx,
FUNC_NAME,
[mask_ptr.into(), idx.into()],
Some(ctx.iw_context().bool_type().into()),
|ctx, [mask_ptr, idx]| {
let (_, bit) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?;
Ok(Some(bit.into()))
},
)
.map(|v| v.expect("i1 return value").into_int_value())
}

/// Emits a check that no array elements have been borrowed.
pub fn build_none_borrowed_check<'c, H: HugrView<Node = Node>>(
ccg: &impl BorrowArrayCodegen,
Expand Down Expand Up @@ -1570,6 +1619,20 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView<Node = Node>>(
let (_, array_v) = build_barray_alloc(ctx, ccg, elem_ty, size, true)?;
outputs.finish(ctx.builder(), [array_v.into()])
}
BArrayUnsafeOpDef::is_borrowed => {
let [array_v, index_v] = inputs
.try_into()
.map_err(|_| anyhow!("BArrayUnsafeOpDef::is_borrowed expects two arguments"))?;
let BArrayFatPtrComponents {
mask_ptr, offset, ..
} = decompose_barray_fat_pointer(builder, array_v)?;
let index_v = index_v.into_int_value();
build_bounds_check(ccg, ctx, size, index_v)?;
let offset_index_v = ctx.builder().build_int_add(index_v, offset, "")?;
// let bit = build_is_borrowed_check(ctx, mask_ptr, offset_index_v)?;
let bit = build_is_borrowed_bit(ctx, mask_ptr, offset_index_v)?;
outputs.finish(ctx.builder(), [bit.into(), array_v])
}
_ => todo!(),
}
}
Expand Down Expand Up @@ -1627,6 +1690,8 @@ mod test {
use hugr_core::extension::prelude::either_type;
use hugr_core::ops::Tag;
use hugr_core::std_extensions::STD_REG;
use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef;
use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef;
use hugr_core::std_extensions::collections::array::ArrayOpBuilder;
use hugr_core::std_extensions::collections::array::op_builder::build_all_borrow_array_ops;
use hugr_core::std_extensions::collections::borrow_array::{
Expand Down Expand Up @@ -2634,7 +2699,7 @@ mod test {
// - Pops specified numbers from the left to introduce an offset
// - Converts it into a regular array
// - Converts it back into a borrow array
// - Borrows alls elements, sums them up, and returns the sum
// - Borrows all elements, sums them up, and returns the sum

let int_ty = int_type(6);
let hugr = SimpleHugrConfig::new()
Expand Down Expand Up @@ -2908,4 +2973,70 @@ mod test {
let msg = "Some array elements have been borrowed";
assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg);
}

#[rstest]
fn exec_is_borrowed_basic(mut exec_ctx: TestContext) {
// We build a HUGR that:
// - Creates a borrow array [1,2,3]
// - Borrows index 1
// - Checks is_borrowed for indices 0, 1
// - Returns 1 if [false, true], else 0
let int_ty = int_type(6);
let size = 3;
let hugr = SimpleHugrConfig::new()
.with_outs(int_ty.clone())
.with_extensions(exec_registry())
.finish(|mut builder| {
let barray = borrow_array::BArrayValue::new(
int_ty.clone(),
(1..=3)
.map(|i| ConstInt::new_u(6, i).unwrap().into())
.collect_vec(),
);
let barray = builder.add_load_value(barray);
let idx1 = builder.add_load_value(ConstUsize::new(1));
let (_, barray) = builder
.add_borrow_array_borrow(int_ty.clone(), size, barray, idx1)
.unwrap();

let idx0 = builder.add_load_value(ConstUsize::new(0));
let (arr, b0_bools) =
[idx0, idx1]
.iter()
.fold((barray, Vec::new()), |(arr, mut bools), idx| {
let (b, arr) = builder
.add_is_borrowed(int_ty.clone(), size, arr, *idx)
.unwrap();
bools.push(b);
(arr, bools)
});
let [b0, b1] = b0_bools.try_into().unwrap();

let b0 = builder.add_not(b0).unwrap(); // flip b0 to true
let and01 = builder.add_and(b0, b1).unwrap();
// convert bool to i1
let i1 = builder
.add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [and01])
.unwrap()
.out_wire(0);
// widen i1 to i64
let i_64 = builder
.add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(0, 6), [i1])
.unwrap()
.out_wire(0);
builder
.add_borrow_array_discard(int_ty.clone(), size, arr)
.unwrap();
builder.finish_hugr_with_outputs([i_64]).unwrap()
});

exec_ctx.add_extensions(|cge| {
cge.add_default_prelude_extensions()
.add_logic_extensions()
.add_conversion_extensions()
.add_default_borrow_array_extensions(DefaultPreludeCodegen)
.add_default_int_extensions()
});
assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
}
}
Loading
Loading