Skip to content

Commit 1a3cb42

Browse files
committed
ValidityKernel wrapper
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 76aeed6 commit 1a3cb42

File tree

20 files changed

+146
-124
lines changed

20 files changed

+146
-124
lines changed

encodings/fastlanes/src/bitpacking/array/bitpack_pipeline.rs

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
use fastlanes::{BitPacking, FastLanes};
55
use static_assertions::const_assert_eq;
6+
use vortex_array::pipeline::validity::ValidityKernel;
67
use vortex_array::pipeline::{
78
BindContext, BitView, Kernel, KernelCtx, N, PipelineInputs, PipelinedNode,
89
};
10+
use vortex_array::validity::Validity;
911
use vortex_buffer::Buffer;
1012
use vortex_dtype::{PTypeDowncastExt, PhysicalPType, match_each_integer_ptype};
1113
use vortex_error::VortexResult;
12-
use vortex_mask::Mask;
1314
use vortex_vector::primitive::PVectorMut;
1415
use vortex_vector::{VectorMut, VectorMutOps};
1516

@@ -53,11 +54,21 @@ impl PipelinedNode for BitPackedArray {
5354
)
5455
}
5556

56-
Ok(Box::new(AlignedBitPackedKernel::<T>::new(
57-
packed_bit_width,
58-
packed_buffer,
59-
self.validity.to_mask(self.len()),
60-
)) as Box<dyn Kernel>)
57+
match self.validity {
58+
Validity::NonNullable | Validity::AllValid => Ok(Box::new(
59+
AlignedBitPackedKernel::<T>::new(packed_bit_width, packed_buffer),
60+
)
61+
as Box<dyn Kernel>),
62+
Validity::AllInvalid => {
63+
todo!("make a kernel that returns constant null");
64+
}
65+
Validity::Array(_) => {
66+
let inner = AlignedBitPackedKernel::<T>::new(packed_bit_width, packed_buffer);
67+
let mask = self.validity_mask();
68+
69+
Ok(Box::new(ValidityKernel::new(inner, mask)) as Box<dyn Kernel>)
70+
}
71+
}
6172
})
6273
}
6374
}
@@ -84,19 +95,12 @@ pub struct AlignedBitPackedKernel<BP: PhysicalPType<Physical: BitPacking>> {
8495
/// The buffer containing the bitpacked values.
8596
packed_buffer: Buffer<BP::Physical>,
8697

87-
/// The validity mask for the bitpacked array.
88-
validity: Mask,
89-
9098
/// The total number of bitpacked chunks we have unpacked.
9199
num_chunks_unpacked: usize,
92100
}
93101

94102
impl<BP: PhysicalPType<Physical: BitPacking>> AlignedBitPackedKernel<BP> {
95-
pub fn new(
96-
packed_bit_width: usize,
97-
packed_buffer: Buffer<BP::Physical>,
98-
validity: Mask,
99-
) -> Self {
103+
pub fn new(packed_bit_width: usize, packed_buffer: Buffer<BP::Physical>) -> Self {
100104
let packed_stride =
101105
packed_bit_width * <<BP as PhysicalPType>::Physical as FastLanes>::LANES;
102106

@@ -110,7 +114,6 @@ impl<BP: PhysicalPType<Physical: BitPacking>> AlignedBitPackedKernel<BP> {
110114
packed_bit_width,
111115
packed_stride,
112116
packed_buffer,
113-
validity,
114117
num_chunks_unpacked: 0,
115118
}
116119
}
@@ -130,9 +133,6 @@ impl<BP: PhysicalPType<Physical: BitPacking>> Kernel for AlignedBitPackedKernel<
130133
let not_yet_unpacked_values = &self.packed_buffer.as_slice()[packed_offset..];
131134

132135
let true_count = selection.true_count();
133-
let chunk_offset = self.num_chunks_unpacked * N;
134-
let array_len = self.validity.len();
135-
debug_assert!(chunk_offset < array_len);
136136

137137
// If the true count is very small (the selection is sparse), we can unpack individual
138138
// elements directly into the output vector.
@@ -141,26 +141,21 @@ impl<BP: PhysicalPType<Physical: BitPacking>> Kernel for AlignedBitPackedKernel<
141141
debug_assert!(true_count <= output_vector.capacity());
142142

143143
selection.iter_ones(|idx| {
144-
let absolute_idx = chunk_offset + idx;
145-
if self.validity.value(absolute_idx) {
146-
// SAFETY:
147-
// - The documentation for `packed_bit_width` explains that the size is valid.
148-
// - We know that the size of the `next_packed_chunk` we provide is equal to
149-
// `self.packed_stride`, and we explain why this is correct in its
150-
// documentation.
151-
let unpacked_value = unsafe {
152-
BitPacking::unchecked_unpack_single(
153-
self.packed_bit_width,
154-
not_yet_unpacked_values,
155-
idx,
156-
)
157-
};
158-
159-
// SAFETY: We just reserved enough capacity to push these values.
160-
unsafe { output_vector.push_unchecked(unpacked_value) };
161-
} else {
162-
output_vector.append_nulls(1);
163-
}
144+
// SAFETY:
145+
// - The documentation for `packed_bit_width` explains that the size is valid.
146+
// - We know that the size of the `next_packed_chunk` we provide is equal to
147+
// `self.packed_stride`, and we explain why this is correct in its
148+
// documentation.
149+
let unpacked_value = unsafe {
150+
BitPacking::unchecked_unpack_single(
151+
self.packed_bit_width,
152+
not_yet_unpacked_values,
153+
idx,
154+
)
155+
};
156+
157+
// SAFETY: We just reserved enough capacity to push these values.
158+
unsafe { output_vector.elements_mut().push_unchecked(unpacked_value) };
164159
});
165160
} else {
166161
// Otherwise if the mask is dense, it is faster to fully unpack the entire 1024
@@ -192,26 +187,6 @@ impl<BP: PhysicalPType<Physical: BitPacking>> Kernel for AlignedBitPackedKernel<
192187
output_vector.as_mut(),
193188
);
194189
}
195-
196-
if array_len < chunk_offset + N {
197-
let vector_len = array_len - chunk_offset;
198-
debug_assert!(vector_len < N, "math is broken");
199-
200-
// SAFETY: This must be less than `N` so this is just a truncate.
201-
unsafe { output_vector.elements_mut().set_len(vector_len) };
202-
203-
let chunk_mask = self.validity.slice(chunk_offset..array_len);
204-
205-
// SAFETY: We have just set the elements length to N, and the validity buffer has
206-
// capacity for N elements.
207-
unsafe { output_vector.validity_mut() }.append_mask(&chunk_mask);
208-
} else {
209-
let chunk_mask = self.validity.slice(chunk_offset..chunk_offset + N);
210-
211-
// SAFETY: We have just set the elements length to N, and the validity buffer has
212-
// capacity for N elements.
213-
unsafe { output_vector.validity_mut() }.append_mask(&chunk_mask);
214-
}
215190
}
216191

217192
self.num_chunks_unpacked += 1;

vortex-array/src/pipeline/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod driver;
5+
pub mod validity;
56

67
use vortex_error::{VortexExpect, VortexResult};
78
use vortex_vector::{Vector, VectorMut};
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::Mask;
3+
use vortex_vector::{VectorMut, VectorMutOps};
4+
5+
use crate::pipeline::{BitView, Kernel, KernelCtx};
6+
7+
/// `ValidityKernel` wraps a child kernel, passing a validity mask through to the output.
8+
pub struct ValidityKernel<K> {
9+
child: K,
10+
validity: Mask,
11+
position: usize,
12+
}
13+
14+
impl<K> ValidityKernel<K> {
15+
pub fn new(inner: K, validity: Mask) -> Self {
16+
Self {
17+
child: inner,
18+
validity,
19+
position: 0,
20+
}
21+
}
22+
}
23+
24+
impl<K: Kernel> Kernel for ValidityKernel<K> {
25+
fn step(
26+
&mut self,
27+
ctx: &KernelCtx,
28+
selection: &BitView,
29+
out: &mut VectorMut,
30+
) -> VortexResult<()> {
31+
// execute the child kernel
32+
self.child.step(ctx, selection, out)?;
33+
34+
debug_assert_eq!(
35+
out.validity().len(),
36+
self.position,
37+
"child kernel should not step validity when wrapped with ValidityKernel"
38+
);
39+
40+
let new_position = self.position + out.len();
41+
42+
let slice = self.validity.slice(self.position..new_position);
43+
44+
// SAFETY: the child kernel must extend elements in its step function.
45+
unsafe { out.validity_mut().append_mask(&slice) };
46+
47+
// Advance the position in the kernel here.
48+
self.position = new_position;
49+
50+
Ok(())
51+
}
52+
}

vortex-compute/src/filter/vector/binaryview.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
use vortex_buffer::{Buffer, BufferMut};
55
use vortex_mask::{Mask, MaskMut};
6-
use vortex_vector::VectorOps;
76
use vortex_vector::binaryview::{
87
BinaryView, BinaryViewType, BinaryViewVector, BinaryViewVectorMut,
98
};
9+
use vortex_vector::{VectorMutOps, VectorOps};
1010

1111
use crate::filter::Filter;
1212

vortex-compute/src/filter/vector/bool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
use vortex_buffer::BitBuffer;
55
use vortex_mask::{Mask, MaskMut};
6-
use vortex_vector::VectorOps;
76
use vortex_vector::bool::{BoolVector, BoolVectorMut};
7+
use vortex_vector::{VectorMutOps, VectorOps};
88

99
use crate::filter::Filter;
1010

vortex-compute/src/filter/vector/dvector.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
use vortex_buffer::{Buffer, BufferMut};
55
use vortex_dtype::NativeDecimalType;
66
use vortex_mask::{Mask, MaskMut};
7-
use vortex_vector::VectorOps;
87
use vortex_vector::decimal::{DVector, DVectorMut};
8+
use vortex_vector::{VectorMutOps, VectorOps};
99

1010
use crate::filter::Filter;
1111

vortex-compute/src/filter/vector/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
use std::sync::Arc;
55

66
use vortex_mask::{Mask, MaskMut};
7-
use vortex_vector::VectorOps;
87
use vortex_vector::listview::{ListViewVector, ListViewVectorMut};
98
use vortex_vector::primitive::{PrimitiveVector, PrimitiveVectorMut};
9+
use vortex_vector::{VectorMutOps, VectorOps};
1010

1111
use crate::filter::Filter;
1212

vortex-compute/src/filter/vector/struct_.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::sync::Arc;
55

66
use vortex_mask::{Mask, MaskMut};
77
use vortex_vector::struct_::{StructVector, StructVectorMut};
8-
use vortex_vector::{Vector, VectorMut, VectorOps};
8+
use vortex_vector::{Vector, VectorMut, VectorMutOps, VectorOps};
99

1010
use crate::filter::Filter;
1111

vortex-vector/src/binaryview/vector_mut.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,6 @@ impl<T: BinaryViewType> BinaryViewVectorMut<T> {
122122
&mut self.views
123123
}
124124

125-
/// Get a mutable handle to the validity mask of the vector.
126-
///
127-
/// # Safety
128-
///
129-
/// Caller must make sure that the length of the validity mask
130-
/// always matches the length of the views
131-
pub unsafe fn validity_mut(&mut self) -> &mut MaskMut {
132-
&mut self.validity
133-
}
134-
135125
/// Get a mutable handle to the vector of buffers backing the string data of the vector.
136126
pub fn buffers(&mut self) -> &mut Vec<ByteBuffer> {
137127
&mut self.buffers
@@ -216,6 +206,10 @@ impl<T: BinaryViewType> VectorMutOps for BinaryViewVectorMut<T> {
216206
&self.validity
217207
}
218208

209+
unsafe fn validity_mut(&mut self) -> &mut MaskMut {
210+
&mut self.validity
211+
}
212+
219213
fn capacity(&self) -> usize {
220214
self.views.capacity()
221215
}

vortex-vector/src/bool/vector_mut.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,6 @@ impl BoolVectorMut {
9393
pub unsafe fn bits_mut(&mut self) -> &mut BitBufferMut {
9494
&mut self.bits
9595
}
96-
97-
/// Get a mutable handle to the validity mask of the vector.
98-
///
99-
/// # Safety
100-
///
101-
/// Caller must ensure that length of the validity always matches
102-
/// length of the bits.
103-
pub unsafe fn validity_mut(&mut self) -> &mut MaskMut {
104-
&mut self.validity
105-
}
10696
}
10797

10898
impl VectorMutOps for BoolVectorMut {
@@ -118,6 +108,10 @@ impl VectorMutOps for BoolVectorMut {
118108
&self.validity
119109
}
120110

111+
unsafe fn validity_mut(&mut self) -> &mut MaskMut {
112+
&mut self.validity
113+
}
114+
121115
fn capacity(&self) -> usize {
122116
self.bits.capacity()
123117
}

0 commit comments

Comments
 (0)