Skip to content

Commit 78e6308

Browse files
committed
refactor
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent eb71f69 commit 78e6308

4 files changed

Lines changed: 101 additions & 109 deletions

File tree

vortex-array/src/arrays/struct_/compute/cast.rs

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ use vortex_error::VortexResult;
66
use vortex_error::vortex_ensure;
77

88
use crate::ArrayRef;
9+
use crate::ArrayView;
910
use crate::ExecutionCtx;
1011
use crate::IntoArray;
11-
use crate::array::ArrayView;
1212
use crate::arrays::ConstantArray;
1313
use crate::arrays::Struct;
1414
use crate::arrays::StructArray;
@@ -19,23 +19,12 @@ use crate::dtype::DType;
1919
use crate::matcher::Matcher;
2020
use crate::scalar::Scalar;
2121
use crate::scalar_fn::fns::cast::Cast;
22-
use crate::scalar_fn::fns::cast::CastKernel;
23-
24-
impl CastKernel for Struct {
25-
fn cast(
26-
array: ArrayView<'_, Struct>,
27-
dtype: &DType,
28-
ctx: &mut ExecutionCtx,
29-
) -> VortexResult<Option<ArrayRef>> {
30-
cast_struct(array, dtype)
31-
}
32-
}
3322

3423
pub(crate) fn struct_cast_execute_parent(
3524
child: &ArrayRef,
3625
parent: &ArrayRef,
3726
_child_idx: usize,
38-
_ctx: &mut ExecutionCtx,
27+
ctx: &mut ExecutionCtx,
3928
) -> VortexResult<Option<ArrayRef>> {
4029
let Some(array) = child.as_opt::<Struct>() else {
4130
return Ok(None);
@@ -45,14 +34,18 @@ pub(crate) fn struct_cast_execute_parent(
4534
};
4635

4736
let dtype = parent.options;
48-
if array.dtype() == dtype {
37+
if array.dtype() == parent.options {
4938
return Ok(Some(array.array().clone()));
5039
}
5140

52-
cast_struct(array, dtype)
41+
struct_cast(array, dtype, ctx)
5342
}
5443

55-
fn cast_struct(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
44+
pub(crate) fn struct_cast(
45+
array: ArrayView<Struct>,
46+
dtype: &DType,
47+
ctx: &mut ExecutionCtx,
48+
) -> VortexResult<Option<ArrayRef>> {
5649
let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
5750
return Ok(None);
5851
};
@@ -84,19 +77,20 @@ fn cast_struct(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult<Opti
8477

8578
let validity = array
8679
.validity()?
87-
.cast_nullability(dtype.nullability(), array.len())?;
88-
89-
StructArray::try_new(
90-
target_sdtype.names().clone(),
91-
cast_fields,
92-
array.len(),
93-
validity,
94-
)
95-
.map(|a| Some(a.into_array()))
80+
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
81+
82+
Ok(Some(
83+
unsafe {
84+
StructArray::new_unchecked(cast_fields, target_sdtype.clone(), array.len(), validity)
85+
}
86+
.into_array(),
87+
))
9688
}
9789

9890
#[cfg(test)]
9991
mod tests {
92+
use std::sync::LazyLock;
93+
10094
use rstest::rstest;
10195
use vortex_buffer::buffer;
10296
use vortex_error::VortexResult;
@@ -105,8 +99,7 @@ mod tests {
10599
use crate::ArrayRef;
106100
use crate::ExecutionCtx;
107101
use crate::IntoArray;
108-
#[expect(deprecated)]
109-
use crate::ToCanonical as _;
102+
use crate::VortexSessionExecute;
110103
use crate::arrays::ConstantArray;
111104
use crate::arrays::PrimitiveArray;
112105
use crate::arrays::StructArray;
@@ -127,8 +120,12 @@ mod tests {
127120
use crate::optimizer::kernels::ExecuteParentFn;
128121
use crate::scalar::Scalar;
129122
use crate::scalar_fn::fns::cast::Cast;
123+
use crate::session::ArraySession;
130124
use crate::validity::Validity;
131125

126+
static SESSION: LazyLock<VortexSession> =
127+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
128+
132129
fn null_struct_cast_execute_parent(
133130
child: &ArrayRef,
134131
parent: &ArrayRef,
@@ -206,12 +203,11 @@ mod tests {
206203
session.kernels().register_execute_parent(
207204
parent_id,
208205
child_id,
209-
[null_struct_cast_execute_parent as ExecuteParentFn],
206+
&[null_struct_cast_execute_parent as ExecuteParentFn],
210207
);
211-
let mut ctx = ExecutionCtx::new(session);
208+
let mut ctx = session.create_execution_ctx();
212209

213-
#[expect(deprecated)]
214-
let result = cast.execute::<ArrayRef>(&mut ctx).unwrap().to_struct();
210+
let result = cast.execute::<StructArray>(&mut ctx).unwrap();
215211

216212
assert_eq!(result.dtype(), &target);
217213
assert_arrays_eq!(
@@ -309,14 +305,17 @@ mod tests {
309305

310306
let target_dtype = struct_array.dtype().as_nullable();
311307

312-
let result = struct_array
308+
let cast = struct_array
313309
.into_array()
314310
.cast(target_dtype.clone())
315311
.unwrap();
316-
assert_eq!(result.dtype(), &target_dtype);
317-
assert_eq!(result.len(), 3);
318-
#[expect(deprecated)]
319-
let nfields = result.to_struct().struct_fields().nfields();
312+
assert_eq!(cast.dtype(), &target_dtype);
313+
assert_eq!(cast.len(), 3);
314+
let nfields = cast
315+
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
316+
.unwrap()
317+
.struct_fields()
318+
.nfields();
320319
assert_eq!(nfields, 2);
321320
}
322321

@@ -346,8 +345,11 @@ mod tests {
346345
.unwrap();
347346
assert_eq!(result.dtype(), &target_dtype);
348347
assert_eq!(result.len(), 3);
349-
#[expect(deprecated)]
350-
let nfields = result.to_struct().struct_fields().nfields();
348+
let nfields = result
349+
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
350+
.unwrap()
351+
.struct_fields()
352+
.nfields();
351353
assert_eq!(nfields, 3);
352354
}
353355
}

0 commit comments

Comments
 (0)