Skip to content

Commit abc6902

Browse files
committed
refactor
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent c263689 commit abc6902

26 files changed

Lines changed: 339 additions & 335 deletions

File tree

encodings/alp/src/alp/compute/cast.rs

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use vortex_array::ArrayView;
66
use vortex_array::IntoArray;
77
use vortex_array::builtins::ArrayBuiltins;
88
use vortex_array::dtype::DType;
9-
use vortex_array::patches::Patches;
109
use vortex_array::scalar_fn::fns::cast::CastReduce;
1110
use vortex_error::VortexResult;
1211

@@ -17,41 +16,29 @@ use crate::alp::ALP;
1716
impl CastReduce for ALP {
1817
fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
1918
// Check if this is just a nullability change
20-
if array.dtype().eq_ignore_nullability(dtype) {
21-
// For nullability-only changes, we can avoid decoding
22-
// Cast the encoded array (integers) to handle nullability
23-
let new_encoded = array.encoded().cast(
24-
array
25-
.encoded()
26-
.dtype()
27-
.with_nullability(dtype.nullability()),
28-
)?;
29-
30-
let new_patches = array
31-
.patches()
32-
.map(|p| {
33-
if p.values().dtype() == dtype {
34-
Ok(p)
35-
} else {
36-
Patches::new(
37-
p.array_len(),
38-
p.offset(),
39-
p.indices().clone(),
40-
p.values().cast(dtype.clone())?,
41-
p.chunk_offsets().clone(),
42-
)
43-
}
44-
})
45-
.transpose()?;
46-
47-
// SAFETY: casting nullability doesn't alter the invariants
48-
unsafe {
49-
Ok(Some(
50-
ALP::new_unchecked(new_encoded, array.exponents(), new_patches).into_array(),
51-
))
52-
}
53-
} else {
54-
Ok(None)
19+
if !array.dtype().eq_ignore_nullability(dtype) {
20+
return Ok(None);
21+
}
22+
23+
// For nullability-only changes, we can avoid decoding
24+
// Cast the encoded array (integers) to handle nullability
25+
let new_encoded = array.encoded().cast(
26+
array
27+
.encoded()
28+
.dtype()
29+
.with_nullability(dtype.nullability()),
30+
)?;
31+
32+
let new_patches = array
33+
.patches()
34+
.map(|p| p.map_values(|v| v.cast(dtype.clone())))
35+
.transpose()?;
36+
37+
// SAFETY: casting nullability doesn't alter the invariants
38+
unsafe {
39+
Ok(Some(
40+
ALP::new_unchecked(new_encoded, array.exponents(), new_patches).into_array(),
41+
))
5542
}
5643
}
5744
}

encodings/alp/src/alp_rd/compute/cast.rs

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,35 @@ use crate::alp_rd::ALPRD;
1616

1717
impl CastReduce for ALPRD {
1818
fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19-
// ALPRDArray stores floating-point values, so only cast between float types
20-
// or if just changing nullability
21-
2219
// Check if this is just a nullability change
23-
if array.dtype().eq_ignore_nullability(dtype) {
24-
// For nullability-only changes, we need to cast the left_parts array
25-
// since it carries the validity information
26-
let new_left_parts = array.left_parts().cast(
27-
array
28-
.left_parts()
29-
.dtype()
30-
.with_nullability(dtype.nullability()),
31-
)?;
20+
if !array.dtype().eq_ignore_nullability(dtype) {
21+
return Ok(None);
22+
}
23+
24+
// For nullability-only changes, we need to cast the left_parts array
25+
// since it carries the validity information
26+
let new_left_parts = array.left_parts().cast(
27+
array
28+
.left_parts()
29+
.dtype()
30+
.with_nullability(dtype.nullability()),
31+
)?;
3232

33-
// NOTE: `CastReduce::cast` has a fixed trait signature without `ExecutionCtx`, so we
34-
// construct a legacy ctx locally at this trait boundary.
35-
return Ok(Some(
36-
ALPRD::try_new(
33+
// NOTE: `CastReduce::cast` has a fixed trait signature without `ExecutionCtx`, so we
34+
// construct a legacy ctx locally at this trait boundary.
35+
Ok(Some(
36+
unsafe {
37+
ALPRD::new_unchecked(
3738
dtype.clone(),
3839
new_left_parts,
3940
array.left_parts_dictionary().clone(),
4041
array.right_parts().clone(),
4142
array.right_bit_width(),
4243
array.left_parts_patches(),
43-
&mut LEGACY_SESSION.create_execution_ctx(),
44-
)?
45-
.into_array(),
46-
));
47-
}
48-
49-
// For other casts (e.g., f32 to f64), decode to canonical and let PrimitiveArray handle it
50-
Ok(None)
44+
)
45+
}
46+
.into_array(),
47+
))
5148
}
5249
}
5350

encodings/bytebool/src/compute.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,21 @@ impl CastReduce for ByteBool {
2525
// ByteBool is essentially a bool array stored as bytes
2626
// The main difference from BoolArray is the storage format
2727
// For casting, we can decode to canonical (BoolArray) and let it handle the cast
28-
2928
// If just changing nullability, we can optimize
30-
if array.dtype().eq_ignore_nullability(dtype) {
31-
let Some(new_validity) = array
32-
.validity()?
33-
.try_cast_nullability(dtype.nullability(), array.len())?
34-
else {
35-
return Ok(None);
36-
};
37-
38-
return Ok(Some(
39-
ByteBool::new(array.buffer().clone(), new_validity).into_array(),
40-
));
29+
if !dtype.is_boolean() {
30+
return Ok(None);
4131
}
4232

43-
// For other casts, decode to canonical and let BoolArray handle it
44-
Ok(None)
33+
let Some(new_validity) = array
34+
.validity()?
35+
.try_cast_nullability(dtype.nullability(), array.len())?
36+
else {
37+
return Ok(None);
38+
};
39+
40+
Ok(Some(
41+
ByteBool::new(array.buffer().clone(), new_validity).into_array(),
42+
))
4543
}
4644
}
4745

@@ -52,7 +50,7 @@ impl CastKernel for ByteBool {
5250
ctx: &mut ExecutionCtx,
5351
) -> VortexResult<Option<ArrayRef>> {
5452
// Only handle nullability changes here; non-bool targets fall through to canonicalization.
55-
if !array.dtype().eq_ignore_nullability(dtype) {
53+
if !dtype.is_boolean() {
5654
return Ok(None);
5755
}
5856

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/cast.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,31 @@ use vortex_array::IntoArray;
77
use vortex_array::builtins::ArrayBuiltins;
88
use vortex_array::dtype::DType;
99
use vortex_array::scalar_fn::fns::cast::CastReduce;
10-
use vortex_error::VortexExpect;
1110
use vortex_error::VortexResult;
1211

1312
use crate::DecimalByteParts;
1413
use crate::decimal_byte_parts::DecimalBytePartsArrayExt;
14+
1515
impl CastReduce for DecimalByteParts {
1616
fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17+
// Check if this is just a nullability change
18+
if !dtype.eq_ignore_nullability(array.dtype()) {
19+
return Ok(None);
20+
}
1721
// DecimalBytePartsArray can only have Decimal dtype, so we only handle decimal-to-decimal casts
1822
let DType::Decimal(target_decimal, target_nullability) = dtype else {
1923
// Cannot cast decimal to non-decimal types - delegate to canonical form
2024
return Ok(None);
2125
};
2226

23-
// Check if this is just a nullability change
24-
if array
25-
.dtype()
26-
.as_decimal_opt()
27-
.vortex_expect("must be a decimal dtype")
28-
== target_decimal
29-
&& array.dtype().nullability() != *target_nullability
30-
{
31-
// Cast the msp array to handle nullability change
32-
let new_msp = array
33-
.msp()
34-
.cast(array.msp().dtype().with_nullability(*target_nullability))?;
35-
36-
return Ok(Some(
37-
DecimalByteParts::try_new(new_msp, *target_decimal)?.into_array(),
38-
));
39-
}
27+
// Cast the msp array to handle nullability change
28+
let new_msp = array
29+
.msp()
30+
.cast(array.msp().dtype().with_nullability(*target_nullability))?;
4031

41-
// For precision/scale changes, decode to canonical and let DecimalArray handle it
42-
Ok(None)
32+
Ok(Some(
33+
DecimalByteParts::try_new(new_msp, *target_decimal)?.into_array(),
34+
))
4335
}
4436
}
4537

encodings/fastlanes/src/bitpacking/array/mod.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,19 +337,25 @@ impl<T: TypedArrayRef<crate::BitPacked>> BitPackedArrayExt for T {}
337337

338338
#[cfg(test)]
339339
mod test {
340+
use std::sync::LazyLock;
341+
340342
use vortex_array::IntoArray;
341-
use vortex_array::LEGACY_SESSION;
342343
use vortex_array::VortexSessionExecute;
343344
use vortex_array::arrays::PrimitiveArray;
344345
use vortex_array::assert_arrays_eq;
346+
use vortex_array::session::ArraySession;
345347
use vortex_buffer::Buffer;
348+
use vortex_session::VortexSession;
346349

347350
use crate::BitPackedData;
348351
use crate::bitpacking::array::BitPackedArrayExt;
349352

353+
static SESSION: LazyLock<VortexSession> =
354+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
355+
350356
#[test]
351357
fn test_encode() {
352-
let mut ctx = LEGACY_SESSION.create_execution_ctx();
358+
let mut ctx = SESSION.create_execution_ctx();
353359
let values = [
354360
Some(1u64),
355361
None,
@@ -372,7 +378,7 @@ mod test {
372378

373379
#[test]
374380
fn test_encode_too_wide() {
375-
let mut ctx = LEGACY_SESSION.create_execution_ctx();
381+
let mut ctx = SESSION.create_execution_ctx();
376382
let values = [Some(1u8), None, Some(1), None, Some(1), None];
377383
let uncompressed = PrimitiveArray::from_option_iter(values);
378384
let _packed = BitPackedData::encode(&uncompressed.clone().into_array(), 8, &mut ctx)
@@ -383,7 +389,7 @@ mod test {
383389

384390
#[test]
385391
fn signed_with_patches() {
386-
let mut ctx = LEGACY_SESSION.create_execution_ctx();
392+
let mut ctx = SESSION.create_execution_ctx();
387393
let values: Buffer<i32> = (0i32..=512).collect();
388394
let parray = values.clone().into_array();
389395

encodings/fastlanes/src/bitpacking/compute/cast.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl CastReduce for BitPacked {
4646
else {
4747
return Ok(None);
4848
};
49-
Ok(Some(build_with_validity(array, dtype, new_validity)?))
49+
build_with_validity(array, dtype, new_validity).map(Some)
5050
}
5151
}
5252

@@ -63,7 +63,7 @@ impl CastKernel for BitPacked {
6363
array
6464
.validity()?
6565
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
66-
Ok(Some(build_with_validity(array, dtype, new_validity)?))
66+
build_with_validity(array, dtype, new_validity).map(Some)
6767
}
6868
}
6969

encodings/fastlanes/src/bitpacking/vtable/mod.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,6 @@ impl VTable for BitPacked {
153153
}
154154
}
155155

156-
fn reduce_parent(
157-
array: ArrayView<'_, Self>,
158-
parent: &ArrayRef,
159-
child_idx: usize,
160-
) -> VortexResult<Option<ArrayRef>> {
161-
RULES.evaluate(array, parent, child_idx)
162-
}
163-
164-
fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
165-
BitPackedSlots::NAMES[idx].to_string()
166-
}
167-
168156
fn serialize(
169157
array: ArrayView<'_, Self>,
170158
_session: &VortexSession,
@@ -283,6 +271,10 @@ impl VTable for BitPacked {
283271
})
284272
}
285273

274+
fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
275+
BitPackedSlots::NAMES[idx].to_string()
276+
}
277+
286278
fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
287279
require_patches!(
288280
array,
@@ -305,6 +297,14 @@ impl VTable for BitPacked {
305297
) -> VortexResult<Option<ArrayRef>> {
306298
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
307299
}
300+
301+
fn reduce_parent(
302+
array: ArrayView<'_, Self>,
303+
parent: &ArrayRef,
304+
child_idx: usize,
305+
) -> VortexResult<Option<ArrayRef>> {
306+
RULES.evaluate(array, parent, child_idx)
307+
}
308308
}
309309

310310
#[derive(Clone, Debug)]

encodings/fastlanes/src/delta/compute/cast.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,17 @@ use vortex_array::dtype::DType;
99
use vortex_array::dtype::Nullability::NonNullable;
1010
use vortex_array::scalar_fn::fns::cast::CastReduce;
1111
use vortex_error::VortexResult;
12-
use vortex_error::vortex_panic;
1312

1413
use crate::delta::Delta;
1514
use crate::delta::array::DeltaArrayExt;
15+
1616
impl CastReduce for Delta {
1717
fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18-
// Delta encoding stores differences between consecutive values, which requires
19-
// unsigned integers to avoid overflow issues. Signed integers could produce
20-
// negative deltas that wouldn't fit in the unsigned delta representation.
21-
// This encoding is optimized for monotonically increasing sequences.
2218
let DType::Primitive(target_ptype, _) = dtype else {
2319
return Ok(None);
2420
};
2521

26-
let DType::Primitive(source_ptype, _) = array.dtype() else {
27-
vortex_panic!("delta should be primitive typed");
28-
};
29-
22+
let source_ptype = array.dtype().as_ptype();
3023
// TODO(DK): narrows can be safe but we must decompress to compute the maximum value.
3124
if target_ptype.is_signed_int() || source_ptype.bit_width() > target_ptype.bit_width() {
3225
return Ok(None);

encodings/fastlanes/src/delta/vtable/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use crate::delta::array::DeltaArrayExt;
3939
use crate::delta::array::SLOT_NAMES;
4040
use crate::delta::array::delta_decompress::delta_decompress;
4141
use crate::delta::array::lane_count;
42+
use crate::delta_compress;
4243

4344
mod operations;
4445
mod rules;
@@ -200,7 +201,7 @@ impl Delta {
200201
ctx: &mut ExecutionCtx,
201202
) -> VortexResult<DeltaArray> {
202203
let logical_len = array.len();
203-
let (bases, deltas) = crate::delta::array::delta_compress::delta_compress(array, ctx)?;
204+
let (bases, deltas) = delta_compress(array, ctx)?;
204205
Self::try_new(bases.into_array(), deltas.into_array(), 0, logical_len)
205206
}
206207
}

0 commit comments

Comments
 (0)