diff --git a/Cargo.lock b/Cargo.lock index f9fd5e6e960..939f3885108 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10084,6 +10084,7 @@ dependencies = [ "arrow-select 58.1.0", "arrow-string 58.1.0", "async-lock", + "base64", "bytes", "cfg-if", "codspeed-divan-compat", @@ -10918,6 +10919,7 @@ dependencies = [ name = "vortex-tensor" version = "0.1.0" dependencies = [ + "arrow-schema 58.1.0", "codspeed-divan-compat", "half", "itertools 0.14.0", @@ -10927,6 +10929,8 @@ dependencies = [ "rand 0.10.1", "rand_distr 0.6.0", "rstest", + "serde", + "serde_json", "vortex-array", "vortex-btrblocks", "vortex-buffer", diff --git a/Cargo.toml b/Cargo.toml index 7770a59d03d..917cd4d3f6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,7 @@ async-lock = "3.4" async-stream = "0.3.6" async-trait = "0.1.89" base16ct = "1.0.0" +base64 = "0.22" bigdecimal = "0.4.8" bindgen = "0.72.0" bit-vec = "0.9.0" diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index f8676d76ef0..a56f4317333 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -33,6 +33,7 @@ arrow-schema = { workspace = true } arrow-select = { workspace = true } arrow-string = { workspace = true } async-lock = { workspace = true } +base64 = { workspace = true } bytes = { workspace = true } cfg-if = { workspace = true } cudarc = { workspace = true, optional = true } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 0675fa157ad..90667c25f37 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -7072,158 +7072,234 @@ pub fn vortex_array::ArrayRef::execute_record_batch(self, schema: &arrow_schema: pub fn vortex_array::ArrayRef::execute_record_batches(self, schema: &arrow_schema::schema::Schema, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -pub trait vortex_array::arrow::FromArrowArray +pub trait vortex_array::arrow::FromArrowArray: core::marker::Sized -pub fn vortex_array::arrow::FromArrowArray::from_arrow(array: A, nullable: bool) -> vortex_error::VortexResult where Self: core::marker::Sized +pub fn vortex_array::arrow::FromArrowArray::from_arrow(array: A, nullable: bool) -> vortex_error::VortexResult + +pub fn vortex_array::arrow::FromArrowArray::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult impl vortex_array::arrow::FromArrowArray<&arrow_array::array::boolean_array::BooleanArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::boolean_array::BooleanArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::fixed_size_list_array::FixedSizeListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::null_array::NullArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::null_array::NullArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::struct_array::StructArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::struct_array::StructArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::struct_array::StructArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::record_batch::RecordBatch> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&dyn arrow_array::array::Array> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &dyn arrow_array::array::Array, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &dyn arrow_array::array::Array, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::dictionary_array::DictionaryArray> for vortex_array::arrays::dict::DictArray pub fn vortex_array::arrays::dict::DictArray::from_arrow(array: &arrow_array::array::dictionary_array::DictionaryArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::arrays::dict::DictArray::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_view_array::GenericListViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::list_view_array::GenericListViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_array::GenericListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::list_array::GenericListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::list_array::GenericListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_array::GenericByteArray> for vortex_array::ArrayRef where ::Offset: vortex_array::dtype::IntegerPType pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_array::GenericByteArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_view_array::GenericByteViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_view_array::GenericByteViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + pub trait vortex_array::arrow::IntoArrowArray pub fn vortex_array::arrow::IntoArrowArray::into_arrow(self, data_type: &arrow_schema::datatype::DataType) -> vortex_error::VortexResult @@ -8744,26 +8820,38 @@ pub trait vortex_array::dtype::arrow::FromArrowType: core::marker::Sized pub fn vortex_array::dtype::arrow::FromArrowType::from_arrow(value: T) -> Self +pub fn vortex_array::dtype::arrow::FromArrowType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::field::Field> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(field: &arrow_schema::field::Field) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::fields::Fields> for vortex_array::dtype::StructFields pub fn vortex_array::dtype::StructFields::from_arrow(value: &arrow_schema::fields::Fields) -> Self +pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::schema::Schema> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: &arrow_schema::schema::Schema) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<(&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow((data_type, nullability): (&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: arrow_schema::schema::SchemaRef) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self + pub trait vortex_array::dtype::arrow::TryFromArrowType: core::marker::Sized pub fn vortex_array::dtype::arrow::TryFromArrowType::try_from_arrow(value: T) -> vortex_error::VortexResult @@ -9062,14 +9150,36 @@ pub mod vortex_array::dtype::serde pub mod vortex_array::dtype::session +pub struct vortex_array::dtype::session::ArrowCanonicalCodec + +pub vortex_array::dtype::session::ArrowCanonicalCodec::from_json: fn(&str) -> vortex_error::VortexResult> + +pub vortex_array::dtype::session::ArrowCanonicalCodec::to_json: fn(&[u8]) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_array::dtype::session::ArrowCanonicalCodec + +pub fn vortex_array::dtype::session::ArrowCanonicalCodec::clone(&self) -> vortex_array::dtype::session::ArrowCanonicalCodec + +impl core::fmt::Debug for vortex_array::dtype::session::ArrowCanonicalCodec + +pub fn vortex_array::dtype::session::ArrowCanonicalCodec::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::marker::Copy for vortex_array::dtype::session::ArrowCanonicalCodec + pub struct vortex_array::dtype::session::DTypeSession impl vortex_array::dtype::session::DTypeSession +pub fn vortex_array::dtype::session::DTypeSession::arrow_alias_for(&self, vortex_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> + pub fn vortex_array::dtype::session::DTypeSession::register(&self, vtable: V) +pub fn vortex_array::dtype::session::DTypeSession::register_arrow_canonical(&self, vortex_id: vortex_array::dtype::extension::ExtId, arrow_id: vortex_array::dtype::extension::ExtId, codec: vortex_array::dtype::session::ArrowCanonicalCodec) + pub fn vortex_array::dtype::session::DTypeSession::registry(&self) -> &vortex_array::dtype::session::ExtDTypeRegistry +pub fn vortex_array::dtype::session::DTypeSession::vortex_alias_for(&self, arrow_id: &vortex_array::dtype::extension::ExtId) -> core::option::Option<(vortex_array::dtype::extension::ExtId, vortex_array::dtype::session::ArrowCanonicalCodec)> + impl core::default::Default for vortex_array::dtype::session::DTypeSession pub fn vortex_array::dtype::session::DTypeSession::default() -> Self @@ -9242,6 +9352,8 @@ pub fn vortex_array::dtype::DType::to_arrow_dtype(&self) -> vortex_error::Vortex pub fn vortex_array::dtype::DType::to_arrow_schema(&self) -> vortex_error::VortexResult +pub fn vortex_array::dtype::DType::to_arrow_schema_with_session(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::clone(&self) -> vortex_array::dtype::DType @@ -9304,18 +9416,26 @@ impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::field::Field> for pub fn vortex_array::dtype::DType::from_arrow(field: &arrow_schema::field::Field) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(field: &arrow_schema::field::Field, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::schema::Schema> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: &arrow_schema::schema::Schema) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: &arrow_schema::schema::Schema, session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType<(&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow((data_type, nullability): (&arrow_schema::datatype::DataType, vortex_array::dtype::Nullability)) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: T, _session: &vortex_session::VortexSession) -> Self + impl vortex_array::dtype::arrow::FromArrowType> for vortex_array::dtype::DType pub fn vortex_array::dtype::DType::from_arrow(value: arrow_schema::schema::SchemaRef) -> Self +pub fn vortex_array::dtype::DType::from_arrow_with_session(value: arrow_schema::schema::SchemaRef, session: &vortex_session::VortexSession) -> Self + impl vortex_flatbuffers::FlatBufferRoot for vortex_array::dtype::DType impl vortex_flatbuffers::WriteFlatBuffer for vortex_array::dtype::DType @@ -10270,6 +10390,8 @@ impl vortex_array::dtype::arrow::FromArrowType<&arrow_schema::fields::Fields> fo pub fn vortex_array::dtype::StructFields::from_arrow(value: &arrow_schema::fields::Fields) -> Self +pub fn vortex_array::dtype::StructFields::from_arrow_with_session(value: &arrow_schema::fields::Fields, session: &vortex_session::VortexSession) -> Self + impl core::iter::traits::collect::FromIterator<(T, V)> for vortex_array::dtype::StructFields where T: core::convert::Into, V: core::convert::Into pub fn vortex_array::dtype::StructFields::from_iter>(iter: I) -> Self @@ -22612,130 +22734,194 @@ impl vortex_array::arrow::FromArrowArray<&arrow_array::array::boolean_array::Boo pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::boolean_array::BooleanArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::fixed_size_list_array::FixedSizeListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::array::fixed_size_list_array::FixedSizeListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::null_array::NullArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::null_array::NullArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::primitive_array::PrimitiveArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::primitive_array::PrimitiveArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::struct_array::StructArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::struct_array::StructArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::struct_array::StructArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::record_batch::RecordBatch> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&dyn arrow_array::array::Array> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: &dyn arrow_array::array::Array, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: &dyn arrow_array::array::Array, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(array: arrow_array::record_batch::RecordBatch, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: arrow_array::record_batch::RecordBatch, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::IntoArrowArray for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::into_arrow(self, data_type: &arrow_schema::datatype::DataType) -> vortex_error::VortexResult @@ -22804,18 +22990,26 @@ impl, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::list_array::GenericListArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::list_array::GenericListArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(value: &arrow_array::array::list_array::GenericListArray, nullable: bool, session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_array::GenericByteArray> for vortex_array::ArrayRef where ::Offset: vortex_array::dtype::IntegerPType pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_array::GenericByteArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl vortex_array::arrow::FromArrowArray<&arrow_array::array::byte_view_array::GenericByteViewArray> for vortex_array::ArrayRef pub fn vortex_array::ArrayRef::from_arrow(value: &arrow_array::array::byte_view_array::GenericByteViewArray, nullable: bool) -> vortex_error::VortexResult +pub fn vortex_array::ArrayRef::from_arrow_with_session(array: A, nullable: bool, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + impl core::convert::AsRef for vortex_array::Array pub fn vortex_array::Array::as_ref(&self) -> &vortex_array::ArrayRef diff --git a/vortex-array/src/arrow/convert.rs b/vortex-array/src/arrow/convert.rs index 6eafa6033c5..556df3a725a 100644 --- a/vortex-array/src/arrow/convert.rs +++ b/vortex-array/src/arrow/convert.rs @@ -56,6 +56,7 @@ use arrow_buffer::ScalarBuffer; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::buffer::OffsetBuffer; use arrow_schema::DataType; +use arrow_schema::Field; use arrow_schema::TimeUnit as ArrowTimeUnit; use itertools::Itertools; use vortex_buffer::Alignment; @@ -66,12 +67,15 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_panic; +use vortex_session::VortexSession; use crate::ArrayRef; use crate::IntoArray; +use crate::LEGACY_SESSION; use crate::arrays::BoolArray; use crate::arrays::DecimalArray; use crate::arrays::DictArray; +use crate::arrays::ExtensionArray; use crate::arrays::FixedSizeListArray; use crate::arrays::ListArray; use crate::arrays::ListViewArray; @@ -87,7 +91,9 @@ use crate::dtype::DecimalDType; use crate::dtype::IntegerPType; use crate::dtype::NativePType; use crate::dtype::PType; +use crate::dtype::arrow::resolve_extension_dtype; use crate::dtype::i256; +use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::TimeUnit; use crate::validity::Validity; @@ -380,23 +386,33 @@ fn remove_nulls(data: arrow_data::ArrayData) -> arrow_data::ArrayData { impl FromArrowArray<&ArrowStructArray> for ArrayRef { fn from_arrow(value: &ArrowStructArray, nullable: bool) -> VortexResult { + Self::from_arrow_with_session(value, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + value: &ArrowStructArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + let columns = value + .columns() + .iter() + .zip(value.fields()) + .map(|(c, field)| { + // Arrow pushes down nulls, even into non-nullable fields. So we strip them + // out here because Vortex is a little more strict. + let storage = if c.null_count() > 0 && !field.is_nullable() { + let stripped = make_array(remove_nulls(c.into_data())); + Self::from_arrow_with_session(stripped.as_ref(), false, session)? + } else { + Self::from_arrow_with_session(c.as_ref(), field.is_nullable(), session)? + }; + wrap_extension_if_field_has_metadata(storage, field.as_ref(), session) + }) + .collect::>>()?; Ok(StructArray::try_new( value.column_names().iter().copied().collect(), - value - .columns() - .iter() - .zip(value.fields()) - .map(|(c, field)| { - // Arrow pushes down nulls, even into non-nullable fields. So we strip them - // out here because Vortex is a little more strict. - if c.null_count() > 0 && !field.is_nullable() { - let stripped = make_array(remove_nulls(c.into_data())); - Self::from_arrow(stripped.as_ref(), false) - } else { - Self::from_arrow(c.as_ref(), field.is_nullable()) - } - }) - .collect::>>()?, + columns, value.len(), nulls(value.nulls(), nullable), )? @@ -406,14 +422,27 @@ impl FromArrowArray<&ArrowStructArray> for ArrayRef { impl FromArrowArray<&GenericListArray> for ArrayRef { fn from_arrow(value: &GenericListArray, nullable: bool) -> VortexResult { - // Extract the validity of the underlying element array. - let elements_are_nullable = match value.data_type() { - DataType::List(field) => field.is_nullable(), - DataType::LargeList(field) => field.is_nullable(), + Self::from_arrow_with_session(value, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + value: &GenericListArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + let elements_field: &Field = match value.data_type() { + DataType::List(field) => field.as_ref(), + DataType::LargeList(field) => field.as_ref(), dt => vortex_panic!("Invalid data type for ListArray: {dt}"), }; - let elements = Self::from_arrow(value.values().as_ref(), elements_are_nullable)?; + let elements_storage = Self::from_arrow_with_session( + value.values().as_ref(), + elements_field.is_nullable(), + session, + )?; + let elements = + wrap_extension_if_field_has_metadata(elements_storage, elements_field, session)?; // `offsets` are always non-nullable. let offsets = value.offsets().clone().into_array(); @@ -445,12 +474,25 @@ impl FromArrowArray<&GenericListViewArray> impl FromArrowArray<&ArrowFixedSizeListArray> for ArrayRef { fn from_arrow(array: &ArrowFixedSizeListArray, nullable: bool) -> VortexResult { + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: &ArrowFixedSizeListArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { let DataType::FixedSizeList(field, list_size) = array.data_type() else { vortex_panic!("Invalid data type for ListArray: {}", array.data_type()); }; + let elements_storage = + Self::from_arrow_with_session(array.values().as_ref(), field.is_nullable(), session)?; + let elements = + wrap_extension_if_field_has_metadata(elements_storage, field.as_ref(), session)?; + Ok(FixedSizeListArray::try_new( - Self::from_arrow(array.values().as_ref(), field.is_nullable())?, + elements, *list_size as u32, nulls(array.nulls(), nullable), array.len(), @@ -494,6 +536,30 @@ fn nulls(nulls: Option<&NullBuffer>, nullable: bool) -> Validity { } impl FromArrowArray<&dyn ArrowArray> for ArrayRef { + fn from_arrow_with_session( + array: &dyn ArrowArray, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + match array.data_type() { + DataType::Struct(_) => { + Self::from_arrow_with_session(array.as_struct(), nullable, session) + } + DataType::List(_) => { + Self::from_arrow_with_session(array.as_list::(), nullable, session) + } + DataType::LargeList(_) => { + Self::from_arrow_with_session(array.as_list::(), nullable, session) + } + DataType::FixedSizeList(..) => { + Self::from_arrow_with_session(array.as_fixed_size_list(), nullable, session) + } + // Other arrays don't carry child Fields, so session-aware dispatch is identical to + // the legacy path; fall through to `from_arrow`. + _ => Self::from_arrow(array, nullable), + } + } + fn from_arrow(array: &dyn ArrowArray, nullable: bool) -> VortexResult { match array.data_type() { DataType::Boolean => Self::from_arrow(array.as_boolean(), nullable), @@ -617,13 +683,45 @@ impl FromArrowArray<&dyn ArrowArray> for ArrayRef { impl FromArrowArray for ArrayRef { fn from_arrow(array: RecordBatch, nullable: bool) -> VortexResult { - ArrayRef::from_arrow(&arrow_array::StructArray::from(array), nullable) + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: RecordBatch, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + Self::from_arrow_with_session(&arrow_array::StructArray::from(array), nullable, session) } } impl FromArrowArray<&RecordBatch> for ArrayRef { fn from_arrow(array: &RecordBatch, nullable: bool) -> VortexResult { - Self::from_arrow(array.clone(), nullable) + Self::from_arrow_with_session(array, nullable, &LEGACY_SESSION) + } + + fn from_arrow_with_session( + array: &RecordBatch, + nullable: bool, + session: &VortexSession, + ) -> VortexResult { + Self::from_arrow_with_session(array.clone(), nullable, session) + } +} + +/// Inverse of `field_from_dtype` (in `dtype/arrow.rs`): if `field` carries +/// `ARROW:extension:name` metadata for a registered extension, rewrap `storage` as an +/// `ExtensionArray`; otherwise fall through to `storage`. Diagnostic warnings live in +/// [`resolve_extension_dtype`]. +fn wrap_extension_if_field_has_metadata( + storage: ArrayRef, + field: &Field, + session: &VortexSession, +) -> VortexResult { + let dtypes = session.dtypes(); + match resolve_extension_dtype(field, &dtypes, storage.dtype()) { + Some(ext_dtype) => Ok(ExtensionArray::try_new(ext_dtype, storage)?.into_array()), + None => Ok(storage), } } diff --git a/vortex-array/src/arrow/executor/mod.rs b/vortex-array/src/arrow/executor/mod.rs index 890e7f8a46a..66e7afd7543 100644 --- a/vortex-array/src/arrow/executor/mod.rs +++ b/vortex-array/src/arrow/executor/mod.rs @@ -30,8 +30,10 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use crate::ArrayRef; +use crate::arrays::ExtensionArray; use crate::arrays::List; use crate::arrays::VarBin; +use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::list::ListArrayExt; use crate::arrays::varbin::VarBinArrayExt; use crate::arrow::executor::bool::to_arrow_bool; @@ -87,6 +89,12 @@ impl ArrowArrayExecutor for ArrayRef { data_type: Option<&DataType>, ctx: &mut ExecutionCtx, ) -> VortexResult { + // Extension identity lives on Field metadata; dispatch on the storage array. + if matches!(self.dtype(), DType::Extension(_)) { + let ext = self.execute::(ctx)?; + return ext.storage_array().clone().execute_arrow(data_type, ctx); + } + let len = self.len(); // Resolve the DataType if it is a leaf type @@ -228,3 +236,37 @@ fn preferred_arrow_type(array: &ArrayRef) -> VortexResult { // Everything else: use canonical dtype conversion array.dtype().to_arrow_dtype() } + +#[cfg(test)] +mod tests { + use arrow_array::cast::AsArray; + use arrow_array::types::UInt64Type; + use arrow_schema::DataType; + + use super::*; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::array::IntoArray; + use crate::arrays::ExtensionArray; + use crate::arrays::PrimitiveArray; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + #[test] + fn execute_arrow_unwraps_extension_to_storage() { + let storage = PrimitiveArray::from_iter(0u64..6).into_array(); + let ext = ExtensionArray::try_new_from_vtable(DivisibleInt, Divisor(1), storage) + .unwrap() + .into_array(); + + let arrow = ext + .execute_arrow( + Some(&DataType::UInt64), + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + + let primitives = arrow.as_primitive::(); + assert_eq!(primitives.values(), &[0, 1, 2, 3, 4, 5]); + } +} diff --git a/vortex-array/src/arrow/mod.rs b/vortex-array/src/arrow/mod.rs index efc83aa6af6..52e32bcf3d7 100644 --- a/vortex-array/src/arrow/mod.rs +++ b/vortex-array/src/arrow/mod.rs @@ -6,6 +6,7 @@ use arrow_array::ArrayRef as ArrowArrayRef; use arrow_schema::DataType; use vortex_error::VortexResult; +use vortex_session::VortexSession; mod convert; mod datum; @@ -24,10 +25,19 @@ use crate::ArrayRef; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; -pub trait FromArrowArray { - fn from_arrow(array: A, nullable: bool) -> VortexResult - where - Self: Sized; +pub trait FromArrowArray: Sized { + fn from_arrow(array: A, nullable: bool) -> VortexResult; + + /// Same conversion, with session for resolving `ARROW:extension:name` field metadata to + /// registered extension dtypes. The default ignores the session — override on impls that + /// see Arrow `Field`s (RecordBatch, Struct, List, FSL). + fn from_arrow_with_session( + array: A, + nullable: bool, + _session: &VortexSession, + ) -> VortexResult { + Self::from_arrow(array, nullable) + } } #[deprecated(note = "Use `execute_arrow(None, ctx)` or `execute_arrow(Some(dt), ctx)` instead")] diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 17af749cfc0..f55954ff639 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -23,19 +23,30 @@ use arrow_schema::Schema; use arrow_schema::SchemaBuilder; use arrow_schema::SchemaRef; use arrow_schema::TimeUnit as ArrowTimeUnit; +use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; +use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use base64::Engine; +use base64::prelude::BASE64_STANDARD; use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_session::VortexSession; +use crate::LEGACY_SESSION; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::FieldName; use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; +use crate::dtype::extension::ExtDTypeRef; +use crate::dtype::extension::ExtId; +use crate::dtype::session::ArrowCanonicalCodec; +use crate::dtype::session::DTypeSession; +use crate::dtype::session::DTypeSessionExt; use crate::extension::datetime::AnyTemporal; use crate::extension::datetime::Date; use crate::extension::datetime::TemporalMetadata; @@ -43,10 +54,19 @@ use crate::extension::datetime::Time; use crate::extension::datetime::TimeUnit; use crate::extension::datetime::Timestamp; +const ARROW_EXT_NAME_VARIANT: &str = "arrow.parquet.variant"; + /// Trait for converting Arrow types to Vortex types. pub trait FromArrowType: Sized { /// Convert the Arrow type to a Vortex type. fn from_arrow(value: T) -> Self; + + /// Convert the Arrow type to a Vortex type, consulting `session` for extension lookup. + /// + /// Unregistered or malformed extension metadata falls back to the storage dtype. + fn from_arrow_with_session(value: T, _session: &VortexSession) -> Self { + Self::from_arrow(value) + } } /// Trait for converting Vortex types to Arrow types. @@ -126,14 +146,22 @@ impl TryFrom for ArrowTimeUnit { impl FromArrowType for DType { fn from_arrow(value: SchemaRef) -> Self { - Self::from_arrow(value.as_ref()) + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } + + fn from_arrow_with_session(value: SchemaRef, session: &VortexSession) -> Self { + Self::from_arrow_with_session(value.as_ref(), session) } } impl FromArrowType<&Schema> for DType { fn from_arrow(value: &Schema) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } + + fn from_arrow_with_session(value: &Schema, session: &VortexSession) -> Self { Self::Struct( - StructFields::from_arrow(value.fields()), + StructFields::from_arrow_with_session(value.fields(), session), Nullability::NonNullable, // Must match From for Array ) } @@ -141,10 +169,15 @@ impl FromArrowType<&Schema> for DType { impl FromArrowType<&Fields> for StructFields { fn from_arrow(value: &Fields) -> Self { + Self::from_arrow_with_session(value, &LEGACY_SESSION) + } + + fn from_arrow_with_session(value: &Fields, session: &VortexSession) -> Self { + let dtypes = session.dtypes(); StructFields::from_iter(value.into_iter().map(|f| { ( FieldName::from(f.name().as_str()), - DType::from_arrow(f.as_ref()), + dtype_from_field(f.as_ref(), &dtypes), ) })) } @@ -210,21 +243,142 @@ impl FromArrowType<(&DataType, Nullability)> for DType { impl FromArrowType<&Field> for DType { fn from_arrow(field: &Field) -> Self { - if field - .metadata() - .get("ARROW:extension:name") - .map(|s| s.as_str()) - == Some("arrow.parquet.variant") - { - return DType::Variant(field.is_nullable().into()); + Self::from_arrow_with_session(field, &LEGACY_SESSION) + } + + fn from_arrow_with_session(field: &Field, session: &VortexSession) -> Self { + dtype_from_field(field, &session.dtypes()) + } +} + +/// Convert an Arrow Field to a [`DType`] with `dtypes` already borrowed from the session, +/// so the handle is acquired once per schema rather than once per field. +fn dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { + if field + .extension_type_name() + .is_some_and(|s| s == ARROW_EXT_NAME_VARIANT) + { + return DType::Variant(field.is_nullable().into()); + } + + let storage_dtype = storage_dtype_from_field(field, dtypes); + match resolve_extension_dtype(field, dtypes, &storage_dtype) { + Some(ext_ref) => DType::Extension(ext_ref), + None => storage_dtype, + } +} + +/// Resolve the [`ExtDTypeRef`] for an Arrow Field whose `ARROW:extension:name` metadata names +/// a registered Vortex extension. Returns `None` for unregistered extensions, malformed +/// metadata, or fields with no extension name; `tracing::warn!` reports the anomaly so callers +/// can simply fall back to the storage representation. +/// +/// Used on both the dtype side ([`dtype_from_field`]) and the array side +/// (`wrap_extension_if_field_has_metadata`); only the final wrap differs. +pub(crate) fn resolve_extension_dtype( + field: &Field, + dtypes: &DTypeSession, + storage_dtype: &DType, +) -> Option { + let ext_name = field.extension_type_name()?; + if ext_name == ARROW_EXT_NAME_VARIANT { + return None; + } + + let arrow_id = ExtId::new(ext_name); + let (ext_id, codec) = match dtypes.vortex_alias_for(&arrow_id) { + Some((vortex_id, codec)) => (vortex_id, Some(codec)), + None => (arrow_id, None), + }; + + let Some(plugin) = dtypes.registry().find(&ext_id) else { + tracing::warn!( + "Arrow field {:?} extension id {ext_name:?} not registered; using storage dtype", + field.name(), + ); + return None; + }; + + let metadata_bytes = match decode_extension_metadata(field, codec) { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!( + "Arrow field {:?} extension id {ext_name:?} has malformed metadata ({e}); \ + using storage dtype", + field.name(), + ); + return None; + } + }; + + match plugin.deserialize(&metadata_bytes, storage_dtype.clone()) { + Ok(ext_ref) => Some(ext_ref), + Err(e) => { + tracing::warn!( + "Arrow field {:?} extension id {ext_name:?} failed to deserialize ({e}); \ + using storage dtype", + field.name(), + ); + None + } + } +} + +/// Non-canonical extensions base64-encode arbitrary binary metadata to survive Arrow's +/// String-typed metadata channel; canonical extensions go through the registered codec. +fn decode_extension_metadata( + field: &Field, + codec: Option, +) -> VortexResult> { + match field.extension_type_metadata() { + None | Some("") => Ok(Vec::new()), + Some(s) => match codec { + Some(codec) => (codec.from_json)(s), + None => BASE64_STANDARD.decode(s).map_err(|e| { + vortex_err!("failed to base64-decode {EXTENSION_TYPE_METADATA_KEY}: {e}") + }), + }, + } +} + +/// Recursively build the storage [`DType`] for an Arrow Field, threading `dtypes` through +/// nested child fields so nested extensions are also resolved. +fn storage_dtype_from_field(field: &Field, dtypes: &DTypeSession) -> DType { + let nullability: Nullability = field.is_nullable().into(); + match field.data_type() { + DataType::Struct(f) => DType::Struct( + StructFields::from_iter(f.into_iter().map(|child| { + ( + FieldName::from(child.name().as_str()), + dtype_from_field(child.as_ref(), dtypes), + ) + })), + nullability, + ), + DataType::List(e) + | DataType::LargeList(e) + | DataType::ListView(e) + | DataType::LargeListView(e) => { + DType::List(Arc::new(dtype_from_field(e.as_ref(), dtypes)), nullability) } - Self::from_arrow((field.data_type(), field.is_nullable().into())) + DataType::FixedSizeList(e, size) => DType::FixedSizeList( + Arc::new(dtype_from_field(e.as_ref(), dtypes)), + *size as u32, + nullability, + ), + other => DType::from_arrow((other, nullability)), } } impl DType { /// Convert a Vortex [`DType`] into an Arrow [`Schema`]. pub fn to_arrow_schema(&self) -> VortexResult { + self.to_arrow_schema_with_session(&LEGACY_SESSION) + } + + /// Convert a Vortex [`DType`] into an Arrow [`Schema`], consulting `session` for Arrow + /// canonical extension aliases registered via [`DTypeSession::register_arrow_canonical`]. + pub fn to_arrow_schema_with_session(&self, session: &VortexSession) -> VortexResult { let DType::Struct(struct_dtype, nullable) = self else { vortex_bail!("only DType::Struct can be converted to arrow schema"); }; @@ -233,25 +387,14 @@ impl DType { vortex_bail!("top-level struct in Schema must be NonNullable"); } + let dtypes = session.dtypes(); let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len()); for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) { - let field = if field_dtype.is_variant() { - let storage = DataType::Struct(variant_storage_fields_minimal()); - Field::new(field_name.as_ref(), storage, field_dtype.is_nullable()).with_metadata( - [( - "ARROW:extension:name".to_owned(), - "arrow.parquet.variant".to_owned(), - )] - .into(), - ) - } else { - Field::new( - field_name.as_ref(), - field_dtype.to_arrow_dtype()?, - field_dtype.is_nullable(), - ) - }; - builder.push(field); + builder.push(field_from_dtype( + field_name.as_ref(), + &field_dtype, + &dtypes, + )?); } Ok(builder.finish()) @@ -259,100 +402,163 @@ impl DType { /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`]. pub fn to_arrow_dtype(&self) -> VortexResult { - Ok(match self { - DType::Null => DataType::Null, - DType::Bool(_) => DataType::Boolean, - DType::Primitive(ptype, _) => match ptype { - PType::U8 => DataType::UInt8, - PType::U16 => DataType::UInt16, - PType::U32 => DataType::UInt32, - PType::U64 => DataType::UInt64, - PType::I8 => DataType::Int8, - PType::I16 => DataType::Int16, - PType::I32 => DataType::Int32, - PType::I64 => DataType::Int64, - PType::F16 => DataType::Float16, - PType::F32 => DataType::Float32, - PType::F64 => DataType::Float64, - }, - DType::Decimal(dt, _) => { - let precision = dt.precision(); - let scale = dt.scale(); - - match precision { - // This code is commented out until DataFusion improves its support for smaller decimals. - // // DECIMAL32_MAX_PRECISION - // 0..=9 => DataType::Decimal32(precision, scale), - // // DECIMAL64_MAX_PRECISION - // 10..=18 => DataType::Decimal64(precision, scale), - // DECIMAL128_MAX_PRECISION - 0..=38 => DataType::Decimal128(precision, scale), - // DECIMAL256_MAX_PRECISION - 39.. => DataType::Decimal256(precision, scale), - } + arrow_dtype_from_dtype(self, &LEGACY_SESSION.dtypes()) + } +} + +fn arrow_dtype_from_dtype(dtype: &DType, dtypes: &DTypeSession) -> VortexResult { + Ok(match dtype { + DType::Null => DataType::Null, + DType::Bool(_) => DataType::Boolean, + DType::Primitive(ptype, _) => match ptype { + PType::U8 => DataType::UInt8, + PType::U16 => DataType::UInt16, + PType::U32 => DataType::UInt32, + PType::U64 => DataType::UInt64, + PType::I8 => DataType::Int8, + PType::I16 => DataType::Int16, + PType::I32 => DataType::Int32, + PType::I64 => DataType::Int64, + PType::F16 => DataType::Float16, + PType::F32 => DataType::Float32, + PType::F64 => DataType::Float64, + }, + DType::Decimal(dt, _) => { + let precision = dt.precision(); + let scale = dt.scale(); + + match precision { + // This code is commented out until DataFusion improves its support for smaller decimals. + // // DECIMAL32_MAX_PRECISION + // 0..=9 => DataType::Decimal32(precision, scale), + // // DECIMAL64_MAX_PRECISION + // 10..=18 => DataType::Decimal64(precision, scale), + // DECIMAL128_MAX_PRECISION + 0..=38 => DataType::Decimal128(precision, scale), + // DECIMAL256_MAX_PRECISION + 39.. => DataType::Decimal256(precision, scale), } - DType::Utf8(_) => DataType::Utf8View, - DType::Binary(_) => DataType::BinaryView, - // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View - // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an - // Arrow dtype because we do not how large our offsets are. - DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field( - elem_dtype.to_arrow_dtype()?, - elem_dtype.nullability().into(), - ))), - DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList( - FieldRef::new(Field::new_list_field( - elem_dtype.to_arrow_dtype()?, - elem_dtype.nullability().into(), - )), - *size as i32, - ), - DType::Struct(struct_dtype, _) => { - let mut fields = Vec::with_capacity(struct_dtype.names().len()); - for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields()) - { - fields.push(FieldRef::from(Field::new( - field_name.as_ref(), - field_dt.to_arrow_dtype()?, - field_dt.is_nullable(), - ))); - } - - DataType::Struct(Fields::from(fields)) + } + DType::Utf8(_) => DataType::Utf8View, + DType::Binary(_) => DataType::BinaryView, + // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View + // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an + // Arrow dtype because we do not how large our offsets are. + DType::List(elem_dtype, _) => DataType::List(FieldRef::new(field_from_dtype( + Field::LIST_FIELD_DEFAULT_NAME, + elem_dtype, + dtypes, + )?)), + DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList( + FieldRef::new(field_from_dtype( + Field::LIST_FIELD_DEFAULT_NAME, + elem_dtype, + dtypes, + )?), + *size as i32, + ), + DType::Struct(struct_dtype, _) => { + let mut fields = Vec::with_capacity(struct_dtype.names().len()); + for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields()) { + fields.push(FieldRef::from(field_from_dtype( + field_name.as_ref(), + &field_dt, + dtypes, + )?)); } - DType::Variant(_) => vortex_bail!( - "DType::Variant requires Arrow Field metadata; use to_arrow_schema or a Field helper" - ), - DType::Extension(ext_dtype) => { - // Try and match against the known extension DTypes. - if let Some(temporal) = ext_dtype.metadata_opt::() { - return Ok(match temporal { - TemporalMetadata::Timestamp(unit, tz) => { - DataType::Timestamp(ArrowTimeUnit::try_from(*unit)?, tz.clone()) - } - TemporalMetadata::Date(unit) => match unit { - TimeUnit::Days => DataType::Date32, - TimeUnit::Milliseconds => DataType::Date64, - TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => { - vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) - } - }, - TemporalMetadata::Time(unit) => match unit { - TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second), - TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond), - TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond), - TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond), - TimeUnit::Days => { - vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) - } - }, - }); - }; - - vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id()) + + DataType::Struct(Fields::from(fields)) + } + DType::Variant(_) => vortex_bail!( + "DType::Variant requires Arrow Field metadata; use to_arrow_schema or a Field helper" + ), + DType::Extension(ext_dtype) => { + if let Some(native) = native_arrow_dtype_for_extension(ext_dtype) { + return Ok(native); } - }) + // Extension identity lives on the Field (see `field_from_dtype`), not on + // DataType, so here we only encode the storage type. + arrow_dtype_from_dtype(ext_dtype.storage_dtype(), dtypes)? + } + }) +} + +/// Build an Arrow [`Field`], attaching `ARROW:extension:name` and, when present, +/// `ARROW:extension:metadata` for extensions and Variant that have no native Arrow mapping. +fn field_from_dtype(name: &str, dtype: &DType, dtypes: &DTypeSession) -> VortexResult { + if dtype.is_variant() { + let storage = DataType::Struct(variant_storage_fields_minimal()); + return Ok( + Field::new(name, storage, dtype.is_nullable()).with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_owned(), + ARROW_EXT_NAME_VARIANT.to_owned(), + )] + .into(), + ), + ); + } + + if let DType::Extension(ext) = dtype { + // Native Arrow mapping carries the semantics in DataType; emitting extension metadata + // on top would break consumers that only understand native Arrow types. + if let Some(native) = native_arrow_dtype_for_extension(ext) { + return Ok(Field::new(name, native, dtype.is_nullable())); + } + + let storage_arrow = arrow_dtype_from_dtype(ext.storage_dtype(), dtypes)?; + let ext_meta_bytes = ext.serialize_metadata()?; + let (ext_name, meta_str) = match dtypes.arrow_alias_for(&ext.id()) { + Some((canonical, codec)) => ( + canonical.as_str().to_owned(), + (codec.to_json)(&ext_meta_bytes)?, + ), + None => ( + ext.id().as_str().to_owned(), + BASE64_STANDARD.encode(&ext_meta_bytes), + ), + }; + + let mut metadata = vec![(EXTENSION_TYPE_NAME_KEY.to_owned(), ext_name)]; + if !meta_str.is_empty() { + metadata.push((EXTENSION_TYPE_METADATA_KEY.to_owned(), meta_str)); + } + return Ok(Field::new(name, storage_arrow, dtype.is_nullable()) + .with_metadata(metadata.into_iter().collect())); } + + Ok(Field::new( + name, + arrow_dtype_from_dtype(dtype, dtypes)?, + dtype.is_nullable(), + )) +} + +/// Returns the native Arrow [`DataType`] for extensions Arrow models directly (e.g. temporal). +/// `None` means the extension should round-trip via storage + Field metadata. +fn native_arrow_dtype_for_extension(ext_dtype: &ExtDTypeRef) -> Option { + let temporal = ext_dtype.metadata_opt::()?; + Some(match temporal { + TemporalMetadata::Timestamp(unit, tz) => { + DataType::Timestamp(ArrowTimeUnit::try_from(*unit).ok()?, tz.clone()) + } + TemporalMetadata::Date(unit) => match unit { + TimeUnit::Days => DataType::Date32, + TimeUnit::Milliseconds => DataType::Date64, + TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => { + vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) + } + }, + TemporalMetadata::Time(unit) => match unit { + TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second), + TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond), + TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond), + TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond), + TimeUnit::Days => { + vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id()) + } + }, + }) } fn variant_storage_fields_minimal() -> Fields { @@ -561,4 +767,167 @@ mod test { assert_eq!(original_dtype, roundtripped_dtype); } + + mod extension_roundtrip { + use vortex_session::VortexSession; + + use super::*; + use crate::dtype::extension::ExtDType; + use crate::dtype::session::DTypeSession; + use crate::dtype::session::DTypeSessionExt; + use crate::extension::tests::divisible_int::DivisibleInt; + use crate::extension::tests::divisible_int::Divisor; + + fn session_with_divisible_int() -> VortexSession { + let session = VortexSession::empty().with::(); + session.dtypes().register(DivisibleInt); + session + } + + fn divisible_ext(divisor: u64) -> DType { + let ext = ExtDType::::try_new( + Divisor(divisor), + DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap(); + DType::Extension(ext.erased()) + } + + #[test] + fn forward_emits_name_and_base64_metadata() { + let dtype = DType::struct_([("div", divisible_ext(7))], Nullability::NonNullable); + + let schema = dtype.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert_eq!(field.data_type(), &DataType::UInt64); + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some("test.divisible_int"), + ); + + let meta_b64 = field.metadata().get(EXTENSION_TYPE_METADATA_KEY).unwrap(); + let decoded = BASE64_STANDARD.decode(meta_b64).unwrap(); + assert_eq!(decoded, 7u64.to_le_bytes()); + } + + #[test] + fn reverse_with_session_recovers_extension() { + let original = DType::struct_([("div", divisible_ext(42))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn reverse_without_registration_falls_back_to_storage() { + let original = DType::struct_([("div", divisible_ext(13))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + // DivisibleInt is not in the default DTypeSession. + let session = VortexSession::empty().with::(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + let expected = DType::struct_( + [( + "div", + DType::Primitive(PType::U64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); + } + + #[test] + fn nested_struct_roundtrip() { + let inner = DType::struct_([("div", divisible_ext(3))], Nullability::Nullable); + let original = DType::struct_([("inner", inner)], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn list_element_roundtrip() { + let list_dtype = DType::List(Arc::new(divisible_ext(5)), Nullability::Nullable); + let original = DType::struct_([("xs", list_dtype)], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + assert_eq!(recovered, original); + } + + #[test] + fn temporal_native_path_emits_no_extension_metadata() { + let ts = Timestamp::new_with_tz(TimeUnit::Microseconds, None, Nullability::Nullable); + let original = DType::struct_( + [("t", DType::Extension(ts.erased()))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert!(matches!( + field.data_type(), + DataType::Timestamp(ArrowTimeUnit::Microsecond, None) + )); + assert!(field.metadata().get(EXTENSION_TYPE_NAME_KEY).is_none()); + + let recovered = DType::from_arrow(&schema); + assert_eq!(recovered, original); + } + + #[test] + fn variant_still_roundtrips() { + let original = DType::struct_( + [("v", DType::Variant(Nullability::NonNullable))], + Nullability::NonNullable, + ); + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow(&schema); + assert_eq!(recovered, original); + } + + #[test] + fn malformed_metadata_falls_back_to_storage() { + let field = Field::new("div", DataType::UInt64, false).with_metadata( + [ + ( + EXTENSION_TYPE_NAME_KEY.to_owned(), + "test.divisible_int".to_owned(), + ), + ( + EXTENSION_TYPE_METADATA_KEY.to_owned(), + "not_base64!!!".to_owned(), + ), + ] + .into(), + ); + let schema = Schema::new(Fields::from(vec![field])); + + let session = session_with_divisible_int(); + let recovered = DType::from_arrow_with_session(&schema, &session); + + let expected = DType::struct_( + [( + "div", + DType::Primitive(PType::U64, Nullability::NonNullable), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); + } + } } diff --git a/vortex-array/src/dtype/session.rs b/vortex-array/src/dtype/session.rs index 2314658869f..0f0db8236c0 100644 --- a/vortex-array/src/dtype/session.rs +++ b/vortex-array/src/dtype/session.rs @@ -6,12 +6,16 @@ use std::any::Any; use std::sync::Arc; +use arc_swap::ArcSwap; +use vortex_error::VortexResult; use vortex_session::Ref; use vortex_session::SessionExt; use vortex_session::SessionVar; use vortex_session::registry::Registry; +use vortex_utils::aliases::hash_map::HashMap; use crate::dtype::extension::ExtDTypePluginRef; +use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::Date; use crate::extension::datetime::Time; @@ -20,19 +24,79 @@ use crate::extension::datetime::Timestamp; /// Registry for extension dtypes. pub type ExtDTypeRegistry = Registry; +/// Converters between an extension's on-disk metadata bytes and the Arrow canonical JSON wire. +/// +/// Bundled with the alias at registration time so [`ExtVTable`] stays Arrow-unaware. +#[derive(Copy, Clone, Debug)] +pub struct ArrowCanonicalCodec { + pub to_json: fn(&[u8]) -> VortexResult, + pub from_json: fn(&str) -> VortexResult>, +} + +/// Forward map is the canonical source: each Vortex extension owns its codec and points at the +/// Arrow canonical name it serializes as. Reverse map is a lookup index for the read path, +/// taking an Arrow name back to the Vortex id whose codec applies. +#[derive(Default, Clone)] +struct AliasState { + forward: HashMap, + reverse: HashMap, +} + +#[derive(Debug, Default)] +struct ArrowCanonicalAliases(ArcSwap); + +impl ArrowCanonicalAliases { + /// Re-registering evicts any prior alias touching either id so both directions agree. + fn register(&self, vortex_id: ExtId, arrow_id: ExtId, codec: ArrowCanonicalCodec) { + self.0.rcu(|prev| { + let mut next = (**prev).clone(); + if let Some((stale_arrow, _)) = next.forward.remove(&vortex_id) { + next.reverse.remove(&stale_arrow); + } + if let Some(stale_vortex) = next.reverse.remove(&arrow_id) { + next.forward.remove(&stale_vortex); + } + next.forward.insert(vortex_id, (arrow_id, codec)); + next.reverse.insert(arrow_id, vortex_id); + Arc::new(next) + }); + } + + fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.0.load().forward.get(vortex_id).copied() + } + + fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + let state = self.0.load(); + let vortex_id = *state.reverse.get(arrow_id)?; + let (_, codec) = *state.forward.get(&vortex_id)?; + Some((vortex_id, codec)) + } +} + +impl std::fmt::Debug for AliasState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AliasState") + .field("forward", &self.forward) + .field("reverse", &self.reverse) + .finish() + } +} + /// Session for managing extension dtypes. #[derive(Debug)] pub struct DTypeSession { registry: ExtDTypeRegistry, + arrow_canonical: ArrowCanonicalAliases, } impl Default for DTypeSession { fn default() -> Self { let this = Self { registry: Registry::default(), + arrow_canonical: ArrowCanonicalAliases::default(), }; - // Register built-in temporal extension dtypes this.register(Date); this.register(Time); this.register(Timestamp); @@ -62,6 +126,27 @@ impl DTypeSession { pub fn registry(&self) -> &ExtDTypeRegistry { &self.registry } + + /// Alias `arrow_id` to `vortex_id` with the codec used at the Arrow boundary. + /// Re-registering evicts the previous mapping for either side. + pub fn register_arrow_canonical( + &self, + vortex_id: ExtId, + arrow_id: ExtId, + codec: ArrowCanonicalCodec, + ) { + self.arrow_canonical.register(vortex_id, arrow_id, codec); + } + + /// Returns the Arrow canonical id and codec aliased to `vortex_id`, if any. + pub fn arrow_alias_for(&self, vortex_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.arrow_canonical.arrow_alias_for(vortex_id) + } + + /// Returns the Vortex id and codec aliased to `arrow_id`, if any. + pub fn vortex_alias_for(&self, arrow_id: &ExtId) -> Option<(ExtId, ArrowCanonicalCodec)> { + self.arrow_canonical.vortex_alias_for(arrow_id) + } } /// Extension trait for accessing the DType session. @@ -75,3 +160,86 @@ impl DTypeSessionExt for S { self.get::() } } + +#[cfg(test)] +mod tests { + use vortex_error::vortex_err; + + use super::*; + + const TEST_CODEC: ArrowCanonicalCodec = ArrowCanonicalCodec { + to_json: |bytes| { + String::from_utf8(bytes.to_vec()).map_err(|e| vortex_err!("non-utf8 test bytes: {e}")) + }, + from_json: |s| Ok(s.as_bytes().to_vec()), + }; + + #[test] + fn arrow_canonical_re_registration_is_clean() { + let session = DTypeSession::default(); + let v = ExtId::new("vortex.test"); + let foo = ExtId::new("arrow.foo"); + let bar = ExtId::new("arrow.bar"); + + session.register_arrow_canonical(v, foo, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(foo)); + assert_eq!(session.vortex_alias_for(&foo).map(|(id, _)| id), Some(v)); + + session.register_arrow_canonical(v, bar, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&v).map(|(id, _)| id), Some(bar)); + assert_eq!(session.vortex_alias_for(&bar).map(|(id, _)| id), Some(v)); + assert!(session.vortex_alias_for(&foo).is_none()); + } + + /// `(vid → old, old → vid)` then `register(vid, new)` should leave `(vid → new, new → vid)`. + #[test] + fn rebind_vortex_id_to_new_arrow_name() { + let session = DTypeSession::default(); + let vid = ExtId::new("vortex.a"); + let old = ExtId::new("arrow.b"); + let new = ExtId::new("arrow.c"); + + session.register_arrow_canonical(vid, old, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(old)); + assert_eq!(session.vortex_alias_for(&old).map(|(id, _)| id), Some(vid)); + + session.register_arrow_canonical(vid, new, TEST_CODEC); + + assert_eq!(session.arrow_alias_for(&vid).map(|(id, _)| id), Some(new)); + assert_eq!(session.vortex_alias_for(&new).map(|(id, _)| id), Some(vid)); + assert!(session.vortex_alias_for(&old).is_none()); + } + + /// `(old → name, name → old)` then `register(new, name)` should leave `(new → name, name → new)`. + #[test] + fn steal_arrow_name_from_another_vortex_id() { + let session = DTypeSession::default(); + let old = ExtId::new("vortex.a"); + let name = ExtId::new("arrow.b"); + let new = ExtId::new("vortex.c"); + + session.register_arrow_canonical(old, name, TEST_CODEC); + assert_eq!(session.arrow_alias_for(&old).map(|(id, _)| id), Some(name)); + + session.register_arrow_canonical(new, name, TEST_CODEC); + + assert_eq!(session.arrow_alias_for(&new).map(|(id, _)| id), Some(name)); + assert_eq!(session.vortex_alias_for(&name).map(|(id, _)| id), Some(new)); + assert!(session.arrow_alias_for(&old).is_none()); + } + + #[test] + fn codec_round_trips_through_lookup() { + let session = DTypeSession::default(); + let vid = ExtId::new("vortex.x"); + let aid = ExtId::new("arrow.x"); + + session.register_arrow_canonical(vid, aid, TEST_CODEC); + + let (_, codec) = session.arrow_alias_for(&vid).unwrap(); + let json = (codec.to_json)(b"hello").unwrap(); + assert_eq!(json, "hello"); + let bytes = (codec.from_json)(&json).unwrap(); + assert_eq!(bytes, b"hello"); + } +} diff --git a/vortex-array/src/extension/mod.rs b/vortex-array/src/extension/mod.rs index 9f81e7fb310..077af4a8337 100644 --- a/vortex-array/src/extension/mod.rs +++ b/vortex-array/src/extension/mod.rs @@ -9,7 +9,7 @@ pub mod datetime; pub mod uuid; #[cfg(test)] -mod tests; +pub(crate) mod tests; /// An empty metadata struct for extension dtypes that do not require any metadata. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/vortex-array/src/extension/tests/mod.rs b/vortex-array/src/extension/tests/mod.rs index 31df677e61d..f4ab560fbf8 100644 --- a/vortex-array/src/extension/tests/mod.rs +++ b/vortex-array/src/extension/tests/mod.rs @@ -3,4 +3,4 @@ //! Test extension types for exercising the [`ExtVTable`] contract. -mod divisible_int; +pub(crate) mod divisible_int; diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index 2f92ce5a107..d3de27f5fb1 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -29,8 +29,11 @@ half = { workspace = true } itertools = { workspace = true } num-traits = { workspace = true } prost = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } [dev-dependencies] +arrow-schema = { workspace = true } divan = { workspace = true } mimalloc = { workspace = true } rand = { workspace = true } diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 7beadc02e93..28ee8795825 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,6 +11,8 @@ )] use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::dtype::session::ArrowCanonicalCodec; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; use vortex_array::session::ArraySessionExt; @@ -47,8 +49,19 @@ pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_P /// Initialize the Vortex tensor library with a Vortex session. pub fn initialize(session: &VortexSession) { - session.dtypes().register(Vector); - session.dtypes().register(FixedShapeTensor); + let dtypes = session.dtypes(); + dtypes.register(Vector); + dtypes.register(FixedShapeTensor); + dtypes.register_arrow_canonical( + FixedShapeTensor.id(), + FixedShapeTensor::arrow_ext_id(), + ArrowCanonicalCodec { + to_json: fixed_shape::proto_to_json, + from_json: fixed_shape::json_to_proto, + }, + ); + // Release the shard read before `scalar_fns` may take a write on the same shard. + drop(dtypes); let session_fns = session.scalar_fns(); @@ -85,4 +98,6 @@ mod tests { crate::initialize(&session); session }); + + mod arrow_roundtrip; } diff --git a/vortex-tensor/src/tests/arrow_roundtrip.rs b/vortex-tensor/src/tests/arrow_roundtrip.rs new file mode 100644 index 00000000000..baa17a480ca --- /dev/null +++ b/vortex-tensor/src/tests/arrow_roundtrip.rs @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arrow ↔ DType round-trip tests for tensor extension types. + +use std::sync::Arc; + +use arrow_schema::DataType; +use arrow_schema::TimeUnit as ArrowTimeUnit; +use arrow_schema::extension::EXTENSION_TYPE_METADATA_KEY; +use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::arrow::FromArrowArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::arrow::FromArrowType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::extension::EmptyMetadata; +use vortex_array::extension::datetime::TimeUnit; +use vortex_array::extension::datetime::Timestamp; +use vortex_array::validity::Validity; + +use crate::tests::SESSION; +use crate::types::fixed_shape::FixedShapeTensor; +use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::vector::Vector; + +fn vector_dtype(len: u32) -> DType { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + len, + Nullability::NonNullable, + ); + let ext = ExtDType::::try_new(EmptyMetadata, storage).unwrap(); + DType::Extension(ext.erased()) +} + +fn fixed_shape_dtype(metadata: FixedShapeTensorMetadata, element_count: u32) -> DType { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + element_count, + Nullability::NonNullable, + ); + let ext = ExtDType::::try_new(metadata, storage).unwrap(); + DType::Extension(ext.erased()) +} + +#[test] +fn vector_forward_carries_extension_name() { + let original = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(Vector.id().as_str()), + ); + // EmptyMetadata → no metadata key. + assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); + + let DataType::FixedSizeList(element, size) = field.data_type() else { + panic!("expected FixedSizeList, got {:?}", field.data_type()); + }; + assert_eq!(*size, 4); + assert_eq!(element.data_type(), &DataType::Float32); +} + +#[test] +fn vector_roundtrip_with_session() { + let original = DType::struct_([("embedding", vector_dtype(128))], Nullability::NonNullable); + + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + + assert_eq!(recovered, original); +} + +#[test] +fn vector_without_registration_falls_back_to_fsl() { + let original = DType::struct_([("embedding", vector_dtype(16))], Nullability::NonNullable); + + let empty_session = vortex_session::VortexSession::empty(); + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &empty_session); + + let expected = DType::struct_( + [( + "embedding", + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), + 16, + Nullability::NonNullable, + ), + )], + Nullability::NonNullable, + ); + assert_eq!(recovered, expected); +} + +#[test] +fn fixed_shape_tensor_metadata_roundtrip() { + let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]) + .unwrap() + .with_permutation(vec![2, 0, 1]) + .unwrap(); + + let original = DType::struct_( + [("tensor", fixed_shape_dtype(metadata, 24))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema_with_session(&SESSION).unwrap(); + let field = schema.field(0); + + assert_eq!( + field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(FixedShapeTensor::arrow_ext_id().as_str()), + ); + + // Canonical wire: raw JSON, not base64. + let meta_str = field.metadata().get(EXTENSION_TYPE_METADATA_KEY).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(meta_str).unwrap(); + assert_eq!(parsed["shape"], serde_json::json!([2, 3, 4])); + assert_eq!(parsed["dim_names"], serde_json::json!(["x", "y", "z"])); + assert_eq!(parsed["permutation"], serde_json::json!([2, 0, 1])); + + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + assert_eq!(recovered, original); +} + +#[test] +fn tensor_inside_nested_struct_roundtrips() { + let inner = DType::struct_([("embedding", vector_dtype(8))], Nullability::Nullable); + let original = DType::struct_( + [("inner", inner), ("id", DType::Utf8(Nullability::Nullable))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let recovered = DType::from_arrow_with_session(&schema, &SESSION); + + assert_eq!(recovered, original); +} + +#[test] +fn vector_record_batch_round_trip_carries_field_metadata() { + let vector_array = Vector::constant_array(&[1.0f32, 2.0, 3.0, 4.0], 2).unwrap(); + let struct_array = StructArray::from_fields(&[("embedding", vector_array)]).unwrap(); + + let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + let rb = struct_array.into_record_batch_with_schema(&schema).unwrap(); + + let column = rb.column(0); + let DataType::FixedSizeList(_, size) = column.data_type() else { + panic!( + "expected storage FixedSizeList, got {:?}", + column.data_type() + ); + }; + assert_eq!(*size, 4); + + assert_eq!( + rb.schema() + .field(0) + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str), + Some(Vector.id().as_str()), + ); +} + +#[test] +fn temporal_extension_still_uses_native_arrow() { + let ts = Timestamp::new_with_tz(TimeUnit::Microseconds, None, Nullability::Nullable); + let original = DType::struct_( + [("ts", DType::Extension(ts.erased()))], + Nullability::NonNullable, + ); + + let schema = original.to_arrow_schema().unwrap(); + let field = schema.field(0); + + assert!(matches!( + field.data_type(), + DataType::Timestamp(ArrowTimeUnit::Microsecond, None) + )); + assert!(field.metadata().get(EXTENSION_TYPE_NAME_KEY).is_none()); + assert!(field.metadata().get(EXTENSION_TYPE_METADATA_KEY).is_none()); +} + +/// Build a storage FSL with `num_rows` rows, each of `elements_per_row` elements. +fn fsl_f32_storage(elements_per_row: u32, num_rows: usize) -> ArrayRef { + let total = elements_per_row as usize * num_rows; + let elements = PrimitiveArray::from_iter((0..total).map(|i| i as f32)); + FixedSizeListArray::try_new( + elements.into_array(), + elements_per_row, + Validity::NonNullable, + num_rows, + ) + .unwrap() + .into_array() +} + +#[test] +fn vector_record_batch_round_trip() { + let vector_array = + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl_f32_storage(4, 2)) + .unwrap() + .into_array(); + let original = StructArray::from_fields(&[("embedding", vector_array)]).unwrap(); + + let dtype = DType::struct_([("embedding", vector_dtype(4))], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + let rb = original.into_record_batch_with_schema(&schema).unwrap(); + + let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); + assert_eq!(recovered.dtype(), &dtype); +} + +#[test] +fn fixed_shape_tensor_record_batch_round_trip() { + let metadata = FixedShapeTensorMetadata::new(vec![2, 2]) + .with_dim_names(vec!["row".into(), "col".into()]) + .unwrap(); + let tensor_dtype = fixed_shape_dtype(metadata.clone(), 4); + let dtype = DType::struct_([("tensor", tensor_dtype)], Nullability::NonNullable); + let schema = dtype.to_arrow_schema_with_session(&SESSION).unwrap(); + + let tensor_array = + ExtensionArray::try_new_from_vtable(FixedShapeTensor, metadata, fsl_f32_storage(4, 3)) + .unwrap() + .into_array(); + let original = StructArray::from_fields(&[("tensor", tensor_array)]).unwrap(); + let rb = original.into_record_batch_with_schema(&schema).unwrap(); + + let recovered = ArrayRef::from_arrow_with_session(rb, false, &SESSION).unwrap(); + assert_eq!(recovered.dtype(), &dtype); +} diff --git a/vortex-tensor/src/types/fixed_shape/canonical.rs b/vortex-tensor/src/types/fixed_shape/canonical.rs new file mode 100644 index 00000000000..e5ff5389f72 --- /dev/null +++ b/vortex-tensor/src/types/fixed_shape/canonical.rs @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Arrow canonical [`arrow.fixed_shape_tensor`] JSON wire ⇄ on-disk proto adapters. +//! +//! Hand-rolled rather than reusing `arrow_schema::extension::FixedShapeTensor` because arrow-rs +//! 58 emits `"permutations"` (plural) while the spec and pyarrow use `"permutation"`. +//! +//! [`arrow.fixed_shape_tensor`]: https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor + +use serde::Deserialize; +use serde::Serialize; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::types::fixed_shape::FixedShapeTensorMetadata; +use crate::types::fixed_shape::proto; + +#[derive(Serialize)] +struct WireRef<'a> { + shape: &'a [usize], + #[serde(skip_serializing_if = "Option::is_none")] + dim_names: Option<&'a [String]>, + #[serde(skip_serializing_if = "Option::is_none")] + permutation: Option<&'a [usize]>, +} + +#[derive(Deserialize)] +struct Wire { + shape: Vec, + #[serde(default)] + dim_names: Option>, + #[serde(default)] + permutation: Option>, +} + +fn metadata_to_json(metadata: &FixedShapeTensorMetadata) -> VortexResult { + let wire = WireRef { + shape: metadata.logical_shape(), + dim_names: metadata.dim_names(), + permutation: metadata.permutation(), + }; + serde_json::to_string(&wire) + .map_err(|e| vortex_err!("fixed_shape_tensor canonical serialize: {e}")) +} + +fn metadata_from_json(json: &str) -> VortexResult { + let wire: Wire = serde_json::from_str(json) + .map_err(|e| vortex_err!("fixed_shape_tensor canonical deserialize: {e}"))?; + + let mut m = FixedShapeTensorMetadata::new(wire.shape); + if let Some(names) = wire.dim_names { + m = m.with_dim_names(names)?; + } + if let Some(perm) = wire.permutation { + m = m.with_permutation(perm)?; + } + Ok(m) +} + +pub(crate) fn proto_to_json(proto_bytes: &[u8]) -> VortexResult { + let metadata = proto::deserialize(proto_bytes)?; + metadata_to_json(&metadata) +} + +pub(crate) fn json_to_proto(json: &str) -> VortexResult> { + let metadata = metadata_from_json(json)?; + Ok(proto::serialize(&metadata)) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use super::*; + + #[rstest] + #[case::scalar_0d(FixedShapeTensorMetadata::new(vec![]))] + #[case::vector_1d(FixedShapeTensorMetadata::new(vec![5]))] + #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))] + #[case::with_dim_names( + FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()]) + .unwrap() + )] + #[case::with_permutation( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_permutation(vec![2, 0, 1]) + .unwrap() + )] + #[case::all_fields( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]).unwrap() + .with_permutation(vec![1, 2, 0]).unwrap() + )] + fn json_roundtrip(#[case] metadata: FixedShapeTensorMetadata) -> VortexResult<()> { + let json = metadata_to_json(&metadata)?; + let decoded = metadata_from_json(&json)?; + assert_eq!(decoded, metadata); + Ok(()) + } + + #[rstest] + #[case::shape_only(FixedShapeTensorMetadata::new(vec![2, 3, 4]))] + #[case::all_fields( + FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()]).unwrap() + .with_permutation(vec![1, 2, 0]).unwrap() + )] + fn proto_to_json_to_proto_roundtrip( + #[case] metadata: FixedShapeTensorMetadata, + ) -> VortexResult<()> { + let proto_bytes = proto::serialize(&metadata); + let json = proto_to_json(&proto_bytes)?; + let proto_again = json_to_proto(&json)?; + let metadata_again = proto::deserialize(&proto_again)?; + assert_eq!(metadata_again, metadata); + Ok(()) + } + + #[test] + fn wire_format_matches_arrow_spec() -> VortexResult<()> { + let metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?; + + let json = metadata_to_json(&metadata)?; + let v: serde_json::Value = + serde_json::from_str(&json).map_err(|e| vortex_err!("parse wire: {e}"))?; + + assert_eq!(v["shape"], serde_json::json!([2, 3, 4])); + assert_eq!(v["dim_names"], serde_json::json!(["x", "y", "z"])); + // Arrow spec uses singular "permutation"; guard against regressions to arrow-rs's plural. + assert_eq!(v["permutation"], serde_json::json!([1, 2, 0])); + assert!(v.get("permutations").is_none()); + Ok(()) + } + + #[test] + fn omits_optional_fields_when_unset() -> VortexResult<()> { + let json = metadata_to_json(&FixedShapeTensorMetadata::new(vec![5]))?; + let v: serde_json::Value = + serde_json::from_str(&json).map_err(|e| vortex_err!("parse wire: {e}"))?; + assert!(v.get("dim_names").is_none()); + assert!(v.get("permutation").is_none()); + Ok(()) + } +} diff --git a/vortex-tensor/src/types/fixed_shape/mod.rs b/vortex-tensor/src/types/fixed_shape/mod.rs index 48f991517ec..94d91e74095 100644 --- a/vortex-tensor/src/types/fixed_shape/mod.rs +++ b/vortex-tensor/src/types/fixed_shape/mod.rs @@ -3,10 +3,20 @@ //! Fixed-shape Tensor extension type. +use vortex_array::dtype::extension::ExtId; +use vortex_session::registry::CachedId; + /// The VTable for the Tensor extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct FixedShapeTensor; +impl FixedShapeTensor { + pub(crate) fn arrow_ext_id() -> ExtId { + static ID: CachedId = CachedId::new("arrow.fixed_shape_tensor"); + *ID + } +} + mod matcher; pub use matcher::AnyFixedShapeTensor; pub use matcher::FixedShapeTensorMatcherMetadata; @@ -14,5 +24,8 @@ pub use matcher::FixedShapeTensorMatcherMetadata; mod metadata; pub use metadata::FixedShapeTensorMetadata; +mod canonical; mod proto; mod vtable; +pub(crate) use canonical::json_to_proto; +pub(crate) use canonical::proto_to_json; diff --git a/vortex-tensor/src/types/fixed_shape/vtable.rs b/vortex-tensor/src/types/fixed_shape/vtable.rs index d0a4b7842f1..97eadbb55fb 100644 --- a/vortex-tensor/src/types/fixed_shape/vtable.rs +++ b/vortex-tensor/src/types/fixed_shape/vtable.rs @@ -10,11 +10,15 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_ensure_eq; +use vortex_session::registry::CachedId; use crate::types::fixed_shape::FixedShapeTensor; use crate::types::fixed_shape::FixedShapeTensorMetadata; use crate::types::fixed_shape::proto; +/// Vortex extension id for [`FixedShapeTensor`]. +static ID: CachedId = CachedId::new("vortex.tensor.fixed_shape_tensor"); + impl ExtVTable for FixedShapeTensor { type Metadata = FixedShapeTensorMetadata; @@ -22,7 +26,7 @@ impl ExtVTable for FixedShapeTensor { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new("vortex.fixed_shape_tensor") + *ID } fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { @@ -141,69 +145,64 @@ mod tests { assert_roundtrip(&metadata?) } - /// Constructs a `FixedShapeTensor` ext dtype wrapped in `DType::Extension`. - fn tensor_dtype( - metadata: FixedShapeTensorMetadata, - element: PType, - list_size: u32, - ) -> VortexResult { + fn tensor_dtype(metadata: FixedShapeTensorMetadata, element: PType, list_size: u32) -> DType { let storage = DType::FixedSizeList( Arc::new(DType::Primitive(element, Nullability::NonNullable)), list_size, Nullability::NonNullable, ); - Ok(DType::Extension( - ExtDType::::try_new(metadata, storage)?.erased(), - )) - } - - #[test] - fn tensor_widens_element_when_metadata_matches() -> VortexResult<()> { - let metadata = FixedShapeTensorMetadata::new(vec![2, 3]); - let lhs = tensor_dtype(metadata.clone(), PType::F32, 6)?; - let rhs = tensor_dtype(metadata.clone(), PType::F64, 6)?; - let expected = tensor_dtype(metadata, PType::F64, 6)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) - } - - #[test] - fn tensor_different_shape_returns_none() -> VortexResult<()> { - let lhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6)?; - let rhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![3, 2]), PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn tensor_different_permutation_returns_none() -> VortexResult<()> { - let lhs_metadata = - FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 1])?; - let rhs_metadata = - FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![1, 0])?; - let lhs = tensor_dtype(lhs_metadata, PType::F32, 6)?; - let rhs = tensor_dtype(rhs_metadata, PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) + DType::Extension( + ExtDType::::try_new(metadata, storage) + .unwrap() + .erased(), + ) } - #[test] - fn tensor_different_dim_names_returns_none() -> VortexResult<()> { - let lhs_metadata = FixedShapeTensorMetadata::new(vec![2, 3]) - .with_dim_names(vec!["x".into(), "y".into()])?; - let rhs_metadata = FixedShapeTensorMetadata::new(vec![2, 3]) - .with_dim_names(vec!["rows".into(), "cols".into()])?; - let lhs = tensor_dtype(lhs_metadata, PType::F32, 6)?; - let rhs = tensor_dtype(rhs_metadata, PType::F32, 6)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn tensor_vs_non_extension_returns_none() -> VortexResult<()> { - let lhs = tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6)?; - let rhs = DType::Primitive(PType::F32, Nullability::NonNullable); - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) + #[rstest] + #[case::widens_element_when_metadata_matches( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F64, 6), + Some(tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F64, 6)), + )] + #[case::different_shape_returns_none( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + tensor_dtype(FixedShapeTensorMetadata::new(vec![3, 2]), PType::F32, 6), + None, + )] + #[case::different_permutation_returns_none( + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 1]).unwrap(), + PType::F32, 6, + ), + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![1, 0]).unwrap(), + PType::F32, 6, + ), + None, + )] + #[case::different_dim_names_returns_none( + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]) + .with_dim_names(vec!["x".into(), "y".into()]).unwrap(), + PType::F32, 6, + ), + tensor_dtype( + FixedShapeTensorMetadata::new(vec![2, 3]) + .with_dim_names(vec!["rows".into(), "cols".into()]).unwrap(), + PType::F32, 6, + ), + None, + )] + #[case::vs_non_extension_returns_none( + tensor_dtype(FixedShapeTensorMetadata::new(vec![2, 3]), PType::F32, 6), + DType::Primitive(PType::F32, Nullability::NonNullable), + None, + )] + fn tensor_least_supertype( + #[case] lhs: DType, + #[case] rhs: DType, + #[case] expected: Option, + ) { + assert_eq!(lhs.least_supertype(&rhs), expected); } } diff --git a/vortex-tensor/src/types/vector/vtable.rs b/vortex-tensor/src/types/vector/vtable.rs index c80f17665f2..83d807a2cd3 100644 --- a/vortex-tensor/src/types/vector/vtable.rs +++ b/vortex-tensor/src/types/vector/vtable.rs @@ -8,10 +8,14 @@ use vortex_array::dtype::extension::ExtVTable; use vortex_array::extension::EmptyMetadata; use vortex_array::scalar::ScalarValue; use vortex_error::VortexResult; +use vortex_session::registry::CachedId; use crate::types::vector::Vector; use crate::types::vector::validate_vector_storage_dtype; +/// Vortex extension id for [`Vector`]. +static ID: CachedId = CachedId::new("vortex.tensor.vector"); + impl ExtVTable for Vector { type Metadata = EmptyMetadata; @@ -19,7 +23,7 @@ impl ExtVTable for Vector { type NativeValue<'a> = &'a ScalarValue; fn id(&self) -> ExtId { - ExtId::new("vortex.tensor.vector") + *ID } fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult> { @@ -138,60 +142,46 @@ mod tests { Ok(()) } - /// Constructs a `Vector` ext dtype wrapped in `DType::Extension`. - fn vector_dtype(ptype: PType, dims: u32) -> VortexResult { - vector_dtype_with_outer(ptype, dims, Nullability::NonNullable) - } - - /// Constructs a `Vector` ext dtype with the given outer `Nullability`, wrapped in - /// `DType::Extension`. - fn vector_dtype_with_outer(ptype: PType, dims: u32, outer: Nullability) -> VortexResult { + fn vector_dtype(ptype: PType, dims: u32, outer: Nullability) -> DType { let storage = vector_storage_dtype(ptype, dims, outer); - Ok(DType::Extension( - ExtDType::::try_new(EmptyMetadata, storage)?.erased(), - )) - } - - #[test] - fn vector_widens_float_precision() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = vector_dtype(PType::F64, 768)?; - let expected = vector_dtype(PType::F64, 768)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) - } - - #[test] - fn vector_dim_mismatch_returns_none() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = vector_dtype(PType::F32, 1024)?; - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn vector_vs_non_extension_returns_none() -> VortexResult<()> { - let lhs = vector_dtype(PType::F32, 768)?; - let rhs = DType::Primitive(PType::F32, Nullability::NonNullable); - assert_eq!(lhs.least_supertype(&rhs), None); - Ok(()) - } - - #[test] - fn vector_unions_outer_nullability_with_float_widening() -> VortexResult<()> { - let lhs = vector_dtype_with_outer(PType::F32, 4, Nullability::NonNullable)?; - let rhs = vector_dtype_with_outer(PType::F64, 4, Nullability::Nullable)?; - let expected = vector_dtype_with_outer(PType::F64, 4, Nullability::Nullable)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) + DType::Extension( + ExtDType::::try_new(EmptyMetadata, storage) + .unwrap() + .erased(), + ) } - #[test] - fn vector_same_ptype_unions_outer_nullability() -> VortexResult<()> { - let lhs = vector_dtype_with_outer(PType::F32, 4, Nullability::NonNullable)?; - let rhs = vector_dtype_with_outer(PType::F32, 4, Nullability::Nullable)?; - let expected = vector_dtype_with_outer(PType::F32, 4, Nullability::Nullable)?; - assert_eq!(lhs.least_supertype(&rhs), Some(expected)); - Ok(()) + #[rstest] + #[case::widens_float_precision( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + vector_dtype(PType::F64, 768, Nullability::NonNullable), + Some(vector_dtype(PType::F64, 768, Nullability::NonNullable)) + )] + #[case::dim_mismatch_returns_none( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + vector_dtype(PType::F32, 1024, Nullability::NonNullable), + None + )] + #[case::vs_non_extension_returns_none( + vector_dtype(PType::F32, 768, Nullability::NonNullable), + DType::Primitive(PType::F32, Nullability::NonNullable), + None + )] + #[case::unions_outer_nullability_with_float_widening( + vector_dtype(PType::F32, 4, Nullability::NonNullable), + vector_dtype(PType::F64, 4, Nullability::Nullable), + Some(vector_dtype(PType::F64, 4, Nullability::Nullable)) + )] + #[case::same_ptype_unions_outer_nullability( + vector_dtype(PType::F32, 4, Nullability::NonNullable), + vector_dtype(PType::F32, 4, Nullability::Nullable), + Some(vector_dtype(PType::F32, 4, Nullability::Nullable)) + )] + fn vector_least_supertype( + #[case] lhs: DType, + #[case] rhs: DType, + #[case] expected: Option, + ) { + assert_eq!(lhs.least_supertype(&rhs), expected); } }