Skip to content

Commit 880c8d3

Browse files
committed
Make struct cast implementation pluggable
Squashed follow-up commits: fmt, less, fixes, refactor. Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 44dd940 commit 880c8d3

9 files changed

Lines changed: 547 additions & 214 deletions

File tree

vortex-array/public-api.lock

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13586,8 +13586,12 @@ impl vortex_array::optimizer::kernels::ArrayKernels
1358613586

1358713587
pub fn vortex_array::optimizer::kernels::ArrayKernels::empty() -> Self
1358813588

13589+
pub fn vortex_array::optimizer::kernels::ArrayKernels::find_execute_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<[vortex_array::optimizer::kernels::ExecuteParentFn]>>
13590+
1358913591
pub fn vortex_array::optimizer::kernels::ArrayKernels::find_reduce_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<[vortex_array::optimizer::kernels::ReduceParentFn]>>
1359013592

13593+
pub fn vortex_array::optimizer::kernels::ArrayKernels::register_execute_parent<I: core::iter::traits::collect::IntoIterator<Item = vortex_array::optimizer::kernels::ExecuteParentFn>>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I)
13594+
1359113595
pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent<I: core::iter::traits::collect::IntoIterator<Item = vortex_array::optimizer::kernels::ReduceParentFn>>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I)
1359213596

1359313597
impl core::default::Default for vortex_array::optimizer::kernels::ArrayKernels
@@ -13612,6 +13616,8 @@ impl<S: vortex_session::SessionExt> vortex_array::optimizer::kernels::ArrayKerne
1361213616

1361313617
pub fn S::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels>
1361413618

13619+
pub type vortex_array::optimizer::kernels::ExecuteParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
13620+
1361513621
pub type vortex_array::optimizer::kernels::ReduceParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
1361613622

1361713623
pub mod vortex_array::optimizer::rules

vortex-array/src/arrays/struct_/compute/cast.rs

Lines changed: 177 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,106 +6,149 @@ use vortex_error::VortexResult;
66
use vortex_error::vortex_ensure;
77

88
use crate::ArrayRef;
9+
use crate::ArrayView;
910
use crate::ExecutionCtx;
1011
use crate::IntoArray;
11-
use crate::array::ArrayView;
1212
use crate::arrays::ConstantArray;
1313
use crate::arrays::Struct;
1414
use crate::arrays::StructArray;
15+
use crate::arrays::scalar_fn::ExactScalarFn;
1516
use crate::arrays::struct_::StructArrayExt;
1617
use crate::builtins::ArrayBuiltins;
1718
use crate::dtype::DType;
19+
use crate::matcher::Matcher;
1820
use crate::scalar::Scalar;
19-
use crate::scalar_fn::fns::cast::CastKernel;
20-
21-
impl CastKernel for Struct {
22-
fn cast(
23-
array: ArrayView<'_, Struct>,
24-
dtype: &DType,
25-
ctx: &mut ExecutionCtx,
26-
) -> VortexResult<Option<ArrayRef>> {
27-
let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
28-
return Ok(None);
29-
};
30-
31-
let source_sdtype = array.struct_fields();
21+
use crate::scalar_fn::fns::cast::Cast;
22+
23+
pub(crate) fn struct_cast_execute_parent(
24+
child: &ArrayRef,
25+
parent: &ArrayRef,
26+
_child_idx: usize,
27+
ctx: &mut ExecutionCtx,
28+
) -> VortexResult<Option<ArrayRef>> {
29+
let Some(array) = child.as_opt::<Struct>() else {
30+
return Ok(None);
31+
};
32+
let Some(parent) = ExactScalarFn::<Cast>::try_match(parent) else {
33+
return Ok(None);
34+
};
35+
36+
let dtype = parent.options;
37+
if array.dtype() == parent.options {
38+
return Ok(Some(array.array().clone()));
39+
}
3240

33-
let fields_match_order = target_sdtype.nfields() == source_sdtype.nfields()
34-
&& target_sdtype
35-
.names()
36-
.iter()
37-
.zip(source_sdtype.names().iter())
38-
.all(|(f1, f2)| f1 == f2);
41+
struct_cast(array, dtype, ctx)
42+
}
3943

40-
let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
41-
if fields_match_order {
42-
for (field, target_type) in array.iter_unmasked_fields().zip_eq(target_sdtype.fields())
43-
{
44-
let cast_field = field.cast(target_type)?;
45-
cast_fields.push(cast_field);
44+
pub(crate) fn struct_cast(
45+
array: ArrayView<Struct>,
46+
dtype: &DType,
47+
ctx: &mut ExecutionCtx,
48+
) -> VortexResult<Option<ArrayRef>> {
49+
let Some(target_sdtype) = dtype.as_struct_fields_opt() else {
50+
return Ok(None);
51+
};
52+
53+
let source_sdtype = array.struct_fields();
54+
55+
let mut cast_fields = Vec::with_capacity(target_sdtype.nfields());
56+
// Re-order, handle fields by value instead.
57+
for (target_name, target_type) in target_sdtype.names().iter().zip_eq(target_sdtype.fields()) {
58+
match source_sdtype.find(target_name) {
59+
None => {
60+
// No source field with this name => evolve the schema compatibly.
61+
// If the field is nullable, we add a new ConstantArray field with the type.
62+
vortex_ensure!(
63+
target_type.is_nullable(),
64+
"CAST for struct only supports added nullable fields"
65+
);
66+
67+
cast_fields
68+
.push(ConstantArray::new(Scalar::null(target_type), array.len()).into_array());
4669
}
47-
} else {
48-
// Re-order, handle fields by value instead.
49-
for (target_name, target_type) in
50-
target_sdtype.names().iter().zip_eq(target_sdtype.fields())
51-
{
52-
match source_sdtype.find(target_name) {
53-
None => {
54-
// No source field with this name => evolve the schema compatibly.
55-
// If the field is nullable, we add a new ConstantArray field with the type.
56-
vortex_ensure!(
57-
target_type.is_nullable(),
58-
"CAST for struct only supports added nullable fields"
59-
);
60-
61-
cast_fields.push(
62-
ConstantArray::new(Scalar::null(target_type), array.len()).into_array(),
63-
);
64-
}
65-
Some(src_field_idx) => {
66-
// Field exists in source field. Cast it to the target type.
67-
let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
68-
cast_fields.push(cast_field);
69-
}
70-
}
70+
Some(src_field_idx) => {
71+
// Field exists in source field. Cast it to the target type.
72+
let cast_field = array.unmasked_field(src_field_idx).cast(target_type)?;
73+
cast_fields.push(cast_field);
7174
}
7275
}
76+
}
7377

74-
let validity = array
75-
.validity()?
76-
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
78+
let validity = array
79+
.validity()?
80+
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
7781

78-
StructArray::try_new(
79-
target_sdtype.names().clone(),
80-
cast_fields,
81-
array.len(),
82-
validity,
83-
)
84-
.map(|a| Some(a.into_array()))
85-
}
82+
Ok(Some(
83+
unsafe {
84+
StructArray::new_unchecked(cast_fields, target_sdtype.clone(), array.len(), validity)
85+
}
86+
.into_array(),
87+
))
8688
}
8789

8890
#[cfg(test)]
8991
mod tests {
92+
use std::sync::LazyLock;
93+
9094
use rstest::rstest;
9195
use vortex_buffer::buffer;
96+
use vortex_error::VortexResult;
97+
use vortex_session::VortexSession;
9298

99+
use crate::ArrayRef;
100+
use crate::ExecutionCtx;
93101
use crate::IntoArray;
94-
#[expect(deprecated)]
95-
use crate::ToCanonical as _;
102+
use crate::VortexSessionExecute;
103+
use crate::arrays::ConstantArray;
96104
use crate::arrays::PrimitiveArray;
97105
use crate::arrays::StructArray;
98106
use crate::arrays::VarBinArray;
107+
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
99108
use crate::arrays::struct_::StructArrayExt;
109+
use crate::assert_arrays_eq;
100110
use crate::builtins::ArrayBuiltins;
101111
use crate::compute::conformance::cast::test_cast_conformance;
102112
use crate::dtype::DType;
103113
use crate::dtype::DecimalDType;
104114
use crate::dtype::FieldNames;
105115
use crate::dtype::Nullability;
106116
use crate::dtype::PType;
117+
use crate::dtype::StructFields;
118+
use crate::optimizer::kernels::ArrayKernels;
119+
use crate::optimizer::kernels::ArrayKernelsExt;
120+
use crate::optimizer::kernels::ExecuteParentFn;
121+
use crate::scalar::Scalar;
122+
use crate::scalar_fn::fns::cast::Cast;
123+
use crate::session::ArraySession;
107124
use crate::validity::Validity;
108125

126+
static SESSION: LazyLock<VortexSession> =
127+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
128+
129+
fn null_struct_cast_execute_parent(
130+
child: &ArrayRef,
131+
parent: &ArrayRef,
132+
_child_idx: usize,
133+
_ctx: &mut ExecutionCtx,
134+
) -> VortexResult<Option<ArrayRef>> {
135+
let Some(target_fields) = parent.dtype().as_struct_fields_opt() else {
136+
return Ok(None);
137+
};
138+
let fields: Vec<ArrayRef> = target_fields
139+
.fields()
140+
.map(|dtype| ConstantArray::new(Scalar::null(dtype), child.len()).into_array())
141+
.collect();
142+
143+
StructArray::try_new(
144+
target_fields.names().clone(),
145+
fields,
146+
child.len(),
147+
Validity::from(parent.dtype().nullability()),
148+
)
149+
.map(|array| Some(array.into_array()))
150+
}
151+
109152
#[rstest]
110153
#[case(create_test_struct(false))]
111154
#[case(create_test_struct(true))]
@@ -115,6 +158,64 @@ mod tests {
115158
test_cast_conformance(&array.into_array());
116159
}
117160

161+
#[test]
162+
fn struct_cast_execute_parent_is_not_static_kernel() {
163+
let source = create_simple_struct().into_array();
164+
let target = DType::struct_(
165+
[(
166+
"value",
167+
DType::Primitive(PType::I64, Nullability::NonNullable),
168+
)],
169+
Nullability::NonNullable,
170+
);
171+
172+
let cast = Cast
173+
.try_new_array(source.len(), target, [source.clone()])
174+
.unwrap();
175+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
176+
177+
assert!(source.execute_parent(&cast, 0, &mut ctx).unwrap().is_none());
178+
}
179+
180+
#[test]
181+
fn struct_cast_execute_parent_uses_session_plugin() {
182+
let source = StructArray::try_new(
183+
FieldNames::from(["a"]),
184+
vec![VarBinArray::from_vec(vec!["A"], DType::Utf8(Nullability::Nullable)).into_array()],
185+
1,
186+
Validity::NonNullable,
187+
)
188+
.unwrap()
189+
.into_array();
190+
let child_id = source.encoding_id();
191+
192+
let utf8_null = DType::Utf8(Nullability::Nullable);
193+
let target = DType::Struct(
194+
StructFields::new(FieldNames::from(["b"]), vec![utf8_null.clone()]),
195+
Nullability::NonNullable,
196+
);
197+
198+
let cast = Cast
199+
.try_new_array(source.len(), target.clone(), [source])
200+
.unwrap();
201+
let parent_id = cast.encoding_id();
202+
let session = VortexSession::empty().with::<ArrayKernels>();
203+
session.kernels().register_execute_parent(
204+
parent_id,
205+
child_id,
206+
&[null_struct_cast_execute_parent as ExecuteParentFn],
207+
);
208+
let mut ctx = session.create_execution_ctx();
209+
210+
let result = cast.execute::<StructArray>(&mut ctx).unwrap();
211+
212+
assert_eq!(result.dtype(), &target);
213+
assert_arrays_eq!(
214+
result.unmasked_field_by_name("b").unwrap(),
215+
ConstantArray::new(Scalar::null(utf8_null), 1)
216+
);
217+
}
218+
118219
fn create_test_struct(nullable: bool) -> StructArray {
119220
let names = FieldNames::from(["a", "b"]);
120221

@@ -204,14 +305,17 @@ mod tests {
204305

205306
let target_dtype = struct_array.dtype().as_nullable();
206307

207-
let result = struct_array
308+
let cast = struct_array
208309
.into_array()
209310
.cast(target_dtype.clone())
210311
.unwrap();
211-
assert_eq!(result.dtype(), &target_dtype);
212-
assert_eq!(result.len(), 3);
213-
#[expect(deprecated)]
214-
let nfields = result.to_struct().struct_fields().nfields();
312+
assert_eq!(cast.dtype(), &target_dtype);
313+
assert_eq!(cast.len(), 3);
314+
let nfields = cast
315+
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
316+
.unwrap()
317+
.struct_fields()
318+
.nfields();
215319
assert_eq!(nfields, 2);
216320
}
217321

@@ -241,8 +345,11 @@ mod tests {
241345
.unwrap();
242346
assert_eq!(result.dtype(), &target_dtype);
243347
assert_eq!(result.len(), 3);
244-
#[expect(deprecated)]
245-
let nfields = result.to_struct().struct_fields().nfields();
348+
let nfields = result
349+
.execute::<StructArray>(&mut SESSION.create_execution_ctx())
350+
.unwrap()
351+
.struct_fields()
352+
.nfields();
246353
assert_eq!(nfields, 3);
247354
}
248355
}

vortex-array/src/arrays/struct_/compute/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
mod cast;
4+
pub(crate) mod cast;
55
mod mask;
66
pub(crate) mod rules;
77
mod slice;

0 commit comments

Comments
 (0)