@@ -6,106 +6,149 @@ use vortex_error::VortexResult;
66use vortex_error:: vortex_ensure;
77
88use crate :: ArrayRef ;
9+ use crate :: ArrayView ;
910use crate :: ExecutionCtx ;
1011use crate :: IntoArray ;
11- use crate :: array:: ArrayView ;
1212use crate :: arrays:: ConstantArray ;
1313use crate :: arrays:: Struct ;
1414use crate :: arrays:: StructArray ;
15+ use crate :: arrays:: scalar_fn:: ExactScalarFn ;
1516use crate :: arrays:: struct_:: StructArrayExt ;
1617use crate :: builtins:: ArrayBuiltins ;
1718use crate :: dtype:: DType ;
19+ use crate :: matcher:: Matcher ;
1820use crate :: scalar:: Scalar ;
19- use crate :: scalar_fn:: fns:: cast:: CastKernel ;
20-
21- impl CastKernel for Struct {
22- fn cast (
23- array : ArrayView < ' _ , Struct > ,
24- dtype : & DType ,
25- ctx : & mut ExecutionCtx ,
26- ) -> VortexResult < Option < ArrayRef > > {
27- let Some ( target_sdtype) = dtype. as_struct_fields_opt ( ) else {
28- return Ok ( None ) ;
29- } ;
30-
31- let source_sdtype = array. struct_fields ( ) ;
21+ use crate :: scalar_fn:: fns:: cast:: Cast ;
22+
23+ pub ( crate ) fn struct_cast_execute_parent (
24+ child : & ArrayRef ,
25+ parent : & ArrayRef ,
26+ _child_idx : usize ,
27+ ctx : & mut ExecutionCtx ,
28+ ) -> VortexResult < Option < ArrayRef > > {
29+ let Some ( array) = child. as_opt :: < Struct > ( ) else {
30+ return Ok ( None ) ;
31+ } ;
32+ let Some ( parent) = ExactScalarFn :: < Cast > :: try_match ( parent) else {
33+ return Ok ( None ) ;
34+ } ;
35+
36+ let dtype = parent. options ;
37+ if array. dtype ( ) == parent. options {
38+ return Ok ( Some ( array. array ( ) . clone ( ) ) ) ;
39+ }
3240
33- let fields_match_order = target_sdtype. nfields ( ) == source_sdtype. nfields ( )
34- && target_sdtype
35- . names ( )
36- . iter ( )
37- . zip ( source_sdtype. names ( ) . iter ( ) )
38- . all ( |( f1, f2) | f1 == f2) ;
41+ struct_cast ( array, dtype, ctx)
42+ }
3943
40- let mut cast_fields = Vec :: with_capacity ( target_sdtype. nfields ( ) ) ;
41- if fields_match_order {
42- for ( field, target_type) in array. iter_unmasked_fields ( ) . zip_eq ( target_sdtype. fields ( ) )
43- {
44- let cast_field = field. cast ( target_type) ?;
45- cast_fields. push ( cast_field) ;
44+ pub ( crate ) fn struct_cast (
45+ array : ArrayView < Struct > ,
46+ dtype : & DType ,
47+ ctx : & mut ExecutionCtx ,
48+ ) -> VortexResult < Option < ArrayRef > > {
49+ let Some ( target_sdtype) = dtype. as_struct_fields_opt ( ) else {
50+ return Ok ( None ) ;
51+ } ;
52+
53+ let source_sdtype = array. struct_fields ( ) ;
54+
55+ let mut cast_fields = Vec :: with_capacity ( target_sdtype. nfields ( ) ) ;
56+ // Re-order, handle fields by value instead.
57+ for ( target_name, target_type) in target_sdtype. names ( ) . iter ( ) . zip_eq ( target_sdtype. fields ( ) ) {
58+ match source_sdtype. find ( target_name) {
59+ None => {
60+ // No source field with this name => evolve the schema compatibly.
61+ // If the field is nullable, we add a new ConstantArray field with the type.
62+ vortex_ensure ! (
63+ target_type. is_nullable( ) ,
64+ "CAST for struct only supports added nullable fields"
65+ ) ;
66+
67+ cast_fields
68+ . push ( ConstantArray :: new ( Scalar :: null ( target_type) , array. len ( ) ) . into_array ( ) ) ;
4669 }
47- } else {
48- // Re-order, handle fields by value instead.
49- for ( target_name, target_type) in
50- target_sdtype. names ( ) . iter ( ) . zip_eq ( target_sdtype. fields ( ) )
51- {
52- match source_sdtype. find ( target_name) {
53- None => {
54- // No source field with this name => evolve the schema compatibly.
55- // If the field is nullable, we add a new ConstantArray field with the type.
56- vortex_ensure ! (
57- target_type. is_nullable( ) ,
58- "CAST for struct only supports added nullable fields"
59- ) ;
60-
61- cast_fields. push (
62- ConstantArray :: new ( Scalar :: null ( target_type) , array. len ( ) ) . into_array ( ) ,
63- ) ;
64- }
65- Some ( src_field_idx) => {
66- // Field exists in source field. Cast it to the target type.
67- let cast_field = array. unmasked_field ( src_field_idx) . cast ( target_type) ?;
68- cast_fields. push ( cast_field) ;
69- }
70- }
70+ Some ( src_field_idx) => {
71+ // Field exists in source field. Cast it to the target type.
72+ let cast_field = array. unmasked_field ( src_field_idx) . cast ( target_type) ?;
73+ cast_fields. push ( cast_field) ;
7174 }
7275 }
76+ }
7377
74- let validity = array
75- . validity ( ) ?
76- . cast_nullability ( dtype. nullability ( ) , array. len ( ) , ctx) ?;
78+ let validity = array
79+ . validity ( ) ?
80+ . cast_nullability ( dtype. nullability ( ) , array. len ( ) , ctx) ?;
7781
78- StructArray :: try_new (
79- target_sdtype. names ( ) . clone ( ) ,
80- cast_fields,
81- array. len ( ) ,
82- validity,
83- )
84- . map ( |a| Some ( a. into_array ( ) ) )
85- }
82+ Ok ( Some (
83+ unsafe {
84+ StructArray :: new_unchecked ( cast_fields, target_sdtype. clone ( ) , array. len ( ) , validity)
85+ }
86+ . into_array ( ) ,
87+ ) )
8688}
8789
8890#[ cfg( test) ]
8991mod tests {
92+ use std:: sync:: LazyLock ;
93+
9094 use rstest:: rstest;
9195 use vortex_buffer:: buffer;
96+ use vortex_error:: VortexResult ;
97+ use vortex_session:: VortexSession ;
9298
99+ use crate :: ArrayRef ;
100+ use crate :: ExecutionCtx ;
93101 use crate :: IntoArray ;
94- # [ expect ( deprecated ) ]
95- use crate :: ToCanonical as _ ;
102+ use crate :: VortexSessionExecute ;
103+ use crate :: arrays :: ConstantArray ;
96104 use crate :: arrays:: PrimitiveArray ;
97105 use crate :: arrays:: StructArray ;
98106 use crate :: arrays:: VarBinArray ;
107+ use crate :: arrays:: scalar_fn:: ScalarFnFactoryExt ;
99108 use crate :: arrays:: struct_:: StructArrayExt ;
109+ use crate :: assert_arrays_eq;
100110 use crate :: builtins:: ArrayBuiltins ;
101111 use crate :: compute:: conformance:: cast:: test_cast_conformance;
102112 use crate :: dtype:: DType ;
103113 use crate :: dtype:: DecimalDType ;
104114 use crate :: dtype:: FieldNames ;
105115 use crate :: dtype:: Nullability ;
106116 use crate :: dtype:: PType ;
117+ use crate :: dtype:: StructFields ;
118+ use crate :: optimizer:: kernels:: ArrayKernels ;
119+ use crate :: optimizer:: kernels:: ArrayKernelsExt ;
120+ use crate :: optimizer:: kernels:: ExecuteParentFn ;
121+ use crate :: scalar:: Scalar ;
122+ use crate :: scalar_fn:: fns:: cast:: Cast ;
123+ use crate :: session:: ArraySession ;
107124 use crate :: validity:: Validity ;
108125
126+ static SESSION : LazyLock < VortexSession > =
127+ LazyLock :: new ( || VortexSession :: empty ( ) . with :: < ArraySession > ( ) ) ;
128+
129+ fn null_struct_cast_execute_parent (
130+ child : & ArrayRef ,
131+ parent : & ArrayRef ,
132+ _child_idx : usize ,
133+ _ctx : & mut ExecutionCtx ,
134+ ) -> VortexResult < Option < ArrayRef > > {
135+ let Some ( target_fields) = parent. dtype ( ) . as_struct_fields_opt ( ) else {
136+ return Ok ( None ) ;
137+ } ;
138+ let fields: Vec < ArrayRef > = target_fields
139+ . fields ( )
140+ . map ( |dtype| ConstantArray :: new ( Scalar :: null ( dtype) , child. len ( ) ) . into_array ( ) )
141+ . collect ( ) ;
142+
143+ StructArray :: try_new (
144+ target_fields. names ( ) . clone ( ) ,
145+ fields,
146+ child. len ( ) ,
147+ Validity :: from ( parent. dtype ( ) . nullability ( ) ) ,
148+ )
149+ . map ( |array| Some ( array. into_array ( ) ) )
150+ }
151+
109152 #[ rstest]
110153 #[ case( create_test_struct( false ) ) ]
111154 #[ case( create_test_struct( true ) ) ]
@@ -115,6 +158,64 @@ mod tests {
115158 test_cast_conformance ( & array. into_array ( ) ) ;
116159 }
117160
161+ #[ test]
162+ fn struct_cast_execute_parent_is_not_static_kernel ( ) {
163+ let source = create_simple_struct ( ) . into_array ( ) ;
164+ let target = DType :: struct_ (
165+ [ (
166+ "value" ,
167+ DType :: Primitive ( PType :: I64 , Nullability :: NonNullable ) ,
168+ ) ] ,
169+ Nullability :: NonNullable ,
170+ ) ;
171+
172+ let cast = Cast
173+ . try_new_array ( source. len ( ) , target, [ source. clone ( ) ] )
174+ . unwrap ( ) ;
175+ let mut ctx = ExecutionCtx :: new ( VortexSession :: empty ( ) ) ;
176+
177+ assert ! ( source. execute_parent( & cast, 0 , & mut ctx) . unwrap( ) . is_none( ) ) ;
178+ }
179+
180+ #[ test]
181+ fn struct_cast_execute_parent_uses_session_plugin ( ) {
182+ let source = StructArray :: try_new (
183+ FieldNames :: from ( [ "a" ] ) ,
184+ vec ! [ VarBinArray :: from_vec( vec![ "A" ] , DType :: Utf8 ( Nullability :: Nullable ) ) . into_array( ) ] ,
185+ 1 ,
186+ Validity :: NonNullable ,
187+ )
188+ . unwrap ( )
189+ . into_array ( ) ;
190+ let child_id = source. encoding_id ( ) ;
191+
192+ let utf8_null = DType :: Utf8 ( Nullability :: Nullable ) ;
193+ let target = DType :: Struct (
194+ StructFields :: new ( FieldNames :: from ( [ "b" ] ) , vec ! [ utf8_null. clone( ) ] ) ,
195+ Nullability :: NonNullable ,
196+ ) ;
197+
198+ let cast = Cast
199+ . try_new_array ( source. len ( ) , target. clone ( ) , [ source] )
200+ . unwrap ( ) ;
201+ let parent_id = cast. encoding_id ( ) ;
202+ let session = VortexSession :: empty ( ) . with :: < ArrayKernels > ( ) ;
203+ session. kernels ( ) . register_execute_parent (
204+ parent_id,
205+ child_id,
206+ & [ null_struct_cast_execute_parent as ExecuteParentFn ] ,
207+ ) ;
208+ let mut ctx = session. create_execution_ctx ( ) ;
209+
210+ let result = cast. execute :: < StructArray > ( & mut ctx) . unwrap ( ) ;
211+
212+ assert_eq ! ( result. dtype( ) , & target) ;
213+ assert_arrays_eq ! (
214+ result. unmasked_field_by_name( "b" ) . unwrap( ) ,
215+ ConstantArray :: new( Scalar :: null( utf8_null) , 1 )
216+ ) ;
217+ }
218+
118219 fn create_test_struct ( nullable : bool ) -> StructArray {
119220 let names = FieldNames :: from ( [ "a" , "b" ] ) ;
120221
@@ -204,14 +305,17 @@ mod tests {
204305
205306 let target_dtype = struct_array. dtype ( ) . as_nullable ( ) ;
206307
207- let result = struct_array
308+ let cast = struct_array
208309 . into_array ( )
209310 . cast ( target_dtype. clone ( ) )
210311 . unwrap ( ) ;
211- assert_eq ! ( result. dtype( ) , & target_dtype) ;
212- assert_eq ! ( result. len( ) , 3 ) ;
213- #[ expect( deprecated) ]
214- let nfields = result. to_struct ( ) . struct_fields ( ) . nfields ( ) ;
312+ assert_eq ! ( cast. dtype( ) , & target_dtype) ;
313+ assert_eq ! ( cast. len( ) , 3 ) ;
314+ let nfields = cast
315+ . execute :: < StructArray > ( & mut SESSION . create_execution_ctx ( ) )
316+ . unwrap ( )
317+ . struct_fields ( )
318+ . nfields ( ) ;
215319 assert_eq ! ( nfields, 2 ) ;
216320 }
217321
@@ -241,8 +345,11 @@ mod tests {
241345 . unwrap ( ) ;
242346 assert_eq ! ( result. dtype( ) , & target_dtype) ;
243347 assert_eq ! ( result. len( ) , 3 ) ;
244- #[ expect( deprecated) ]
245- let nfields = result. to_struct ( ) . struct_fields ( ) . nfields ( ) ;
348+ let nfields = result
349+ . execute :: < StructArray > ( & mut SESSION . create_execution_ctx ( ) )
350+ . unwrap ( )
351+ . struct_fields ( )
352+ . nfields ( ) ;
246353 assert_eq ! ( nfields, 3 ) ;
247354 }
248355}
0 commit comments