@@ -6,9 +6,9 @@ use vortex_error::VortexResult;
66use vortex_error:: vortex_ensure;
77
88use crate :: ArrayRef ;
9+ use crate :: ArrayView ;
910use crate :: ExecutionCtx ;
1011use crate :: IntoArray ;
11- use crate :: array:: ArrayView ;
1212use crate :: arrays:: ConstantArray ;
1313use crate :: arrays:: Struct ;
1414use crate :: arrays:: StructArray ;
@@ -19,23 +19,12 @@ use crate::dtype::DType;
1919use crate :: matcher:: Matcher ;
2020use crate :: scalar:: Scalar ;
2121use crate :: scalar_fn:: fns:: cast:: Cast ;
22- use crate :: scalar_fn:: fns:: cast:: CastKernel ;
23-
24- impl CastKernel for Struct {
25- fn cast (
26- array : ArrayView < ' _ , Struct > ,
27- dtype : & DType ,
28- ctx : & mut ExecutionCtx ,
29- ) -> VortexResult < Option < ArrayRef > > {
30- cast_struct ( array, dtype)
31- }
32- }
3322
3423pub ( crate ) fn struct_cast_execute_parent (
3524 child : & ArrayRef ,
3625 parent : & ArrayRef ,
3726 _child_idx : usize ,
38- _ctx : & mut ExecutionCtx ,
27+ ctx : & mut ExecutionCtx ,
3928) -> VortexResult < Option < ArrayRef > > {
4029 let Some ( array) = child. as_opt :: < Struct > ( ) else {
4130 return Ok ( None ) ;
@@ -45,14 +34,18 @@ pub(crate) fn struct_cast_execute_parent(
4534 } ;
4635
4736 let dtype = parent. options ;
48- if array. dtype ( ) == dtype {
37+ if array. dtype ( ) == parent . options {
4938 return Ok ( Some ( array. array ( ) . clone ( ) ) ) ;
5039 }
5140
52- cast_struct ( array, dtype)
41+ struct_cast ( array, dtype, ctx )
5342}
5443
55- fn cast_struct ( array : ArrayView < ' _ , Struct > , dtype : & DType ) -> VortexResult < Option < ArrayRef > > {
44+ pub ( crate ) fn struct_cast (
45+ array : ArrayView < Struct > ,
46+ dtype : & DType ,
47+ ctx : & mut ExecutionCtx ,
48+ ) -> VortexResult < Option < ArrayRef > > {
5649 let Some ( target_sdtype) = dtype. as_struct_fields_opt ( ) else {
5750 return Ok ( None ) ;
5851 } ;
@@ -84,19 +77,20 @@ fn cast_struct(array: ArrayView<'_, Struct>, dtype: &DType) -> VortexResult<Opti
8477
8578 let validity = array
8679 . validity ( ) ?
87- . cast_nullability ( dtype. nullability ( ) , array. len ( ) ) ?;
88-
89- StructArray :: try_new (
90- target_sdtype. names ( ) . clone ( ) ,
91- cast_fields,
92- array. len ( ) ,
93- validity,
94- )
95- . map ( |a| Some ( a. into_array ( ) ) )
80+ . cast_nullability ( dtype. nullability ( ) , array. len ( ) , ctx) ?;
81+
82+ Ok ( Some (
83+ unsafe {
84+ StructArray :: new_unchecked ( cast_fields, target_sdtype. clone ( ) , array. len ( ) , validity)
85+ }
86+ . into_array ( ) ,
87+ ) )
9688}
9789
9890#[ cfg( test) ]
9991mod tests {
92+ use std:: sync:: LazyLock ;
93+
10094 use rstest:: rstest;
10195 use vortex_buffer:: buffer;
10296 use vortex_error:: VortexResult ;
@@ -105,8 +99,7 @@ mod tests {
10599 use crate :: ArrayRef ;
106100 use crate :: ExecutionCtx ;
107101 use crate :: IntoArray ;
108- #[ expect( deprecated) ]
109- use crate :: ToCanonical as _;
102+ use crate :: VortexSessionExecute ;
110103 use crate :: arrays:: ConstantArray ;
111104 use crate :: arrays:: PrimitiveArray ;
112105 use crate :: arrays:: StructArray ;
@@ -127,8 +120,12 @@ mod tests {
127120 use crate :: optimizer:: kernels:: ExecuteParentFn ;
128121 use crate :: scalar:: Scalar ;
129122 use crate :: scalar_fn:: fns:: cast:: Cast ;
123+ use crate :: session:: ArraySession ;
130124 use crate :: validity:: Validity ;
131125
126+ static SESSION : LazyLock < VortexSession > =
127+ LazyLock :: new ( || VortexSession :: empty ( ) . with :: < ArraySession > ( ) ) ;
128+
132129 fn null_struct_cast_execute_parent (
133130 child : & ArrayRef ,
134131 parent : & ArrayRef ,
@@ -206,12 +203,11 @@ mod tests {
206203 session. kernels ( ) . register_execute_parent (
207204 parent_id,
208205 child_id,
209- [ null_struct_cast_execute_parent as ExecuteParentFn ] ,
206+ & [ null_struct_cast_execute_parent as ExecuteParentFn ] ,
210207 ) ;
211- let mut ctx = ExecutionCtx :: new ( session) ;
208+ let mut ctx = session. create_execution_ctx ( ) ;
212209
213- #[ expect( deprecated) ]
214- let result = cast. execute :: < ArrayRef > ( & mut ctx) . unwrap ( ) . to_struct ( ) ;
210+ let result = cast. execute :: < StructArray > ( & mut ctx) . unwrap ( ) ;
215211
216212 assert_eq ! ( result. dtype( ) , & target) ;
217213 assert_arrays_eq ! (
@@ -309,14 +305,17 @@ mod tests {
309305
310306 let target_dtype = struct_array. dtype ( ) . as_nullable ( ) ;
311307
312- let result = struct_array
308+ let cast = struct_array
313309 . into_array ( )
314310 . cast ( target_dtype. clone ( ) )
315311 . unwrap ( ) ;
316- assert_eq ! ( result. dtype( ) , & target_dtype) ;
317- assert_eq ! ( result. len( ) , 3 ) ;
318- #[ expect( deprecated) ]
319- let nfields = result. to_struct ( ) . struct_fields ( ) . nfields ( ) ;
312+ assert_eq ! ( cast. dtype( ) , & target_dtype) ;
313+ assert_eq ! ( cast. len( ) , 3 ) ;
314+ let nfields = cast
315+ . execute :: < StructArray > ( & mut SESSION . create_execution_ctx ( ) )
316+ . unwrap ( )
317+ . struct_fields ( )
318+ . nfields ( ) ;
320319 assert_eq ! ( nfields, 2 ) ;
321320 }
322321
@@ -346,8 +345,11 @@ mod tests {
346345 . unwrap ( ) ;
347346 assert_eq ! ( result. dtype( ) , & target_dtype) ;
348347 assert_eq ! ( result. len( ) , 3 ) ;
349- #[ expect( deprecated) ]
350- let nfields = result. to_struct ( ) . struct_fields ( ) . nfields ( ) ;
348+ let nfields = result
349+ . execute :: < StructArray > ( & mut SESSION . create_execution_ctx ( ) )
350+ . unwrap ( )
351+ . struct_fields ( )
352+ . nfields ( ) ;
351353 assert_eq ! ( nfields, 3 ) ;
352354 }
353355}
0 commit comments