Skip to content

Commit 6fa3cae

Browse files
authored
feat: teach Array<Struct> try_from(&[Array<Struct>]) (#7632)
Look ma, no execution context! Signed-off-by: Daniel King <dan@spiraldb.com>
1 parent e3a6b62 commit 6fa3cae

5 files changed

Lines changed: 139 additions & 81 deletions

File tree

vortex-array/public-api.lock

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19480,6 +19480,10 @@ pub fn vortex_array::validity::Validity::union_nullability(self, nullability: vo
1948019480

1948119481
impl vortex_array::validity::Validity
1948219482

19483+
pub fn vortex_array::validity::Validity::concat(validities: alloc::vec::Vec<(vortex_array::validity::Validity, usize)>) -> core::option::Option<Self>
19484+
19485+
impl vortex_array::validity::Validity
19486+
1948319487
pub fn vortex_array::validity::Validity::execute(self, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::validity::Validity>
1948419488

1948519489
impl vortex_array::validity::Validity
@@ -22066,6 +22070,10 @@ pub fn vortex_array::Array<vortex_array::arrays::Struct>::project(&self, project
2206622070

2206722071
pub fn vortex_array::Array<vortex_array::arrays::Struct>::remove_column(&self, name: impl core::convert::Into<vortex_array::dtype::FieldName>) -> core::option::Option<(Self, vortex_array::ArrayRef)>
2206822072

22073+
pub fn vortex_array::Array<vortex_array::arrays::Struct>::remove_column_owned(&self, name: impl core::convert::Into<vortex_array::dtype::FieldName>) -> core::option::Option<(Self, vortex_array::ArrayRef)>
22074+
22075+
pub fn vortex_array::Array<vortex_array::arrays::Struct>::try_concat<T>(chunks: impl core::iter::traits::collect::IntoIterator<Item = T>) -> vortex_error::VortexResult<Self> where T: core::borrow::Borrow<vortex_array::Array<vortex_array::arrays::Struct>>
22076+
2206922077
pub fn vortex_array::Array<vortex_array::arrays::Struct>::try_from_iter<N: core::convert::AsRef<str>, A: vortex_array::IntoArray, T: core::iter::traits::collect::IntoIterator<Item = (N, A)>>(iter: T) -> vortex_error::VortexResult<Self>
2207022078

2207122079
pub fn vortex_array::Array<vortex_array::arrays::Struct>::try_from_iter_with_validity<N: core::convert::AsRef<str>, A: vortex_array::IntoArray, T: core::iter::traits::collect::IntoIterator<Item = (N, A)>>(iter: T, validity: vortex_array::validity::Validity) -> vortex_error::VortexResult<Self>
@@ -22074,15 +22082,11 @@ pub fn vortex_array::Array<vortex_array::arrays::Struct>::try_new(names: vortex_
2207422082

2207522083
pub fn vortex_array::Array<vortex_array::arrays::Struct>::try_new_with_dtype(fields: impl core::convert::Into<alloc::sync::Arc<[vortex_array::ArrayRef]>>, dtype: vortex_array::dtype::StructFields, length: usize, validity: vortex_array::validity::Validity) -> vortex_error::VortexResult<Self>
2207622084

22077-
impl vortex_array::Array<vortex_array::arrays::Struct>
22078-
22079-
pub fn vortex_array::Array<vortex_array::arrays::Struct>::into_record_batch_with_schema(self, schema: impl core::convert::AsRef<arrow_schema::schema::Schema>) -> vortex_error::VortexResult<arrow_array::record_batch::RecordBatch>
22085+
pub fn vortex_array::Array<vortex_array::arrays::Struct>::with_column(&self, name: impl core::convert::Into<vortex_array::dtype::FieldName>, array: vortex_array::ArrayRef) -> vortex_error::VortexResult<Self>
2208022086

2208122087
impl vortex_array::Array<vortex_array::arrays::Struct>
2208222088

22083-
pub fn vortex_array::Array<vortex_array::arrays::Struct>::remove_column_owned(&self, name: impl core::convert::Into<vortex_array::dtype::FieldName>) -> core::option::Option<(Self, vortex_array::ArrayRef)>
22084-
22085-
pub fn vortex_array::Array<vortex_array::arrays::Struct>::with_column(&self, name: impl core::convert::Into<vortex_array::dtype::FieldName>, array: vortex_array::ArrayRef) -> vortex_error::VortexResult<Self>
22089+
pub fn vortex_array::Array<vortex_array::arrays::Struct>::into_record_batch_with_schema(self, schema: impl core::convert::AsRef<arrow_schema::schema::Schema>) -> vortex_error::VortexResult<arrow_array::record_batch::RecordBatch>
2208622090

2208722091
impl vortex_array::Array<vortex_array::arrays::VarBin>
2208822092

vortex-array/src/arrays/chunked/vtable/canonical.rs

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use itertools::Itertools as _;
45
use vortex_buffer::Buffer;
56
use vortex_error::VortexExpect;
67
use vortex_error::VortexResult;
@@ -18,13 +19,11 @@ use crate::arrays::StructArray;
1819
use crate::arrays::chunked::ChunkedArrayExt;
1920
use crate::arrays::listview::ListViewArrayExt;
2021
use crate::arrays::listview::ListViewRebuildMode;
21-
use crate::arrays::struct_::StructArrayExt;
2222
use crate::builders::builder_with_capacity_in;
2323
use crate::builtins::ArrayBuiltins;
2424
use crate::dtype::DType;
2525
use crate::dtype::Nullability;
2626
use crate::dtype::PType;
27-
use crate::dtype::StructFields;
2827
use crate::memory::HostAllocatorExt;
2928
use crate::validity::Validity;
3029

@@ -41,9 +40,8 @@ pub(super) fn _canonicalize(
4140

4241
let owned_chunks: Vec<ArrayRef> = array.iter_chunks().cloned().collect();
4342
Ok(match array.dtype() {
44-
DType::Struct(struct_dtype, _) => {
45-
let struct_array =
46-
pack_struct_chunks(&owned_chunks, array.array().validity()?, struct_dtype, ctx)?;
43+
DType::Struct(..) => {
44+
let struct_array = pack_struct_chunks(owned_chunks, ctx)?;
4745
Canonical::Struct(struct_array)
4846
}
4947
DType::List(elem_dtype, _) => Canonical::List(swizzle_list_chunks(
@@ -64,36 +62,11 @@ pub(super) fn _canonicalize(
6462
/// field is a [`ChunkedArray`].
6563
///
6664
/// The caller guarantees there are at least 2 chunks.
67-
fn pack_struct_chunks(
68-
chunks: &[ArrayRef],
69-
validity: Validity,
70-
struct_dtype: &StructFields,
71-
ctx: &mut ExecutionCtx,
72-
) -> VortexResult<StructArray> {
73-
let len = chunks.iter().map(|chunk| chunk.len()).sum();
74-
let mut field_arrays = Vec::new();
75-
76-
let executed_chunks: Vec<StructArray> = chunks
77-
.iter()
78-
.map(|c| c.clone().execute::<StructArray>(ctx))
79-
.collect::<VortexResult<_>>()?;
80-
81-
for (field_idx, field_dtype) in struct_dtype.fields().enumerate() {
82-
let mut field_chunks = Vec::with_capacity(chunks.len());
83-
for struct_array in &executed_chunks {
84-
let field = struct_array.unmasked_field(field_idx).clone();
85-
field_chunks.push(field);
86-
}
87-
88-
// SAFETY: field_chunks are extracted from valid StructArrays with matching dtypes.
89-
// Each chunk's field array is guaranteed to be valid for field_dtype.
90-
let field_array = unsafe { ChunkedArray::new_unchecked(field_chunks, field_dtype.clone()) };
91-
field_arrays.push(field_array.into_array());
92-
}
93-
94-
// SAFETY: field_arrays are built from corresponding chunks of same length, dtypes match by
95-
// construction.
96-
Ok(unsafe { StructArray::new_unchecked(field_arrays, struct_dtype.clone(), len, validity) })
65+
fn pack_struct_chunks(chunks: Vec<ArrayRef>, ctx: &mut ExecutionCtx) -> VortexResult<StructArray> {
66+
chunks
67+
.into_iter()
68+
.map(|c| c.execute::<StructArray>(ctx))
69+
.process_results(|iter| StructArray::try_concat(iter))?
9770
}
9871

9972
/// Packs [`ListViewArray`]s together into a chunked `ListViewArray`.

vortex-array/src/arrays/chunked/vtable/validity.rs

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,52 +4,24 @@
44
use itertools::Itertools;
55
use vortex_error::VortexResult;
66

7-
use crate::IntoArray;
87
use crate::array::ArrayView;
98
use crate::array::ValidityVTable;
109
use crate::arrays::Chunked;
11-
use crate::arrays::ChunkedArray;
1210
use crate::arrays::chunked::ChunkedArrayExt;
13-
use crate::dtype::DType;
14-
use crate::dtype::Nullability;
1511
use crate::validity::Validity;
1612

1713
impl ValidityVTable<Chunked> for Chunked {
1814
fn validity(array: ArrayView<'_, Chunked>) -> VortexResult<Validity> {
19-
let validities: Vec<Validity> =
20-
array.chunks().iter().map(|c| c.validity()).try_collect()?;
15+
let validities = array
16+
.chunks()
17+
.iter()
18+
.map(|chunk| chunk.validity().map(|v| (v, chunk.len())))
19+
.try_collect()?;
20+
let Some(validity) = Validity::concat(validities) else {
21+
// If there are no chunks:
22+
return Ok(array.dtype().nullability().into());
23+
};
2124

22-
match validities.first() {
23-
// If there are no chunks, return the array's dtype nullability
24-
None => return Ok(array.dtype().nullability().into()),
25-
// If all chunks have the same non-array validity, return that validity directly
26-
// We skip Validity::Array since equality is very expensive.
27-
Some(first) if !matches!(first, Validity::Array(_)) => {
28-
let target = std::mem::discriminant(first);
29-
if validities
30-
.iter()
31-
.all(|v| std::mem::discriminant(v) == target)
32-
{
33-
return Ok(first.clone());
34-
}
35-
}
36-
_ => {
37-
// Array validity or mixed validities, proceed to build the validity array
38-
}
39-
}
40-
41-
Ok(Validity::Array(
42-
unsafe {
43-
ChunkedArray::new_unchecked(
44-
validities
45-
.into_iter()
46-
.zip(array.iter_chunks())
47-
.map(|(v, chunk)| v.to_array(chunk.len()))
48-
.collect(),
49-
DType::Bool(Nullability::NonNullable),
50-
)
51-
}
52-
.into_array(),
53-
))
25+
Ok(validity)
5426
}
5527
}

vortex-array/src/arrays/struct_/array.rs

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::borrow::Borrow;
45
use std::iter::once;
56
use std::sync::Arc;
67

8+
use itertools::Itertools;
79
use vortex_error::VortexExpect;
810
use vortex_error::VortexResult;
11+
use vortex_error::vortex_bail;
912
use vortex_error::vortex_err;
1013

1114
use crate::ArrayRef;
@@ -16,6 +19,7 @@ use crate::array::EmptyArrayData;
1619
use crate::array::TypedArrayRef;
1720
use crate::array::child_to_validity;
1821
use crate::array::validity_to_child;
22+
use crate::arrays::ChunkedArray;
1923
use crate::arrays::Struct;
2024
use crate::dtype::DType;
2125
use crate::dtype::FieldName;
@@ -430,9 +434,7 @@ impl Array<Struct> {
430434
};
431435
Some((new_array, field))
432436
}
433-
}
434437

435-
impl Array<Struct> {
436438
pub fn with_column(&self, name: impl Into<FieldName>, array: ArrayRef) -> VortexResult<Self> {
437439
let name = name.into();
438440
let struct_dtype = self.struct_fields();
@@ -453,4 +455,70 @@ impl Array<Struct> {
453455
pub fn remove_column_owned(&self, name: impl Into<FieldName>) -> Option<(Self, ArrayRef)> {
454456
self.remove_column(name)
455457
}
458+
459+
pub fn try_concat<T>(chunks: impl IntoIterator<Item = T>) -> VortexResult<Self>
460+
where
461+
T: Borrow<Array<Struct>>,
462+
{
463+
let mut it = chunks.into_iter();
464+
let Some(first) = it.next() else {
465+
vortex_bail!("cannot concat empty iterator of arrays");
466+
};
467+
let first_dtype = first.borrow().dtype().clone();
468+
let struct_fields = first_dtype.as_struct_fields().clone();
469+
let names = struct_fields.names();
470+
471+
let it = [first].into_iter().chain(it);
472+
let (field_arrays_per_chunk, validities) = it
473+
.map(|chunk| {
474+
let chunk = chunk.borrow();
475+
if &first_dtype != chunk.dtype() {
476+
vortex_bail!(
477+
"cannot concatenate struct arrays with differing dtypes: {}, {}",
478+
first_dtype,
479+
chunk.dtype(),
480+
);
481+
}
482+
483+
let fields = names
484+
.iter()
485+
.map(|name| {
486+
chunk
487+
.unmasked_field_by_name(name)
488+
.vortex_expect("field exists because it is in dtype")
489+
.clone()
490+
})
491+
.collect::<Vec<_>>();
492+
let validity = chunk.validity()?;
493+
494+
Ok((fields, (validity, chunk.len())))
495+
})
496+
.process_results(|iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
497+
498+
let field_arrays = struct_fields
499+
.fields()
500+
.enumerate()
501+
.map(|(i, dtype)| {
502+
// SAFETY: We establish above that every array has the same type.
503+
let chunks = field_arrays_per_chunk
504+
.iter()
505+
.map(|x| x[i].clone())
506+
.collect();
507+
unsafe { ChunkedArray::new_unchecked(chunks, dtype) }.into_array()
508+
})
509+
.collect::<Vec<_>>();
510+
let len = validities.iter().map(|(_v, len)| len).sum();
511+
let validity = Validity::concat(validities).vortex_expect("verified non-empty above");
512+
513+
// SAFETY:
514+
//
515+
// 1. The field arrays, by construction, have the type specified in fields.
516+
//
517+
// 2. Each Array<Struct> has a valid len, therefore the sum of those lens should be valid
518+
// for the concatenation of each field.
519+
//
520+
// 3. Each Array<Struct> has a valid validity, so the concatenation of those validities has
521+
// the correct length and dtype harmony.
522+
Ok(unsafe { Array::<Struct>::new_unchecked(field_arrays, struct_fields, len, validity) })
523+
}
456524
}

vortex-array/src/validity.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::fmt::Debug;
77
use std::ops::Range;
88

9+
use itertools::Itertools as _;
910
use vortex_buffer::BitBuffer;
1011
use vortex_error::VortexExpect as _;
1112
use vortex_error::VortexResult;
@@ -22,6 +23,7 @@ use crate::IntoArray;
2223
use crate::LEGACY_SESSION;
2324
use crate::VortexSessionExecute;
2425
use crate::arrays::BoolArray;
26+
use crate::arrays::ChunkedArray;
2527
use crate::arrays::ConstantArray;
2628
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
2729
use crate::builtins::ArrayBuiltins;
@@ -447,6 +449,45 @@ impl From<&Nullability> for Validity {
447449
}
448450
}
449451

452+
impl Validity {
453+
/// Concatenate one or more validities together.
454+
///
455+
/// Returns None if the vector is empty.
456+
pub fn concat(validities: Vec<(Validity, usize)>) -> Option<Self> {
457+
let mut validity_kinds = validities
458+
.iter()
459+
.map(|(v, _)| std::mem::discriminant(v))
460+
.unique();
461+
let validity_kind = validity_kinds.next()?;
462+
if validity_kinds.next().is_none() {
463+
// If there is only one kind of validity and its not Validity::Array, avoid constructing
464+
// a Validity::Array.
465+
if validity_kind == std::mem::discriminant(&Validity::AllValid) {
466+
return Some(Validity::AllValid);
467+
}
468+
if validity_kind == std::mem::discriminant(&Validity::AllInvalid) {
469+
return Some(Validity::AllInvalid);
470+
}
471+
if validity_kind == std::mem::discriminant(&Validity::NonNullable) {
472+
return Some(Validity::NonNullable);
473+
}
474+
}
475+
476+
Some(Validity::Array(
477+
unsafe {
478+
ChunkedArray::new_unchecked(
479+
validities
480+
.into_iter()
481+
.map(|(v, len)| v.to_array(len))
482+
.collect(),
483+
DType::Bool(Nullability::NonNullable),
484+
)
485+
}
486+
.into_array(),
487+
))
488+
}
489+
}
490+
450491
impl Validity {
451492
pub fn from_bit_buffer(buffer: BitBuffer, nullability: Nullability) -> Self {
452493
if buffer.true_count() == buffer.len() {

0 commit comments

Comments
 (0)