diff --git a/hugr-core/src/std_extensions/collections/borrow_array.rs b/hugr-core/src/std_extensions/collections/borrow_array.rs index ddd2148e2c..463f8a4de7 100644 --- a/hugr-core/src/std_extensions/collections/borrow_array.rs +++ b/hugr-core/src/std_extensions/collections/borrow_array.rs @@ -45,7 +45,7 @@ pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array" /// Reported unique name of the extension pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr"); /// Extension version. -pub const VERSION: semver::Version = semver::Version::new(0, 1, 1); +pub const VERSION: semver::Version = semver::Version::new(0, 1, 2); /// A linear, unsafe, fixed-length collection of values. /// @@ -123,6 +123,8 @@ pub enum BArrayUnsafeOpDef { discard_all_borrowed, /// `new_all_borrowed: () -> borrow_array` new_all_borrowed, + /// is_borrowed: borrow_array, usize -> bool, borrow_array + is_borrowed, } impl BArrayUnsafeOpDef { @@ -166,6 +168,13 @@ impl BArrayUnsafeOpDef { Self::new_all_borrowed => { PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![array_ty])) } + Self::is_borrowed => PolyFuncTypeRV::new( + params, + FuncValueType::new( + vec![array_ty.clone(), usize_t], + vec![crate::extension::prelude::bool_t(), array_ty], + ), + ), } .into() } @@ -210,6 +219,7 @@ impl MakeOpDef for BArrayUnsafeOpDef { "Discard a borrow array where all elements have been borrowed" } Self::new_all_borrowed => "Create a new borrow array that contains no elements", + Self::is_borrowed => "Test whether an element in a borrow array has been borrowed", } .into() } @@ -719,6 +729,38 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder { .outputs_arr(); Ok(arr) } + + /// Adds an operation to test whether an element in a borrow array has been borrowed. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to test. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// A tuple containing: + /// * The wire representing the boolean result (true if borrowed). + /// * The wire representing the updated array. + fn add_is_borrowed( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + let op = BArrayUnsafeOpDef::is_borrowed.instantiate(&[size.into(), elem_ty.into()])?; + let [is_borrowed, arr] = self + .add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])? + .outputs_arr(); + Ok((is_borrowed, arr)) + } } impl BArrayOpBuilder for D {} @@ -804,4 +846,25 @@ mod test { builder.finish_hugr_with_outputs([arr_with_put]).unwrap() }; } + #[test] + fn test_is_borrowed() { + let size = 4; + let elem_ty = qb_t(); + let arr_ty = borrow_array_type(size, elem_ty.clone()); + + let mut builder = + DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t(), arr_ty])).unwrap(); + let idx = builder.add_load_value(ConstUsize::new(2)); + let [arr] = builder.input_wires_arr(); + // Borrow the element at index 2 + let (qb, arr_with_borrowed) = builder + .add_borrow_array_borrow(elem_ty.clone(), size, arr, idx) + .unwrap(); + let (_is_borrowed, arr_after_check) = builder + .add_is_borrowed(elem_ty.clone(), size, arr_with_borrowed, idx) + .unwrap(); + builder + .finish_hugr_with_outputs([qb, arr_after_check]) + .unwrap(); + } } diff --git a/hugr-llvm/src/extension/collections/borrow_array.rs b/hugr-llvm/src/extension/collections/borrow_array.rs index 99e0884b87..b0a73f468e 100644 --- a/hugr-llvm/src/extension/collections/borrow_array.rs +++ b/hugr-llvm/src/extension/collections/borrow_array.rs @@ -599,20 +599,16 @@ impl MaskCheck { None, |ctx, [mask_ptr, idx]| { // Compute mask bitarray block index via `idx // BLOCK_SIZE` - let mask_ptr = mask_ptr.into_pointer_value(); - let idx = idx.into_int_value(); let usize_t = usize_ty(&ctx.typing_session()); - let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false); - let builder = ctx.builder(); - let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?; - let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? }; - let block = builder.build_load(block_ptr, "")?.into_int_value(); - - // Extract bit from the block at position `idx % BLOCK_SIZE` - let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?; - let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?; - let bit = - builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?; + let ( + BlockData { + block_ptr, + block, + idx_in_block, + }, + bit, + ) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?; + let panic_bb = ctx.build_positioned_new_block("panic", None, |ctx, panic_bb| { let err: &ConstError = match self { MaskCheck::CheckNotBorrowed | MaskCheck::Borrow => &ERR_ALREADY_BORROWED, @@ -651,6 +647,38 @@ impl MaskCheck { } } +struct BlockData<'c> { + block_ptr: PointerValue<'c>, + block: IntValue<'c>, + idx_in_block: IntValue<'c>, +} + +fn inspect_mask_idx_bit<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + mask_ptr: BasicValueEnum<'c>, + idx: BasicValueEnum<'c>, +) -> Result<(BlockData<'c>, IntValue<'c>)> { + let usize_t = usize_ty(&ctx.typing_session()); + let mask_ptr = mask_ptr.into_pointer_value(); + let idx = idx.into_int_value(); + let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false); + let builder = ctx.builder(); + let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?; + let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? }; + let block = builder.build_load(block_ptr, "")?.into_int_value(); + let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?; + let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?; + let bit = builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?; + Ok(( + BlockData { + block_ptr, + block, + idx_in_block, + }, + bit, + )) +} + struct MaskInfo<'a> { mask_ptr: PointerValue<'a>, offset: IntValue<'a>, @@ -787,6 +815,27 @@ fn build_mask_padding1d<'c, H: HugrView>( Ok(()) } +/// Emits a check that returns whether a specific array element is borrowed (true) or not (false). +pub fn build_is_borrowed_bit<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + mask_ptr: PointerValue<'c>, + idx: IntValue<'c>, +) -> Result> { + // Wrap the check into a function instead of inlining + const FUNC_NAME: &str = "__barray_is_borrowed"; + get_or_make_function( + ctx, + FUNC_NAME, + [mask_ptr.into(), idx.into()], + Some(ctx.iw_context().bool_type().into()), + |ctx, [mask_ptr, idx]| { + let (_, bit) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?; + Ok(Some(bit.into())) + }, + ) + .map(|v| v.expect("i1 return value").into_int_value()) +} + /// Emits a check that no array elements have been borrowed. pub fn build_none_borrowed_check<'c, H: HugrView>( ccg: &impl BorrowArrayCodegen, @@ -1570,6 +1619,20 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView>( let (_, array_v) = build_barray_alloc(ctx, ccg, elem_ty, size, true)?; outputs.finish(ctx.builder(), [array_v.into()]) } + BArrayUnsafeOpDef::is_borrowed => { + let [array_v, index_v] = inputs + .try_into() + .map_err(|_| anyhow!("BArrayUnsafeOpDef::is_borrowed expects two arguments"))?; + let BArrayFatPtrComponents { + mask_ptr, offset, .. + } = decompose_barray_fat_pointer(builder, array_v)?; + let index_v = index_v.into_int_value(); + build_bounds_check(ccg, ctx, size, index_v)?; + let offset_index_v = ctx.builder().build_int_add(index_v, offset, "")?; + // let bit = build_is_borrowed_check(ctx, mask_ptr, offset_index_v)?; + let bit = build_is_borrowed_bit(ctx, mask_ptr, offset_index_v)?; + outputs.finish(ctx.builder(), [bit.into(), array_v]) + } _ => todo!(), } } @@ -1627,6 +1690,8 @@ mod test { use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; use hugr_core::std_extensions::STD_REG; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::collections::array::ArrayOpBuilder; use hugr_core::std_extensions::collections::array::op_builder::build_all_borrow_array_ops; use hugr_core::std_extensions::collections::borrow_array::{ @@ -2634,7 +2699,7 @@ mod test { // - Pops specified numbers from the left to introduce an offset // - Converts it into a regular array // - Converts it back into a borrow array - // - Borrows alls elements, sums them up, and returns the sum + // - Borrows all elements, sums them up, and returns the sum let int_ty = int_type(6); let hugr = SimpleHugrConfig::new() @@ -2908,4 +2973,70 @@ mod test { let msg = "Some array elements have been borrowed"; assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg); } + + #[rstest] + fn exec_is_borrowed_basic(mut exec_ctx: TestContext) { + // We build a HUGR that: + // - Creates a borrow array [1,2,3] + // - Borrows index 1 + // - Checks is_borrowed for indices 0, 1 + // - Returns 1 if [false, true], else 0 + let int_ty = int_type(6); + let size = 3; + let hugr = SimpleHugrConfig::new() + .with_outs(int_ty.clone()) + .with_extensions(exec_registry()) + .finish(|mut builder| { + let barray = borrow_array::BArrayValue::new( + int_ty.clone(), + (1..=3) + .map(|i| ConstInt::new_u(6, i).unwrap().into()) + .collect_vec(), + ); + let barray = builder.add_load_value(barray); + let idx1 = builder.add_load_value(ConstUsize::new(1)); + let (_, barray) = builder + .add_borrow_array_borrow(int_ty.clone(), size, barray, idx1) + .unwrap(); + + let idx0 = builder.add_load_value(ConstUsize::new(0)); + let (arr, b0_bools) = + [idx0, idx1] + .iter() + .fold((barray, Vec::new()), |(arr, mut bools), idx| { + let (b, arr) = builder + .add_is_borrowed(int_ty.clone(), size, arr, *idx) + .unwrap(); + bools.push(b); + (arr, bools) + }); + let [b0, b1] = b0_bools.try_into().unwrap(); + + let b0 = builder.add_not(b0).unwrap(); // flip b0 to true + let and01 = builder.add_and(b0, b1).unwrap(); + // convert bool to i1 + let i1 = builder + .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [and01]) + .unwrap() + .out_wire(0); + // widen i1 to i64 + let i_64 = builder + .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(0, 6), [i1]) + .unwrap() + .out_wire(0); + builder + .add_borrow_array_discard(int_ty.clone(), size, arr) + .unwrap(); + builder.finish_hugr_with_outputs([i_64]).unwrap() + }); + + exec_ctx.add_extensions(|cge| { + cge.add_default_prelude_extensions() + .add_logic_extensions() + .add_conversion_extensions() + .add_default_borrow_array_extensions(DefaultPreludeCodegen) + .add_default_int_extensions() + }); + assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main")); + } } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json b/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json index 1774b4aea6..a22457393b 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/borrow_arr.json @@ -1,5 +1,5 @@ { - "version": "0.1.1", + "version": "0.1.2", "name": "collections.borrow_arr", "types": { "borrow_array": { @@ -493,6 +493,86 @@ }, "binary": false }, + "is_borrowed": { + "extension": "collections.borrow_arr", + "name": "is_borrowed", + "description": "Test whether an element in a borrow array has been borrowed", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, "new_all_borrowed": { "extension": "collections.borrow_arr", "name": "new_all_borrowed", diff --git a/specification/std_extensions/collections/borrow_arr.json b/specification/std_extensions/collections/borrow_arr.json index 1774b4aea6..a22457393b 100644 --- a/specification/std_extensions/collections/borrow_arr.json +++ b/specification/std_extensions/collections/borrow_arr.json @@ -1,5 +1,5 @@ { - "version": "0.1.1", + "version": "0.1.2", "name": "collections.borrow_arr", "types": { "borrow_array": { @@ -493,6 +493,86 @@ }, "binary": false }, + "is_borrowed": { + "extension": "collections.borrow_arr", + "name": "is_borrowed", + "description": "Test whether an element in a borrow array has been borrowed", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "I" + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Opaque", + "extension": "collections.borrow_arr", + "id": "borrow_array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ] + } + }, + "binary": false + }, "new_all_borrowed": { "extension": "collections.borrow_arr", "name": "new_all_borrowed",